adv-optm 2.4.dev18__tar.gz → 2.4.dev19__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/SinkSGD_adv.py +14 -10
  4. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/sinkhorn.py +2 -19
  5. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm.egg-info/PKG-INFO +1 -1
  6. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/setup.py +1 -1
  7. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/LICENSE +0 -0
  8. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/README.md +0 -0
  9. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/AdaMuon_adv.py +0 -0
  10. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/AdamW_adv.py +0 -0
  11. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/Adopt_adv.py +0 -0
  12. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  13. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/Lion_adv.py +0 -0
  14. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/Muon_adv.py +0 -0
  15. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/Prodigy_adv.py +0 -0
  16. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/SignSGD_adv.py +0 -0
  17. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  18. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/__init__.py +0 -0
  19. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/Kourkoutas.py +0 -0
  20. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/Muon_AuxAdam.py +0 -0
  21. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/Muon_util.py +0 -0
  22. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/OrthoGrad.py +0 -0
  23. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/__init__.py +0 -0
  24. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/centered_decay.py +0 -0
  25. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/factorization_util.py +0 -0
  26. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/lion_k.py +0 -0
  27. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/param_update.py +0 -0
  28. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/scaled_optm.py +0 -0
  29. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/signed_util.py +0 -0
  30. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/state_util.py +0 -0
  31. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/update_util.py +0 -0
  32. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm.egg-info/SOURCES.txt +0 -0
  33. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm.egg-info/dependency_links.txt +0 -0
  34. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm.egg-info/requires.txt +0 -0
  35. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm.egg-info/top_level.txt +0 -0
  36. {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev18
3
+ Version: 2.4.dev19
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -24,4 +24,4 @@ __all__ = [
24
24
  "SinkSGD_adv",
25
25
  ]
26
26
 
27
- __version__ = "2.4.dev18"
27
+ __version__ = "2.4.dev19"
@@ -9,7 +9,7 @@ from ..util.OrthoGrad import _orthogonalize_gradient
9
9
  from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
10
10
  from ..util.centered_decay import _init_anchor
11
11
  from ..util.state_util import init_state_tensor, get_state, set_state, upcast_grad_for_precision
12
- from ..util.sinkhorn import apply_sr_sinkhorn, _sinkhorn_sq_grad, get_sinkhorn_wd_scaler
12
+ from ..util.sinkhorn import apply_sr_sinkhorn, get_sinkhorn_wd_scaler
13
13
  from ..util.signed_util import apply_stochastic_sign_
14
14
 
15
15
  class SinkSGD_adv(torch.optim.Optimizer):
@@ -90,6 +90,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
90
90
  raise ValueError(f"Momentum should be >= 0.0. Got {momentum}")
91
91
  if not (weight_decay >= 0.0):
92
92
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
93
+ if centered_vt and not normed_momentum:
94
+ raise NotImplementedError(f"centered_vt is intended to be used with normed_momentum")
93
95
 
94
96
  state_precision = state_precision.lower()
95
97
  valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp8_sr", "int8_sr"}
@@ -183,9 +185,11 @@ class SinkSGD_adv(torch.optim.Optimizer):
183
185
  init_state_tensor(state, 'momentum_buffer', p.shape, actual_precision, p.device, dtype)
184
186
 
185
187
  if group.get('centered_vt', False):
186
- p_shape = p.shape
187
- state['vt_row'] = torch.zeros(p_shape[:-1], device=device, dtype=torch.float32)
188
- state['vt_col'] = torch.zeros(p_shape[:-2] + p_shape[-1:], device=device, dtype=torch.float32)
188
+ # Align shapes with Sinkhorn's 2D flattening
189
+ dim0 = p.shape[0]
190
+ dim1 = p.numel() // dim0
191
+ state['vt_row'] = torch.zeros(dim0, device=device, dtype=torch.float32)
192
+ state['vt_col'] = torch.zeros(dim1, device=device, dtype=torch.float32)
189
193
 
190
194
  if group.get('spectral_normalization', False) and is_spectral(p):
191
195
  init_spectral_norm(state, p)
@@ -280,7 +284,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
280
284
  if group.get('centered_vt', False):
281
285
  vt_row, vt_col = state['vt_row'], state['vt_col']
282
286
  grad_vt = grad - buf
283
- grad_vt_sq = grad_vt * grad_vt
287
+ grad_vt_sq = grad_vt.mul_(grad_vt).view(grad.shape[0], -1)
284
288
  mean_row_grad = grad_vt_sq.mean(dim=-1)
285
289
  mean_col_grad = grad_vt_sq.mean(dim=-2)
286
290
  vt_row.mul_(momentum).add_(mean_row_grad, alpha=1.0 - momentum)
@@ -289,7 +293,6 @@ class SinkSGD_adv(torch.optim.Optimizer):
289
293
  nv_coef = momentum if nesterov_coef is None else nesterov_coef
290
294
  vt_row = vt_row.lerp(mean_row_grad, 1.0 - nv_coef)
291
295
  vt_col = vt_col.lerp(mean_col_grad, 1.0 - nv_coef)
292
- vt = _sinkhorn_sq_grad(vt_row, vt_col)
293
296
  else:
294
297
  vt_row = None
295
298
  vt_col = None
@@ -309,10 +312,11 @@ class SinkSGD_adv(torch.optim.Optimizer):
309
312
  del random_int_state_tensor
310
313
 
311
314
  if group.get('centered_vt', False):
312
- denom = vt
313
- update.atan2_(denom)
314
- else:
315
- denom = None
315
+ # Align with Sinkhorn: Alternate row/col preconditioning
316
+ update_2d = update.view(update.shape[0], -1)
317
+ update_2d.div_(vt_row.clamp_min(1e-30).sqrt().unsqueeze(1))
318
+ update_2d.div_(vt_col.clamp_min(1e-30).sqrt().unsqueeze(0))
319
+ update = update_2d.atan_().view_as(p)
316
320
 
317
321
  if not group.get('normed_momentum', False):
318
322
  if not is_vector:
@@ -80,23 +80,6 @@ def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm):
80
80
  scale_factor = target_norm / g_orth_norm
81
81
  return update_2d.mul_(scale_factor)
82
82
 
83
- def _sinkhorn_sq_grad(
84
- vt_row: torch.Tensor,
85
- vt_col: torch.Tensor,
86
- ) -> torch.Tensor:
87
- """
88
- Reconstructs the variance precondition from its rank-1 factors.
89
- Modified from:
90
- https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/adafactor.py
91
- """
92
- r_factor = (
93
- (vt_row / vt_row.mean(dim=-1).clamp_min_(1e-30))
94
- .sqrt_()
95
- .unsqueeze(-1)
96
- )
97
- c_factor = vt_col.unsqueeze(-2).sqrt()
98
- return torch.mul(r_factor, c_factor)
99
-
100
83
  def get_sinkhorn_wd_scaler(
101
84
  p: torch.Tensor,
102
85
  row_denom: torch.Tensor | None = None,
@@ -126,8 +109,8 @@ def get_sinkhorn_wd_scaler(
126
109
 
127
110
  if row_denom is not None and col_denom is not None:
128
111
  # Reshape denominators to ensure safe in-place broadcasting
129
- row_denom = row_denom.view(p_2d.shape[0], 1)
130
- col_denom = col_denom.view(1, p_2d.shape[1])
112
+ row_denom = row_denom.sqrt().view(p_2d.shape[0], 1)
113
+ col_denom = col_denom.sqrt().view(1, p_2d.shape[1])
131
114
 
132
115
  # High denom (noise) -> smaller angle (protects weights)
133
116
  # Low denom (confident) -> larger angle (decays weights)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev18
3
+ Version: 2.4.dev19
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
5
5
 
6
6
  setup(
7
7
  name="adv_optm",
8
- version="2.4.dev18",
8
+ version="2.4.dev19",
9
9
  author="Koratahiu",
10
10
  author_email="hiuhonor@gmail.com",
11
11
  license='Apache 2.0',
File without changes
File without changes
File without changes