adv-optm 2.4.dev2__tar.gz → 2.4.dev5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/PKG-INFO +1 -1
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/AdaMuon_adv.py +9 -5
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/AdamW_adv.py +30 -10
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/Adopt_adv.py +45 -24
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/Lion_Prodigy_adv.py +2 -2
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/Lion_adv.py +42 -26
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/Muon_adv.py +8 -4
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/Prodigy_adv.py +27 -12
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/SignSGD_adv.py +39 -24
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/Simplified_AdEMAMix.py +12 -6
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/Kourkoutas.py +43 -12
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/Muon_AuxAdam.py +8 -2
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/Muon_util.py +7 -5
- adv_optm-2.4.dev5/adv_optm/util/OrthoGrad.py +50 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/centered_decay.py +1 -1
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/param_update.py +54 -12
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/scaled_optm.py +28 -20
- adv_optm-2.4.dev5/adv_optm/util/signed_util.py +13 -0
- adv_optm-2.4.dev5/adv_optm/util/update_util.py +111 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm.egg-info/SOURCES.txt +1 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/setup.py +1 -1
- adv_optm-2.4.dev2/adv_optm/util/OrthoGrad.py +0 -21
- adv_optm-2.4.dev2/adv_optm/util/update_util.py +0 -32
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/LICENSE +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/README.md +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/setup.cfg +0 -0
|
@@ -206,6 +206,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
206
206
|
if spectral_normalization and rms_rescaling:
|
|
207
207
|
print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
|
|
208
208
|
rms_rescaling = False
|
|
209
|
+
if spectral_normalization and accelerated_ns:
|
|
210
|
+
ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
209
211
|
|
|
210
212
|
defaults = {
|
|
211
213
|
"lr": lr, "betas": betas, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
|
|
@@ -260,6 +262,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
260
262
|
if group.get('use_muon') is None: # Fallback
|
|
261
263
|
group['use_muon'] = group.get('optim_type') == 'muon'
|
|
262
264
|
|
|
265
|
+
self.init_step()
|
|
266
|
+
|
|
263
267
|
self.kourkoutas_helper = None
|
|
264
268
|
if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
|
|
265
269
|
self.kourkoutas_helper = KourkoutasHelper(self)
|
|
@@ -280,8 +284,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
280
284
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
281
285
|
"""
|
|
282
286
|
Overrides default load_state_dict to implement a workaround for PyTorch's
|
|
283
|
-
automatic dtype casting. It ensures factorized states remain float32 for
|
|
284
|
-
stability, preserves integer/float8 quantized anchor states, and forces
|
|
287
|
+
automatic dtype casting. It ensures factorized states remain float32 for
|
|
288
|
+
stability, preserves integer/float8 quantized anchor states, and forces
|
|
285
289
|
standard states onto the parameter's current dtype/device.
|
|
286
290
|
"""
|
|
287
291
|
super().load_state_dict(state_dict)
|
|
@@ -419,7 +423,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
419
423
|
step_size = group['lr'] / bias_correction1
|
|
420
424
|
|
|
421
425
|
if is_compiled:
|
|
422
|
-
step_size = torch.as_tensor(step_size
|
|
426
|
+
step_size = torch.as_tensor(step_size)
|
|
423
427
|
adam_step_param = self._compiled_adam_step_parameter
|
|
424
428
|
else:
|
|
425
429
|
adam_step_param = Muon_AuxAdam._adam_step_parameter
|
|
@@ -430,7 +434,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
430
434
|
|
|
431
435
|
else: # Muon path
|
|
432
436
|
if is_compiled:
|
|
433
|
-
lr = torch.as_tensor(group['lr']
|
|
437
|
+
lr = torch.as_tensor(group['lr'])
|
|
434
438
|
muon_step_param = self._compiled_muon_step_parameter
|
|
435
439
|
else:
|
|
436
440
|
lr = group['lr']
|
|
@@ -467,7 +471,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
467
471
|
else:
|
|
468
472
|
shape_for_scaling = p.shape
|
|
469
473
|
|
|
470
|
-
scaled_eps, adaptive_eps, spectral_target, wd_scale = get_spectral_scaling(shape_for_scaling, group['n_layers'])
|
|
474
|
+
scaled_eps, adaptive_eps, spectral_target, wd_scale = get_spectral_scaling(p, shape_for_scaling, group['n_layers'])
|
|
471
475
|
|
|
472
476
|
weight_decay = group['weight_decay'] * wd_scale
|
|
473
477
|
decoupled_wd = True
|
|
@@ -6,7 +6,7 @@ from typing import Optional, Callable
|
|
|
6
6
|
|
|
7
7
|
from ..util import param_update
|
|
8
8
|
from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
|
|
9
|
-
from ..util.update_util import _grams_update, _cautious_update
|
|
9
|
+
from ..util.update_util import _grams_update, _cautious_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
|
|
10
10
|
from ..util.OrthoGrad import _orthogonalize_gradient
|
|
11
11
|
from ..util.Kourkoutas import KourkoutasHelper
|
|
12
12
|
from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
|
|
@@ -29,6 +29,9 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
29
29
|
eps (float): term added to the denominator to improve
|
|
30
30
|
numerical stability (default: 1e-8)
|
|
31
31
|
weight_decay (float): weight decay (L2 penalty) (default: 0).
|
|
32
|
+
fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
|
|
33
|
+
the decay direction through the empirical Fisher information matrix and
|
|
34
|
+
clipping its RMS. (default: False)
|
|
32
35
|
cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
|
|
33
36
|
applied only to parameter coordinates where the sign of the parameter
|
|
34
37
|
and the sign of the optimizer update align (default: False).
|
|
@@ -91,7 +94,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
91
94
|
'int4': Uses 4-bit block-wise quantization (block size 32).
|
|
92
95
|
nnmf_factor (bool): whether to use the factorization or disable it to use
|
|
93
96
|
the uncompressed optimizer. (default: False)
|
|
94
|
-
factored_2nd (bool): whether to keep the first moment uncompressed (dense)
|
|
97
|
+
factored_2nd (bool): whether to keep the first moment uncompressed (dense)
|
|
95
98
|
while only factorizing the second moment. (default: True)
|
|
96
99
|
"""
|
|
97
100
|
|
|
@@ -103,6 +106,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
103
106
|
eps: float = 1e-8,
|
|
104
107
|
# Decoupled/cautious weight decay
|
|
105
108
|
weight_decay: float = 0.0,
|
|
109
|
+
fisher_wd: bool = False,
|
|
106
110
|
cautious_wd: bool = False,
|
|
107
111
|
# Adam's Bias Correction
|
|
108
112
|
use_bias_correction: bool = True,
|
|
@@ -155,7 +159,8 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
155
159
|
cautious_mask = False
|
|
156
160
|
|
|
157
161
|
defaults = {
|
|
158
|
-
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
162
|
+
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
163
|
+
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
|
|
159
164
|
"use_atan2": use_atan2,
|
|
160
165
|
"orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
|
|
161
166
|
"beta3_ema": beta3_ema, "alpha": alpha, "compiled_optimizer": compiled_optimizer,
|
|
@@ -192,8 +197,8 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
192
197
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
193
198
|
"""
|
|
194
199
|
Overrides default load_state_dict to implement a workaround for PyTorch's
|
|
195
|
-
automatic dtype casting. It ensures factorized states remain float32 for
|
|
196
|
-
stability, preserves integer/float8 quantized anchor states, and forces
|
|
200
|
+
automatic dtype casting. It ensures factorized states remain float32 for
|
|
201
|
+
stability, preserves integer/float8 quantized anchor states, and forces
|
|
197
202
|
standard states onto the parameter's current dtype/device.
|
|
198
203
|
"""
|
|
199
204
|
super().load_state_dict(state_dict)
|
|
@@ -273,6 +278,8 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
273
278
|
|
|
274
279
|
_init_anchor(p, state, group)
|
|
275
280
|
|
|
281
|
+
_init_fisher_wd_scaler(group, state, p)
|
|
282
|
+
|
|
276
283
|
beta1, beta2 = group['betas']
|
|
277
284
|
|
|
278
285
|
current_step = state['step']
|
|
@@ -294,7 +301,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
294
301
|
random_int_tensor = None
|
|
295
302
|
|
|
296
303
|
if group.get('compiled_optimizer', False):
|
|
297
|
-
step_size = torch.as_tensor(step_size
|
|
304
|
+
step_size = torch.as_tensor(step_size)
|
|
298
305
|
if p.dtype == torch.bfloat16 and self.stochastic_rounding:
|
|
299
306
|
# Pre-generate random tensor for stochastic rounding if needed.
|
|
300
307
|
random_int_tensor = param_update._get_random_int_for_sr(p)
|
|
@@ -349,7 +356,11 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
349
356
|
update_mt = mt if not factored_2nd else mt.clone()
|
|
350
357
|
|
|
351
358
|
vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
|
|
352
|
-
|
|
359
|
+
|
|
360
|
+
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
361
|
+
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped * (1.0 - beta2))
|
|
362
|
+
else:
|
|
363
|
+
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
|
|
353
364
|
|
|
354
365
|
if self.use_AdEMAMix:
|
|
355
366
|
if factored_2nd:
|
|
@@ -363,7 +374,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
363
374
|
update = update_mt.add_(mt_slow, alpha=alpha)
|
|
364
375
|
else:
|
|
365
376
|
update = grad_reshaped.add(mt_slow, alpha=alpha)
|
|
366
|
-
|
|
377
|
+
|
|
367
378
|
if not factored_2nd:
|
|
368
379
|
# Factorize
|
|
369
380
|
state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
|
|
@@ -385,6 +396,9 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
385
396
|
denom = vt.sqrt_()
|
|
386
397
|
denom.div_(sqrt_bias_correction2).add_(group['eps'])
|
|
387
398
|
update.div_(denom)
|
|
399
|
+
|
|
400
|
+
wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
|
|
401
|
+
|
|
388
402
|
del vt
|
|
389
403
|
|
|
390
404
|
update = update.view(p.shape)
|
|
@@ -413,7 +427,10 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
413
427
|
update = update_mt if beta1 > 0 else grad.clone()
|
|
414
428
|
|
|
415
429
|
exp_avg_sq = state['exp_avg_sq']
|
|
416
|
-
|
|
430
|
+
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
431
|
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad * (1.0 - beta2))
|
|
432
|
+
else:
|
|
433
|
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
|
|
417
434
|
|
|
418
435
|
if group['use_atan2']:
|
|
419
436
|
denom = exp_avg_sq.sqrt()
|
|
@@ -423,6 +440,9 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
423
440
|
denom = exp_avg_sq.sqrt()
|
|
424
441
|
denom.div_(sqrt_bias_correction2).add_(group['eps'])
|
|
425
442
|
update.div_(denom)
|
|
443
|
+
|
|
444
|
+
wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
|
|
445
|
+
|
|
426
446
|
del denom
|
|
427
447
|
|
|
428
448
|
update_scaling = step_size * A if group['use_atan2'] else step_size
|
|
@@ -431,7 +451,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
431
451
|
else:
|
|
432
452
|
update.mul_(update_scaling)
|
|
433
453
|
|
|
434
|
-
param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor)
|
|
454
|
+
param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
|
|
435
455
|
|
|
436
456
|
def compile(self, *args, **kwargs):
|
|
437
457
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
@@ -7,7 +7,7 @@ from ..util import param_update
|
|
|
7
7
|
from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state, _nnmf
|
|
8
8
|
from ..util.OrthoGrad import _orthogonalize_gradient
|
|
9
9
|
from ..util.Kourkoutas import KourkoutasHelper
|
|
10
|
-
from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update
|
|
10
|
+
from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
|
|
11
11
|
from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
|
|
12
12
|
from ..util.centered_decay import _init_anchor
|
|
13
13
|
|
|
@@ -33,6 +33,9 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
33
33
|
eps (float): term added to the denominator to improve
|
|
34
34
|
numerical stability (default: 1e-6)
|
|
35
35
|
weight_decay (float): weight decay (L2 penalty) (default: 0)
|
|
36
|
+
fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
|
|
37
|
+
the decay direction through the empirical Fisher information matrix and
|
|
38
|
+
clipping its RMS. (default: False)
|
|
36
39
|
cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
|
|
37
40
|
applied only to parameter coordinates where the sign of the parameter
|
|
38
41
|
and the sign of the optimizer update align (default: False).
|
|
@@ -107,7 +110,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
107
110
|
'int4': Uses 4-bit block-wise quantization (block size 32).
|
|
108
111
|
nnmf_factor (bool): whether to use the factorization or disable it to use
|
|
109
112
|
the uncompressed optimizer. (default: False)
|
|
110
|
-
factored_2nd (bool): whether to keep the first moment uncompressed (dense)
|
|
113
|
+
factored_2nd (bool): whether to keep the first moment uncompressed (dense)
|
|
111
114
|
while only factorizing the second moment. (default: True)
|
|
112
115
|
"""
|
|
113
116
|
|
|
@@ -119,6 +122,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
119
122
|
eps: float = 1e-6,
|
|
120
123
|
# Decoupled/cautious weight decay
|
|
121
124
|
weight_decay: float = 0.0,
|
|
125
|
+
fisher_wd: bool = False,
|
|
122
126
|
cautious_wd: bool = False,
|
|
123
127
|
# ADOPT clipping
|
|
124
128
|
clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
|
|
@@ -181,7 +185,8 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
181
185
|
print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
|
|
182
186
|
|
|
183
187
|
defaults = {
|
|
184
|
-
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
188
|
+
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
189
|
+
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
|
|
185
190
|
"beta3_ema": beta3_ema, "alpha": alpha,
|
|
186
191
|
"alpha_grad": alpha_grad,
|
|
187
192
|
"kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
|
|
@@ -189,7 +194,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
189
194
|
"scaled_optm": scaled_optm,
|
|
190
195
|
"centered_wd": centered_wd,
|
|
191
196
|
"centered_wd_mode": centered_wd_mode,
|
|
192
|
-
"nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd,
|
|
197
|
+
"nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd,
|
|
193
198
|
"compiled_optimizer": compiled_optimizer,
|
|
194
199
|
}
|
|
195
200
|
self.clip_lambda = clip_lambda
|
|
@@ -222,8 +227,8 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
222
227
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
223
228
|
"""
|
|
224
229
|
Overrides default load_state_dict to implement a workaround for PyTorch's
|
|
225
|
-
automatic dtype casting. It ensures factorized states remain float32 for
|
|
226
|
-
stability, preserves integer/float8 quantized anchor states, and forces
|
|
230
|
+
automatic dtype casting. It ensures factorized states remain float32 for
|
|
231
|
+
stability, preserves integer/float8 quantized anchor states, and forces
|
|
227
232
|
standard states onto the parameter's current dtype/device.
|
|
228
233
|
"""
|
|
229
234
|
super().load_state_dict(state_dict)
|
|
@@ -244,6 +249,19 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
244
249
|
grad = p.grad
|
|
245
250
|
state = self.state[p]
|
|
246
251
|
|
|
252
|
+
|
|
253
|
+
beta1, beta2 = group['betas']
|
|
254
|
+
|
|
255
|
+
if group.get('kourkoutas_beta', False):
|
|
256
|
+
if 'step' not in state:
|
|
257
|
+
current_step = 0
|
|
258
|
+
else:
|
|
259
|
+
current_step = state['step']
|
|
260
|
+
# Call prepare_step() once at the beginning of the step for all params
|
|
261
|
+
self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
|
|
262
|
+
# Get the dynamic beta2 calculated in prepare_step()
|
|
263
|
+
beta2 = self.kourkoutas_helper.get_beta2(p, group)
|
|
264
|
+
|
|
247
265
|
# State Initialization
|
|
248
266
|
if 'step' not in state:
|
|
249
267
|
state['step'] = 0
|
|
@@ -256,6 +274,12 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
256
274
|
|
|
257
275
|
dtype = torch.float32 if state['factored'] else p.dtype
|
|
258
276
|
|
|
277
|
+
vt_init = grad.pow(2).to(dtype)
|
|
278
|
+
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
279
|
+
vt_init.mul_(beta2).addcmul_(grad.to(dtype), grad.to(dtype) * (1.0 - beta2))
|
|
280
|
+
else:
|
|
281
|
+
vt_init.mul_(beta2).addcmul_(grad.to(dtype), grad.to(dtype), value=1.0 - beta2)
|
|
282
|
+
|
|
259
283
|
if state['factored']:
|
|
260
284
|
state['effective_shape'] = _get_effective_shape(p.numel())
|
|
261
285
|
d1, d2 = state['effective_shape']
|
|
@@ -279,33 +303,23 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
279
303
|
if self.use_AdEMAMix:
|
|
280
304
|
state['exp_avg_slow'] = torch.zeros_like(p, device=p.device, dtype=dtype)
|
|
281
305
|
# Second moment (v)
|
|
282
|
-
|
|
283
|
-
# Allocate NMF factors for vt
|
|
284
|
-
state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
|
|
285
|
-
state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
|
|
286
|
-
# Initialize v_0
|
|
287
|
-
state['mu_v_nmf'], state['mv_v_nmf'] = _nnmf(vt_init)
|
|
306
|
+
state['mu_v_nmf'], state['mv_v_nmf'] = _nnmf(vt_init.view(d1, d2))
|
|
288
307
|
del vt_init
|
|
289
308
|
else: # Fallback for non-factored tensors
|
|
290
309
|
if group['betas'][0] > 0:
|
|
291
310
|
state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
|
|
292
311
|
if self.use_AdEMAMix:
|
|
293
312
|
state['exp_avg_slow'] = torch.zeros_like(p, device=p.device, dtype=dtype)
|
|
294
|
-
state['exp_avg_sq'] =
|
|
313
|
+
state['exp_avg_sq'] = vt_init
|
|
295
314
|
|
|
296
315
|
if group.get('scaled_optm', False) and is_spectral(p):
|
|
297
316
|
init_spectral_norm(group, state, p)
|
|
298
317
|
|
|
299
318
|
_init_anchor(p, state, group)
|
|
300
319
|
|
|
301
|
-
|
|
320
|
+
_init_fisher_wd_scaler(group, state, p)
|
|
302
321
|
|
|
303
322
|
current_step = state['step']
|
|
304
|
-
if group.get('kourkoutas_beta', False):
|
|
305
|
-
# Call prepare_step() once at the beginning of the step for all params
|
|
306
|
-
self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
|
|
307
|
-
# Get the dynamic beta2 calculated in prepare_step()
|
|
308
|
-
beta2 = self.kourkoutas_helper.get_beta2(p, group)
|
|
309
323
|
|
|
310
324
|
# The first step is for initialization only (skip when use_atan2 as it's scale invariant).
|
|
311
325
|
if state['step'] == 0 and not self.use_atan2:
|
|
@@ -315,7 +329,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
315
329
|
random_int_tensor = None
|
|
316
330
|
|
|
317
331
|
if group.get('compiled_optimizer', False):
|
|
318
|
-
lr = torch.as_tensor(group['lr']
|
|
332
|
+
lr = torch.as_tensor(group['lr'])
|
|
319
333
|
if p.dtype == torch.bfloat16 and self.stochastic_rounding:
|
|
320
334
|
# Pre-generate random tensor for stochastic rounding if needed.
|
|
321
335
|
random_int_tensor = param_update._get_random_int_for_sr(p)
|
|
@@ -359,9 +373,13 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
359
373
|
|
|
360
374
|
# ADOPT Step A: Decorrelate g_t using v_{t-1}
|
|
361
375
|
denom = vt.sqrt()
|
|
376
|
+
wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, self.use_atan2)
|
|
362
377
|
|
|
363
378
|
# Update second moment v_t for the *next* step using raw g_t
|
|
364
|
-
|
|
379
|
+
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
380
|
+
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped * (1.0 - beta2))
|
|
381
|
+
else:
|
|
382
|
+
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
|
|
365
383
|
# Factorize
|
|
366
384
|
state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(vt, signed=False)
|
|
367
385
|
del vt
|
|
@@ -434,6 +452,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
434
452
|
|
|
435
453
|
# ADOPT Step A: Decorrelate g_t using v_{t-1}
|
|
436
454
|
denom = vt.sqrt()
|
|
455
|
+
wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, self.use_atan2)
|
|
437
456
|
|
|
438
457
|
if self.use_atan2:
|
|
439
458
|
normalized_grad = torch.atan2(grad, denom, out=denom)
|
|
@@ -475,9 +494,11 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
475
494
|
else:
|
|
476
495
|
update = normalized_grad
|
|
477
496
|
|
|
478
|
-
|
|
479
497
|
# Update second moment v_t for the next step using raw g_t
|
|
480
|
-
|
|
498
|
+
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
499
|
+
vt.mul_(beta2).addcmul_(grad, grad * (1.0 - beta2))
|
|
500
|
+
else:
|
|
501
|
+
vt.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
|
481
502
|
|
|
482
503
|
update_scaling = lr * A if self.use_atan2 else lr
|
|
483
504
|
|
|
@@ -487,7 +508,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
487
508
|
update.mul_(update_scaling)
|
|
488
509
|
|
|
489
510
|
# Parameter Update
|
|
490
|
-
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
|
|
511
|
+
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
|
|
491
512
|
|
|
492
513
|
def compile(self, *args, **kwargs):
|
|
493
514
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
@@ -225,8 +225,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
|
|
|
225
225
|
# Pre-generate random tensor for stochastic rounding if needed.
|
|
226
226
|
random_int_tensor = param_update._get_random_int_for_sr(p)
|
|
227
227
|
# TODO, workaround until pytorch#169634 is fixed
|
|
228
|
-
d = torch.as_tensor(group['d']
|
|
229
|
-
dlr = torch.as_tensor(dlr
|
|
228
|
+
d = torch.as_tensor(group['d'])
|
|
229
|
+
dlr = torch.as_tensor(dlr)
|
|
230
230
|
step_param_fn = self._compiled_step_parameter
|
|
231
231
|
else:
|
|
232
232
|
d = group['d']
|
|
@@ -8,6 +8,8 @@ from ..util.factorization_util import _get_effective_shape, _reconstruct_state,
|
|
|
8
8
|
from ..util.lion_k import _get_lion_k_update
|
|
9
9
|
from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
|
|
10
10
|
from ..util.centered_decay import _init_anchor
|
|
11
|
+
from ..util.update_util import _get_l1_adaptive_lr
|
|
12
|
+
from ..util.signed_util import apply_stochastic_sign
|
|
11
13
|
|
|
12
14
|
|
|
13
15
|
class Lion_adv(torch.optim.Optimizer):
|
|
@@ -44,9 +46,10 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
44
46
|
parameter dimensionality. Sets p=2.0 for 4D tensors (Conv2D) (Biases/Norms) to
|
|
45
47
|
use Spherical updates, and p=1.0 for others (Linear/Embeddings) to use Sign
|
|
46
48
|
updates. Overrides explicit kappa_p value. (default: False).
|
|
49
|
+
stochastic_sign (bool): whether to use the Stochastic Sign operator. (default: False)
|
|
47
50
|
freeze_on_flip (bool): Projected SignGD One-hit freeze. Masks updates for
|
|
48
51
|
coordinates where the gradient sign flips compared to the previous step. (default: False)
|
|
49
|
-
l1_adaptive (bool): Scales learning rate dynamically
|
|
52
|
+
l1_adaptive (bool): Scales learning rate dynamically
|
|
50
53
|
by the L1 norm of the gradient to handle gradient heterogeneity. (default: False).
|
|
51
54
|
centered_wd (float): Centered Weight Decay coefficient. Instead of decaying weights
|
|
52
55
|
toward zero, they are decayed toward their initial values (anchors). This
|
|
@@ -79,6 +82,8 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
79
82
|
# Lion-k
|
|
80
83
|
kappa_p: float = 1.0,
|
|
81
84
|
auto_kappa_p: bool = False,
|
|
85
|
+
# Stochastic Sign Operator
|
|
86
|
+
stochastic_sign: bool = False,
|
|
82
87
|
# Projected and adaptive sign
|
|
83
88
|
freeze_on_flip: bool = False,
|
|
84
89
|
l1_adaptive: bool = False,
|
|
@@ -110,6 +115,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
110
115
|
clip_threshold=clip_threshold,
|
|
111
116
|
kappa_p=kappa_p,
|
|
112
117
|
auto_kappa_p=auto_kappa_p,
|
|
118
|
+
stochastic_sign=stochastic_sign,
|
|
113
119
|
freeze_on_flip=freeze_on_flip,
|
|
114
120
|
l1_adaptive=l1_adaptive,
|
|
115
121
|
scaled_optm= scaled_optm,
|
|
@@ -137,8 +143,8 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
137
143
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
138
144
|
"""
|
|
139
145
|
Overrides default load_state_dict to implement a workaround for PyTorch's
|
|
140
|
-
automatic dtype casting. It ensures factorized states remain float32 for
|
|
141
|
-
stability, preserves integer/float8 quantized anchor states, and forces
|
|
146
|
+
automatic dtype casting. It ensures factorized states remain float32 for
|
|
147
|
+
stability, preserves integer/float8 quantized anchor states, and forces
|
|
142
148
|
standard states onto the parameter's current dtype/device.
|
|
143
149
|
"""
|
|
144
150
|
super().load_state_dict(state_dict)
|
|
@@ -201,19 +207,22 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
201
207
|
lr = group["lr"]
|
|
202
208
|
|
|
203
209
|
random_int_tensor = None
|
|
210
|
+
random_noise_tensor = None
|
|
204
211
|
|
|
205
212
|
if group.get('compiled_optimizer', False):
|
|
206
213
|
if p.dtype == torch.bfloat16 and self.stochastic_rounding:
|
|
207
214
|
# Pre-generate random tensor for stochastic rounding if needed.
|
|
208
215
|
random_int_tensor = param_update._get_random_int_for_sr(p)
|
|
209
|
-
|
|
216
|
+
if group.get('stochastic_sign', False):
|
|
217
|
+
random_noise_tensor = param_update._get_random_noise_for_sso(p)
|
|
218
|
+
lr = torch.as_tensor(lr)
|
|
210
219
|
step_param_fn = self._compiled_step_parameter
|
|
211
220
|
else:
|
|
212
221
|
step_param_fn = self._step_parameter
|
|
213
222
|
|
|
214
|
-
step_param_fn(p, grad, state, group, lr, random_int_tensor)
|
|
223
|
+
step_param_fn(p, grad, state, group, lr, random_int_tensor, random_noise_tensor)
|
|
215
224
|
|
|
216
|
-
def _step_parameter(self, p, grad, state, group, lr, random_int_tensor):
|
|
225
|
+
def _step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_noise_tensor):
|
|
217
226
|
if grad.dtype != torch.float32 and state['factored']:
|
|
218
227
|
grad = grad.float()
|
|
219
228
|
if group["clip_threshold"] > 0.0:
|
|
@@ -251,9 +260,6 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
251
260
|
# Compute update term c_t
|
|
252
261
|
update = torch.lerp(grad_reshaped, exp_avg, beta1)
|
|
253
262
|
|
|
254
|
-
if group.get("l1_adaptive", False) and kappa_p == 1:
|
|
255
|
-
lr = lr * (update.norm(p=1))
|
|
256
|
-
|
|
257
263
|
# Standard Lion momentum update
|
|
258
264
|
# m_t = beta2 * m_{t-1} + (1-beta2) * g_t
|
|
259
265
|
exp_avg.lerp_(grad_reshaped, 1 - beta2)
|
|
@@ -262,7 +268,12 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
262
268
|
state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(exp_avg, signed=True)
|
|
263
269
|
del exp_avg
|
|
264
270
|
|
|
265
|
-
|
|
271
|
+
if freeze_on_flip:
|
|
272
|
+
# Fast binary diff (XOR) from momentum sign directly
|
|
273
|
+
flipped_packed = prev_sign_packed ^ state['sign']
|
|
274
|
+
flipped_mask = _unpack_bools(flipped_packed, original_m=d2).view_as(update)
|
|
275
|
+
update = torch.where(flipped_mask, 0.0, update)
|
|
276
|
+
del prev_sign_packed, flipped_packed, flipped_mask
|
|
266
277
|
|
|
267
278
|
if self.cautious_mask:
|
|
268
279
|
mask = (update * grad_reshaped > 0).to(grad_reshaped.dtype)
|
|
@@ -272,12 +283,12 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
272
283
|
|
|
273
284
|
update = update.view(p.shape)
|
|
274
285
|
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
286
|
+
l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p, rescale=False)
|
|
287
|
+
|
|
288
|
+
if group.get('stochastic_sign', False):
|
|
289
|
+
update = apply_stochastic_sign(update, noise=random_noise_tensor)
|
|
290
|
+
else:
|
|
291
|
+
update = _get_lion_k_update(update, kappa_p)
|
|
281
292
|
|
|
282
293
|
else:
|
|
283
294
|
# Fallback to standard Lion logic
|
|
@@ -286,10 +297,13 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
286
297
|
# Compute update term
|
|
287
298
|
update = torch.lerp(grad, exp_avg, beta1)
|
|
288
299
|
|
|
289
|
-
|
|
290
|
-
|
|
300
|
+
# Standard Lion momentum update
|
|
301
|
+
exp_avg.lerp_(grad, 1 - beta2)
|
|
291
302
|
|
|
292
|
-
|
|
303
|
+
if freeze_on_flip:
|
|
304
|
+
current_sign = (update > 0).to(torch.uint8)
|
|
305
|
+
update = torch.where(current_sign == state['prev_sign'], update, 0.0)
|
|
306
|
+
state['prev_sign'] = current_sign
|
|
293
307
|
|
|
294
308
|
if self.cautious_mask:
|
|
295
309
|
mask = (update * grad > 0).to(grad.dtype)
|
|
@@ -297,20 +311,22 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
297
311
|
update.mul_(mask)
|
|
298
312
|
del mask
|
|
299
313
|
|
|
300
|
-
|
|
301
|
-
exp_avg.lerp_(grad, 1 - beta2)
|
|
314
|
+
l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p, rescale=False)
|
|
302
315
|
|
|
303
|
-
if
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
316
|
+
if group.get('stochastic_sign', False):
|
|
317
|
+
update = apply_stochastic_sign(update, noise=random_noise_tensor)
|
|
318
|
+
else:
|
|
319
|
+
update = _get_lion_k_update(update, kappa_p)
|
|
320
|
+
|
|
321
|
+
if l1_mean is not None:
|
|
322
|
+
update.mul_(l1_mean)
|
|
307
323
|
|
|
308
324
|
if group.get('scaled_optm', False):
|
|
309
325
|
update = scale_update(p, update, lr, vector_state=state.get('spectral_v'))
|
|
310
326
|
else:
|
|
311
327
|
update.mul_(lr)
|
|
312
328
|
|
|
313
|
-
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
|
|
329
|
+
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_scaler=l1_mean)
|
|
314
330
|
|
|
315
331
|
def compile(self, *args, **kwargs):
|
|
316
332
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
@@ -183,6 +183,8 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
183
183
|
if spectral_normalization and rms_rescaling:
|
|
184
184
|
print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
|
|
185
185
|
rms_rescaling = False
|
|
186
|
+
if spectral_normalization and accelerated_ns:
|
|
187
|
+
ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
186
188
|
|
|
187
189
|
defaults = {
|
|
188
190
|
"lr": lr, "beta1": beta1, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
|
|
@@ -239,6 +241,8 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
239
241
|
if group.get('use_muon') is None: # Fallback
|
|
240
242
|
group['use_muon'] = group.get('optim_type') == 'muon'
|
|
241
243
|
|
|
244
|
+
self.init_step()
|
|
245
|
+
|
|
242
246
|
self.kourkoutas_helper = None
|
|
243
247
|
if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
|
|
244
248
|
self.kourkoutas_helper = KourkoutasHelper(self)
|
|
@@ -259,8 +263,8 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
259
263
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
260
264
|
"""
|
|
261
265
|
Overrides default load_state_dict to implement a workaround for PyTorch's
|
|
262
|
-
automatic dtype casting. It ensures factorized states remain float32 for
|
|
263
|
-
stability, preserves integer/float8 quantized anchor states, and forces
|
|
266
|
+
automatic dtype casting. It ensures factorized states remain float32 for
|
|
267
|
+
stability, preserves integer/float8 quantized anchor states, and forces
|
|
264
268
|
standard states onto the parameter's current dtype/device.
|
|
265
269
|
"""
|
|
266
270
|
super().load_state_dict(state_dict)
|
|
@@ -393,7 +397,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
393
397
|
step_size = group['lr'] / bias_correction1
|
|
394
398
|
|
|
395
399
|
if is_compiled:
|
|
396
|
-
step_size = torch.as_tensor(step_size
|
|
400
|
+
step_size = torch.as_tensor(step_size)
|
|
397
401
|
adam_step_param = self._compiled_adam_step_parameter
|
|
398
402
|
else:
|
|
399
403
|
adam_step_param = Muon_AuxAdam._adam_step_parameter
|
|
@@ -404,7 +408,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
404
408
|
|
|
405
409
|
else: # Muon path
|
|
406
410
|
if is_compiled:
|
|
407
|
-
lr = torch.as_tensor(group['lr']
|
|
411
|
+
lr = torch.as_tensor(group['lr'])
|
|
408
412
|
muon_step_param = self._compiled_muon_step_parameter
|
|
409
413
|
else:
|
|
410
414
|
lr = group['lr']
|