adv-optm 2.4.dev16__tar.gz → 2.4.dev18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/PKG-INFO +1 -1
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/AdamW_adv.py +66 -47
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/SinkSGD_adv.py +47 -11
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/signed_util.py +16 -13
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/sinkhorn.py +70 -4
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/setup.py +1 -1
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/LICENSE +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/README.md +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/AdaMuon_adv.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/Muon_adv.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/SignSGD_adv.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/param_update.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/scaled_optm.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/state_util.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/setup.cfg +0 -0
|
@@ -63,6 +63,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
63
63
|
before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
|
|
64
64
|
A higher value increases the stabilizing influence of the slow
|
|
65
65
|
momentum. (default: 5.0)
|
|
66
|
+
normed_momentum (bool): whether to compute the first moment on the normalized gradient. (default: False)
|
|
66
67
|
kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
|
|
67
68
|
If `False`, the optimizer behaves as standard AdamW. (default: False)
|
|
68
69
|
beta2_min (float): The minimum value for dynamic β₂, used during periods of
|
|
@@ -131,6 +132,8 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
131
132
|
# Nesterov momentum
|
|
132
133
|
nesterov: bool = False,
|
|
133
134
|
nesterov_coef: float | None = None,
|
|
135
|
+
# Normalization then Momentum
|
|
136
|
+
normed_momentum: bool = False,
|
|
134
137
|
# K-b (adaptive beta2)
|
|
135
138
|
kourkoutas_beta: bool = False,
|
|
136
139
|
beta2_min: float = 0.9,
|
|
@@ -181,6 +184,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
181
184
|
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
182
185
|
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
|
|
183
186
|
"use_atan2": use_atan2, "nesterov": nesterov, "nesterov_coef": nesterov_coef,
|
|
187
|
+
"normed_momentum": normed_momentum,
|
|
184
188
|
"orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
|
|
185
189
|
"beta3_ema": beta3_ema, "alpha": alpha, "compiled_optimizer": compiled_optimizer,
|
|
186
190
|
"kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
|
|
@@ -383,6 +387,27 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
383
387
|
d1, d2 = state['effective_shape']
|
|
384
388
|
grad_reshaped = grad.view(d1, d2)
|
|
385
389
|
|
|
390
|
+
vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
|
|
391
|
+
|
|
392
|
+
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
393
|
+
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped * (1.0 - beta2))
|
|
394
|
+
else:
|
|
395
|
+
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
|
|
396
|
+
|
|
397
|
+
# Factorize
|
|
398
|
+
state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(vt, signed=False)
|
|
399
|
+
|
|
400
|
+
if group['use_atan2']:
|
|
401
|
+
denom = vt.sqrt_()
|
|
402
|
+
denom.div_(sqrt_bias_correction2)
|
|
403
|
+
if group.get('normed_momentum', False):
|
|
404
|
+
grad_reshaped.atan2_(denom)
|
|
405
|
+
else:
|
|
406
|
+
denom = vt.sqrt_()
|
|
407
|
+
denom.div_(sqrt_bias_correction2).add_(adaptive_eps)
|
|
408
|
+
if group.get('normed_momentum', False):
|
|
409
|
+
grad_reshaped.div_(denom)
|
|
410
|
+
|
|
386
411
|
# Reconstruct momentum from previous step's factors
|
|
387
412
|
if use_mt:
|
|
388
413
|
mt = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True)
|
|
@@ -404,13 +429,6 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
404
429
|
nv_coef = beta1 if nesterov_coef is None else nesterov_coef
|
|
405
430
|
update_mt = update_mt.lerp_(grad_reshaped, 1-nv_coef)
|
|
406
431
|
|
|
407
|
-
vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
|
|
408
|
-
|
|
409
|
-
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
410
|
-
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped * (1.0 - beta2))
|
|
411
|
-
else:
|
|
412
|
-
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
|
|
413
|
-
|
|
414
432
|
if self.use_AdEMAMix:
|
|
415
433
|
mt_slow = _reconstruct_state((state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'], d2), signed=True)
|
|
416
434
|
|
|
@@ -430,17 +448,11 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
430
448
|
else:
|
|
431
449
|
update = grad_reshaped.clone()
|
|
432
450
|
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
denom.div_(sqrt_bias_correction2)
|
|
439
|
-
update.atan2_(denom)
|
|
440
|
-
else:
|
|
441
|
-
denom = vt.sqrt_()
|
|
442
|
-
denom.div_(sqrt_bias_correction2).add_(adaptive_eps)
|
|
443
|
-
update.div_(denom)
|
|
451
|
+
if not group.get('normed_momentum', False):
|
|
452
|
+
if group['use_atan2']:
|
|
453
|
+
update.atan2_(denom)
|
|
454
|
+
else:
|
|
455
|
+
update.div_(denom)
|
|
444
456
|
|
|
445
457
|
wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
|
|
446
458
|
|
|
@@ -452,6 +464,36 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
452
464
|
actual_precision = group['actual_state_precision']
|
|
453
465
|
factored_2nd = state.get('factored_2nd', False)
|
|
454
466
|
|
|
467
|
+
if factored_2nd:
|
|
468
|
+
d1, d2 = state['effective_shape']
|
|
469
|
+
exp_avg_sq = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
|
|
470
|
+
exp_avg_sq = exp_avg_sq.view(p.shape)
|
|
471
|
+
else:
|
|
472
|
+
exp_avg_sq = get_state(state, 'exp_avg_sq', actual_precision)
|
|
473
|
+
|
|
474
|
+
grad_vt = grad.float() if factored_2nd else grad
|
|
475
|
+
|
|
476
|
+
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
477
|
+
exp_avg_sq.mul_(beta2).addcmul_(grad_vt, grad_vt * (1.0 - beta2))
|
|
478
|
+
else:
|
|
479
|
+
exp_avg_sq.mul_(beta2).addcmul_(grad_vt, grad_vt, value=1.0 - beta2)
|
|
480
|
+
|
|
481
|
+
if factored_2nd:
|
|
482
|
+
state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(exp_avg_sq.view(d1, d2), signed=False)
|
|
483
|
+
else:
|
|
484
|
+
set_state(state, 'exp_avg_sq', exp_avg_sq, actual_precision, random_int_state_tensor, non_neg=True)
|
|
485
|
+
|
|
486
|
+
if group['use_atan2']:
|
|
487
|
+
denom = exp_avg_sq.sqrt()
|
|
488
|
+
denom.div_(sqrt_bias_correction2)
|
|
489
|
+
if group.get('normed_momentum', False):
|
|
490
|
+
grad.atan2_(denom.to(grad.dtype))
|
|
491
|
+
else:
|
|
492
|
+
denom = exp_avg_sq.sqrt()
|
|
493
|
+
denom.div_(sqrt_bias_correction2).add_(adaptive_eps)
|
|
494
|
+
if group.get('normed_momentum', False):
|
|
495
|
+
grad.div_(denom.to(grad.dtype))
|
|
496
|
+
|
|
455
497
|
if use_mt:
|
|
456
498
|
exp_avg = get_state(state, 'exp_avg', actual_precision)
|
|
457
499
|
exp_avg.lerp_(grad, 1.0 - beta1)
|
|
@@ -481,38 +523,15 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
481
523
|
else:
|
|
482
524
|
update = update_mt if use_mt else grad.clone()
|
|
483
525
|
|
|
484
|
-
if
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
exp_avg_sq = get_state(state, 'exp_avg_sq', actual_precision)
|
|
490
|
-
|
|
491
|
-
grad_vt = grad.float() if factored_2nd else grad
|
|
492
|
-
|
|
493
|
-
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
494
|
-
exp_avg_sq.mul_(beta2).addcmul_(grad_vt, grad_vt * (1.0 - beta2))
|
|
495
|
-
else:
|
|
496
|
-
exp_avg_sq.mul_(beta2).addcmul_(grad_vt, grad_vt, value=1.0 - beta2)
|
|
497
|
-
|
|
498
|
-
if factored_2nd:
|
|
499
|
-
state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(exp_avg_sq.view(d1, d2), signed=False)
|
|
500
|
-
else:
|
|
501
|
-
set_state(state, 'exp_avg_sq', exp_avg_sq, actual_precision, random_int_state_tensor, non_neg=True)
|
|
502
|
-
del random_int_state_tensor
|
|
503
|
-
|
|
504
|
-
if group['use_atan2']:
|
|
505
|
-
denom = exp_avg_sq.sqrt()
|
|
506
|
-
denom.div_(sqrt_bias_correction2)
|
|
507
|
-
update.atan2_(denom.to(update.dtype))
|
|
508
|
-
else:
|
|
509
|
-
denom = exp_avg_sq.sqrt()
|
|
510
|
-
denom.div_(sqrt_bias_correction2).add_(adaptive_eps)
|
|
511
|
-
update.div_(denom.to(update.dtype))
|
|
526
|
+
if not group.get('normed_momentum', False):
|
|
527
|
+
if group['use_atan2']:
|
|
528
|
+
update.atan2_(denom.to(update.dtype))
|
|
529
|
+
else:
|
|
530
|
+
update.div_(denom.to(update.dtype))
|
|
512
531
|
|
|
513
532
|
wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
|
|
514
533
|
|
|
515
|
-
del denom
|
|
534
|
+
del denom, random_int_state_tensor
|
|
516
535
|
|
|
517
536
|
update_scaling = step_size * A if group['use_atan2'] else step_size
|
|
518
537
|
if group.get('spectral_normalization', False):
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
import math
|
|
4
4
|
|
|
5
5
|
from ..util import param_update
|
|
6
6
|
from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
|
|
@@ -9,7 +9,7 @@ from ..util.OrthoGrad import _orthogonalize_gradient
|
|
|
9
9
|
from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
|
|
10
10
|
from ..util.centered_decay import _init_anchor
|
|
11
11
|
from ..util.state_util import init_state_tensor, get_state, set_state, upcast_grad_for_precision
|
|
12
|
-
from ..util.sinkhorn import apply_sr_sinkhorn
|
|
12
|
+
from ..util.sinkhorn import apply_sr_sinkhorn, _sinkhorn_sq_grad, get_sinkhorn_wd_scaler
|
|
13
13
|
from ..util.signed_util import apply_stochastic_sign_
|
|
14
14
|
|
|
15
15
|
class SinkSGD_adv(torch.optim.Optimizer):
|
|
@@ -26,8 +26,6 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
26
26
|
weight_decay (float): weight decay (L2 penalty or decoupled) (default: 0).
|
|
27
27
|
nesterov (bool): enables Nesterov momentum. Only applicable when momentum
|
|
28
28
|
is non-zero. (default: False)
|
|
29
|
-
decoupled_wd (bool): whether to apply decoupled weight decay (like AdamW)
|
|
30
|
-
instead of standard L2 penalty. (default: False)
|
|
31
29
|
cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
|
|
32
30
|
applied only to parameter coordinates where the sign of the parameter
|
|
33
31
|
and the sign of the optimizer update align (default: False).
|
|
@@ -61,11 +59,13 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
61
59
|
orthogonal_sinkhorn: bool = False,
|
|
62
60
|
# Normalization then Momentum
|
|
63
61
|
normed_momentum: bool = False,
|
|
62
|
+
# Centered Variance Precondition
|
|
63
|
+
centered_vt: bool = False,
|
|
64
64
|
# Nesterov Momentum
|
|
65
65
|
nesterov: bool = False,
|
|
66
66
|
nesterov_coef: float | None = None,
|
|
67
|
-
#
|
|
68
|
-
|
|
67
|
+
# weight decay features
|
|
68
|
+
geometric_wd: bool = False,
|
|
69
69
|
cautious_wd: bool = False,
|
|
70
70
|
# Stochastic Rounding for BF16
|
|
71
71
|
stochastic_rounding: bool = True,
|
|
@@ -101,8 +101,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
101
101
|
|
|
102
102
|
defaults = {
|
|
103
103
|
"lr": lr, "momentum": momentum,
|
|
104
|
-
"weight_decay": weight_decay, "nesterov": nesterov, "nesterov_coef": nesterov_coef, "normed_momentum": normed_momentum,
|
|
105
|
-
"
|
|
104
|
+
"weight_decay": weight_decay, "nesterov": nesterov, "nesterov_coef": nesterov_coef, "normed_momentum": normed_momentum, "centered_vt": centered_vt,
|
|
105
|
+
"geometric_wd": geometric_wd, "cautious_wd": cautious_wd,
|
|
106
106
|
"orthogonal_gradient": orthogonal_gradient,
|
|
107
107
|
"compiled_optimizer": compiled_optimizer,
|
|
108
108
|
"sinkhorn_iterations": sinkhorn_iterations,
|
|
@@ -182,6 +182,11 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
182
182
|
if group['momentum'] != 0:
|
|
183
183
|
init_state_tensor(state, 'momentum_buffer', p.shape, actual_precision, p.device, dtype)
|
|
184
184
|
|
|
185
|
+
if group.get('centered_vt', False):
|
|
186
|
+
p_shape = p.shape
|
|
187
|
+
state['vt_row'] = torch.zeros(p_shape[:-1], device=device, dtype=torch.float32)
|
|
188
|
+
state['vt_col'] = torch.zeros(p_shape[:-2] + p_shape[-1:], device=device, dtype=torch.float32)
|
|
189
|
+
|
|
185
190
|
if group.get('spectral_normalization', False) and is_spectral(p):
|
|
186
191
|
init_spectral_norm(state, p)
|
|
187
192
|
|
|
@@ -237,7 +242,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
237
242
|
if group.get('normed_momentum', False):
|
|
238
243
|
if not is_vector:
|
|
239
244
|
# Sinkhorn iterative normalization
|
|
240
|
-
grad = apply_sr_sinkhorn(grad, p, ortho_project=orthogonal_sinkhorn
|
|
245
|
+
grad = apply_sr_sinkhorn(grad, iters=sinkhorn_iterations, p=p, ortho_project=orthogonal_sinkhorn)
|
|
241
246
|
else:
|
|
242
247
|
# For vectors, apply adaptive stochastic sign
|
|
243
248
|
grad = apply_stochastic_sign_(grad, sign_noise, is_vector=is_vector)
|
|
@@ -271,6 +276,24 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
271
276
|
|
|
272
277
|
if momentum != 0:
|
|
273
278
|
buf = get_state(state, 'momentum_buffer', actual_precision)
|
|
279
|
+
|
|
280
|
+
if group.get('centered_vt', False):
|
|
281
|
+
vt_row, vt_col = state['vt_row'], state['vt_col']
|
|
282
|
+
grad_vt = grad - buf
|
|
283
|
+
grad_vt_sq = grad_vt * grad_vt
|
|
284
|
+
mean_row_grad = grad_vt_sq.mean(dim=-1)
|
|
285
|
+
mean_col_grad = grad_vt_sq.mean(dim=-2)
|
|
286
|
+
vt_row.mul_(momentum).add_(mean_row_grad, alpha=1.0 - momentum)
|
|
287
|
+
vt_col.mul_(momentum).add_(mean_col_grad, alpha=1.0 - momentum)
|
|
288
|
+
if nesterov:
|
|
289
|
+
nv_coef = momentum if nesterov_coef is None else nesterov_coef
|
|
290
|
+
vt_row = vt_row.lerp(mean_row_grad, 1.0 - nv_coef)
|
|
291
|
+
vt_col = vt_col.lerp(mean_col_grad, 1.0 - nv_coef)
|
|
292
|
+
vt = _sinkhorn_sq_grad(vt_row, vt_col)
|
|
293
|
+
else:
|
|
294
|
+
vt_row = None
|
|
295
|
+
vt_col = None
|
|
296
|
+
|
|
274
297
|
buf.lerp_(grad, 1 - momentum)
|
|
275
298
|
|
|
276
299
|
set_state(state, 'momentum_buffer', buf, actual_precision, random_int_state_tensor)
|
|
@@ -285,21 +308,34 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
285
308
|
|
|
286
309
|
del random_int_state_tensor
|
|
287
310
|
|
|
311
|
+
if group.get('centered_vt', False):
|
|
312
|
+
denom = vt
|
|
313
|
+
update.atan2_(denom)
|
|
314
|
+
else:
|
|
315
|
+
denom = None
|
|
316
|
+
|
|
288
317
|
if not group.get('normed_momentum', False):
|
|
289
318
|
if not is_vector:
|
|
290
319
|
# Sinkhorn iterative normalization
|
|
291
|
-
update = apply_sr_sinkhorn(update, p, ortho_project=orthogonal_sinkhorn
|
|
320
|
+
update = apply_sr_sinkhorn(update, iters=sinkhorn_iterations, p=p, ortho_project=orthogonal_sinkhorn)
|
|
292
321
|
else:
|
|
293
322
|
# For vectors, apply adaptive stochastic sign
|
|
294
323
|
update = apply_stochastic_sign_(update, sign_noise, is_vector=is_vector)
|
|
295
324
|
|
|
325
|
+
if group.get('geometric_wd', False):
|
|
326
|
+
wd_scaler = get_sinkhorn_wd_scaler(p, row_denom=vt_row, col_denom=vt_col)
|
|
327
|
+
else:
|
|
328
|
+
wd_scaler = None
|
|
329
|
+
|
|
296
330
|
update_scaling = step_size
|
|
297
331
|
if group.get('spectral_normalization', False):
|
|
298
332
|
update = scale_update(p, update, update_scaling, state=state)
|
|
299
333
|
else:
|
|
334
|
+
if group.get('centered_vt', False):
|
|
335
|
+
update_scaling = update_scaling * (4/math.pi)
|
|
300
336
|
update.mul_(update_scaling)
|
|
301
337
|
|
|
302
|
-
param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor)
|
|
338
|
+
param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
|
|
303
339
|
|
|
304
340
|
def compile(self, *args, **kwargs):
|
|
305
341
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
@@ -4,15 +4,19 @@ from . import param_update
|
|
|
4
4
|
|
|
5
5
|
def apply_stochastic_sign_(update: torch.Tensor, noise: torch.Tensor | None, is_vector: bool = False) -> torch.Tensor:
|
|
6
6
|
"""
|
|
7
|
-
Applies the Stochastic Sign operator
|
|
7
|
+
Applies the Iterative L-infinity Stochastic Sign operator.
|
|
8
8
|
Uses uniform noise injection to compute the stochastic sign
|
|
9
9
|
"""
|
|
10
10
|
if update.dim() >= 2 and not is_vector:
|
|
11
|
-
|
|
12
|
-
#
|
|
13
|
-
|
|
14
|
-
R_row =
|
|
15
|
-
|
|
11
|
+
# Iterative L-infinity Sinkhorn algorithm
|
|
12
|
+
# This converges in just one iteration
|
|
13
|
+
# Step 1: Row Max (every row max is 1.0, all values <= 1.0)
|
|
14
|
+
R_row = torch.linalg.vector_norm(update, ord=float('inf'), dim=1, keepdim=True).clamp_min_(1e-12)
|
|
15
|
+
update.div_(R_row)
|
|
16
|
+
|
|
17
|
+
# Step 2: Col Max (every col max is 1.0 and every row max stays 1.0)
|
|
18
|
+
R_col = torch.linalg.vector_norm(update, ord=float('inf'), dim=0, keepdim=True).clamp_min_(1e-12)
|
|
19
|
+
update.div_(R_col)
|
|
16
20
|
else:
|
|
17
21
|
# Fallback for 1D tensors (e.g., biases, layernorm)
|
|
18
22
|
# Block-wise scaling to protect against outliers
|
|
@@ -21,7 +25,8 @@ def apply_stochastic_sign_(update: torch.Tensor, noise: torch.Tensor | None, is_
|
|
|
21
25
|
|
|
22
26
|
if numel <= block_size:
|
|
23
27
|
# Too small to chunk, just use global max
|
|
24
|
-
R = update.abs().max()
|
|
28
|
+
R = update.abs().max().clamp_min_(1e-12)
|
|
29
|
+
update.div_(R)
|
|
25
30
|
else:
|
|
26
31
|
# Calculate how much padding we need to make it divisible by block_size
|
|
27
32
|
remainder = numel % block_size
|
|
@@ -41,13 +46,11 @@ def apply_stochastic_sign_(update: torch.Tensor, noise: torch.Tensor | None, is_
|
|
|
41
46
|
R_blocks = blocks.abs().max(dim=1, keepdim=True).values
|
|
42
47
|
|
|
43
48
|
# Broadcast R_blocks back to the padded shape, slice off padding, and restore original shape
|
|
44
|
-
R = R_blocks.expand_as(blocks).reshape(-1)[:numel].view_as(update)
|
|
45
|
-
|
|
46
|
-
# Prevent division by zero
|
|
47
|
-
R = R.clamp_min(1e-12)
|
|
49
|
+
R = R_blocks.expand_as(blocks).reshape(-1)[:numel].view_as(update).clamp_min(1e-12)
|
|
50
|
+
update.div_(R)
|
|
48
51
|
|
|
49
52
|
if noise is None:
|
|
50
53
|
noise = param_update._get_random_noise_for_sso(update)
|
|
51
54
|
|
|
52
|
-
#
|
|
53
|
-
return update.
|
|
55
|
+
# Final stochastic step: sign(v + U[-1, 1])
|
|
56
|
+
return update.add_(noise).sign_()
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import math
|
|
2
2
|
import torch
|
|
3
3
|
|
|
4
|
-
def apply_sr_sinkhorn(update: torch.Tensor, p: torch.Tensor, ortho_project: bool
|
|
4
|
+
def apply_sr_sinkhorn(update: torch.Tensor, iters: int = 5, p: torch.Tensor | None = None, ortho_project: bool = False) -> torch.Tensor:
|
|
5
5
|
"""
|
|
6
6
|
Applies Square-Root Sinkhorn (SR-Sinkhorn) multi-normalization.
|
|
7
7
|
As described in 'Gradient Multi-Normalization for Efficient LLM Training'.
|
|
@@ -47,13 +47,16 @@ def apply_sr_sinkhorn(update: torch.Tensor, p: torch.Tensor, ortho_project: bool
|
|
|
47
47
|
# In-place alternating Sinkhorn normalization steps
|
|
48
48
|
for _ in range(iters):
|
|
49
49
|
# First normalization step
|
|
50
|
-
|
|
50
|
+
# Stability floor: equivalent to a single-element vector norm lower bound (lb)
|
|
51
|
+
norm1_lb = 1 / math.sqrt(update_2d.shape[dim])
|
|
52
|
+
norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm1_lb)
|
|
51
53
|
update_2d.mul_(scale_first / norm1)
|
|
52
54
|
if ortho_project:
|
|
53
55
|
update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_dim, dim, scale_first)
|
|
54
56
|
|
|
55
57
|
# Second normalization step
|
|
56
|
-
|
|
58
|
+
norm2_lb = 1 / math.sqrt(update_2d.shape[1-dim])
|
|
59
|
+
norm2 = update_2d.norm(p=2, dim=1-dim, keepdim=True).clamp_min_(norm2_lb)
|
|
57
60
|
update_2d.mul_(scale_second / norm2)
|
|
58
61
|
if ortho_project:
|
|
59
62
|
update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_adim, 1-dim, scale_second)
|
|
@@ -72,6 +75,69 @@ def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm):
|
|
|
72
75
|
update_2d.addcmul_(proj, p_2d, value=-1.0)
|
|
73
76
|
|
|
74
77
|
# Magnitude Preservation
|
|
75
|
-
|
|
78
|
+
norm_lb = 1 / math.sqrt(update_2d.shape[dim])
|
|
79
|
+
g_orth_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm_lb)
|
|
76
80
|
scale_factor = target_norm / g_orth_norm
|
|
77
81
|
return update_2d.mul_(scale_factor)
|
|
82
|
+
|
|
83
|
+
def _sinkhorn_sq_grad(
    vt_row: torch.Tensor,
    vt_col: torch.Tensor,
) -> torch.Tensor:
    """
    Reconstructs the (square-rooted) variance precondition from its rank-1 factors.

    Uses the Adafactor-style approximation
    ``V ~= outer(vt_row, vt_col) / mean(vt_row)`` and returns ``sqrt(V)``.
    It is computed as ``sqrt(vt_row / mean(vt_row))`` (as a column) times
    ``sqrt(vt_col)`` (as a row).

    Modified from:
    https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/adafactor.py

    Args:
        vt_row: row statistics with shape ``(*batch, R)``.
        vt_col: column statistics with shape ``(*batch, C)``.

    Returns:
        Tensor of shape ``(*batch, R, C)``. The inputs are not modified.
    """
    # keepdim=True keeps the mean aligned with the row axis, so leading batch
    # dims broadcast correctly. Without it, a (B, R) row tensor would be
    # divided by a (B,) mean: that errors for B != R and silently divides by
    # the wrong means for B == R.
    row_mean = vt_row.mean(dim=-1, keepdim=True).clamp_min_(1e-30)
    r_factor = (vt_row / row_mean).sqrt_().unsqueeze(-1)
    c_factor = vt_col.unsqueeze(-2).sqrt()
    return torch.mul(r_factor, c_factor)
|
|
99
|
+
|
|
100
|
+
def get_sinkhorn_wd_scaler(
|
|
101
|
+
p: torch.Tensor,
|
|
102
|
+
row_denom: torch.Tensor | None = None,
|
|
103
|
+
col_denom: torch.Tensor | None = None
|
|
104
|
+
):
|
|
105
|
+
"""
|
|
106
|
+
Computes a structural weight decay multiplier.
|
|
107
|
+
Penalizes parameters belonging to dominant rows/columns more heavily,
|
|
108
|
+
while protecting parameters in under-utilized/noisy rows/columns from decay.
|
|
109
|
+
"""
|
|
110
|
+
if p.ndim < 2:
|
|
111
|
+
return 1.0
|
|
112
|
+
|
|
113
|
+
p_2d = p.view(p.shape[0], -1)
|
|
114
|
+
|
|
115
|
+
# Lower bounds based on the effective 2D shapes
|
|
116
|
+
row_lb = 1 / math.sqrt(p_2d.shape[1])
|
|
117
|
+
col_lb = 1 / math.sqrt(p_2d.shape[0])
|
|
118
|
+
|
|
119
|
+
# Get the norms
|
|
120
|
+
row_norms = torch.linalg.vector_norm(p_2d, ord=2, dim=1, keepdim=True).clamp_min_(row_lb)
|
|
121
|
+
col_norms = torch.linalg.vector_norm(p_2d, ord=2, dim=0, keepdim=True).clamp_min_(col_lb)
|
|
122
|
+
|
|
123
|
+
# Compute the structural scaler
|
|
124
|
+
row_factor = row_norms.sqrt_()
|
|
125
|
+
col_factor = col_norms.sqrt_()
|
|
126
|
+
|
|
127
|
+
if row_denom is not None and col_denom is not None:
|
|
128
|
+
# Reshape denominators to ensure safe in-place broadcasting
|
|
129
|
+
row_denom = row_denom.view(p_2d.shape[0], 1)
|
|
130
|
+
col_denom = col_denom.view(1, p_2d.shape[1])
|
|
131
|
+
|
|
132
|
+
# High denom (noise) -> smaller angle (protects weights)
|
|
133
|
+
# Low denom (confident) -> larger angle (decays weights)
|
|
134
|
+
row_factor.atan2_(row_denom)
|
|
135
|
+
col_factor.atan2_(col_denom)
|
|
136
|
+
|
|
137
|
+
# Outer product: merges the row and column confidences into a 2D matrix
|
|
138
|
+
wd_scaler = row_factor * col_factor
|
|
139
|
+
|
|
140
|
+
# Normalize the scaler so its mean is exactly 1.0
|
|
141
|
+
wd_scaler.div_(wd_scaler.mean().clamp_min_(1e-12))
|
|
142
|
+
|
|
143
|
+
return wd_scaler.view_as(p)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|