adv-optm 2.4.dev5.tar.gz → 2.4.dev6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/AdamW_adv.py +8 -3
  4. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/Adopt_adv.py +8 -3
  5. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/scaled_optm.py +18 -6
  6. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm.egg-info/PKG-INFO +1 -1
  7. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/setup.py +1 -1
  8. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/LICENSE +0 -0
  9. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/README.md +0 -0
  10. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/AdaMuon_adv.py +0 -0
  11. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  12. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/Lion_adv.py +0 -0
  13. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/Muon_adv.py +0 -0
  14. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/Prodigy_adv.py +0 -0
  15. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/SignSGD_adv.py +0 -0
  16. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  17. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/__init__.py +0 -0
  18. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/Kourkoutas.py +0 -0
  19. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/Muon_AuxAdam.py +0 -0
  20. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/Muon_util.py +0 -0
  21. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/OrthoGrad.py +0 -0
  22. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/__init__.py +0 -0
  23. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/centered_decay.py +0 -0
  24. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/factorization_util.py +0 -0
  25. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/lion_k.py +0 -0
  26. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/param_update.py +0 -0
  27. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/signed_util.py +0 -0
  28. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/update_util.py +0 -0
  29. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm.egg-info/SOURCES.txt +0 -0
  30. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm.egg-info/dependency_links.txt +0 -0
  31. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm.egg-info/requires.txt +0 -0
  32. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm.egg-info/top_level.txt +0 -0
  33. {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/setup.cfg +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 2.4.dev5
+Version: 2.4.dev6
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu

adv_optm/__init__.py
@@ -22,4 +22,4 @@ __all__ = [
     "SignSGD_adv",
 ]
 
-__version__ = "2.4.dev5"
+__version__ = "2.4.dev6"

adv_optm/optim/AdamW_adv.py
@@ -9,7 +9,7 @@ from ..util.factorization_util import _get_effective_shape, _reconstruct_state,
 from ..util.update_util import _grams_update, _cautious_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.Kourkoutas import KourkoutasHelper
-from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
+from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm, scale_eps
 from ..util.centered_decay import _init_anchor
 
 A = 4 / math.pi
@@ -153,6 +153,9 @@ class AdamW_adv(torch.optim.Optimizer):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
         if kourkoutas_beta and not (betas[1] > beta2_min):
             raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
+        if scaled_optm and use_atan2:
+            print("Warning: use_atan2 is incompatible with scaled_optm, Disabling atan2.")
+            use_atan2 = False
 
         if cautious_mask and grams_moment:
             print("Warning: cautious is incompatible with grams, Disabling cautious.")
@@ -330,6 +333,8 @@
                 # Determine if we are using dense first-moments alongside a factored second-order second-moment
                 factored_2nd = group.get('factored_2nd', False)
 
+                adaptive_eps = scale_eps(group, p)
+
                 if state['factored']:
                     d1, d2 = state['effective_shape']
                     grad_reshaped = grad.view(d1, d2)
@@ -394,7 +399,7 @@
                         update.atan2_(denom)
                     else:
                         denom = vt.sqrt_()
-                        denom.div_(sqrt_bias_correction2).add_(group['eps'])
+                        denom.div_(sqrt_bias_correction2).add_(adaptive_eps)
                         update.div_(denom)
 
                     wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
@@ -438,7 +443,7 @@
                     update.atan2_(denom)
                 else:
                     denom = exp_avg_sq.sqrt()
-                    denom.div_(sqrt_bias_correction2).add_(group['eps'])
+                    denom.div_(sqrt_bias_correction2).add_(adaptive_eps)
                     update.div_(denom)
 
                 wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
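
In both the factored and dense branches the change is confined to the denominator offset: the fixed group['eps'] becomes the per-tensor adaptive_eps computed once at the top of the parameter loop. A rough sketch of the dense else-branch above, using plain tensors in place of optimizer state (names are illustrative):

import math
import torch

def adam_denom(exp_avg_sq: torch.Tensor, beta2: float, step: int,
               eps: float) -> torch.Tensor:
    # sqrt(v_t) / sqrt(1 - beta2^t) + eps, as in the else-branch above
    sqrt_bias_correction2 = math.sqrt(1.0 - beta2 ** step)
    return exp_avg_sq.sqrt() / sqrt_bias_correction2 + eps

Under scaled_optm, eps here would be the scale_eps result rather than a fixed constant.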

adv_optm/optim/Adopt_adv.py
@@ -8,7 +8,7 @@ from ..util.factorization_util import _get_effective_shape, _reconstruct_state,
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.Kourkoutas import KourkoutasHelper
 from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
-from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
+from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm, scale_eps
 from ..util.centered_decay import _init_anchor
 
 A = 4 / math.pi
@@ -183,6 +183,9 @@
             print("Warning: grams is incompatible with Simplified_AdEMAMix, Disabling grams.")
         if cautious_mask and Simplified_AdEMAMix:
             print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
+        if scaled_optm and use_atan2:
+            print("Warning: use_atan2 is incompatible with scaled_optm, Disabling atan2.")
+            use_atan2 = False
 
         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
@@ -364,6 +367,8 @@
                 # Determine if we are using dense first-moments alongside a factored second-order second-moment
                 factored_2nd = group.get('factored_2nd', False)
 
+                adaptive_eps = scale_eps(group, p)
+
                 if state['factored']:
                     d1, d2 = state['effective_shape']
                     grad_reshaped = grad.view(d1, d2)
@@ -387,7 +392,7 @@
                 if self.use_atan2:
                     normalized_grad = torch.atan2(grad_reshaped, denom, out=denom)
                 else:
-                    normalized_grad = torch.div(grad_reshaped, denom.add_(group['eps']), out=denom)
+                    normalized_grad = torch.div(grad_reshaped, denom.add_(adaptive_eps), out=denom)
                 if self.clip_lambda is not None:
                     clip_val = self.clip_lambda(state['step'])
                     normalized_grad.clamp_(-clip_val, clip_val)
@@ -457,7 +462,7 @@
                 if self.use_atan2:
                     normalized_grad = torch.atan2(grad, denom, out=denom)
                 else:
-                    normalized_grad = torch.div(grad, denom.add_(group['eps']), out=denom)
+                    normalized_grad = torch.div(grad, denom.add_(adaptive_eps), out=denom)
                 if self.clip_lambda is not None:
                     clip_val = self.clip_lambda(state['step'])
                     normalized_grad.clamp_(-clip_val, clip_val)
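
ADOPT divides the raw gradient by the previous second-moment estimate before it enters the momentum buffer, optionally clamping the result; the dev6 change again only swaps which epsilon enters that division. A condensed sketch under assumed names (clip_val standing in for self.clip_lambda(state['step'])):

import torch

def adopt_normalize(grad: torch.Tensor, denom: torch.Tensor,
                    eps: float, clip_val: float | None) -> torch.Tensor:
    # Divide by the prior denom estimate, then optionally clip.
    normalized = grad / (denom + eps)
    if clip_val is not None:
        normalized = normalized.clamp(-clip_val, clip_val)
    return normalized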

adv_optm/util/scaled_optm.py
@@ -2,6 +2,8 @@ import torch
 
 from . import param_update
 
+import math
+
 def scale_update(
     p: torch.Tensor,
     update: torch.Tensor,
@@ -26,16 +28,16 @@
 
     # DoRA Magnitude Scales (1D) or 1D Bias/Norm layers
     if is_dora_scale or p.ndim == 1:
-        return rms_normalization(update, dim=None, lr=lr)
+        return l2_normalization(update, dim=None, lr=lr)
 
     # Orthogonal Fine-Tuning (OFT)
     # This guarantees O(1) update complexity scaling, independent of block sizes.
     if is_oft:
         n = update.shape[1]
         # Calculate block size (b)
-        b = (1 + (1 + 8 * n) ** 0.5) / 2
-        target_norm = (b / 8) ** 0.5
-        scale = target_norm / (n ** 0.5)
+        b = (1 + math.sqrt(1 + 8 * n)) / 2
+        target_norm = math.sqrt(b / 8)
+        scale = target_norm / math.sqrt(n)
         return rms_normalization(update, dim=1, lr=lr * scale)
 
     # LoRA Factors or Full Finetuning weights
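
The OFT arithmetic reads as inverting the triangular-number relation: if a block of size b contributes n = b(b-1)/2 independent entries, solving that quadratic for b gives b = (1 + sqrt(1 + 8n)) / 2, which is exactly the expression in the hunk (my reading of the code, not documented by the package). A quick numeric check:

import math

def block_size(n: int) -> float:
    # Invert n = b * (b - 1) / 2 for the block size b.
    return (1 + math.sqrt(1 + 8 * n)) / 2

assert block_size(6 * 5 // 2) == 6.0    # b = 6  -> n = 15 -> b recovered
assert block_size(10 * 9 // 2) == 10.0  # b = 10 -> n = 45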
@@ -46,6 +48,16 @@
     return update.mul_(lr)
 
 
+def scale_eps(group: dict, p) -> tuple[float, float]:
+    """
+    Scales Adam eps to be scale-invariant.
+    """
+    if group.get('scaled_optm', False):
+        adaptive_eps = (1.0 / group['n_layers']) * (1.0 / math.sqrt(p.numel()))
+    else:
+        adaptive_eps = group['eps']
+    return adaptive_eps
+
 def adjust_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
     """
     Adjusts standard weight decay and centered weight decay.
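
The new helper picks the epsilon once per parameter tensor; with scaled_optm on, it shrinks with both depth (n_layers) and parameter count, which is what makes the epsilon commensurate with the scaled update per the docstring. A usage sketch with a hypothetical param group (n_layers must be present in the group when scaled_optm is set, or the lookup raises KeyError):

import math
import torch

group = {"scaled_optm": True, "n_layers": 12, "eps": 1e-8}
p = torch.zeros(768, 768)

# Mirrors scale_eps: 1 / (n_layers * sqrt(numel)) when scaled, else group['eps'].
adaptive_eps = (1.0 / group["n_layers"]) * (1.0 / math.sqrt(p.numel()))
print(adaptive_eps)  # ~1.085e-4 for a 768x768 tensor across 12 layers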
@@ -93,7 +105,7 @@ def rms_normalization(update: torch.Tensor, dim: int | None, lr: float) -> torch
     """Performs Root Mean Square normalization on the update tensor."""
     n = update.numel() if dim is None else update.shape[dim]
     norm = torch.linalg.vector_norm(update, ord=2, dim=dim, keepdim=True).clamp_min_(1e-12)
-    scale_n = n**0.5
+    scale_n = math.sqrt(n)
     return update.mul_(lr * scale_n / norm)
 
 
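rms_normalization rescales so that the update's root-mean-square equals lr: dividing by the L2 norm and multiplying by sqrt(n) makes the RMS exactly 1 before the lr factor. A small self-contained check of that identity (same formula, standalone):

import math
import torch

u = torch.randn(64, 64)
n = u.numel()
norm = torch.linalg.vector_norm(u).clamp_min(1e-12)
out = u * (0.01 * math.sqrt(n) / norm)         # same scaling as rms_normalization
rms = out.pow(2).mean().sqrt()
assert torch.isclose(rms, torch.tensor(0.01))  # RMS of the update equals lr
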
@@ -123,7 +135,7 @@ def spectral_normalization(update: torch.Tensor, vector_state: torch.Tensor, lr:
     update = update.to(vector_state.dtype)
     update_flat = update.view(d_out, d_in)
     # Target scale derived from the "Modular Norm" paper
-    target_scale = (d_out / d_in) ** 0.5
+    target_scale = math.sqrt(d_out / d_in)
     # Power Iteration step to estimate the largest singular value (sigma)
     # u = Wv
     u = torch.mv(update_flat, vector_state)
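
The changed line sits inside a power-iteration estimate of the update's top singular value: u = Wv, and once v converges toward the top right singular vector, ||Wv|| approaches sigma. A generic one-step sketch of that estimator (not the package's exact state handling):

import torch

def power_iteration_step(W: torch.Tensor, v: torch.Tensor):
    # One round of power iteration on W^T W; returns (new v, sigma estimate).
    u = torch.mv(W, v)                        # u = W v
    v_new = torch.mv(W.t(), u)                # v' = W^T u
    v_new = v_new / v_new.norm().clamp_min(1e-12)
    sigma = torch.mv(W, v_new).norm()         # ||W v'|| ~ top singular value
    return v_new, sigma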

adv_optm.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 2.4.dev5
+Version: 2.4.dev6
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu

setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
 
 setup(
     name="adv_optm",
-    version="2.4.dev5",
+    version="2.4.dev6",
     author="Koratahiu",
     author_email="hiuhonor@gmail.com",
     license='Apache 2.0',