adv-optm 2.4.dev25__tar.gz → 2.5.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adv_optm-2.5.1/PKG-INFO +113 -0
- adv_optm-2.5.1/README.md +82 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/AdaMuon_adv.py +7 -7
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/AdamW_adv.py +5 -5
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/Adopt_adv.py +4 -6
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/Lion_adv.py +6 -5
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/Muon_adv.py +7 -7
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/Prodigy_adv.py +4 -4
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/SignSGD_adv.py +7 -8
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/SinkSGD_adv.py +7 -8
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/Muon_AuxAdam.py +2 -3
- adv_optm-2.5.1/adv_optm/util/OrthoGrad.py +99 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/param_update.py +3 -3
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/scaled_optm.py +2 -2
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/state_util.py +1 -1
- adv_optm-2.5.1/adv_optm.egg-info/PKG-INFO +113 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/setup.py +1 -1
- adv_optm-2.4.dev25/PKG-INFO +0 -109
- adv_optm-2.4.dev25/README.md +0 -78
- adv_optm-2.4.dev25/adv_optm/util/OrthoGrad.py +0 -80
- adv_optm-2.4.dev25/adv_optm.egg-info/PKG-INFO +0 -109
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/LICENSE +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/sinkhorn.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5.1}/setup.cfg +0 -0
adv_optm-2.5.1/PKG-INFO
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: adv_optm
|
|
3
|
+
Version: 2.5.1
|
|
4
|
+
Summary: A family of highly efficient, lightweight yet powerful optimizers.
|
|
5
|
+
Home-page: https://github.com/Koratahiu/Advanced_Optimizers
|
|
6
|
+
Author: Koratahiu
|
|
7
|
+
Author-email: hiuhonor@gmail.com
|
|
8
|
+
License: Apache 2.0
|
|
9
|
+
Keywords: llm,fine-tuning,memory-efficient,low-rank,compression,pytorch,optimizer,adam
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
12
|
+
Classifier: Operating System :: OS Independent
|
|
13
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
14
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
15
|
+
Requires-Python: >=3.8
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
License-File: LICENSE
|
|
18
|
+
Requires-Dist: torch>=2.1
|
|
19
|
+
Dynamic: author
|
|
20
|
+
Dynamic: author-email
|
|
21
|
+
Dynamic: classifier
|
|
22
|
+
Dynamic: description
|
|
23
|
+
Dynamic: description-content-type
|
|
24
|
+
Dynamic: home-page
|
|
25
|
+
Dynamic: keywords
|
|
26
|
+
Dynamic: license
|
|
27
|
+
Dynamic: license-file
|
|
28
|
+
Dynamic: requires-dist
|
|
29
|
+
Dynamic: requires-python
|
|
30
|
+
Dynamic: summary
|
|
31
|
+
|
|
32
|
+
# Advanced Optimizers (AIO)
|
|
33
|
+
|
|
34
|
+
A comprehensive, all-in-one collection of state-of-the-art optimization algorithms for deep learning. Designed for **maximum efficiency**, **minimal memory footprint**, and **superior performance** across diverse model architectures and training scenarios.
|
|
35
|
+
|
|
36
|
+
[![PyPI version](https://img.shields.io/pypi/v/adv_optm)](https://pypi.org/project/adv_optm/)
|
|
37
|
+
[![Python versions](https://img.shields.io/pypi/pyversions/adv_optm)](https://pypi.org/project/adv_optm/)
|
|
38
|
+
[![License: Apache 2.0](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE)
|
|
39
|
+
|
|
40
|
+
---
|
|
41
|
+
|
|
42
|
+
## 📦 Installation
|
|
43
|
+
|
|
44
|
+
```bash
|
|
45
|
+
pip install adv_optm
|
|
46
|
+
```
|
|
47
|
+
*Requires PyTorch 2.1+; PyTorch 2.3+ is required for `torch.compile` support (`compiled_optimizer=True`).*
|
|
48
|
+
|
|
49
|
+
---
|
|
50
|
+
|
|
51
|
+
## What's New
|
|
52
|
+
|
|
53
|
+
### 🌟 Version 2.5.x: The Massive Refactor
|
|
54
|
+
This major update introduces a complete architectural refactor of the library:
|
|
55
|
+
|
|
56
|
+
**🆕 New Optimizers & Scaling**
|
|
57
|
+
* **`SinkSGD_adv`:** Added a powerful new optimizer to the lineup.
|
|
58
|
+
* **Spectral Scaling:** Now available across *all* optimizers, producing width- and rank-invariant updates for highly stable training.
|
|
59
|
+
|
|
60
|
+
**💾 Memory & State Precision Control**
|
|
61
|
+
* **Granular State Precision (`state_precision`):** Drastically reduce memory overhead with new optimizer state modes:
|
|
62
|
+
* `factored` (Rank-2 factored mode)
|
|
63
|
+
* `fp32` (Full precision)
|
|
64
|
+
* `bf16_sr` & `int8_sr` (BF16/Int8 with Stochastic Rounding)
|
|
65
|
+
* **Factored Second Moment (`factored_2nd`):** Available for all Adam variants. Works seamlessly alongside any `state_precision` setting to further slash memory usage.
|
|
66
|
+
|
|
67
|
+
**⚙️ Advanced Dynamics & Momentum**
|
|
68
|
+
* **Variance Normalized Momentum (`normed_momentum`):** Applies optimizer normalization *before* momentum (Normalization then Momentum/NtM). Available for `AdamW_adv`, `SignSGD_adv`, and `SinkSGD_adv`.
|
|
69
|
+
* **Universal Nesterov Momentum:** Replaced the hard-to-tune `Simplified_AdEMAMix` with Nesterov momentum (`nesterov`) and a dedicated coefficient (`nesterov_coef`), available across all optimizers.
|
|
70
|
+
* **Preconditioning & Signs:**
|
|
71
|
+
* Added **Variance/Confidence Preconditioning (`snr_cond`)** for `SignSGD_adv` and `SinkSGD_adv` (requires `normed_momentum`). Read the technical reports: [AASS](https://koratahiu.github.io/aass/) & [sink-v](https://koratahiu.github.io/sink-v/).
|
|
72
|
+
    * Added **Adaptive Stochastic Sign** with $L_\infty$ preconditioning (`stochastic_sign`) for `SignSGD_adv` and `Lion_adv`.
|
|
73
|
+
* **Improved CANS (`accelerated_ns`):** Enhanced for Muon variants by integrating a dynamic lower bound.
|
|
74
|
+
    * **New OrthoGrad modes (`orthogonal_gradient`):** Choose `disabled` (default), `flattened` (standard vectorized OrthoGrad), or the new matrix-wise mode `iterative`.
|
|
75
|
+
|
|
76
|
+
**⚓ Weight Decay Innovations**
|
|
77
|
+
* **Centered Weight Decay (`centered_wd`):** Pulls weights toward their pre-training values (the anchor). To save memory, the anchor's storage precision (`centered_wd_mode`) can be set to full, float8, int8, or int4.
|
|
78
|
+
* **Fisher Weight Decay (`fisher_wd`):** Now available for Adam variants based on the [FAdam paper](https://arxiv.org/abs/2405.12807).
|
|
79
|
+
* **Geometric Weight Decay:** Added specifically for `SinkSGD_adv` and `SignSGD_adv`.
|
|
80
|
+
|
|
81
|
+
*(Note: `Lion_Prodigy_adv`, `Simplified_AdEMAMix`, and the heuristic cautious/grams modes have been deprecated in favor of these superior, theoretically grounded features.)*
|
|
82
|
+
|
|
83
|
+
<details>
|
|
84
|
+
<summary><b>Click to see older release notes (v1.2.x - v2.1.x)</b></summary>
|
|
85
|
+
|
|
86
|
+
### Version 2.1.x
|
|
87
|
+
* **New Optimizer:** Added **Signum** (SignSGD with momentum) to the `SignSGD_adv` family.
|
|
88
|
+
|
|
89
|
+
### Version 2.0.x
|
|
90
|
+
* ⚡ **`torch.compile` Support:** Fully implemented for all advanced optimizers. Enable via `compiled_optimizer=True` to heavily fuse and optimize the optimizer step path.
|
|
91
|
+
* 📉 **1-Bit Factored Mode:** Vastly improved implementation via `nnmf_factor=True`.
|
|
92
|
+
* 🛠️ Broad performance and stability improvements across all optimizers.
|
|
93
|
+
|
|
94
|
+
### Version 1.2.x
|
|
95
|
+
* **Advanced Muon Variants:** Brought the groundbreaking [Muon optimizer](https://kellerjordan.github.io/posts/muon/) into the fold, enriched with features from recent literature.
|
|
96
|
+
|
|
97
|
+
| Optimizer | Description |
|
|
98
|
+
|---|---|
|
|
99
|
+
| `Muon_adv` | Advanced Muon implementation featuring CANS, NorMuon, Low-Rank Orthogonalization, and more. |
|
|
100
|
+
| `AdaMuon_adv` | Combines Muon's geometry with Adam-like adaptive scaling and sign-based orthogonalization. |
|
|
101
|
+
|
|
102
|
+
* **Prodigy Speedup:** Prodigy variants are now **50% faster** by eliminating unnecessary CUDA syncs (Shoutout to **@dxqb**!).
|
|
103
|
+
* **Stochastic Rounding for BF16:** Parameter updates and weight decay now accumulate in float32 and round once at the end.
|
|
104
|
+
* **Cautious Weight Decay:** Implemented for all advanced optimizers ([Paper](https://arxiv.org/abs/2510.12402)).
|
|
105
|
+
* **Fused Operations:** Transitioned to fused and in-place operations wherever possible.
|
|
106
|
+
|
|
107
|
+
</details>
|
|
108
|
+
|
|
109
|
+
---
|
|
110
|
+
|
|
111
|
+
## 💡 Core Innovations
|
|
112
|
+
|
|
113
|
+
*(Documentation expanding on the theory and usage of these features is coming soon!)*
|
adv_optm-2.5.1/README.md
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
# Advanced Optimizers (AIO)
|
|
2
|
+
|
|
3
|
+
A comprehensive, all-in-one collection of state-of-the-art optimization algorithms for deep learning. Designed for **maximum efficiency**, **minimal memory footprint**, and **superior performance** across diverse model architectures and training scenarios.
|
|
4
|
+
|
|
5
|
+
[![PyPI version](https://img.shields.io/pypi/v/adv_optm)](https://pypi.org/project/adv_optm/)
|
|
6
|
+
[![Python versions](https://img.shields.io/pypi/pyversions/adv_optm)](https://pypi.org/project/adv_optm/)
|
|
7
|
+
[![License: Apache 2.0](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE)
|
|
8
|
+
|
|
9
|
+
---
|
|
10
|
+
|
|
11
|
+
## 📦 Installation
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
pip install adv_optm
|
|
15
|
+
```
|
|
16
|
+
*Requires PyTorch 2.1+; PyTorch 2.3+ is required for `torch.compile` support (`compiled_optimizer=True`).*
|
|
17
|
+
|
|
18
|
+
---
|
|
19
|
+
|
|
20
|
+
## What's New
|
|
21
|
+
|
|
22
|
+
### 🌟 Version 2.5.x: The Massive Refactor
|
|
23
|
+
This major update introduces a complete architectural refactor of the library:
|
|
24
|
+
|
|
25
|
+
**🆕 New Optimizers & Scaling**
|
|
26
|
+
* **`SinkSGD_adv`:** Added a powerful new optimizer to the lineup.
|
|
27
|
+
* **Spectral Scaling:** Now available across *all* optimizers, producing width- and rank-invariant updates for highly stable training.
|
|
28
|
+
|
|
29
|
+
**💾 Memory & State Precision Control**
|
|
30
|
+
* **Granular State Precision (`state_precision`):** Drastically reduce memory overhead with new optimizer state modes:
|
|
31
|
+
* `factored` (Rank-2 factored mode)
|
|
32
|
+
* `fp32` (Full precision)
|
|
33
|
+
* `bf16_sr` & `int8_sr` (BF16/Int8 with Stochastic Rounding)
|
|
34
|
+
* **Factored Second Moment (`factored_2nd`):** Available for all Adam variants. Works seamlessly alongside any `state_precision` setting to further slash memory usage.
|
|
35
|
+
|
|
36
|
+
**⚙️ Advanced Dynamics & Momentum**
|
|
37
|
+
* **Variance Normalized Momentum (`normed_momentum`):** Applies optimizer normalization *before* momentum (Normalization then Momentum/NtM). Available for `AdamW_adv`, `SignSGD_adv`, and `SinkSGD_adv`.
|
|
38
|
+
* **Universal Nesterov Momentum:** Replaced the hard-to-tune `Simplified_AdEMAMix` with Nesterov momentum (`nesterov`) and a dedicated coefficient (`nesterov_coef`), available across all optimizers.
|
|
39
|
+
* **Preconditioning & Signs:**
|
|
40
|
+
* Added **Variance/Confidence Preconditioning (`snr_cond`)** for `SignSGD_adv` and `SinkSGD_adv` (requires `normed_momentum`). Read the technical reports: [AASS](https://koratahiu.github.io/aass/) & [sink-v](https://koratahiu.github.io/sink-v/).
|
|
41
|
+
    * Added **Adaptive Stochastic Sign** with $L_\infty$ preconditioning (`stochastic_sign`) for `SignSGD_adv` and `Lion_adv`.
|
|
42
|
+
* **Improved CANS (`accelerated_ns`):** Enhanced for Muon variants by integrating a dynamic lower bound.
|
|
43
|
+
    * **New OrthoGrad modes (`orthogonal_gradient`):** Choose `disabled` (default), `flattened` (standard vectorized OrthoGrad), or the new matrix-wise mode `iterative`.
|
|
44
|
+
|
|
45
|
+
**⚓ Weight Decay Innovations**
|
|
46
|
+
* **Centered Weight Decay (`centered_wd`):** Pulls weights toward their pre-training values (the anchor). To save memory, the anchor's storage precision (`centered_wd_mode`) can be set to full, float8, int8, or int4.
|
|
47
|
+
* **Fisher Weight Decay (`fisher_wd`):** Now available for Adam variants based on the [FAdam paper](https://arxiv.org/abs/2405.12807).
|
|
48
|
+
* **Geometric Weight Decay:** Added specifically for `SinkSGD_adv` and `SignSGD_adv`.
|
|
49
|
+
|
|
50
|
+
*(Note: `Lion_Prodigy_adv`, `Simplified_AdEMAMix`, and the heuristic cautious/grams modes have been deprecated in favor of these superior, theoretically grounded features.)*
|
|
51
|
+
|
|
52
|
+
<details>
|
|
53
|
+
<summary><b>Click to see older release notes (v1.2.x - v2.1.x)</b></summary>
|
|
54
|
+
|
|
55
|
+
### Version 2.1.x
|
|
56
|
+
* **New Optimizer:** Added **Signum** (SignSGD with momentum) to the `SignSGD_adv` family.
|
|
57
|
+
|
|
58
|
+
### Version 2.0.x
|
|
59
|
+
* ⚡ **`torch.compile` Support:** Fully implemented for all advanced optimizers. Enable via `compiled_optimizer=True` to heavily fuse and optimize the optimizer step path.
|
|
60
|
+
* 📉 **1-Bit Factored Mode:** Vastly improved implementation via `nnmf_factor=True`.
|
|
61
|
+
* 🛠️ Broad performance and stability improvements across all optimizers.
|
|
62
|
+
|
|
63
|
+
### Version 1.2.x
|
|
64
|
+
* **Advanced Muon Variants:** Brought the groundbreaking [Muon optimizer](https://kellerjordan.github.io/posts/muon/) into the fold, enriched with features from recent literature.
|
|
65
|
+
|
|
66
|
+
| Optimizer | Description |
|
|
67
|
+
|---|---|
|
|
68
|
+
| `Muon_adv` | Advanced Muon implementation featuring CANS, NorMuon, Low-Rank Orthogonalization, and more. |
|
|
69
|
+
| `AdaMuon_adv` | Combines Muon's geometry with Adam-like adaptive scaling and sign-based orthogonalization. |
|
|
70
|
+
|
|
71
|
+
* **Prodigy Speedup:** Prodigy variants are now **50% faster** by eliminating unnecessary CUDA syncs (Shoutout to **@dxqb**!).
|
|
72
|
+
* **Stochastic Rounding for BF16:** Parameter updates and weight decay now accumulate in float32 and round once at the end.
|
|
73
|
+
* **Cautious Weight Decay:** Implemented for all advanced optimizers ([Paper](https://arxiv.org/abs/2510.12402)).
|
|
74
|
+
* **Fused Operations:** Transitioned to fused and in-place operations wherever possible.
|
|
75
|
+
|
|
76
|
+
</details>
|
|
77
|
+
|
|
78
|
+
---
|
|
79
|
+
|
|
80
|
+
## 💡 Core Innovations
|
|
81
|
+
|
|
82
|
+
*(Documentation expanding on the theory and usage of these features is coming soon!)*
|
|
@@ -57,7 +57,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
57
57
|
(default: (3.4445, -4.7750, 2.0315)).
|
|
58
58
|
stochastic_rounding (bool): whether to use stochastic rounding for
|
|
59
59
|
BF16 parameter updates (default: True).
|
|
60
|
-
orthogonal_gradient (
|
|
60
|
+
orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
|
|
61
|
+
'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
|
|
61
62
|
nesterov (bool): enables Nesterov momentum (default: False).
|
|
62
63
|
use_atan2 (bool): whether to use the atan2 update rule. (default: False)
|
|
63
64
|
vector_reshape (bool): whether to reshape 1D vectors into 2D
|
|
@@ -114,7 +115,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
114
115
|
adam_fisher_wd (bool): Fisher Adam (FAdam) weight decay for the AdamW part. (default: False)
|
|
115
116
|
adam_use_bias_correction (bool): Bias correction for AdamW.
|
|
116
117
|
adam_use_atan2 (bool): Atan2 update rule for AdamW.
|
|
117
|
-
adam_orthogonal_gradient (
|
|
118
|
+
adam_orthogonal_gradient (str): OrthoGrad for AdamW.
|
|
118
119
|
adam_nesterov (bool): Nesterov momentum for AdamW. (default: False)
|
|
119
120
|
adam_nesterov_coef (float, optional): Nesterov coefficient for AdamW. (default: None)
|
|
120
121
|
adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
|
|
@@ -149,7 +150,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
149
150
|
# Stochastic Rounding for BF16
|
|
150
151
|
stochastic_rounding: bool = True,
|
|
151
152
|
# OrthoGrad
|
|
152
|
-
orthogonal_gradient:
|
|
153
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
153
154
|
# Adam_atan2 (scale invariant)
|
|
154
155
|
use_atan2: bool = False,
|
|
155
156
|
# NorMuon
|
|
@@ -190,7 +191,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
190
191
|
adam_fisher_wd: bool = False,
|
|
191
192
|
adam_use_bias_correction: bool = True,
|
|
192
193
|
adam_use_atan2: bool = False,
|
|
193
|
-
adam_orthogonal_gradient:
|
|
194
|
+
adam_orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
194
195
|
adam_nesterov: bool = False,
|
|
195
196
|
adam_nesterov_coef: float | None = None,
|
|
196
197
|
adam_kourkoutas_beta: bool = False,
|
|
@@ -213,7 +214,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
213
214
|
print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
|
|
214
215
|
rms_rescaling = False
|
|
215
216
|
if spectral_normalization and accelerated_ns:
|
|
216
|
-
ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
217
|
+
raise ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
217
218
|
|
|
218
219
|
# Legacy backwards compatibility support for `nnmf_factor=True`
|
|
219
220
|
if nnmf_factor:
|
|
@@ -515,8 +516,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
515
516
|
grad = approx_mars(grad, state['last_grad'], group['mars_gamma'], beta1)
|
|
516
517
|
|
|
517
518
|
|
|
518
|
-
|
|
519
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
519
|
+
grad = _orthogonalize_gradient(p, grad, group.get("orthogonal_gradient"))
|
|
520
520
|
|
|
521
521
|
if state['factored']: # Factored Muon
|
|
522
522
|
d1, d2 = state['effective_shape']
|
|
@@ -45,7 +45,8 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
45
45
|
stochastic_rounding (bool): whether to use stochastic
|
|
46
46
|
rounding for BF16 parameter updates (default: True).
|
|
47
47
|
use_atan2 (bool): whether to use the atan2 update rule. (default: False)
|
|
48
|
-
orthogonal_gradient (
|
|
48
|
+
orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
|
|
49
|
+
'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
|
|
49
50
|
normed_momentum (bool): whether to compute the first moment on the normalized gradient. (default: False)
|
|
50
51
|
kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
|
|
51
52
|
If `False`, the optimizer behaves as standard AdamW. (default: False)
|
|
@@ -104,7 +105,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
104
105
|
# Adam_atan2 (scale invariant)
|
|
105
106
|
use_atan2: bool = False,
|
|
106
107
|
# OrthoGrad
|
|
107
|
-
orthogonal_gradient:
|
|
108
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
108
109
|
# Nesterov momentum
|
|
109
110
|
nesterov: bool = False,
|
|
110
111
|
nesterov_coef: float | None = None,
|
|
@@ -326,8 +327,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
326
327
|
def _step_parameter(self, p, grad, state, group, step_size, beta1, beta2, sqrt_bias_correction2, random_int_tensor, random_int_state_tensor):
|
|
327
328
|
grad = upcast_grad_for_precision(grad, state, group['state_precision'])
|
|
328
329
|
|
|
329
|
-
|
|
330
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
330
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
331
331
|
|
|
332
332
|
nesterov = group.get('nesterov', False)
|
|
333
333
|
nesterov_coef = group.get('nesterov_coef', None)
|
|
@@ -462,7 +462,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
462
462
|
else:
|
|
463
463
|
update.mul_(update_scaling)
|
|
464
464
|
|
|
465
|
-
param_update.apply_parameter_update(self, p, group, update,
|
|
465
|
+
param_update.apply_parameter_update(self, p, group, update, group['lr'], random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
|
|
466
466
|
|
|
467
467
|
def compile(self, *args, **kwargs):
|
|
468
468
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
@@ -108,7 +108,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
108
108
|
# Stochastic Rounding for BF16
|
|
109
109
|
stochastic_rounding: bool = True,
|
|
110
110
|
# OrthoGrad
|
|
111
|
-
orthogonal_gradient:
|
|
111
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
112
112
|
# Nesterov momentum
|
|
113
113
|
nesterov: bool = False,
|
|
114
114
|
nesterov_coef: float | None = None,
|
|
@@ -158,7 +158,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
158
158
|
|
|
159
159
|
defaults = {
|
|
160
160
|
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
161
|
-
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
|
|
161
|
+
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "orthogonal_gradient": orthogonal_gradient,
|
|
162
162
|
"nesterov": nesterov, "nesterov_coef": nesterov_coef,
|
|
163
163
|
"kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
|
|
164
164
|
"tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
|
|
@@ -172,7 +172,6 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
172
172
|
self.clip_lambda = clip_lambda
|
|
173
173
|
self.stochastic_rounding = stochastic_rounding
|
|
174
174
|
self.use_atan2 = use_atan2
|
|
175
|
-
self.orthogonal_gradient = orthogonal_gradient
|
|
176
175
|
self.kourkoutas_beta = kourkoutas_beta
|
|
177
176
|
self.layer_key_fn = layer_key_fn
|
|
178
177
|
self._init_lr = lr if lr > 0 else 1
|
|
@@ -237,7 +236,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
237
236
|
dtype = torch.float32 if (state['factored'] or req_precision == 'factored') else p.dtype
|
|
238
237
|
|
|
239
238
|
vt_dtype = torch.float32 if (state['factored'] or state['factored_2nd'] or req_precision in ['factored', 'bf16_sr', 'int8_sr']) else dtype
|
|
240
|
-
vt_init = grad.pow(2).to(vt_dtype)
|
|
239
|
+
vt_init = grad.pow(2).to(vt_dtype)
|
|
241
240
|
|
|
242
241
|
if state['factored']:
|
|
243
242
|
state['effective_shape'] = _get_effective_shape(p.numel())
|
|
@@ -329,8 +328,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
329
328
|
def _step_parameter(self, p, grad, state, group, lr, beta1, beta2, random_int_tensor, random_int_state_tensor):
|
|
330
329
|
grad = upcast_grad_for_precision(grad, state, group['state_precision'])
|
|
331
330
|
|
|
332
|
-
|
|
333
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
331
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
334
332
|
|
|
335
333
|
nesterov = group.get('nesterov', False)
|
|
336
334
|
nesterov_coef = group.get('nesterov_coef', None)
|
|
@@ -67,7 +67,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
67
67
|
# Stochastic Rounding for BF16
|
|
68
68
|
stochastic_rounding: bool = True,
|
|
69
69
|
# OrthoGrad
|
|
70
|
-
orthogonal_gradient:
|
|
70
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
71
71
|
# Lion-k
|
|
72
72
|
kappa_p: float = 1.0,
|
|
73
73
|
auto_kappa_p: bool = False,
|
|
@@ -213,8 +213,9 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
213
213
|
def _step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_noise_tensor):
|
|
214
214
|
if grad.dtype != torch.float32 and state['factored']:
|
|
215
215
|
grad = grad.float()
|
|
216
|
-
|
|
217
|
-
|
|
216
|
+
is_vector = p.ndim < 2 or getattr(p, '_is_dora_scale', False) or getattr(p, 'is_vector', False)
|
|
217
|
+
|
|
218
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
218
219
|
|
|
219
220
|
# Lion-K Logic
|
|
220
221
|
kappa_p = group.get("kappa_p", 1.0)
|
|
@@ -250,7 +251,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
250
251
|
update = update.view(p.shape)
|
|
251
252
|
|
|
252
253
|
if group.get('stochastic_sign', False):
|
|
253
|
-
update = apply_stochastic_sign_(update, noise=random_noise_tensor)
|
|
254
|
+
update = apply_stochastic_sign_(update, noise=random_noise_tensor, is_vector=is_vector)
|
|
254
255
|
else:
|
|
255
256
|
update = _get_lion_k_update(update, kappa_p)
|
|
256
257
|
|
|
@@ -265,7 +266,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
265
266
|
exp_avg.lerp_(grad, 1 - beta2)
|
|
266
267
|
|
|
267
268
|
if group.get('stochastic_sign', False):
|
|
268
|
-
update = apply_stochastic_sign_(update, noise=random_noise_tensor)
|
|
269
|
+
update = apply_stochastic_sign_(update, noise=random_noise_tensor, is_vector=is_vector)
|
|
269
270
|
else:
|
|
270
271
|
update = _get_lion_k_update(update, kappa_p)
|
|
271
272
|
|
|
@@ -39,7 +39,8 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
39
39
|
(default: (3.4445, -4.7750, 2.0315)).
|
|
40
40
|
stochastic_rounding (bool): whether to use stochastic rounding for
|
|
41
41
|
BF16 parameter updates (default: True).
|
|
42
|
-
orthogonal_gradient (
|
|
42
|
+
orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
|
|
43
|
+
'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
|
|
43
44
|
vector_reshape (bool): whether to reshape 1D vectors into 2D
|
|
44
45
|
matrices to apply low-rank compression (default: True).
|
|
45
46
|
nnmf_factor (bool): whether to use the factorization or disable it to use
|
|
@@ -89,7 +90,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
89
90
|
adam_fisher_wd (bool): Fisher Adam (FAdam) weight decay for the AdamW part. (default: False)
|
|
90
91
|
adam_use_bias_correction (bool): Bias correction for AdamW.
|
|
91
92
|
adam_use_atan2 (bool): Atan2 update rule for AdamW.
|
|
92
|
-
adam_orthogonal_gradient (
|
|
93
|
+
adam_orthogonal_gradient (str): OrthoGrad for AdamW.
|
|
93
94
|
adam_nesterov (bool): Nesterov momentum for AdamW. (default: False)
|
|
94
95
|
adam_nesterov_coef (float, optional): Nesterov coefficient for AdamW. (default: None)
|
|
95
96
|
adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
|
|
@@ -121,7 +122,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
121
122
|
# Stochastic Rounding for BF16
|
|
122
123
|
stochastic_rounding: bool = True,
|
|
123
124
|
# OrthoGrad
|
|
124
|
-
orthogonal_gradient:
|
|
125
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
125
126
|
# RMS Rescaling
|
|
126
127
|
rms_rescaling: bool = True,
|
|
127
128
|
# SMMF factorization
|
|
@@ -159,7 +160,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
159
160
|
adam_fisher_wd: bool = False,
|
|
160
161
|
adam_use_bias_correction: bool = True,
|
|
161
162
|
adam_use_atan2: bool = False,
|
|
162
|
-
adam_orthogonal_gradient:
|
|
163
|
+
adam_orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
163
164
|
adam_nesterov: bool = False,
|
|
164
165
|
adam_nesterov_coef: float | None = None,
|
|
165
166
|
adam_kourkoutas_beta: bool = False,
|
|
@@ -186,7 +187,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
186
187
|
print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
|
|
187
188
|
rms_rescaling = False
|
|
188
189
|
if spectral_normalization and accelerated_ns:
|
|
189
|
-
ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
190
|
+
raise ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
190
191
|
|
|
191
192
|
# Legacy backwards compatibility support for `nnmf_factor=True`
|
|
192
193
|
if nnmf_factor:
|
|
@@ -457,8 +458,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
457
458
|
if grad.dtype != torch.float32 and state.get('factored', False):
|
|
458
459
|
grad = grad.float()
|
|
459
460
|
|
|
460
|
-
|
|
461
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
461
|
+
grad = _orthogonalize_gradient(p, grad, group.get("orthogonal_gradient"))
|
|
462
462
|
|
|
463
463
|
if state['factored']: # Factored Muon
|
|
464
464
|
d1, d2 = state['effective_shape']
|
|
@@ -43,7 +43,8 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
43
43
|
stochastic_rounding (bool): whether to use stochastic
|
|
44
44
|
rounding for BF16 parameter updates (default: True).
|
|
45
45
|
use_atan2 (bool): whether to use the atan2 update rule. (default: False)
|
|
46
|
-
orthogonal_gradient (
|
|
46
|
+
orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
|
|
47
|
+
'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
|
|
47
48
|
nnmf_factor (bool): whether to use the factorization or disable it to use
|
|
48
49
|
the uncompressed optimizer. (default: False)
|
|
49
50
|
factored_2nd (bool): whether to keep the first moment uncompressed (dense)
|
|
@@ -119,7 +120,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
119
120
|
# Adam_atan2 (scale invariant)
|
|
120
121
|
use_atan2: bool = False,
|
|
121
122
|
# OrthoGrad
|
|
122
|
-
orthogonal_gradient:
|
|
123
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
123
124
|
# Nesterov momentum
|
|
124
125
|
nesterov: bool = False,
|
|
125
126
|
nesterov_coef: float | None = None,
|
|
@@ -371,8 +372,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
371
372
|
def _step_parameter(self, p, grad, state, group, beta2, d, dlr, random_int_tensor, random_int_state_tensor):
|
|
372
373
|
grad = upcast_grad_for_precision(grad, state, group['state_precision'])
|
|
373
374
|
|
|
374
|
-
|
|
375
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
375
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
376
376
|
|
|
377
377
|
nesterov = group.get('nesterov', False)
|
|
378
378
|
nesterov_coef = group.get('nesterov_coef', None)
|
|
@@ -62,7 +62,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
62
62
|
# Stochastic Rounding for BF16
|
|
63
63
|
stochastic_rounding: bool = True,
|
|
64
64
|
# OrthoGrad
|
|
65
|
-
orthogonal_gradient:
|
|
65
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
66
66
|
# Stochastic Sign Operator
|
|
67
67
|
stochastic_sign: bool = False,
|
|
68
68
|
# Nesterov momentum
|
|
@@ -171,7 +171,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
171
171
|
def __init_state(self, p, group):
|
|
172
172
|
state = self.state[p]
|
|
173
173
|
# State Initialization
|
|
174
|
-
if
|
|
174
|
+
if 'step' not in state:
|
|
175
175
|
req_precision = group['state_precision']
|
|
176
176
|
is_vector = len(p.shape) == 1 and not group['vector_reshape']
|
|
177
177
|
|
|
@@ -259,8 +259,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
259
259
|
wd_target = None
|
|
260
260
|
cwd_target = None
|
|
261
261
|
|
|
262
|
-
|
|
263
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
262
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
264
263
|
|
|
265
264
|
if normed_mt:
|
|
266
265
|
if sso:
|
|
@@ -282,7 +281,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
282
281
|
|
|
283
282
|
if nesterov and normed_mt:
|
|
284
283
|
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
285
|
-
normed_grad =
|
|
284
|
+
normed_grad = exp_avg.abs().mul_(grad_reshaped)
|
|
286
285
|
|
|
287
286
|
exp_avg.lerp_(grad_reshaped, 1 - momentum)
|
|
288
287
|
|
|
@@ -313,7 +312,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
313
312
|
|
|
314
313
|
if nesterov and normed_mt:
|
|
315
314
|
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
316
|
-
normed_grad =
|
|
315
|
+
normed_grad = exp_avg.abs().mul_(grad)
|
|
317
316
|
|
|
318
317
|
exp_avg.lerp_(grad, 1 - momentum)
|
|
319
318
|
|
|
@@ -344,7 +343,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
344
343
|
if group.get('geometric_wd', False) and group["weight_decay"] > 0 :
|
|
345
344
|
wd_target = get_signsgd_wd_target(p, denom=denom, stochastic_sign=sso, noise=random_noise_tensor, is_vector=is_vector)
|
|
346
345
|
|
|
347
|
-
if group.get('centered_wd', 0.0) > 0 and '
|
|
346
|
+
if group.get('centered_wd', 0.0) > 0 and 'anchor_data' in state:
|
|
348
347
|
anchor = dequantize_anchor(p, state, group, p.dtype)
|
|
349
348
|
cwd_target = get_signsgd_wd_target(p.sub(anchor), denom=denom, stochastic_sign=sso, noise=random_noise_tensor, is_vector=is_vector)
|
|
350
349
|
del anchor
|
|
@@ -355,7 +354,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
355
354
|
update_scaling = lr * A if snr_cond else lr
|
|
356
355
|
update.mul_(update_scaling)
|
|
357
356
|
|
|
358
|
-
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_target=wd_target, cwd_target=cwd_target
|
|
357
|
+
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_target=wd_target, cwd_target=cwd_target)
|
|
359
358
|
|
|
360
359
|
def compile(self, *args, **kwargs):
|
|
361
360
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
@@ -69,7 +69,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
69
69
|
# Stochastic Rounding for BF16
|
|
70
70
|
stochastic_rounding: bool = True,
|
|
71
71
|
# OrthoGrad
|
|
72
|
-
orthogonal_gradient:
|
|
72
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
73
73
|
# Spectral Normed Optimizer
|
|
74
74
|
spectral_normalization: bool = False,
|
|
75
75
|
# Centered WD
|
|
@@ -89,8 +89,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
89
89
|
raise ValueError(f"Momentum should be >= 0.0. Got {momentum}")
|
|
90
90
|
if not (weight_decay >= 0.0):
|
|
91
91
|
raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
|
|
92
|
-
if snr_cond and not normed_momentum:
|
|
93
|
-
raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum")
|
|
92
|
+
if snr_cond and not normed_momentum and not momentum > 0:
|
|
93
|
+
raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum.")
|
|
94
94
|
|
|
95
95
|
state_precision = state_precision.lower()
|
|
96
96
|
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
|
|
@@ -237,8 +237,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
237
237
|
wd_target = None
|
|
238
238
|
cwd_target = None
|
|
239
239
|
|
|
240
|
-
|
|
241
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
240
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
242
241
|
|
|
243
242
|
if normed_mt:
|
|
244
243
|
if not is_vector:
|
|
@@ -266,7 +265,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
266
265
|
|
|
267
266
|
if nesterov and normed_mt:
|
|
268
267
|
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
269
|
-
normed_grad =
|
|
268
|
+
normed_grad = buf.abs().mul_(grad_reshaped)
|
|
270
269
|
|
|
271
270
|
buf.lerp_(grad_reshaped, 1 - momentum)
|
|
272
271
|
|
|
@@ -303,7 +302,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
303
302
|
|
|
304
303
|
if nesterov and normed_mt:
|
|
305
304
|
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
306
|
-
normed_grad =
|
|
305
|
+
normed_grad = buf.abs().mul_(grad)
|
|
307
306
|
|
|
308
307
|
buf.lerp_(grad, 1 - momentum)
|
|
309
308
|
|
|
@@ -346,7 +345,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
346
345
|
wd_scaler = get_sinkhorn_wd_scaler(p, row_denom=vt_row, col_denom=vt_col)
|
|
347
346
|
else:
|
|
348
347
|
wd_target = get_signsgd_wd_target(p, denom=denom)
|
|
349
|
-
if is_vector and group.get('centered_wd', 0.0) > 0 and '
|
|
348
|
+
if is_vector and group.get('centered_wd', 0.0) > 0 and 'anchor_data' in state:
|
|
350
349
|
anchor = dequantize_anchor(p, state, group, p.dtype)
|
|
351
350
|
cwd_target = get_signsgd_wd_target(p.sub(anchor), denom=denom)
|
|
352
351
|
del anchor
|
|
@@ -71,8 +71,7 @@ def _init_auxadam_state(self, p, group):
|
|
|
71
71
|
def _adam_step_parameter(self, p, grad, state, group, beta1_adam, beta2_adam, sqrt_bias_correction2, step_size, random_int_tensor, random_int_state_tensor=None):
|
|
72
72
|
grad = upcast_grad_for_precision(grad, state, group.get('adam_state_precision', 'auto'))
|
|
73
73
|
|
|
74
|
-
|
|
75
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
74
|
+
grad = _orthogonalize_gradient(p, grad, group.get("adam_orthogonal_gradient"))
|
|
76
75
|
|
|
77
76
|
if hasattr(self, 'kourkoutas_helper') and self.kourkoutas_helper:
|
|
78
77
|
# Accumulate current grad's norm for the *next* step
|
|
@@ -190,4 +189,4 @@ def _adam_step_parameter(self, p, grad, state, group, beta1_adam, beta2_adam, sq
|
|
|
190
189
|
else:
|
|
191
190
|
update.mul_(update_scaling)
|
|
192
191
|
|
|
193
|
-
param_update.apply_parameter_update(self, p, group, update,
|
|
192
|
+
param_update.apply_parameter_update(self, p, group, update, group['lr'], group["adam_weight_decay"], random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
|