adv-optm 2.4.dev18__tar.gz → 2.4.dev20__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/SinkSGD_adv.py +14 -15
  4. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/sinkhorn.py +2 -19
  5. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm.egg-info/PKG-INFO +1 -1
  6. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/setup.py +1 -1
  7. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/LICENSE +0 -0
  8. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/README.md +0 -0
  9. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/AdaMuon_adv.py +0 -0
  10. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/AdamW_adv.py +0 -0
  11. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/Adopt_adv.py +0 -0
  12. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  13. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/Lion_adv.py +0 -0
  14. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/Muon_adv.py +0 -0
  15. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/Prodigy_adv.py +0 -0
  16. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/SignSGD_adv.py +0 -0
  17. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  18. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/__init__.py +0 -0
  19. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/Kourkoutas.py +0 -0
  20. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/Muon_AuxAdam.py +0 -0
  21. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/Muon_util.py +0 -0
  22. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/OrthoGrad.py +0 -0
  23. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/__init__.py +0 -0
  24. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/centered_decay.py +0 -0
  25. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/factorization_util.py +0 -0
  26. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/lion_k.py +0 -0
  27. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/param_update.py +0 -0
  28. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/scaled_optm.py +0 -0
  29. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/signed_util.py +0 -0
  30. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/state_util.py +0 -0
  31. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/update_util.py +0 -0
  32. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm.egg-info/SOURCES.txt +0 -0
  33. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm.egg-info/dependency_links.txt +0 -0
  34. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm.egg-info/requires.txt +0 -0
  35. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm.egg-info/top_level.txt +0 -0
  36. {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev18
3
+ Version: 2.4.dev20
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -24,4 +24,4 @@ __all__ = [
24
24
  "SinkSGD_adv",
25
25
  ]
26
26
 
27
- __version__ = "2.4.dev18"
27
+ __version__ = "2.4.dev20"
@@ -4,12 +4,11 @@ import math
4
4
 
5
5
  from ..util import param_update
6
6
  from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
7
- from ..util.update_util import _grams_update, _cautious_update
8
7
  from ..util.OrthoGrad import _orthogonalize_gradient
9
8
  from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
10
9
  from ..util.centered_decay import _init_anchor
11
10
  from ..util.state_util import init_state_tensor, get_state, set_state, upcast_grad_for_precision
12
- from ..util.sinkhorn import apply_sr_sinkhorn, _sinkhorn_sq_grad, get_sinkhorn_wd_scaler
11
+ from ..util.sinkhorn import apply_sr_sinkhorn, get_sinkhorn_wd_scaler
13
12
  from ..util.signed_util import apply_stochastic_sign_
14
13
 
15
14
  class SinkSGD_adv(torch.optim.Optimizer):
@@ -90,6 +89,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
90
89
  raise ValueError(f"Momentum should be >= 0.0. Got {momentum}")
91
90
  if not (weight_decay >= 0.0):
92
91
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
92
+ if centered_vt and not normed_momentum:
93
+ raise NotImplementedError(f"centered_vt is intended to be used with normed_momentum")
93
94
 
94
95
  state_precision = state_precision.lower()
95
96
  valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp8_sr", "int8_sr"}
@@ -183,9 +184,11 @@ class SinkSGD_adv(torch.optim.Optimizer):
183
184
  init_state_tensor(state, 'momentum_buffer', p.shape, actual_precision, p.device, dtype)
184
185
 
185
186
  if group.get('centered_vt', False):
186
- p_shape = p.shape
187
- state['vt_row'] = torch.zeros(p_shape[:-1], device=device, dtype=torch.float32)
188
- state['vt_col'] = torch.zeros(p_shape[:-2] + p_shape[-1:], device=device, dtype=torch.float32)
187
+ # Align shapes with Sinkhorn's 2D flattening
188
+ dim0 = p.shape[0]
189
+ dim1 = p.numel() // dim0
190
+ state['vt_row'] = torch.zeros(dim0, device=device, dtype=torch.float32)
191
+ state['vt_col'] = torch.zeros(dim1, device=device, dtype=torch.float32)
189
192
 
190
193
  if group.get('spectral_normalization', False) and is_spectral(p):
191
194
  init_spectral_norm(state, p)
@@ -280,16 +283,11 @@ class SinkSGD_adv(torch.optim.Optimizer):
280
283
  if group.get('centered_vt', False):
281
284
  vt_row, vt_col = state['vt_row'], state['vt_col']
282
285
  grad_vt = grad - buf
283
- grad_vt_sq = grad_vt * grad_vt
286
+ grad_vt_sq = grad_vt.mul_(grad_vt).view(grad.shape[0], -1)
284
287
  mean_row_grad = grad_vt_sq.mean(dim=-1)
285
288
  mean_col_grad = grad_vt_sq.mean(dim=-2)
286
289
  vt_row.mul_(momentum).add_(mean_row_grad, alpha=1.0 - momentum)
287
290
  vt_col.mul_(momentum).add_(mean_col_grad, alpha=1.0 - momentum)
288
- if nesterov:
289
- nv_coef = momentum if nesterov_coef is None else nesterov_coef
290
- vt_row = vt_row.lerp(mean_row_grad, 1.0 - nv_coef)
291
- vt_col = vt_col.lerp(mean_col_grad, 1.0 - nv_coef)
292
- vt = _sinkhorn_sq_grad(vt_row, vt_col)
293
291
  else:
294
292
  vt_row = None
295
293
  vt_col = None
@@ -309,10 +307,11 @@ class SinkSGD_adv(torch.optim.Optimizer):
309
307
  del random_int_state_tensor
310
308
 
311
309
  if group.get('centered_vt', False):
312
- denom = vt
313
- update.atan2_(denom)
314
- else:
315
- denom = None
310
+ # Align with Sinkhorn: Alternate row/col preconditioning
311
+ update_2d = update.view(update.shape[0], -1)
312
+ update_2d.mul_(vt_row.clamp_min(1e-30).rsqrt().unsqueeze(1))
313
+ update_2d.mul_(vt_col.clamp_min(1e-30).rsqrt().unsqueeze(0))
314
+ update = update_2d.atan_().view_as(p)
316
315
 
317
316
  if not group.get('normed_momentum', False):
318
317
  if not is_vector:
@@ -80,23 +80,6 @@ def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm):
80
80
  scale_factor = target_norm / g_orth_norm
81
81
  return update_2d.mul_(scale_factor)
82
82
 
83
- def _sinkhorn_sq_grad(
84
- vt_row: torch.Tensor,
85
- vt_col: torch.Tensor,
86
- ) -> torch.Tensor:
87
- """
88
- Reconstructs the variance precondition from its rank-1 factors.
89
- Modified from:
90
- https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/adafactor.py
91
- """
92
- r_factor = (
93
- (vt_row / vt_row.mean(dim=-1).clamp_min_(1e-30))
94
- .sqrt_()
95
- .unsqueeze(-1)
96
- )
97
- c_factor = vt_col.unsqueeze(-2).sqrt()
98
- return torch.mul(r_factor, c_factor)
99
-
100
83
  def get_sinkhorn_wd_scaler(
101
84
  p: torch.Tensor,
102
85
  row_denom: torch.Tensor | None = None,
@@ -126,8 +109,8 @@ def get_sinkhorn_wd_scaler(
126
109
 
127
110
  if row_denom is not None and col_denom is not None:
128
111
  # Reshape denominators to ensure safe in-place broadcasting
129
- row_denom = row_denom.view(p_2d.shape[0], 1)
130
- col_denom = col_denom.view(1, p_2d.shape[1])
112
+ row_denom = row_denom.sqrt().view(p_2d.shape[0], 1)
113
+ col_denom = col_denom.sqrt().view(1, p_2d.shape[1])
131
114
 
132
115
  # High denom (noise) -> smaller angle (protects weights)
133
116
  # Low denom (confident) -> larger angle (decays weights)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev18
3
+ Version: 2.4.dev20
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
5
5
 
6
6
  setup(
7
7
  name="adv_optm",
8
- version="2.4.dev18",
8
+ version="2.4.dev20",
9
9
  author="Koratahiu",
10
10
  author_email="hiuhonor@gmail.com",
11
11
  license='Apache 2.0',
File without changes
File without changes
File without changes