adv-optm 2.4.dev17__tar.gz → 2.4.dev19__tar.gz

This diff shows the contents of publicly available package versions that were released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (36)
  1. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/SinkSGD_adv.py +51 -11
  4. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/sinkhorn.py +53 -4
  5. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm.egg-info/PKG-INFO +1 -1
  6. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/setup.py +1 -1
  7. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/LICENSE +0 -0
  8. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/README.md +0 -0
  9. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/AdaMuon_adv.py +0 -0
  10. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/AdamW_adv.py +0 -0
  11. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/Adopt_adv.py +0 -0
  12. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  13. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/Lion_adv.py +0 -0
  14. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/Muon_adv.py +0 -0
  15. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/Prodigy_adv.py +0 -0
  16. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/SignSGD_adv.py +0 -0
  17. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  18. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/__init__.py +0 -0
  19. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/Kourkoutas.py +0 -0
  20. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/Muon_AuxAdam.py +0 -0
  21. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/Muon_util.py +0 -0
  22. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/OrthoGrad.py +0 -0
  23. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/__init__.py +0 -0
  24. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/centered_decay.py +0 -0
  25. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/factorization_util.py +0 -0
  26. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/lion_k.py +0 -0
  27. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/param_update.py +0 -0
  28. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/scaled_optm.py +0 -0
  29. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/signed_util.py +0 -0
  30. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/state_util.py +0 -0
  31. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/update_util.py +0 -0
  32. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm.egg-info/SOURCES.txt +0 -0
  33. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm.egg-info/dependency_links.txt +0 -0
  34. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm.egg-info/requires.txt +0 -0
  35. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm.egg-info/top_level.txt +0 -0
  36. {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev17
3
+ Version: 2.4.dev19
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -24,4 +24,4 @@ __all__ = [
24
24
  "SinkSGD_adv",
25
25
  ]
26
26
 
27
- __version__ = "2.4.dev17"
27
+ __version__ = "2.4.dev19"
@@ -1,6 +1,6 @@
1
1
  import torch
2
2
 
3
- from typing import Optional, Callable
3
+ import math
4
4
 
5
5
  from ..util import param_update
6
6
  from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
@@ -9,7 +9,7 @@ from ..util.OrthoGrad import _orthogonalize_gradient
9
9
  from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
10
10
  from ..util.centered_decay import _init_anchor
11
11
  from ..util.state_util import init_state_tensor, get_state, set_state, upcast_grad_for_precision
12
- from ..util.sinkhorn import apply_sr_sinkhorn
12
+ from ..util.sinkhorn import apply_sr_sinkhorn, get_sinkhorn_wd_scaler
13
13
  from ..util.signed_util import apply_stochastic_sign_
14
14
 
15
15
  class SinkSGD_adv(torch.optim.Optimizer):
@@ -26,8 +26,6 @@ class SinkSGD_adv(torch.optim.Optimizer):
26
26
  weight_decay (float): weight decay (L2 penalty or decoupled) (default: 0).
27
27
  nesterov (bool): enables Nesterov momentum. Only applicable when momentum
28
28
  is non-zero. (default: False)
29
- decoupled_wd (bool): whether to apply decoupled weight decay (like AdamW)
30
- instead of standard L2 penalty. (default: False)
31
29
  cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
32
30
  applied only to parameter coordinates where the sign of the parameter
33
31
  and the sign of the optimizer update align (default: False).
@@ -61,11 +59,13 @@ class SinkSGD_adv(torch.optim.Optimizer):
61
59
  orthogonal_sinkhorn: bool = False,
62
60
  # Normalization then Momentum
63
61
  normed_momentum: bool = False,
62
+ # Centered Variance Precondition
63
+ centered_vt: bool = False,
64
64
  # Nesterov Momentum
65
65
  nesterov: bool = False,
66
66
  nesterov_coef: float | None = None,
67
- # Decoupled/cautious weight decay
68
- decoupled_wd: bool = False,
67
+ # weight decay features
68
+ geometric_wd: bool = False,
69
69
  cautious_wd: bool = False,
70
70
  # Stochastic Rounding for BF16
71
71
  stochastic_rounding: bool = True,
@@ -90,6 +90,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
90
90
  raise ValueError(f"Momentum should be >= 0.0. Got {momentum}")
91
91
  if not (weight_decay >= 0.0):
92
92
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
93
+ if centered_vt and not normed_momentum:
94
+ raise NotImplementedError(f"centered_vt is intended to be used with normed_momentum")
93
95
 
94
96
  state_precision = state_precision.lower()
95
97
  valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp8_sr", "int8_sr"}
@@ -101,8 +103,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
101
103
 
102
104
  defaults = {
103
105
  "lr": lr, "momentum": momentum,
104
- "weight_decay": weight_decay, "nesterov": nesterov, "nesterov_coef": nesterov_coef, "normed_momentum": normed_momentum,
105
- "decoupled_wd": decoupled_wd, "cautious_wd": cautious_wd,
106
+ "weight_decay": weight_decay, "nesterov": nesterov, "nesterov_coef": nesterov_coef, "normed_momentum": normed_momentum, "centered_vt": centered_vt,
107
+ "geometric_wd": geometric_wd, "cautious_wd": cautious_wd,
106
108
  "orthogonal_gradient": orthogonal_gradient,
107
109
  "compiled_optimizer": compiled_optimizer,
108
110
  "sinkhorn_iterations": sinkhorn_iterations,
@@ -182,6 +184,13 @@ class SinkSGD_adv(torch.optim.Optimizer):
182
184
  if group['momentum'] != 0:
183
185
  init_state_tensor(state, 'momentum_buffer', p.shape, actual_precision, p.device, dtype)
184
186
 
187
+ if group.get('centered_vt', False):
188
+ # Align shapes with Sinkhorn's 2D flattening
189
+ dim0 = p.shape[0]
190
+ dim1 = p.numel() // dim0
191
+ state['vt_row'] = torch.zeros(dim0, device=device, dtype=torch.float32)
192
+ state['vt_col'] = torch.zeros(dim1, device=device, dtype=torch.float32)
193
+
185
194
  if group.get('spectral_normalization', False) and is_spectral(p):
186
195
  init_spectral_norm(state, p)
187
196
 
@@ -237,7 +246,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
237
246
  if group.get('normed_momentum', False):
238
247
  if not is_vector:
239
248
  # Sinkhorn iterative normalization
240
- grad = apply_sr_sinkhorn(grad, p, ortho_project=orthogonal_sinkhorn, iters=sinkhorn_iterations)
249
+ grad = apply_sr_sinkhorn(grad, iters=sinkhorn_iterations, p=p, ortho_project=orthogonal_sinkhorn)
241
250
  else:
242
251
  # For vectors, apply adaptive stochastic sign
243
252
  grad = apply_stochastic_sign_(grad, sign_noise, is_vector=is_vector)
@@ -271,6 +280,23 @@ class SinkSGD_adv(torch.optim.Optimizer):
271
280
 
272
281
  if momentum != 0:
273
282
  buf = get_state(state, 'momentum_buffer', actual_precision)
283
+
284
+ if group.get('centered_vt', False):
285
+ vt_row, vt_col = state['vt_row'], state['vt_col']
286
+ grad_vt = grad - buf
287
+ grad_vt_sq = grad_vt.mul_(grad_vt).view(grad.shape[0], -1)
288
+ mean_row_grad = grad_vt_sq.mean(dim=-1)
289
+ mean_col_grad = grad_vt_sq.mean(dim=-2)
290
+ vt_row.mul_(momentum).add_(mean_row_grad, alpha=1.0 - momentum)
291
+ vt_col.mul_(momentum).add_(mean_col_grad, alpha=1.0 - momentum)
292
+ if nesterov:
293
+ nv_coef = momentum if nesterov_coef is None else nesterov_coef
294
+ vt_row = vt_row.lerp(mean_row_grad, 1.0 - nv_coef)
295
+ vt_col = vt_col.lerp(mean_col_grad, 1.0 - nv_coef)
296
+ else:
297
+ vt_row = None
298
+ vt_col = None
299
+
274
300
  buf.lerp_(grad, 1 - momentum)
275
301
 
276
302
  set_state(state, 'momentum_buffer', buf, actual_precision, random_int_state_tensor)
@@ -285,21 +311,35 @@ class SinkSGD_adv(torch.optim.Optimizer):
285
311
 
286
312
  del random_int_state_tensor
287
313
 
314
+ if group.get('centered_vt', False):
315
+ # Align with Sinkhorn: Alternate row/col preconditioning
316
+ update_2d = update.view(update.shape[0], -1)
317
+ update_2d.div_(vt_row.clamp_min(1e-30).sqrt().unsqueeze(1))
318
+ update_2d.div_(vt_col.clamp_min(1e-30).sqrt().unsqueeze(0))
319
+ update = update_2d.atan_().view_as(p)
320
+
288
321
  if not group.get('normed_momentum', False):
289
322
  if not is_vector:
290
323
  # Sinkhorn iterative normalization
291
- update = apply_sr_sinkhorn(update, p, ortho_project=orthogonal_sinkhorn, iters=sinkhorn_iterations)
324
+ update = apply_sr_sinkhorn(update, iters=sinkhorn_iterations, p=p, ortho_project=orthogonal_sinkhorn)
292
325
  else:
293
326
  # For vectors, apply adaptive stochastic sign
294
327
  update = apply_stochastic_sign_(update, sign_noise, is_vector=is_vector)
295
328
 
329
+ if group.get('geometric_wd', False):
330
+ wd_scaler = get_sinkhorn_wd_scaler(p, row_denom=vt_row, col_denom=vt_col)
331
+ else:
332
+ wd_scaler = None
333
+
296
334
  update_scaling = step_size
297
335
  if group.get('spectral_normalization', False):
298
336
  update = scale_update(p, update, update_scaling, state=state)
299
337
  else:
338
+ if group.get('centered_vt', False):
339
+ update_scaling = update_scaling * (4/math.pi)
300
340
  update.mul_(update_scaling)
301
341
 
302
- param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor)
342
+ param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
303
343
 
304
344
  def compile(self, *args, **kwargs):
305
345
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -1,7 +1,7 @@
1
1
  import math
2
2
  import torch
3
3
 
4
- def apply_sr_sinkhorn(update: torch.Tensor, p: torch.Tensor, ortho_project: bool, iters: int = 5) -> torch.Tensor:
4
+ def apply_sr_sinkhorn(update: torch.Tensor, iters: int = 5, p: torch.Tensor | None = None, ortho_project: bool = False) -> torch.Tensor:
5
5
  """
6
6
  Applies Square-Root Sinkhorn (SR-Sinkhorn) multi-normalization.
7
7
  As described in 'Gradient Multi-Normalization for Efficient LLM Training'.
@@ -47,13 +47,16 @@ def apply_sr_sinkhorn(update: torch.Tensor, p: torch.Tensor, ortho_project: bool
47
47
  # In-place alternating Sinkhorn normalization steps
48
48
  for _ in range(iters):
49
49
  # First normalization step
50
- norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
50
+ # Stability floor: equivalent to a single-element vector norm lower bound (lb)
51
+ norm1_lb = 1 / math.sqrt(update_2d.shape[dim])
52
+ norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm1_lb)
51
53
  update_2d.mul_(scale_first / norm1)
52
54
  if ortho_project:
53
55
  update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_dim, dim, scale_first)
54
56
 
55
57
  # Second normalization step
56
- norm2 = update_2d.norm(p=2, dim=1-dim, keepdim=True).clamp_min_(1e-12)
58
+ norm2_lb = 1 / math.sqrt(update_2d.shape[1-dim])
59
+ norm2 = update_2d.norm(p=2, dim=1-dim, keepdim=True).clamp_min_(norm2_lb)
57
60
  update_2d.mul_(scale_second / norm2)
58
61
  if ortho_project:
59
62
  update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_adim, 1-dim, scale_second)
@@ -72,6 +75,52 @@ def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm):
72
75
  update_2d.addcmul_(proj, p_2d, value=-1.0)
73
76
 
74
77
  # Magnitude Preservation
75
- g_orth_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
78
+ norm_lb = 1 / math.sqrt(update_2d.shape[dim])
79
+ g_orth_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm_lb)
76
80
  scale_factor = target_norm / g_orth_norm
77
81
  return update_2d.mul_(scale_factor)
82
+
83
+ def get_sinkhorn_wd_scaler(
84
+ p: torch.Tensor,
85
+ row_denom: torch.Tensor | None = None,
86
+ col_denom: torch.Tensor | None = None
87
+ ):
88
+ """
89
+ Computes a structural weight decay multiplier.
90
+ Penalizes parameters belonging to dominant rows/columns more heavily,
91
+ while protecting parameters in under-utilized/noisy rows/columns from decay.
92
+ """
93
+ if p.ndim < 2:
94
+ return 1.0
95
+
96
+ p_2d = p.view(p.shape[0], -1)
97
+
98
+ # Lower bounds based on the effective 2D shapes
99
+ row_lb = 1 / math.sqrt(p_2d.shape[1])
100
+ col_lb = 1 / math.sqrt(p_2d.shape[0])
101
+
102
+ # Get the norms
103
+ row_norms = torch.linalg.vector_norm(p_2d, ord=2, dim=1, keepdim=True).clamp_min_(row_lb)
104
+ col_norms = torch.linalg.vector_norm(p_2d, ord=2, dim=0, keepdim=True).clamp_min_(col_lb)
105
+
106
+ # Compute the structural scaler
107
+ row_factor = row_norms.sqrt_()
108
+ col_factor = col_norms.sqrt_()
109
+
110
+ if row_denom is not None and col_denom is not None:
111
+ # Reshape denominators to ensure safe in-place broadcasting
112
+ row_denom = row_denom.sqrt().view(p_2d.shape[0], 1)
113
+ col_denom = col_denom.sqrt().view(1, p_2d.shape[1])
114
+
115
+ # High denom (noise) -> smaller angle (protects weights)
116
+ # Low denom (confident) -> larger angle (decays weights)
117
+ row_factor.atan2_(row_denom)
118
+ col_factor.atan2_(col_denom)
119
+
120
+ # Outer product: merges the row and column confidences into a 2D matrix
121
+ wd_scaler = row_factor * col_factor
122
+
123
+ # Normalize the scaler so its mean is exactly 1.0
124
+ wd_scaler.div_(wd_scaler.mean().clamp_min_(1e-12))
125
+
126
+ return wd_scaler.view_as(p)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev17
3
+ Version: 2.4.dev19
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
5
5
 
6
6
  setup(
7
7
  name="adv_optm",
8
- version="2.4.dev17",
8
+ version="2.4.dev19",
9
9
  author="Koratahiu",
10
10
  author_email="hiuhonor@gmail.com",
11
11
  license='Apache 2.0',
File without changes
File without changes
File without changes