adv-optm 1.0.6__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +61 -9
- adv_optm/optim/Adopt_adv.py +435 -388
- adv_optm/optim/Lion_Prodigy_adv.py +315 -315
- adv_optm/optim/Lion_adv.py +1 -1
- adv_optm/optim/Prodigy_adv.py +78 -19
- adv_optm/optim/Simplified_AdEMAMix.py +54 -2
- adv_optm/util/Kourkoutas.py +171 -0
- {adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dist-info}/METADATA +1 -1
- adv_optm-1.1.0.dist-info/RECORD +20 -0
- adv_optm-1.0.6.dist-info/RECORD +0 -19
- {adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dist-info}/WHEEL +0 -0
- {adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdamW_adv.py
CHANGED
@@ -1,15 +1,16 @@
 import torch
-from typing import Optional
+from typing import Optional, Callable
 
 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+from ..util.Kourkoutas import KourkoutasHelper
 
 class AdamW_adv(torch.optim.Optimizer):
     """
-    Implements
+    Implements an advanced AdamW algorithm.
     This is an advanced version of AdamW with optional features like
     low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.
 
@@ -54,6 +55,28 @@ class AdamW_adv(torch.optim.Optimizer):
             as it gradually introduces the stabilizing slow momentum term. During
             the warmup, `alpha` ramps from 0 to its target value. If `None`,
             the scheduler is disabled. (default: None)
+        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+            If `False`, the optimizer behaves as standard AdamW. (default: False)
+        beta2_min (float): The minimum value for dynamic β₂, used during periods of
+            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+            (default: 0.88)
+        ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+            the pooled gradient norms. Corresponds to `α` in the paper.
+            (default: 0.93)
+        tiny_spike (float): A small constant added to the denominator of the
+            "sunspike" ratio calculation to prevent division by zero. Corresponds
+            to `ε_spike` in the paper. (default: 1e-9)
+        k_warmup_steps (int): The number of initial steps during which β₂ is held
+            at a fixed beta2 value before the
+            dynamic logic activates. (default: 0)
+        k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+            every logging steps. Useful for debugging and tuning. Set to 0 to disable
+            logging (default: 0).
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+            and returns a unique, hashable key representing its "layer" or "bucket".
+            If `None`, parameters are bucketed by their memory ID (tensor-wise).
+            (default: None)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
     """
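The docstring above describes the mechanism only in words: an EMA of pooled gradient norms feeds a per-layer "sunspike" ratio, which pulls β₂ down from `betas[1]` toward `beta2_min` when gradients spike and lets it sit near `betas[1]` when they are calm. The `util/Kourkoutas.py` module that implements this is not rendered in this diff, so the following is only a minimal sketch of a rule matching that description; the function name, the squashing step, and the linear interpolation are assumptions, not the package's code.

```python
# Minimal sketch of a sunspike-driven beta2, matching the docstring above.
# Not the package's implementation: the squashing and interpolation are assumed.
def dynamic_beta2(grad_sq_norm_sum, prev_ema,
                  beta2_max=0.999,   # betas[1] acts as the upper bound
                  beta2_min=0.88, ema_alpha=0.93, tiny_spike=1e-9):
    pooled_norm = grad_sq_norm_sum ** 0.5
    ema = ema_alpha * prev_ema + (1.0 - ema_alpha) * pooled_norm
    # Sunspike ratio: near 0 for calm gradients, large when the current
    # pooled norm dwarfs its running average.
    ratio = pooled_norm / (ema + tiny_spike)
    sunspike = ratio / (1.0 + ratio)                 # squash into [0, 1)
    # Spike -> beta2 approaches beta2_min (adapt quickly);
    # calm  -> beta2 stays near beta2_max (smooth second moment).
    return beta2_max - (beta2_max - beta2_min) * sunspike, ema
```

Under this reading, `beta2_min` is only approached during a severe spike; on calm steps β₂ stays near `betas[1]`, which is consistent with the bias-correction change further down that keeps using `betas[1]` as the β₂ upper bound.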
@@ -76,6 +99,13 @@ class AdamW_adv(torch.optim.Optimizer):
         beta3_ema: float = 0.9999,
         alpha: float = 5.0,
         t_alpha: int | None = None,
+        kourkoutas_beta: bool = False,
+        beta2_min: float = 0.9,
+        ema_alpha: float = 0.95,
+        tiny_spike: float = 1e-9,
+        k_warmup_steps: int = 0,
+        k_logging: int = 0,
+        layer_key_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
     ):
         if not (lr >= 0.0):
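Judging from the new signature, enabling the feature is a constructor-level switch and the remaining arguments keep their usual AdamW_adv meaning. Below is a hedged usage sketch; the top-level import path and the toy training step are assumptions, not taken from the package documentation.

```python
# Usage sketch; `from adv_optm import AdamW_adv` is assumed to be the public path.
import torch
from adv_optm import AdamW_adv

model = torch.nn.Linear(128, 64)

opt = AdamW_adv(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),    # betas[1] doubles as beta2_max once kourkoutas_beta is on
    kourkoutas_beta=True,  # enable layer-wise dynamic beta2
    beta2_min=0.9,         # floor used during gradient spikes (must stay below betas[1])
    k_warmup_steps=100,    # hold beta2 fixed for the first 100 steps
    k_logging=500,         # print Kourkoutas-β stats every 500 steps
)

x, y = torch.randn(32, 128), torch.randn(32, 64)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()
```

Since the class advertises `supports_fused_back_pass`, the same arguments should apply when the optimizer is driven parameter-by-parameter through `step_parameter`, as the loop at the end of the diff shows.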
@@ -86,6 +116,8 @@ class AdamW_adv(torch.optim.Optimizer):
             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
         if not (weight_decay >= 0.0):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+        if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
+
         if cautious_mask and grams_moment:
             print("Warning: cautious is incompatible with grams, Disabling cautious.")
             cautious_mask = False
@@ -95,14 +127,21 @@ class AdamW_adv(torch.optim.Optimizer):
             "vector_reshape": vector_reshape, "use_atan2": use_atan2,
             "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
             "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
+            "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
         }
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask
         self.grams_moment = grams_moment
         self.use_AdEMAMix = use_AdEMAMix
         self.factored = nnmf_factor
+        self.kourkoutas_beta = kourkoutas_beta
+        self.layer_key_fn = layer_key_fn
         super().__init__(params, defaults)
 
+        if self.kourkoutas_beta:
+            self.kourkoutas_helper = KourkoutasHelper(self)
+
     @property
     def supports_fused_back_pass(self):
         return True
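Note that `layer_key_fn` is stored on the optimizer object rather than in the per-group defaults, so a single callable covers every param group. The docstring only requires a hashable key per parameter; the sketch below buckets parameters by the module that owns them, which is one plausible choice for layer-wise grouping, not something the package prescribes.

```python
# Sketch of a custom layer_key_fn: bucket parameters by owning module so that
# every tensor of a submodule shares one dynamic beta2. The module-name mapping
# is an illustrative choice; only "hashable key per parameter" is required.
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(128, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 64),
)

param_to_module = {
    p: name
    for name, module in model.named_modules()
    for p in module.parameters(recurse=False)
}

def layer_key_fn(p: torch.Tensor):
    # Fall back to the tensor's id() for anything not indexed above,
    # mirroring the documented tensor-wise default.
    return param_to_module.get(p, id(p))
```

It would then be passed as `AdamW_adv(model.parameters(), kourkoutas_beta=True, layer_key_fn=layer_key_fn, ...)`.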
@@ -127,10 +166,8 @@ class AdamW_adv(torch.optim.Optimizer):
             grad = _orthogonalize_gradient(p, grad)
         state = self.state[p]
 
-        beta1, beta2 = group['betas']
-
         # State Initialization
-        if
+        if 'step' not in state:
             state['step'] = 0
 
             should_factor = (
@@ -148,7 +185,7 @@ class AdamW_adv(torch.optim.Optimizer):
                 d1, d2 = state['effective_shape']
 
                 # First moment (m)
-                if
+                if group['betas'][0] > 0:
                     state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                     state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                     if not self.grams_moment:
@@ -163,16 +200,31 @@ class AdamW_adv(torch.optim.Optimizer):
                 state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                 state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
             else: # Fallback to standard AdamW for non-factored tensors
-                if
+                if group['betas'][0] > 0:
                     state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
                     if self.use_AdEMAMix:
                         state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
                 state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
 
+        beta1, beta2 = group['betas']
+
+        current_step = state['step']
+        if group['kourkoutas_beta']:
+            # Call prepare_step() once at the beginning of the step for all params
+            self.kourkoutas_helper.maybe_prepare_step(current_step)
+            # Accumulate current grad's norm for the *next* step
+            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+            # Get the dynamic beta2 calculated in prepare_step()
+            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
         step = state['step'] + 1
         if group['use_bias_correction']:
             bias_correction1 = 1.0 - beta1 ** step
-
+            if group['kourkoutas_beta']:
+                bias_correction2 = 1.0 - group['betas'][1] ** step
+                # Use beta2_max for bias correction
+            else:
+                bias_correction2 = 1.0 - beta2 ** step
         else:
             bias_correction1 = 1
             bias_correction2 = 1
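The call sites above outline the helper's contract: during step t it accumulates squared gradient norms per bucket, at the start of step t+1 `maybe_prepare_step` folds them into each bucket's EMA and fixes that step's β₂, and bias correction deliberately keeps using the static `betas[1]` (the β₂ upper bound). Since `util/Kourkoutas.py` itself is not rendered in this diff, the class below is only a guess at how such a helper could be structured; the internals, the warmup behaviour, and reading hyper-parameters from the first param group are assumptions.

```python
# Sketch only: the real KourkoutasHelper lives in adv_optm/util/Kourkoutas.py
# (+171 lines in this release), which this diff does not render. Method names
# match the call sites above; the internals and bucket handling are assumptions.
from collections import defaultdict

class KourkoutasHelperSketch:
    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.accum = defaultdict(float)  # bucket key -> sum of ||g||^2 seen this step
        self.ema = defaultdict(float)    # bucket key -> EMA of pooled gradient norms
        self.beta2 = {}                  # bucket key -> beta2 for the current step
        self.prepared_step = -1

    def _key(self, p):
        fn = self.optimizer.layer_key_fn
        return fn(p) if fn is not None else id(p)  # tensor-wise fallback, as documented

    def maybe_prepare_step(self, step):
        # Runs once per optimizer step: fold the norms accumulated during the
        # previous step into each bucket's EMA and fix that bucket's beta2.
        if step == self.prepared_step:
            return
        self.prepared_step = step
        g = self.optimizer.param_groups[0]  # hyper-parameters from group 0, for brevity
        for key, sq_sum in self.accum.items():
            pooled = sq_sum ** 0.5
            self.ema[key] = g['ema_alpha'] * self.ema[key] + (1 - g['ema_alpha']) * pooled
            spike = pooled / (self.ema[key] + g['tiny_spike'])
            spike = spike / (1.0 + spike)   # squash into [0, 1); an assumed form
            self.beta2[key] = g['betas'][1] - (g['betas'][1] - g['beta2_min']) * spike
        self.accum.clear()

    def accumulate_gradient_sq_norm(self, p, grad):
        self.accum[self._key(p)] += float(grad.pow(2).sum())

    def get_beta2(self, p, group, step):
        if step < group['k_warmup_steps']:   # warmup: keep beta2 at its static upper value
            return group['betas'][1]
        return self.beta2.get(self._key(p), group['betas'][1])
```

The bias-correction branch above fits this picture: because β₂ varies per step and per bucket, correcting with the static `betas[1]` keeps the correction term well defined, which is what the in-code comment ("Use beta2_max for bias correction") indicates.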
@@ -315,4 +367,4 @@ class AdamW_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)
 
-        return loss
+        return loss