adv-optm 2.4.dev16__tar.gz → 2.4.dev18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36):
  1. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/AdamW_adv.py +66 -47
  4. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/SinkSGD_adv.py +47 -11
  5. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/signed_util.py +16 -13
  6. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/sinkhorn.py +70 -4
  7. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm.egg-info/PKG-INFO +1 -1
  8. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/setup.py +1 -1
  9. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/LICENSE +0 -0
  10. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/README.md +0 -0
  11. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/AdaMuon_adv.py +0 -0
  12. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/Adopt_adv.py +0 -0
  13. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  14. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/Lion_adv.py +0 -0
  15. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/Muon_adv.py +0 -0
  16. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/Prodigy_adv.py +0 -0
  17. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/SignSGD_adv.py +0 -0
  18. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  19. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/optim/__init__.py +0 -0
  20. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/Kourkoutas.py +0 -0
  21. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/Muon_AuxAdam.py +0 -0
  22. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/Muon_util.py +0 -0
  23. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/OrthoGrad.py +0 -0
  24. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/__init__.py +0 -0
  25. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/centered_decay.py +0 -0
  26. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/factorization_util.py +0 -0
  27. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/lion_k.py +0 -0
  28. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/param_update.py +0 -0
  29. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/scaled_optm.py +0 -0
  30. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/state_util.py +0 -0
  31. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm/util/update_util.py +0 -0
  32. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm.egg-info/SOURCES.txt +0 -0
  33. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm.egg-info/dependency_links.txt +0 -0
  34. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm.egg-info/requires.txt +0 -0
  35. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/adv_optm.egg-info/top_level.txt +0 -0
  36. {adv_optm-2.4.dev16 → adv_optm-2.4.dev18}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev16
3
+ Version: 2.4.dev18
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -24,4 +24,4 @@ __all__ = [
24
24
  "SinkSGD_adv",
25
25
  ]
26
26
 
27
- __version__ = "2.4.dev16"
27
+ __version__ = "2.4.dev18"
@@ -63,6 +63,7 @@ class AdamW_adv(torch.optim.Optimizer):
63
63
  before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
64
64
  A higher value increases the stabilizing influence of the slow
65
65
  momentum. (default: 5.0)
66
+ normed_momentum (bool): whether to compute the first moment on the normalized gradient. (default: False)
66
67
  kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
67
68
  If `False`, the optimizer behaves as standard AdamW. (default: False)
68
69
  beta2_min (float): The minimum value for dynamic β₂, used during periods of
@@ -131,6 +132,8 @@ class AdamW_adv(torch.optim.Optimizer):
131
132
  # Nesterov momentum
132
133
  nesterov: bool = False,
133
134
  nesterov_coef: float | None = None,
135
+ # Normalization then Momentum
136
+ normed_momentum: bool = False,
134
137
  # K-b (adaptive beta2)
135
138
  kourkoutas_beta: bool = False,
136
139
  beta2_min: float = 0.9,
@@ -181,6 +184,7 @@ class AdamW_adv(torch.optim.Optimizer):
181
184
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
182
185
  "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
183
186
  "use_atan2": use_atan2, "nesterov": nesterov, "nesterov_coef": nesterov_coef,
187
+ "normed_momentum": normed_momentum,
184
188
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
185
189
  "beta3_ema": beta3_ema, "alpha": alpha, "compiled_optimizer": compiled_optimizer,
186
190
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
@@ -383,6 +387,27 @@ class AdamW_adv(torch.optim.Optimizer):
383
387
  d1, d2 = state['effective_shape']
384
388
  grad_reshaped = grad.view(d1, d2)
385
389
 
390
+ vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
391
+
392
+ if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
393
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped * (1.0 - beta2))
394
+ else:
395
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
396
+
397
+ # Factorize
398
+ state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(vt, signed=False)
399
+
400
+ if group['use_atan2']:
401
+ denom = vt.sqrt_()
402
+ denom.div_(sqrt_bias_correction2)
403
+ if group.get('normed_momentum', False):
404
+ grad_reshaped.atan2_(denom)
405
+ else:
406
+ denom = vt.sqrt_()
407
+ denom.div_(sqrt_bias_correction2).add_(adaptive_eps)
408
+ if group.get('normed_momentum', False):
409
+ grad_reshaped.div_(denom)
410
+
386
411
  # Reconstruct momentum from previous step's factors
387
412
  if use_mt:
388
413
  mt = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True)
@@ -404,13 +429,6 @@ class AdamW_adv(torch.optim.Optimizer):
404
429
  nv_coef = beta1 if nesterov_coef is None else nesterov_coef
405
430
  update_mt = update_mt.lerp_(grad_reshaped, 1-nv_coef)
406
431
 
407
- vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
408
-
409
- if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
410
- vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped * (1.0 - beta2))
411
- else:
412
- vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
413
-
414
432
  if self.use_AdEMAMix:
415
433
  mt_slow = _reconstruct_state((state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'], d2), signed=True)
416
434
 
@@ -430,17 +448,11 @@ class AdamW_adv(torch.optim.Optimizer):
430
448
  else:
431
449
  update = grad_reshaped.clone()
432
450
 
433
- # Factorize
434
- state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(vt, signed=False)
435
-
436
- if group['use_atan2']:
437
- denom = vt.sqrt_()
438
- denom.div_(sqrt_bias_correction2)
439
- update.atan2_(denom)
440
- else:
441
- denom = vt.sqrt_()
442
- denom.div_(sqrt_bias_correction2).add_(adaptive_eps)
443
- update.div_(denom)
451
+ if not group.get('normed_momentum', False):
452
+ if group['use_atan2']:
453
+ update.atan2_(denom)
454
+ else:
455
+ update.div_(denom)
444
456
 
445
457
  wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
446
458
 
@@ -452,6 +464,36 @@ class AdamW_adv(torch.optim.Optimizer):
452
464
  actual_precision = group['actual_state_precision']
453
465
  factored_2nd = state.get('factored_2nd', False)
454
466
 
467
+ if factored_2nd:
468
+ d1, d2 = state['effective_shape']
469
+ exp_avg_sq = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
470
+ exp_avg_sq = exp_avg_sq.view(p.shape)
471
+ else:
472
+ exp_avg_sq = get_state(state, 'exp_avg_sq', actual_precision)
473
+
474
+ grad_vt = grad.float() if factored_2nd else grad
475
+
476
+ if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
477
+ exp_avg_sq.mul_(beta2).addcmul_(grad_vt, grad_vt * (1.0 - beta2))
478
+ else:
479
+ exp_avg_sq.mul_(beta2).addcmul_(grad_vt, grad_vt, value=1.0 - beta2)
480
+
481
+ if factored_2nd:
482
+ state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(exp_avg_sq.view(d1, d2), signed=False)
483
+ else:
484
+ set_state(state, 'exp_avg_sq', exp_avg_sq, actual_precision, random_int_state_tensor, non_neg=True)
485
+
486
+ if group['use_atan2']:
487
+ denom = exp_avg_sq.sqrt()
488
+ denom.div_(sqrt_bias_correction2)
489
+ if group.get('normed_momentum', False):
490
+ grad.atan2_(denom.to(grad.dtype))
491
+ else:
492
+ denom = exp_avg_sq.sqrt()
493
+ denom.div_(sqrt_bias_correction2).add_(adaptive_eps)
494
+ if group.get('normed_momentum', False):
495
+ grad.div_(denom.to(grad.dtype))
496
+
455
497
  if use_mt:
456
498
  exp_avg = get_state(state, 'exp_avg', actual_precision)
457
499
  exp_avg.lerp_(grad, 1.0 - beta1)
@@ -481,38 +523,15 @@ class AdamW_adv(torch.optim.Optimizer):
481
523
  else:
482
524
  update = update_mt if use_mt else grad.clone()
483
525
 
484
- if factored_2nd:
485
- d1, d2 = state['effective_shape']
486
- exp_avg_sq = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
487
- exp_avg_sq = exp_avg_sq.view(p.shape)
488
- else:
489
- exp_avg_sq = get_state(state, 'exp_avg_sq', actual_precision)
490
-
491
- grad_vt = grad.float() if factored_2nd else grad
492
-
493
- if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
494
- exp_avg_sq.mul_(beta2).addcmul_(grad_vt, grad_vt * (1.0 - beta2))
495
- else:
496
- exp_avg_sq.mul_(beta2).addcmul_(grad_vt, grad_vt, value=1.0 - beta2)
497
-
498
- if factored_2nd:
499
- state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(exp_avg_sq.view(d1, d2), signed=False)
500
- else:
501
- set_state(state, 'exp_avg_sq', exp_avg_sq, actual_precision, random_int_state_tensor, non_neg=True)
502
- del random_int_state_tensor
503
-
504
- if group['use_atan2']:
505
- denom = exp_avg_sq.sqrt()
506
- denom.div_(sqrt_bias_correction2)
507
- update.atan2_(denom.to(update.dtype))
508
- else:
509
- denom = exp_avg_sq.sqrt()
510
- denom.div_(sqrt_bias_correction2).add_(adaptive_eps)
511
- update.div_(denom.to(update.dtype))
526
+ if not group.get('normed_momentum', False):
527
+ if group['use_atan2']:
528
+ update.atan2_(denom.to(update.dtype))
529
+ else:
530
+ update.div_(denom.to(update.dtype))
512
531
 
513
532
  wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
514
533
 
515
- del denom
534
+ del denom, random_int_state_tensor
516
535
 
517
536
  update_scaling = step_size * A if group['use_atan2'] else step_size
518
537
  if group.get('spectral_normalization', False):
@@ -1,6 +1,6 @@
1
1
  import torch
2
2
 
3
- from typing import Optional, Callable
3
+ import math
4
4
 
5
5
  from ..util import param_update
6
6
  from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
@@ -9,7 +9,7 @@ from ..util.OrthoGrad import _orthogonalize_gradient
9
9
  from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
10
10
  from ..util.centered_decay import _init_anchor
11
11
  from ..util.state_util import init_state_tensor, get_state, set_state, upcast_grad_for_precision
12
- from ..util.sinkhorn import apply_sr_sinkhorn
12
+ from ..util.sinkhorn import apply_sr_sinkhorn, _sinkhorn_sq_grad, get_sinkhorn_wd_scaler
13
13
  from ..util.signed_util import apply_stochastic_sign_
14
14
 
15
15
  class SinkSGD_adv(torch.optim.Optimizer):
@@ -26,8 +26,6 @@ class SinkSGD_adv(torch.optim.Optimizer):
26
26
  weight_decay (float): weight decay (L2 penalty or decoupled) (default: 0).
27
27
  nesterov (bool): enables Nesterov momentum. Only applicable when momentum
28
28
  is non-zero. (default: False)
29
- decoupled_wd (bool): whether to apply decoupled weight decay (like AdamW)
30
- instead of standard L2 penalty. (default: False)
31
29
  cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
32
30
  applied only to parameter coordinates where the sign of the parameter
33
31
  and the sign of the optimizer update align (default: False).
@@ -61,11 +59,13 @@ class SinkSGD_adv(torch.optim.Optimizer):
61
59
  orthogonal_sinkhorn: bool = False,
62
60
  # Normalization then Momentum
63
61
  normed_momentum: bool = False,
62
+ # Centered Variance Precondition
63
+ centered_vt: bool = False,
64
64
  # Nesterov Momentum
65
65
  nesterov: bool = False,
66
66
  nesterov_coef: float | None = None,
67
- # Decoupled/cautious weight decay
68
- decoupled_wd: bool = False,
67
+ # weight decay features
68
+ geometric_wd: bool = False,
69
69
  cautious_wd: bool = False,
70
70
  # Stochastic Rounding for BF16
71
71
  stochastic_rounding: bool = True,
@@ -101,8 +101,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
101
101
 
102
102
  defaults = {
103
103
  "lr": lr, "momentum": momentum,
104
- "weight_decay": weight_decay, "nesterov": nesterov, "nesterov_coef": nesterov_coef, "normed_momentum": normed_momentum,
105
- "decoupled_wd": decoupled_wd, "cautious_wd": cautious_wd,
104
+ "weight_decay": weight_decay, "nesterov": nesterov, "nesterov_coef": nesterov_coef, "normed_momentum": normed_momentum, "centered_vt": centered_vt,
105
+ "geometric_wd": geometric_wd, "cautious_wd": cautious_wd,
106
106
  "orthogonal_gradient": orthogonal_gradient,
107
107
  "compiled_optimizer": compiled_optimizer,
108
108
  "sinkhorn_iterations": sinkhorn_iterations,
@@ -182,6 +182,11 @@ class SinkSGD_adv(torch.optim.Optimizer):
182
182
  if group['momentum'] != 0:
183
183
  init_state_tensor(state, 'momentum_buffer', p.shape, actual_precision, p.device, dtype)
184
184
 
185
+ if group.get('centered_vt', False):
186
+ p_shape = p.shape
187
+ state['vt_row'] = torch.zeros(p_shape[:-1], device=device, dtype=torch.float32)
188
+ state['vt_col'] = torch.zeros(p_shape[:-2] + p_shape[-1:], device=device, dtype=torch.float32)
189
+
185
190
  if group.get('spectral_normalization', False) and is_spectral(p):
186
191
  init_spectral_norm(state, p)
187
192
 
@@ -237,7 +242,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
237
242
  if group.get('normed_momentum', False):
238
243
  if not is_vector:
239
244
  # Sinkhorn iterative normalization
240
- grad = apply_sr_sinkhorn(grad, p, ortho_project=orthogonal_sinkhorn, iters=sinkhorn_iterations)
245
+ grad = apply_sr_sinkhorn(grad, iters=sinkhorn_iterations, p=p, ortho_project=orthogonal_sinkhorn)
241
246
  else:
242
247
  # For vectors, apply adaptive stochastic sign
243
248
  grad = apply_stochastic_sign_(grad, sign_noise, is_vector=is_vector)
@@ -271,6 +276,24 @@ class SinkSGD_adv(torch.optim.Optimizer):
271
276
 
272
277
  if momentum != 0:
273
278
  buf = get_state(state, 'momentum_buffer', actual_precision)
279
+
280
+ if group.get('centered_vt', False):
281
+ vt_row, vt_col = state['vt_row'], state['vt_col']
282
+ grad_vt = grad - buf
283
+ grad_vt_sq = grad_vt * grad_vt
284
+ mean_row_grad = grad_vt_sq.mean(dim=-1)
285
+ mean_col_grad = grad_vt_sq.mean(dim=-2)
286
+ vt_row.mul_(momentum).add_(mean_row_grad, alpha=1.0 - momentum)
287
+ vt_col.mul_(momentum).add_(mean_col_grad, alpha=1.0 - momentum)
288
+ if nesterov:
289
+ nv_coef = momentum if nesterov_coef is None else nesterov_coef
290
+ vt_row = vt_row.lerp(mean_row_grad, 1.0 - nv_coef)
291
+ vt_col = vt_col.lerp(mean_col_grad, 1.0 - nv_coef)
292
+ vt = _sinkhorn_sq_grad(vt_row, vt_col)
293
+ else:
294
+ vt_row = None
295
+ vt_col = None
296
+
274
297
  buf.lerp_(grad, 1 - momentum)
275
298
 
276
299
  set_state(state, 'momentum_buffer', buf, actual_precision, random_int_state_tensor)
@@ -285,21 +308,34 @@ class SinkSGD_adv(torch.optim.Optimizer):
285
308
 
286
309
  del random_int_state_tensor
287
310
 
311
+ if group.get('centered_vt', False):
312
+ denom = vt
313
+ update.atan2_(denom)
314
+ else:
315
+ denom = None
316
+
288
317
  if not group.get('normed_momentum', False):
289
318
  if not is_vector:
290
319
  # Sinkhorn iterative normalization
291
- update = apply_sr_sinkhorn(update, p, ortho_project=orthogonal_sinkhorn, iters=sinkhorn_iterations)
320
+ update = apply_sr_sinkhorn(update, iters=sinkhorn_iterations, p=p, ortho_project=orthogonal_sinkhorn)
292
321
  else:
293
322
  # For vectors, apply adaptive stochastic sign
294
323
  update = apply_stochastic_sign_(update, sign_noise, is_vector=is_vector)
295
324
 
325
+ if group.get('geometric_wd', False):
326
+ wd_scaler = get_sinkhorn_wd_scaler(p, row_denom=vt_row, col_denom=vt_col)
327
+ else:
328
+ wd_scaler = None
329
+
296
330
  update_scaling = step_size
297
331
  if group.get('spectral_normalization', False):
298
332
  update = scale_update(p, update, update_scaling, state=state)
299
333
  else:
334
+ if group.get('centered_vt', False):
335
+ update_scaling = update_scaling * (4/math.pi)
300
336
  update.mul_(update_scaling)
301
337
 
302
- param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor)
338
+ param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
303
339
 
304
340
  def compile(self, *args, **kwargs):
305
341
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -4,15 +4,19 @@ from . import param_update
4
4
 
5
5
  def apply_stochastic_sign_(update: torch.Tensor, noise: torch.Tensor | None, is_vector: bool = False) -> torch.Tensor:
6
6
  """
7
- Applies the Stochastic Sign operator S_R(v).
7
+ Applies the Iterative L-infinity Stochastic Sign operator.
8
8
  Uses uniform noise injection to compute the stochastic sign
9
9
  """
10
10
  if update.dim() >= 2 and not is_vector:
11
- update_abs = update.abs()
12
- # Calculate row and col maximums
13
- R_col = update_abs.amax(dim=0, keepdim=True) # Shape: (1, cols)
14
- R_row = update_abs.amax(dim=1, keepdim=True) # Shape: (rows, 1)
15
- R = torch.minimum(R_row, R_col)
11
+ # Iterative L-infinity Sinkhorn algorithm
12
+ # This converges in just one iteration
13
+ # Step 1: Row Max (every row max is 1.0, all values <= 1.0)
14
+ R_row = torch.linalg.vector_norm(update, ord=float('inf'), dim=1, keepdim=True).clamp_min_(1e-12)
15
+ update.div_(R_row)
16
+
17
+ # Step 2: Col Max (every col max is 1.0 and every row max stays 1.0)
18
+ R_col = torch.linalg.vector_norm(update, ord=float('inf'), dim=0, keepdim=True).clamp_min_(1e-12)
19
+ update.div_(R_col)
16
20
  else:
17
21
  # Fallback for 1D tensors (e.g., biases, layernorm)
18
22
  # Block-wise scaling to protect against outliers
@@ -21,7 +25,8 @@ def apply_stochastic_sign_(update: torch.Tensor, noise: torch.Tensor | None, is_
21
25
 
22
26
  if numel <= block_size:
23
27
  # Too small to chunk, just use global max
24
- R = update.abs().max()
28
+ R = update.abs().max().clamp_min_(1e-12)
29
+ update.div_(R)
25
30
  else:
26
31
  # Calculate how much padding we need to make it divisible by block_size
27
32
  remainder = numel % block_size
@@ -41,13 +46,11 @@ def apply_stochastic_sign_(update: torch.Tensor, noise: torch.Tensor | None, is_
41
46
  R_blocks = blocks.abs().max(dim=1, keepdim=True).values
42
47
 
43
48
  # Broadcast R_blocks back to the padded shape, slice off padding, and restore original shape
44
- R = R_blocks.expand_as(blocks).reshape(-1)[:numel].view_as(update)
45
-
46
- # Prevent division by zero
47
- R = R.clamp_min(1e-12)
49
+ R = R_blocks.expand_as(blocks).reshape(-1)[:numel].view_as(update).clamp_min(1e-12)
50
+ update.div_(R)
48
51
 
49
52
  if noise is None:
50
53
  noise = param_update._get_random_noise_for_sso(update)
51
54
 
52
- # Chain inplace operations: torch.sign(update / R + noise)
53
- return update.div_(R).add_(noise).sign_()
55
+ # Final stochastic step: sign(v + U[-1, 1])
56
+ return update.add_(noise).sign_()
@@ -1,7 +1,7 @@
1
1
  import math
2
2
  import torch
3
3
 
4
- def apply_sr_sinkhorn(update: torch.Tensor, p: torch.Tensor, ortho_project: bool, iters: int = 5) -> torch.Tensor:
4
+ def apply_sr_sinkhorn(update: torch.Tensor, iters: int = 5, p: torch.Tensor | None = None, ortho_project: bool = False) -> torch.Tensor:
5
5
  """
6
6
  Applies Square-Root Sinkhorn (SR-Sinkhorn) multi-normalization.
7
7
  As described in 'Gradient Multi-Normalization for Efficient LLM Training'.
@@ -47,13 +47,16 @@ def apply_sr_sinkhorn(update: torch.Tensor, p: torch.Tensor, ortho_project: bool
47
47
  # In-place alternating Sinkhorn normalization steps
48
48
  for _ in range(iters):
49
49
  # First normalization step
50
- norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
50
+ # Stability floor: equivalent to a single-element vector norm lower bound (lb)
51
+ norm1_lb = 1 / math.sqrt(update_2d.shape[dim])
52
+ norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm1_lb)
51
53
  update_2d.mul_(scale_first / norm1)
52
54
  if ortho_project:
53
55
  update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_dim, dim, scale_first)
54
56
 
55
57
  # Second normalization step
56
- norm2 = update_2d.norm(p=2, dim=1-dim, keepdim=True).clamp_min_(1e-12)
58
+ norm2_lb = 1 / math.sqrt(update_2d.shape[1-dim])
59
+ norm2 = update_2d.norm(p=2, dim=1-dim, keepdim=True).clamp_min_(norm2_lb)
57
60
  update_2d.mul_(scale_second / norm2)
58
61
  if ortho_project:
59
62
  update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_adim, 1-dim, scale_second)
@@ -72,6 +75,69 @@ def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm):
72
75
  update_2d.addcmul_(proj, p_2d, value=-1.0)
73
76
 
74
77
  # Magnitude Preservation
75
- g_orth_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
78
+ norm_lb = 1 / math.sqrt(update_2d.shape[dim])
79
+ g_orth_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm_lb)
76
80
  scale_factor = target_norm / g_orth_norm
77
81
  return update_2d.mul_(scale_factor)
82
+
83
+ def _sinkhorn_sq_grad(
84
+ vt_row: torch.Tensor,
85
+ vt_col: torch.Tensor,
86
+ ) -> torch.Tensor:
87
+ """
88
+ Reconstructs the variance precondition from its rank-1 factors.
89
+ Modified from:
90
+ https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/adafactor.py
91
+ """
92
+ r_factor = (
93
+ (vt_row / vt_row.mean(dim=-1).clamp_min_(1e-30))
94
+ .sqrt_()
95
+ .unsqueeze(-1)
96
+ )
97
+ c_factor = vt_col.unsqueeze(-2).sqrt()
98
+ return torch.mul(r_factor, c_factor)
99
+
100
+ def get_sinkhorn_wd_scaler(
101
+ p: torch.Tensor,
102
+ row_denom: torch.Tensor | None = None,
103
+ col_denom: torch.Tensor | None = None
104
+ ):
105
+ """
106
+ Computes a structural weight decay multiplier.
107
+ Penalizes parameters belonging to dominant rows/columns more heavily,
108
+ while protecting parameters in under-utilized/noisy rows/columns from decay.
109
+ """
110
+ if p.ndim < 2:
111
+ return 1.0
112
+
113
+ p_2d = p.view(p.shape[0], -1)
114
+
115
+ # Lower bounds based on the effective 2D shapes
116
+ row_lb = 1 / math.sqrt(p_2d.shape[1])
117
+ col_lb = 1 / math.sqrt(p_2d.shape[0])
118
+
119
+ # Get the norms
120
+ row_norms = torch.linalg.vector_norm(p_2d, ord=2, dim=1, keepdim=True).clamp_min_(row_lb)
121
+ col_norms = torch.linalg.vector_norm(p_2d, ord=2, dim=0, keepdim=True).clamp_min_(col_lb)
122
+
123
+ # Compute the structural scaler
124
+ row_factor = row_norms.sqrt_()
125
+ col_factor = col_norms.sqrt_()
126
+
127
+ if row_denom is not None and col_denom is not None:
128
+ # Reshape denominators to ensure safe in-place broadcasting
129
+ row_denom = row_denom.view(p_2d.shape[0], 1)
130
+ col_denom = col_denom.view(1, p_2d.shape[1])
131
+
132
+ # High denom (noise) -> smaller angle (protects weights)
133
+ # Low denom (confident) -> larger angle (decays weights)
134
+ row_factor.atan2_(row_denom)
135
+ col_factor.atan2_(col_denom)
136
+
137
+ # Outer product: merges the row and column confidences into a 2D matrix
138
+ wd_scaler = row_factor * col_factor
139
+
140
+ # Normalize the scaler so its mean is exactly 1.0
141
+ wd_scaler.div_(wd_scaler.mean().clamp_min_(1e-12))
142
+
143
+ return wd_scaler.view_as(p)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev16
3
+ Version: 2.4.dev18
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
5
5
 
6
6
  setup(
7
7
  name="adv_optm",
8
- version="2.4.dev16",
8
+ version="2.4.dev18",
9
9
  author="Koratahiu",
10
10
  author_email="hiuhonor@gmail.com",
11
11
  license='Apache 2.0',
File without changes
File without changes
File without changes