adv-optm 2.4.dev4__tar.gz → 2.4.dev5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/PKG-INFO +1 -1
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/AdaMuon_adv.py +7 -3
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/AdamW_adv.py +17 -4
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/Adopt_adv.py +13 -4
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/Lion_Prodigy_adv.py +2 -2
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/Lion_adv.py +35 -21
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/Muon_adv.py +6 -2
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/Prodigy_adv.py +14 -5
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/SignSGD_adv.py +31 -15
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/Simplified_AdEMAMix.py +1 -1
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/Muon_util.py +7 -5
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/param_update.py +49 -7
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/scaled_optm.py +20 -16
- adv_optm-2.4.dev5/adv_optm/util/signed_util.py +13 -0
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/update_util.py +46 -8
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm.egg-info/SOURCES.txt +1 -0
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/setup.py +1 -1
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/LICENSE +0 -0
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/README.md +0 -0
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/setup.cfg +0 -0
|
@@ -206,6 +206,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
206
206
|
if spectral_normalization and rms_rescaling:
|
|
207
207
|
print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
|
|
208
208
|
rms_rescaling = False
|
|
209
|
+
if spectral_normalization and accelerated_ns:
|
|
210
|
+
ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
209
211
|
|
|
210
212
|
defaults = {
|
|
211
213
|
"lr": lr, "betas": betas, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
|
|
@@ -260,6 +262,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
260
262
|
if group.get('use_muon') is None: # Fallback
|
|
261
263
|
group['use_muon'] = group.get('optim_type') == 'muon'
|
|
262
264
|
|
|
265
|
+
self.init_step()
|
|
266
|
+
|
|
263
267
|
self.kourkoutas_helper = None
|
|
264
268
|
if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
|
|
265
269
|
self.kourkoutas_helper = KourkoutasHelper(self)
|
|
@@ -419,7 +423,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
419
423
|
step_size = group['lr'] / bias_correction1
|
|
420
424
|
|
|
421
425
|
if is_compiled:
|
|
422
|
-
step_size = torch.as_tensor(step_size
|
|
426
|
+
step_size = torch.as_tensor(step_size)
|
|
423
427
|
adam_step_param = self._compiled_adam_step_parameter
|
|
424
428
|
else:
|
|
425
429
|
adam_step_param = Muon_AuxAdam._adam_step_parameter
|
|
@@ -430,7 +434,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
430
434
|
|
|
431
435
|
else: # Muon path
|
|
432
436
|
if is_compiled:
|
|
433
|
-
lr = torch.as_tensor(group['lr']
|
|
437
|
+
lr = torch.as_tensor(group['lr'])
|
|
434
438
|
muon_step_param = self._compiled_muon_step_parameter
|
|
435
439
|
else:
|
|
436
440
|
lr = group['lr']
|
|
@@ -467,7 +471,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
467
471
|
else:
|
|
468
472
|
shape_for_scaling = p.shape
|
|
469
473
|
|
|
470
|
-
scaled_eps, adaptive_eps, spectral_target, wd_scale = get_spectral_scaling(shape_for_scaling, group['n_layers'])
|
|
474
|
+
scaled_eps, adaptive_eps, spectral_target, wd_scale = get_spectral_scaling(p, shape_for_scaling, group['n_layers'])
|
|
471
475
|
|
|
472
476
|
weight_decay = group['weight_decay'] * wd_scale
|
|
473
477
|
decoupled_wd = True
|
|
@@ -6,7 +6,7 @@ from typing import Optional, Callable
|
|
|
6
6
|
|
|
7
7
|
from ..util import param_update
|
|
8
8
|
from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
|
|
9
|
-
from ..util.update_util import _grams_update, _cautious_update
|
|
9
|
+
from ..util.update_util import _grams_update, _cautious_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
|
|
10
10
|
from ..util.OrthoGrad import _orthogonalize_gradient
|
|
11
11
|
from ..util.Kourkoutas import KourkoutasHelper
|
|
12
12
|
from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
|
|
@@ -29,6 +29,9 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
29
29
|
eps (float): term added to the denominator to improve
|
|
30
30
|
numerical stability (default: 1e-8)
|
|
31
31
|
weight_decay (float): weight decay (L2 penalty) (default: 0).
|
|
32
|
+
fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
|
|
33
|
+
the decay direction through the empirical Fisher information matrix and
|
|
34
|
+
clipping its RMS. (default: False)
|
|
32
35
|
cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
|
|
33
36
|
applied only to parameter coordinates where the sign of the parameter
|
|
34
37
|
and the sign of the optimizer update align (default: False).
|
|
@@ -103,6 +106,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
103
106
|
eps: float = 1e-8,
|
|
104
107
|
# Decoupled/cautious weight decay
|
|
105
108
|
weight_decay: float = 0.0,
|
|
109
|
+
fisher_wd: bool = False,
|
|
106
110
|
cautious_wd: bool = False,
|
|
107
111
|
# Adam's Bias Correction
|
|
108
112
|
use_bias_correction: bool = True,
|
|
@@ -155,7 +159,8 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
155
159
|
cautious_mask = False
|
|
156
160
|
|
|
157
161
|
defaults = {
|
|
158
|
-
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
162
|
+
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
163
|
+
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
|
|
159
164
|
"use_atan2": use_atan2,
|
|
160
165
|
"orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
|
|
161
166
|
"beta3_ema": beta3_ema, "alpha": alpha, "compiled_optimizer": compiled_optimizer,
|
|
@@ -273,6 +278,8 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
273
278
|
|
|
274
279
|
_init_anchor(p, state, group)
|
|
275
280
|
|
|
281
|
+
_init_fisher_wd_scaler(group, state, p)
|
|
282
|
+
|
|
276
283
|
beta1, beta2 = group['betas']
|
|
277
284
|
|
|
278
285
|
current_step = state['step']
|
|
@@ -294,7 +301,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
294
301
|
random_int_tensor = None
|
|
295
302
|
|
|
296
303
|
if group.get('compiled_optimizer', False):
|
|
297
|
-
step_size = torch.as_tensor(step_size
|
|
304
|
+
step_size = torch.as_tensor(step_size)
|
|
298
305
|
if p.dtype == torch.bfloat16 and self.stochastic_rounding:
|
|
299
306
|
# Pre-generate random tensor for stochastic rounding if needed.
|
|
300
307
|
random_int_tensor = param_update._get_random_int_for_sr(p)
|
|
@@ -389,6 +396,9 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
389
396
|
denom = vt.sqrt_()
|
|
390
397
|
denom.div_(sqrt_bias_correction2).add_(group['eps'])
|
|
391
398
|
update.div_(denom)
|
|
399
|
+
|
|
400
|
+
wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
|
|
401
|
+
|
|
392
402
|
del vt
|
|
393
403
|
|
|
394
404
|
update = update.view(p.shape)
|
|
@@ -430,6 +440,9 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
430
440
|
denom = exp_avg_sq.sqrt()
|
|
431
441
|
denom.div_(sqrt_bias_correction2).add_(group['eps'])
|
|
432
442
|
update.div_(denom)
|
|
443
|
+
|
|
444
|
+
wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
|
|
445
|
+
|
|
433
446
|
del denom
|
|
434
447
|
|
|
435
448
|
update_scaling = step_size * A if group['use_atan2'] else step_size
|
|
@@ -438,7 +451,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
438
451
|
else:
|
|
439
452
|
update.mul_(update_scaling)
|
|
440
453
|
|
|
441
|
-
param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor)
|
|
454
|
+
param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
|
|
442
455
|
|
|
443
456
|
def compile(self, *args, **kwargs):
|
|
444
457
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
@@ -7,7 +7,7 @@ from ..util import param_update
|
|
|
7
7
|
from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state, _nnmf
|
|
8
8
|
from ..util.OrthoGrad import _orthogonalize_gradient
|
|
9
9
|
from ..util.Kourkoutas import KourkoutasHelper
|
|
10
|
-
from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update
|
|
10
|
+
from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
|
|
11
11
|
from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
|
|
12
12
|
from ..util.centered_decay import _init_anchor
|
|
13
13
|
|
|
@@ -33,6 +33,9 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
33
33
|
eps (float): term added to the denominator to improve
|
|
34
34
|
numerical stability (default: 1e-6)
|
|
35
35
|
weight_decay (float): weight decay (L2 penalty) (default: 0)
|
|
36
|
+
fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
|
|
37
|
+
the decay direction through the empirical Fisher information matrix and
|
|
38
|
+
clipping its RMS. (default: False)
|
|
36
39
|
cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
|
|
37
40
|
applied only to parameter coordinates where the sign of the parameter
|
|
38
41
|
and the sign of the optimizer update align (default: False).
|
|
@@ -119,6 +122,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
119
122
|
eps: float = 1e-6,
|
|
120
123
|
# Decoupled/cautious weight decay
|
|
121
124
|
weight_decay: float = 0.0,
|
|
125
|
+
fisher_wd: bool = False,
|
|
122
126
|
cautious_wd: bool = False,
|
|
123
127
|
# ADOPT clipping
|
|
124
128
|
clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
|
|
@@ -181,7 +185,8 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
181
185
|
print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
|
|
182
186
|
|
|
183
187
|
defaults = {
|
|
184
|
-
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
188
|
+
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
189
|
+
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
|
|
185
190
|
"beta3_ema": beta3_ema, "alpha": alpha,
|
|
186
191
|
"alpha_grad": alpha_grad,
|
|
187
192
|
"kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
|
|
@@ -312,6 +317,8 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
312
317
|
|
|
313
318
|
_init_anchor(p, state, group)
|
|
314
319
|
|
|
320
|
+
_init_fisher_wd_scaler(group, state, p)
|
|
321
|
+
|
|
315
322
|
current_step = state['step']
|
|
316
323
|
|
|
317
324
|
# The first step is for initialization only (skip when use_atan2 as it's scale invariant).
|
|
@@ -322,7 +329,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
322
329
|
random_int_tensor = None
|
|
323
330
|
|
|
324
331
|
if group.get('compiled_optimizer', False):
|
|
325
|
-
lr = torch.as_tensor(group['lr']
|
|
332
|
+
lr = torch.as_tensor(group['lr'])
|
|
326
333
|
if p.dtype == torch.bfloat16 and self.stochastic_rounding:
|
|
327
334
|
# Pre-generate random tensor for stochastic rounding if needed.
|
|
328
335
|
random_int_tensor = param_update._get_random_int_for_sr(p)
|
|
@@ -366,6 +373,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
366
373
|
|
|
367
374
|
# ADOPT Step A: Decorrelate g_t using v_{t-1}
|
|
368
375
|
denom = vt.sqrt()
|
|
376
|
+
wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, self.use_atan2)
|
|
369
377
|
|
|
370
378
|
# Update second moment v_t for the *next* step using raw g_t
|
|
371
379
|
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
@@ -444,6 +452,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
444
452
|
|
|
445
453
|
# ADOPT Step A: Decorrelate g_t using v_{t-1}
|
|
446
454
|
denom = vt.sqrt()
|
|
455
|
+
wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, self.use_atan2)
|
|
447
456
|
|
|
448
457
|
if self.use_atan2:
|
|
449
458
|
normalized_grad = torch.atan2(grad, denom, out=denom)
|
|
@@ -499,7 +508,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
499
508
|
update.mul_(update_scaling)
|
|
500
509
|
|
|
501
510
|
# Parameter Update
|
|
502
|
-
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
|
|
511
|
+
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
|
|
503
512
|
|
|
504
513
|
def compile(self, *args, **kwargs):
|
|
505
514
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
@@ -225,8 +225,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
|
|
|
225
225
|
# Pre-generate random tensor for stochastic rounding if needed.
|
|
226
226
|
random_int_tensor = param_update._get_random_int_for_sr(p)
|
|
227
227
|
# TODO, workaround until pytorch#169634 is fixed
|
|
228
|
-
d = torch.as_tensor(group['d']
|
|
229
|
-
dlr = torch.as_tensor(dlr
|
|
228
|
+
d = torch.as_tensor(group['d'])
|
|
229
|
+
dlr = torch.as_tensor(dlr)
|
|
230
230
|
step_param_fn = self._compiled_step_parameter
|
|
231
231
|
else:
|
|
232
232
|
d = group['d']
|
|
@@ -9,6 +9,7 @@ from ..util.lion_k import _get_lion_k_update
|
|
|
9
9
|
from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
|
|
10
10
|
from ..util.centered_decay import _init_anchor
|
|
11
11
|
from ..util.update_util import _get_l1_adaptive_lr
|
|
12
|
+
from ..util.signed_util import apply_stochastic_sign
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
class Lion_adv(torch.optim.Optimizer):
|
|
@@ -45,6 +46,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
45
46
|
parameter dimensionality. Sets p=2.0 for 4D tensors (Conv2D) (Biases/Norms) to
|
|
46
47
|
use Spherical updates, and p=1.0 for others (Linear/Embeddings) to use Sign
|
|
47
48
|
updates. Overrides explicit kappa_p value. (default: False).
|
|
49
|
+
stochastic_sign (bool): whether to use the Stochastic Sign operator. (default: False)
|
|
48
50
|
freeze_on_flip (bool): Projected SignGD One-hit freeze. Masks updates for
|
|
49
51
|
coordinates where the gradient sign flips compared to the previous step. (default: False)
|
|
50
52
|
l1_adaptive (bool): Scales learning rate dynamically
|
|
@@ -80,6 +82,8 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
80
82
|
# Lion-k
|
|
81
83
|
kappa_p: float = 1.0,
|
|
82
84
|
auto_kappa_p: bool = False,
|
|
85
|
+
# Stochastic Sign Operator
|
|
86
|
+
stochastic_sign: bool = False,
|
|
83
87
|
# Projected and adaptive sign
|
|
84
88
|
freeze_on_flip: bool = False,
|
|
85
89
|
l1_adaptive: bool = False,
|
|
@@ -111,6 +115,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
111
115
|
clip_threshold=clip_threshold,
|
|
112
116
|
kappa_p=kappa_p,
|
|
113
117
|
auto_kappa_p=auto_kappa_p,
|
|
118
|
+
stochastic_sign=stochastic_sign,
|
|
114
119
|
freeze_on_flip=freeze_on_flip,
|
|
115
120
|
l1_adaptive=l1_adaptive,
|
|
116
121
|
scaled_optm= scaled_optm,
|
|
@@ -202,19 +207,22 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
202
207
|
lr = group["lr"]
|
|
203
208
|
|
|
204
209
|
random_int_tensor = None
|
|
210
|
+
random_noise_tensor = None
|
|
205
211
|
|
|
206
212
|
if group.get('compiled_optimizer', False):
|
|
207
213
|
if p.dtype == torch.bfloat16 and self.stochastic_rounding:
|
|
208
214
|
# Pre-generate random tensor for stochastic rounding if needed.
|
|
209
215
|
random_int_tensor = param_update._get_random_int_for_sr(p)
|
|
210
|
-
|
|
216
|
+
if group.get('stochastic_sign', False):
|
|
217
|
+
random_noise_tensor = param_update._get_random_noise_for_sso(p)
|
|
218
|
+
lr = torch.as_tensor(lr)
|
|
211
219
|
step_param_fn = self._compiled_step_parameter
|
|
212
220
|
else:
|
|
213
221
|
step_param_fn = self._step_parameter
|
|
214
222
|
|
|
215
|
-
step_param_fn(p, grad, state, group, lr, random_int_tensor)
|
|
223
|
+
step_param_fn(p, grad, state, group, lr, random_int_tensor, random_noise_tensor)
|
|
216
224
|
|
|
217
|
-
def _step_parameter(self, p, grad, state, group, lr, random_int_tensor):
|
|
225
|
+
def _step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_noise_tensor):
|
|
218
226
|
if grad.dtype != torch.float32 and state['factored']:
|
|
219
227
|
grad = grad.float()
|
|
220
228
|
if group["clip_threshold"] > 0.0:
|
|
@@ -252,8 +260,6 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
252
260
|
# Compute update term c_t
|
|
253
261
|
update = torch.lerp(grad_reshaped, exp_avg, beta1)
|
|
254
262
|
|
|
255
|
-
l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p)
|
|
256
|
-
|
|
257
263
|
# Standard Lion momentum update
|
|
258
264
|
# m_t = beta2 * m_{t-1} + (1-beta2) * g_t
|
|
259
265
|
exp_avg.lerp_(grad_reshaped, 1 - beta2)
|
|
@@ -262,7 +268,12 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
262
268
|
state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(exp_avg, signed=True)
|
|
263
269
|
del exp_avg
|
|
264
270
|
|
|
265
|
-
|
|
271
|
+
if freeze_on_flip:
|
|
272
|
+
# Fast binary diff (XOR) from momentum sign directly
|
|
273
|
+
flipped_packed = prev_sign_packed ^ state['sign']
|
|
274
|
+
flipped_mask = _unpack_bools(flipped_packed, original_m=d2).view_as(update)
|
|
275
|
+
update = torch.where(flipped_mask, 0.0, update)
|
|
276
|
+
del prev_sign_packed, flipped_packed, flipped_mask
|
|
266
277
|
|
|
267
278
|
if self.cautious_mask:
|
|
268
279
|
mask = (update * grad_reshaped > 0).to(grad_reshaped.dtype)
|
|
@@ -272,12 +283,12 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
272
283
|
|
|
273
284
|
update = update.view(p.shape)
|
|
274
285
|
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
286
|
+
l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p, rescale=False)
|
|
287
|
+
|
|
288
|
+
if group.get('stochastic_sign', False):
|
|
289
|
+
update = apply_stochastic_sign(update, noise=random_noise_tensor)
|
|
290
|
+
else:
|
|
291
|
+
update = _get_lion_k_update(update, kappa_p)
|
|
281
292
|
|
|
282
293
|
else:
|
|
283
294
|
# Fallback to standard Lion logic
|
|
@@ -286,9 +297,13 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
286
297
|
# Compute update term
|
|
287
298
|
update = torch.lerp(grad, exp_avg, beta1)
|
|
288
299
|
|
|
289
|
-
|
|
300
|
+
# Standard Lion momentum update
|
|
301
|
+
exp_avg.lerp_(grad, 1 - beta2)
|
|
290
302
|
|
|
291
|
-
|
|
303
|
+
if freeze_on_flip:
|
|
304
|
+
current_sign = (update > 0).to(torch.uint8)
|
|
305
|
+
update = torch.where(current_sign == state['prev_sign'], update, 0.0)
|
|
306
|
+
state['prev_sign'] = current_sign
|
|
292
307
|
|
|
293
308
|
if self.cautious_mask:
|
|
294
309
|
mask = (update * grad > 0).to(grad.dtype)
|
|
@@ -296,13 +311,12 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
296
311
|
update.mul_(mask)
|
|
297
312
|
del mask
|
|
298
313
|
|
|
299
|
-
|
|
300
|
-
exp_avg.lerp_(grad, 1 - beta2)
|
|
314
|
+
l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p, rescale=False)
|
|
301
315
|
|
|
302
|
-
if
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
316
|
+
if group.get('stochastic_sign', False):
|
|
317
|
+
update = apply_stochastic_sign(update, noise=random_noise_tensor)
|
|
318
|
+
else:
|
|
319
|
+
update = _get_lion_k_update(update, kappa_p)
|
|
306
320
|
|
|
307
321
|
if l1_mean is not None:
|
|
308
322
|
update.mul_(l1_mean)
|
|
@@ -312,7 +326,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
312
326
|
else:
|
|
313
327
|
update.mul_(lr)
|
|
314
328
|
|
|
315
|
-
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
|
|
329
|
+
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_scaler=l1_mean)
|
|
316
330
|
|
|
317
331
|
def compile(self, *args, **kwargs):
|
|
318
332
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
@@ -183,6 +183,8 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
183
183
|
if spectral_normalization and rms_rescaling:
|
|
184
184
|
print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
|
|
185
185
|
rms_rescaling = False
|
|
186
|
+
if spectral_normalization and accelerated_ns:
|
|
187
|
+
ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
186
188
|
|
|
187
189
|
defaults = {
|
|
188
190
|
"lr": lr, "beta1": beta1, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
|
|
@@ -239,6 +241,8 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
239
241
|
if group.get('use_muon') is None: # Fallback
|
|
240
242
|
group['use_muon'] = group.get('optim_type') == 'muon'
|
|
241
243
|
|
|
244
|
+
self.init_step()
|
|
245
|
+
|
|
242
246
|
self.kourkoutas_helper = None
|
|
243
247
|
if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
|
|
244
248
|
self.kourkoutas_helper = KourkoutasHelper(self)
|
|
@@ -393,7 +397,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
393
397
|
step_size = group['lr'] / bias_correction1
|
|
394
398
|
|
|
395
399
|
if is_compiled:
|
|
396
|
-
step_size = torch.as_tensor(step_size
|
|
400
|
+
step_size = torch.as_tensor(step_size)
|
|
397
401
|
adam_step_param = self._compiled_adam_step_parameter
|
|
398
402
|
else:
|
|
399
403
|
adam_step_param = Muon_AuxAdam._adam_step_parameter
|
|
@@ -404,7 +408,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
404
408
|
|
|
405
409
|
else: # Muon path
|
|
406
410
|
if is_compiled:
|
|
407
|
-
lr = torch.as_tensor(group['lr']
|
|
411
|
+
lr = torch.as_tensor(group['lr'])
|
|
408
412
|
muon_step_param = self._compiled_muon_step_parameter
|
|
409
413
|
else:
|
|
410
414
|
lr = group['lr']
|
|
@@ -9,7 +9,7 @@ from ..util import param_update
|
|
|
9
9
|
from ..util.OrthoGrad import _orthogonalize_gradient
|
|
10
10
|
from ..util.Kourkoutas import KourkoutasHelper
|
|
11
11
|
from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
|
|
12
|
-
from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update
|
|
12
|
+
from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
|
|
13
13
|
from ..util.centered_decay import _init_anchor
|
|
14
14
|
|
|
15
15
|
A = 4 / math.pi
|
|
@@ -29,6 +29,9 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
29
29
|
eps (float): term added to the denominator to improve
|
|
30
30
|
numerical stability (default: 1e-8)
|
|
31
31
|
weight_decay (float): weight decay (L2 penalty) (default: 0)
|
|
32
|
+
fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
|
|
33
|
+
the decay direction through the empirical Fisher information matrix and
|
|
34
|
+
clipping its RMS. (default: False)
|
|
32
35
|
cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
|
|
33
36
|
applied only to parameter coordinates where the sign of the parameter
|
|
34
37
|
and the sign of the optimizer update align (default: False).
|
|
@@ -133,6 +136,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
133
136
|
eps: float = 1e-8,
|
|
134
137
|
# Decoupled/cautious weight decay
|
|
135
138
|
weight_decay: float = 0.0,
|
|
139
|
+
fisher_wd: bool = False,
|
|
136
140
|
cautious_wd: bool = False,
|
|
137
141
|
# Stochastic Rounding for BF16
|
|
138
142
|
stochastic_rounding: bool = True,
|
|
@@ -206,7 +210,8 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
206
210
|
raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
|
|
207
211
|
|
|
208
212
|
defaults = {
|
|
209
|
-
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
213
|
+
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
214
|
+
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
|
|
210
215
|
"use_atan2": use_atan2,
|
|
211
216
|
"orthogonal_gradient": orthogonal_gradient,
|
|
212
217
|
"beta3_ema": beta3_ema, "alpha": alpha, "compiled_optimizer": compiled_optimizer,
|
|
@@ -354,6 +359,8 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
354
359
|
|
|
355
360
|
_init_anchor(p, state, group)
|
|
356
361
|
|
|
362
|
+
_init_fisher_wd_scaler(group, state, p)
|
|
363
|
+
|
|
357
364
|
if not hasattr(self, 'd_denom'):
|
|
358
365
|
self.d_denom = torch.tensor(0.0, device=p.device)
|
|
359
366
|
self.d_numerator = torch.tensor(group.get('d_numerator', 0.0), device=p.device)
|
|
@@ -376,8 +383,8 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
376
383
|
# Pre-generate random tensor for stochastic rounding if needed.
|
|
377
384
|
random_int_tensor = param_update._get_random_int_for_sr(p)
|
|
378
385
|
# TODO, workaround until pytorch#169634 is fixed
|
|
379
|
-
d = torch.as_tensor(group['d']
|
|
380
|
-
dlr = torch.as_tensor(dlr
|
|
386
|
+
d = torch.as_tensor(group['d'])
|
|
387
|
+
dlr = torch.as_tensor(dlr)
|
|
381
388
|
step_param_fn = self._compiled_step_parameter
|
|
382
389
|
else:
|
|
383
390
|
d = group['d']
|
|
@@ -478,6 +485,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
478
485
|
else:
|
|
479
486
|
denom = vt.sqrt_()
|
|
480
487
|
update.div_(denom.add_(d * group['eps']))
|
|
488
|
+
wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
|
|
481
489
|
del vt
|
|
482
490
|
|
|
483
491
|
update_scaling = dlr * A if group['use_atan2'] else dlr
|
|
@@ -528,6 +536,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
528
536
|
else:
|
|
529
537
|
denom = exp_avg_sq.sqrt()
|
|
530
538
|
update.div_(denom.add_(d * group['eps']))
|
|
539
|
+
wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
|
|
531
540
|
del denom
|
|
532
541
|
|
|
533
542
|
update_scaling = dlr * A if group['use_atan2'] else dlr
|
|
@@ -557,7 +566,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
557
566
|
if 'p0' in state:
|
|
558
567
|
del state['p0']
|
|
559
568
|
|
|
560
|
-
param_update.apply_parameter_update(self, p, group, update, dlr, random_int_tensor=random_int_tensor)
|
|
569
|
+
param_update.apply_parameter_update(self, p, group, update, dlr, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
|
|
561
570
|
|
|
562
571
|
def compile(self, *args, **kwargs):
|
|
563
572
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
@@ -9,6 +9,7 @@ from ..util.lion_k import _get_lion_k_update
|
|
|
9
9
|
from ..util.update_util import _get_l1_adaptive_lr
|
|
10
10
|
from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
|
|
11
11
|
from ..util.centered_decay import _init_anchor
|
|
12
|
+
from ..util.signed_util import apply_stochastic_sign
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
class SignSGD_adv(torch.optim.Optimizer):
|
|
@@ -39,6 +40,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
39
40
|
parameter dimensionality. Sets p=2.0 for 4D tensors (Conv2D) (Biases/Norms) to
|
|
40
41
|
use Spherical updates, and p=1.0 for others (Linear/Embeddings) to use Sign
|
|
41
42
|
updates. Overrides explicit kappa_p value. (default: False).
|
|
43
|
+
stochastic_sign (bool): whether to use the Stochastic Sign operator. (default: False)
|
|
42
44
|
Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
|
|
43
45
|
This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
|
|
44
46
|
more responsive, especially for small batch sizes. (default: False)
|
|
@@ -79,6 +81,8 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
79
81
|
# Projection-k
|
|
80
82
|
kappa_p: float = 1.0,
|
|
81
83
|
auto_kappa_p: bool = True,
|
|
84
|
+
# Stochastic Sign Operator
|
|
85
|
+
stochastic_sign: bool = False,
|
|
82
86
|
# Simplified_AdEMAMix
|
|
83
87
|
alpha_grad: float = 1.0,
|
|
84
88
|
Simplified_AdEMAMix: bool = False,
|
|
@@ -112,6 +116,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
112
116
|
orthogonal_gradient=orthogonal_gradient,
|
|
113
117
|
kappa_p=kappa_p,
|
|
114
118
|
auto_kappa_p=auto_kappa_p,
|
|
119
|
+
stochastic_sign=stochastic_sign,
|
|
115
120
|
alpha_grad=alpha_grad,
|
|
116
121
|
Simplified_AdEMAMix=Simplified_AdEMAMix,
|
|
117
122
|
scaled_optm= scaled_optm,
|
|
@@ -203,23 +208,26 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
203
208
|
lr = group["lr"]
|
|
204
209
|
|
|
205
210
|
random_int_tensor = None
|
|
211
|
+
random_noise_tensor = None
|
|
206
212
|
|
|
207
213
|
if group.get('compiled_optimizer', False):
|
|
208
214
|
if p.dtype == torch.bfloat16 and self.stochastic_rounding:
|
|
209
215
|
# Pre-generate random tensor for stochastic rounding if needed.
|
|
210
216
|
random_int_tensor = param_update._get_random_int_for_sr(p)
|
|
211
|
-
|
|
217
|
+
if group.get('stochastic_sign', False):
|
|
218
|
+
random_noise_tensor = param_update._get_random_noise_for_sso(p)
|
|
219
|
+
lr = torch.as_tensor(lr)
|
|
212
220
|
step_param_fn = self._compiled_step_parameter
|
|
213
221
|
else:
|
|
214
222
|
step_param_fn = self._step_parameter
|
|
215
223
|
|
|
216
|
-
step_param_fn(p, grad, state, group, lr, random_int_tensor)
|
|
224
|
+
step_param_fn(p, grad, state, group, lr, random_int_tensor, random_noise_tensor)
|
|
217
225
|
|
|
218
226
|
if group.get("l1_adaptive", False):
|
|
219
227
|
state["step"] += 1
|
|
220
228
|
|
|
221
|
-
def _step_parameter(self, p, grad, state, group, lr, random_int_tensor):
|
|
222
|
-
if grad.dtype != torch.float32 and state
|
|
229
|
+
def _step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_noise_tensor):
|
|
230
|
+
if grad.dtype != torch.float32 and state.get('factored', False):
|
|
223
231
|
grad = grad.float()
|
|
224
232
|
|
|
225
233
|
if group["orthogonal_gradient"]:
|
|
@@ -269,18 +277,23 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
269
277
|
if freeze_on_flip:
|
|
270
278
|
state['sign'] = _pack_bools(raw_update > 0)
|
|
271
279
|
|
|
272
|
-
l1_mean = _get_l1_adaptive_lr(p, raw_update, state, group, kappa_p)
|
|
273
280
|
|
|
274
|
-
|
|
275
|
-
update = update.view(p.shape)
|
|
281
|
+
raw_update = raw_update.view(p.shape)
|
|
276
282
|
|
|
277
283
|
if freeze_on_flip:
|
|
278
284
|
# Fast binary diff (XOR) from momentum sign directly
|
|
279
285
|
flipped_packed = prev_sign_packed ^ state['sign']
|
|
280
|
-
flipped_mask = _unpack_bools(flipped_packed, original_m=d2).view_as(
|
|
281
|
-
|
|
286
|
+
flipped_mask = _unpack_bools(flipped_packed, original_m=d2).view_as(raw_update)
|
|
287
|
+
raw_update = torch.where(flipped_mask, 0.0, raw_update)
|
|
282
288
|
del prev_sign_packed, flipped_packed, flipped_mask
|
|
283
289
|
|
|
290
|
+
l1_mean = _get_l1_adaptive_lr(p, raw_update, state, group, kappa_p)
|
|
291
|
+
|
|
292
|
+
if group.get('stochastic_sign', False):
|
|
293
|
+
update = apply_stochastic_sign(raw_update, noise=random_noise_tensor)
|
|
294
|
+
else:
|
|
295
|
+
update = _get_lion_k_update(raw_update, kappa_p)
|
|
296
|
+
|
|
284
297
|
else:
|
|
285
298
|
# Fallback to standard SignSGD logic
|
|
286
299
|
if momentum > 0:
|
|
@@ -294,15 +307,18 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
294
307
|
else:
|
|
295
308
|
raw_update = grad.clone()
|
|
296
309
|
|
|
297
|
-
l1_mean = _get_l1_adaptive_lr(p, raw_update, state, group, kappa_p)
|
|
298
|
-
|
|
299
|
-
update = _get_lion_k_update(raw_update, kappa_p)
|
|
300
|
-
|
|
301
310
|
if freeze_on_flip:
|
|
302
311
|
current_sign = (raw_update > 0).to(torch.uint8)
|
|
303
|
-
|
|
312
|
+
raw_update = torch.where(current_sign == state['prev_sign'], raw_update, 0.0)
|
|
304
313
|
state['prev_sign'] = current_sign
|
|
305
314
|
|
|
315
|
+
l1_mean = _get_l1_adaptive_lr(p, raw_update, state, group, kappa_p)
|
|
316
|
+
|
|
317
|
+
if group.get('stochastic_sign', False):
|
|
318
|
+
update = apply_stochastic_sign(raw_update, noise=random_noise_tensor)
|
|
319
|
+
else:
|
|
320
|
+
update = _get_lion_k_update(raw_update, kappa_p)
|
|
321
|
+
|
|
306
322
|
if l1_mean is not None:
|
|
307
323
|
update.mul_(l1_mean)
|
|
308
324
|
|
|
@@ -311,7 +327,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
311
327
|
else:
|
|
312
328
|
update.mul_(lr)
|
|
313
329
|
|
|
314
|
-
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
|
|
330
|
+
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_scaler=l1_mean)
|
|
315
331
|
|
|
316
332
|
def compile(self, *args, **kwargs):
|
|
317
333
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
@@ -288,7 +288,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
|
|
|
288
288
|
# Pre-generate random tensor for stochastic rounding if needed.
|
|
289
289
|
random_int_tensor = param_update._get_random_int_for_sr(p)
|
|
290
290
|
# TODO, workaround until pytorch#169634 is fixed
|
|
291
|
-
lr = torch.as_tensor(lr
|
|
291
|
+
lr = torch.as_tensor(lr)
|
|
292
292
|
step_param_fn = self._compiled_step_parameter
|
|
293
293
|
else:
|
|
294
294
|
step_param_fn = self._step_parameter
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
|
|
3
|
+
import math
|
|
4
|
+
|
|
3
5
|
@torch.no_grad()
|
|
4
6
|
def _newton_schulz_iteration(
|
|
5
7
|
G: torch.Tensor,
|
|
@@ -359,11 +361,11 @@ def rms_adjustment(update: torch.Tensor, rms_rescaling: bool, lr):
|
|
|
359
361
|
# This is slower due to norm calculations but it worked the best for t2i models.
|
|
360
362
|
rms_target = 0.2 # default (Adam) value for RMS
|
|
361
363
|
update_norm = torch.linalg.vector_norm(update)
|
|
362
|
-
return update.mul_(lr * rms_target * (update.numel()
|
|
364
|
+
return update.mul_(lr * rms_target * (math.sqrt(update.numel())) / update_norm.clamp_min_(1e-8))
|
|
363
365
|
else:
|
|
364
366
|
# Original Muon scaling
|
|
365
367
|
r, c = update.size(-2), update.size(-1)
|
|
366
|
-
scaling_factor = max(1, r / c)
|
|
368
|
+
scaling_factor = math.sqrt(max(1, r / c))
|
|
367
369
|
return update.mul_(lr * scaling_factor)
|
|
368
370
|
|
|
369
371
|
def _auto_projection_for_adamuon(raw_update: torch.Tensor, kappa_p: float) -> torch.Tensor:
|
|
@@ -474,15 +476,15 @@ def get_spectral_scaling(shape: torch.Size, n_layers: int):
|
|
|
474
476
|
# A) Newton-Schulz Damping
|
|
475
477
|
# This ensures the matrix orthogonalization is stable across scales.
|
|
476
478
|
# Formula: (1/L) * sqrt(d_in / d_out)
|
|
477
|
-
ns_eps = (1.0 / L) * (d_in / d_out)
|
|
479
|
+
ns_eps = (1.0 / L) * math.sqrt(d_in / d_out)
|
|
478
480
|
|
|
479
481
|
# B) Adaptive Denominator Epsilon
|
|
480
482
|
# This ensures the Adam-style division doesn't explode or vanish.
|
|
481
483
|
# Formula: (1/L) * (1 / sqrt(d_in * d_out))
|
|
482
|
-
adaptive_eps = (1.0 / L) * (1.0 / (d_in * d_out)
|
|
484
|
+
adaptive_eps = (1.0 / L) * (1.0 / math.sqrt(d_in * d_out))
|
|
483
485
|
|
|
484
486
|
# Spectral Target (Section F) -> sqrt(d_out/d_in)
|
|
485
|
-
spectral_target = (d_out / d_in)
|
|
487
|
+
spectral_target = math.sqrt(d_out / d_in)
|
|
486
488
|
|
|
487
489
|
# Weight Decay (Section 3.4) -> 1/width
|
|
488
490
|
wd_scale = 1.0 / d_in
|
|
@@ -4,7 +4,7 @@ from torch.optim import Optimizer
|
|
|
4
4
|
|
|
5
5
|
from typing import Dict, Any
|
|
6
6
|
|
|
7
|
-
from .scaled_optm import scale_wds
|
|
7
|
+
from .scaled_optm import adjust_wds, scale_wds
|
|
8
8
|
from .centered_decay import dequantize_anchor
|
|
9
9
|
|
|
10
10
|
_generators: Dict[torch.device, torch.Generator] = {}
|
|
@@ -29,11 +29,17 @@ def _apply_weight_decay(
|
|
|
29
29
|
# Cautious Weight Decay: only decay if the update pushes in the same direction as the decay
|
|
30
30
|
if cautious:
|
|
31
31
|
mask = (update_calc * p_calc >= 0).to(p_calc.dtype)
|
|
32
|
-
|
|
32
|
+
if isinstance(scaled_wd, Tensor):
|
|
33
|
+
p_calc.addcmul_(p_calc, mask * scaled_wd, value=-1.0)
|
|
34
|
+
else:
|
|
35
|
+
p_calc.addcmul_(p_calc, mask, value=-scaled_wd)
|
|
33
36
|
del mask
|
|
34
37
|
else:
|
|
35
38
|
# Standard decoupled weight decay
|
|
36
|
-
|
|
39
|
+
if isinstance(scaled_wd, Tensor):
|
|
40
|
+
p_calc.addcmul_(p_calc, scaled_wd, value=-1.0)
|
|
41
|
+
else:
|
|
42
|
+
p_calc.add_(p_calc, alpha=-scaled_wd)
|
|
37
43
|
|
|
38
44
|
# Centered Weight Decay (pulls toward anchor)
|
|
39
45
|
if scaled_cwd is not None and 'anchor_type' in state:
|
|
@@ -43,15 +49,20 @@ def _apply_weight_decay(
|
|
|
43
49
|
if cautious:
|
|
44
50
|
# Cautious Weight Decay: only decay if the update pushes in the same direction as the decay
|
|
45
51
|
mask = (update_calc * decay_target >= 0).to(p_calc.dtype)
|
|
46
|
-
|
|
52
|
+
if isinstance(scaled_cwd, Tensor):
|
|
53
|
+
p_calc.addcmul_(decay_target, mask * scaled_cwd, value=-1.0)
|
|
54
|
+
else:
|
|
55
|
+
p_calc.addcmul_(decay_target, mask, value=-scaled_cwd)
|
|
47
56
|
del mask
|
|
48
57
|
else:
|
|
49
58
|
# Standard decoupled weight decay
|
|
50
|
-
|
|
59
|
+
if isinstance(scaled_cwd, Tensor):
|
|
60
|
+
p_calc.addcmul_(decay_target, scaled_cwd, value=-1.0)
|
|
61
|
+
else:
|
|
62
|
+
p_calc.add_(decay_target, alpha=-scaled_cwd)
|
|
51
63
|
|
|
52
64
|
del anchor, decay_target
|
|
53
65
|
|
|
54
|
-
|
|
55
66
|
def apply_parameter_update(
|
|
56
67
|
self,
|
|
57
68
|
p: Tensor,
|
|
@@ -61,6 +72,7 @@ def apply_parameter_update(
|
|
|
61
72
|
wd: float | None = None,
|
|
62
73
|
random_int_tensor: Tensor | None = None,
|
|
63
74
|
decoupled: bool = False,
|
|
75
|
+
wd_scaler: float | Tensor | None = None,
|
|
64
76
|
) -> None:
|
|
65
77
|
"""
|
|
66
78
|
Applies decoupled weight decay (standard, cautious, centered) and the final
|
|
@@ -75,13 +87,16 @@ def apply_parameter_update(
|
|
|
75
87
|
random_int_tensor: Optional pre-generated random tensor for stochastic
|
|
76
88
|
rounding. Required for the `torch.compile` path.
|
|
77
89
|
decoupled: Whenever to use the true decoupled weight decay.
|
|
90
|
+
wd_scaler: A multiplier/tensor to scale the calculated wd/cwd magnitude (e.g. for Fisher Adam WD).
|
|
78
91
|
"""
|
|
79
92
|
wd = group["weight_decay"] if wd is None else wd
|
|
80
93
|
cwd = group.get("centered_wd", 0.0)
|
|
81
94
|
|
|
82
95
|
if group.get('scaled_optm', False):
|
|
83
96
|
decoupled = True
|
|
84
|
-
wd, cwd =
|
|
97
|
+
wd, cwd = adjust_wds(wd, cwd, p)
|
|
98
|
+
if wd_scaler is None:
|
|
99
|
+
wd, cwd = scale_wds(wd, cwd, p)
|
|
85
100
|
|
|
86
101
|
# Calculate global decay factor for decoupled vs standard
|
|
87
102
|
decay_factor = (lr / self._init_lr) if decoupled else lr
|
|
@@ -89,6 +104,12 @@ def apply_parameter_update(
|
|
|
89
104
|
scaled_wd = (wd * decay_factor) if wd != 0 else None
|
|
90
105
|
scaled_cwd = (cwd * decay_factor) if cwd != 0 else None
|
|
91
106
|
|
|
107
|
+
if wd_scaler is not None:
|
|
108
|
+
if scaled_wd is not None:
|
|
109
|
+
scaled_wd = scaled_wd * wd_scaler
|
|
110
|
+
if scaled_cwd is not None:
|
|
111
|
+
scaled_cwd = scaled_cwd * wd_scaler
|
|
112
|
+
|
|
92
113
|
state = self.state[p]
|
|
93
114
|
|
|
94
115
|
# Compute full update in float32 if using bfloat16 with stochastic rounding
|
|
@@ -284,3 +305,24 @@ def post_process_loaded_state(optimizer: Optimizer) -> None:
|
|
|
284
305
|
# Ensure device match
|
|
285
306
|
if state[key].device != p.device:
|
|
286
307
|
state[key] = state[key].to(p.device)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def _get_random_noise_for_sso(source: torch.Tensor) -> torch.Tensor:
|
|
311
|
+
"""
|
|
312
|
+
Generates a random noise tensor for Stochastic Sign operator.
|
|
313
|
+
This function is not torch.compile-path friendly due to its use of torch.Generator.
|
|
314
|
+
"""
|
|
315
|
+
global _generators
|
|
316
|
+
device = source.device
|
|
317
|
+
if device not in _generators:
|
|
318
|
+
set_seed(device)
|
|
319
|
+
# TODO, this is a workaround until torch compile error
|
|
320
|
+
# NotImplementedError: UserDefinedObjectVariable(generator) is fixed
|
|
321
|
+
generator = _generators[device]
|
|
322
|
+
# create a random noise tensor
|
|
323
|
+
return torch.randint(
|
|
324
|
+
size=source.shape,
|
|
325
|
+
device=source.device,
|
|
326
|
+
dtype=source.dtype,
|
|
327
|
+
generator=generator,
|
|
328
|
+
)
|
|
@@ -46,35 +46,39 @@ def scale_update(
|
|
|
46
46
|
return update.mul_(lr)
|
|
47
47
|
|
|
48
48
|
|
|
49
|
-
def
|
|
49
|
+
def adjust_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
|
|
50
50
|
"""
|
|
51
|
-
Adjusts standard weight decay and centered weight decay
|
|
52
|
-
shape and type to maintain effective regularization strength.
|
|
51
|
+
Adjusts standard weight decay and centered weight decay.
|
|
53
52
|
"""
|
|
54
53
|
# DoRA Scale (Magnitude Vector)
|
|
55
54
|
if getattr(p, '_is_dora_scale', False):
|
|
56
55
|
return wd, cwd
|
|
57
56
|
|
|
58
|
-
conflict = cwd != 0
|
|
59
|
-
|
|
60
57
|
if getattr(p, '_is_oft', False):
|
|
61
|
-
|
|
62
|
-
return (cwd if conflict else wd), 0.0
|
|
58
|
+
return wd, 0.0
|
|
63
59
|
|
|
64
60
|
if p.ndim >= 2:
|
|
65
|
-
fan_in = p.numel() // p.shape[0]
|
|
66
|
-
|
|
67
|
-
# When both WDs are active on LoRA, fallback to standard WD (using cwd value)
|
|
68
|
-
# Reverts the behavior for better DoRA tuning.
|
|
69
61
|
is_lora = getattr(p, '_is_lora_A', False) or getattr(p, '_is_lora_B', False)
|
|
70
|
-
if
|
|
71
|
-
return
|
|
62
|
+
if is_lora:
|
|
63
|
+
return wd, 0.0
|
|
64
|
+
|
|
65
|
+
else:
|
|
66
|
+
# 1D Biases or generic 1D parameters
|
|
67
|
+
# Centered WD safely regularizes the delta without collapsing base feature variance.
|
|
68
|
+
return 0.0, cwd
|
|
69
|
+
|
|
72
70
|
|
|
71
|
+
def scale_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
|
|
72
|
+
"""
|
|
73
|
+
Scales standard weight decay and centered weight decay based on the parameter's
|
|
74
|
+
shape and type to maintain effective regularization strength.
|
|
75
|
+
"""
|
|
76
|
+
if p.ndim >= 2:
|
|
77
|
+
fan_in = p.numel() // p.shape[0]
|
|
73
78
|
return wd / fan_in, cwd / fan_in
|
|
74
79
|
|
|
75
|
-
# 1D
|
|
76
|
-
|
|
77
|
-
return 0.0, cwd
|
|
80
|
+
# 1D tensors (like DoRA scale and Biases)
|
|
81
|
+
return wd, cwd
|
|
78
82
|
|
|
79
83
|
|
|
80
84
|
@torch.no_grad()
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def apply_stochastic_sign(update: torch.Tensor, noise: torch.Tensor | None) -> torch.Tensor:
|
|
5
|
+
"""
|
|
6
|
+
Applies the Stochastic Sign operator S_R(v).
|
|
7
|
+
Uses uniform noise injection to compute the stochastic sign
|
|
8
|
+
"""
|
|
9
|
+
R = update.abs().max().clamp_min(1e-12)
|
|
10
|
+
|
|
11
|
+
if noise is None:
|
|
12
|
+
noise = torch.rand_like(update) * 2.0 - 1.0
|
|
13
|
+
return torch.sign(update / R + noise, out=update)
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
|
|
3
|
+
import math
|
|
4
|
+
|
|
3
5
|
def _grams_update(mt: torch.Tensor, grad: torch.Tensor, inplace: bool=False):
|
|
4
6
|
"""
|
|
5
7
|
Applies the update rule of "Gradient Descent with Adaptive Momentum Scaling"
|
|
@@ -31,27 +33,63 @@ def _scale_sim_AdEMAMix_update(beta: float, current_step: int, alpha_grad: float
|
|
|
31
33
|
lr = lr * total_scale
|
|
32
34
|
return lr
|
|
33
35
|
|
|
36
|
+
|
|
37
|
+
def _init_fisher_wd_scaler(group: dict, state: dict, p: torch.Tensor) -> torch.Tensor | None:
|
|
38
|
+
if not group.get('fisher_wd', False):
|
|
39
|
+
return
|
|
40
|
+
|
|
41
|
+
state["wd_scaler"] = torch.tensor(1.0, device=p.device)
|
|
42
|
+
|
|
43
|
+
def _get_fisher_wd_scaler(group: dict, stored_scaler: torch.Tensor, p: torch.Tensor, denom: torch.Tensor, atan2: bool) -> torch.Tensor | None:
|
|
44
|
+
"""
|
|
45
|
+
Calculates the Fisher weight decay scaler.
|
|
46
|
+
Maps the decay direction through the empirical Fisher information matrix
|
|
47
|
+
and clips its RMS to ensure stability.
|
|
48
|
+
From the paper:
|
|
49
|
+
"FAdam: Adam is a natural gradient optimizer using diagonal empirical Fisher information"
|
|
50
|
+
"""
|
|
51
|
+
if not group.get('fisher_wd', False):
|
|
52
|
+
return None
|
|
53
|
+
|
|
54
|
+
if atan2:
|
|
55
|
+
wd_scaler = torch.atan2(stored_scaler, denom).mul_(4 / math.pi)
|
|
56
|
+
else:
|
|
57
|
+
eps = group.get('eps', 1e-8)
|
|
58
|
+
wd_scaler = 1.0 / (denom + eps)
|
|
59
|
+
|
|
60
|
+
# Reshape scaler if necessary to match parameter shape (for factored states)
|
|
61
|
+
wd_scaler = wd_scaler.view(p.shape)
|
|
62
|
+
|
|
63
|
+
gw_rms = torch.sqrt(torch.mean((p * wd_scaler) ** 2))
|
|
64
|
+
clip_coef = torch.clamp(gw_rms / 1.0, min=1.0)
|
|
65
|
+
return wd_scaler / clip_coef
|
|
66
|
+
|
|
34
67
|
def _get_l1_adaptive_lr(
|
|
35
68
|
p: torch.Tensor,
|
|
36
69
|
update: torch.Tensor,
|
|
37
70
|
state: dict,
|
|
38
71
|
group: dict,
|
|
39
|
-
kappa_p: float
|
|
72
|
+
kappa_p: float,
|
|
73
|
+
rescale: bool = False,
|
|
40
74
|
) -> torch.Tensor:
|
|
41
75
|
"""
|
|
42
76
|
Calculates the L1 adaptive learning rate based on gradient heterogeneity.
|
|
43
77
|
"""
|
|
44
|
-
if not group.get("l1_adaptive", False)
|
|
78
|
+
if not group.get("l1_adaptive", False) or kappa_p != 1:
|
|
45
79
|
return None
|
|
46
80
|
|
|
47
|
-
momentum = group["momentum"]
|
|
48
|
-
alpha_grad = group["alpha_grad"]
|
|
49
81
|
update_view = update.view(p.shape)
|
|
50
82
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
83
|
+
if rescale:
|
|
84
|
+
momentum = group["momentum"]
|
|
85
|
+
alpha_grad = group["alpha_grad"]
|
|
86
|
+
|
|
87
|
+
# Calculate scale factor based on momentum/update magnitude
|
|
88
|
+
scale_factor = _scale_sim_AdEMAMix_update(
|
|
89
|
+
momentum, state["step"] + 1, alpha_grad, 1, False
|
|
90
|
+
)
|
|
91
|
+
else:
|
|
92
|
+
scale_factor = 1
|
|
55
93
|
|
|
56
94
|
# Determine dimension for mean calculation based on parameter type
|
|
57
95
|
if getattr(p, '_is_oft', False) or getattr(p, '_is_lora_A', False):
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|