adv-optm 2.5.10__tar.gz → 2.6.1.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/PKG-INFO +1 -1
  2. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/optim/AdaMuon_adv.py +2 -1
  4. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/optim/AdamW_adv.py +2 -1
  5. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/optim/Adopt_adv.py +2 -1
  6. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/optim/Lion_adv.py +2 -0
  7. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/optim/Muon_adv.py +2 -1
  8. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/optim/Prodigy_adv.py +2 -1
  9. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/optim/SignSGD_adv.py +2 -0
  10. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/optim/SinkSGD_adv.py +2 -1
  11. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/util/param_update.py +33 -27
  12. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/util/scaled_optm.py +14 -0
  13. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm.egg-info/PKG-INFO +1 -1
  14. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/setup.py +1 -1
  15. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/LICENSE +0 -0
  16. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/README.md +0 -0
  17. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/optim/__init__.py +0 -0
  18. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/util/Kourkoutas.py +0 -0
  19. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/util/Muon_AuxAdam.py +0 -0
  20. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/util/Muon_util.py +0 -0
  21. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/util/OrthoGrad.py +0 -0
  22. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/util/__init__.py +0 -0
  23. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/util/centered_decay.py +0 -0
  24. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/util/factorization_util.py +0 -0
  25. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/util/lion_k.py +0 -0
  26. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/util/signed_util.py +0 -0
  27. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/util/sinkhorn.py +0 -0
  28. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/util/state_util.py +0 -0
  29. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm/util/update_util.py +0 -0
  30. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm.egg-info/SOURCES.txt +0 -0
  31. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm.egg-info/dependency_links.txt +0 -0
  32. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm.egg-info/requires.txt +0 -0
  33. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/adv_optm.egg-info/top_level.txt +0 -0
  34. {adv_optm-2.5.10 → adv_optm-2.6.1.dev2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.5.10
3
+ Version: 2.6.1.dev2
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -20,4 +20,4 @@ __all__ = [
20
20
  "SinkSGD_adv",
21
21
  ]
22
22
 
23
- __version__ = "2.5.10"
23
+ __version__ = "2.6.1.dev2"
@@ -137,6 +137,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
137
137
  # Decoupled/cautious weight decay
138
138
  weight_decay: float = 0,
139
139
  cautious_wd: bool = False,
140
+ scaled_wd: bool = False,
140
141
  # Nesterov momentum
141
142
  nesterov: bool = True,
142
143
  nesterov_coef: float | None = None,
@@ -227,7 +228,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
227
228
 
228
229
  defaults = {
229
230
  "lr": lr, "betas": betas, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
230
- "eps": eps, "rms_rescaling": rms_rescaling, "ns_steps": ns_steps,
231
+ "eps": eps, "rms_rescaling": rms_rescaling, "ns_steps": ns_steps, "scaled_wd": scaled_wd,
231
232
  "ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
232
233
  "vector_reshape": vector_reshape,
233
234
  "nesterov":nesterov, "nesterov_coef": nesterov_coef, "use_atan2":use_atan2,
@@ -98,6 +98,7 @@ class AdamW_adv(torch.optim.Optimizer):
98
98
  weight_decay: float = 0.0,
99
99
  fisher_wd: bool = False,
100
100
  cautious_wd: bool = False,
101
+ scaled_wd: bool = False,
101
102
  # Adam's Bias Correction
102
103
  use_bias_correction: bool = True,
103
104
  # Stochastic Rounding for BF16
@@ -156,7 +157,7 @@ class AdamW_adv(torch.optim.Optimizer):
156
157
 
157
158
  defaults = {
158
159
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
159
- "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
160
+ "fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "scaled_wd": scaled_wd,
160
161
  "use_atan2": use_atan2, "nesterov": nesterov, "nesterov_coef": nesterov_coef,
161
162
  "normed_momentum": normed_momentum,
162
163
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
@@ -101,6 +101,7 @@ class Adopt_adv(torch.optim.Optimizer):
101
101
  weight_decay: float = 0.0,
102
102
  fisher_wd: bool = False,
103
103
  cautious_wd: bool = False,
104
+ scaled_wd: bool = False,
104
105
  # ADOPT clipping
105
106
  clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
106
107
  # Adam_atan2 (scale invariant)
@@ -157,7 +158,7 @@ class Adopt_adv(torch.optim.Optimizer):
157
158
  state_precision = "factored"
158
159
 
159
160
  defaults = {
160
- "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
161
+ "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "scaled_wd": scaled_wd,
161
162
  "fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "orthogonal_gradient": orthogonal_gradient,
162
163
  "nesterov": nesterov, "nesterov_coef": nesterov_coef,
163
164
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
@@ -64,6 +64,7 @@ class Lion_adv(torch.optim.Optimizer):
64
64
  # Decoupled/cautious weight decay
65
65
  weight_decay: float = 0.0,
66
66
  cautious_wd: bool = False,
67
+ scaled_wd: bool = False,
67
68
  # Stochastic Rounding for BF16
68
69
  stochastic_rounding: bool = True,
69
70
  # OrthoGrad
@@ -96,6 +97,7 @@ class Lion_adv(torch.optim.Optimizer):
96
97
  betas=betas,
97
98
  weight_decay=weight_decay,
98
99
  cautious_wd=cautious_wd,
100
+ scaled_wd=scaled_wd,
99
101
  vector_reshape=vector_reshape,
100
102
  orthogonal_gradient=orthogonal_gradient,
101
103
  kappa_p=kappa_p,
@@ -111,6 +111,7 @@ class Muon_adv(torch.optim.Optimizer):
111
111
  # Decoupled/cautious weight decay
112
112
  weight_decay: float = 0.0,
113
113
  cautious_wd: bool = False,
114
+ scaled_wd: bool = False,
114
115
  # Nesterov momentum
115
116
  nesterov: bool = True,
116
117
  nesterov_coef: float | None = None,
@@ -201,7 +202,7 @@ class Muon_adv(torch.optim.Optimizer):
201
202
  defaults = {
202
203
  "lr": lr, "beta1": beta1, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
203
204
  "nesterov": nesterov, "nesterov_coef": nesterov_coef, "ns_steps": ns_steps, "ns_eps": ns_eps,
204
- "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
205
+ "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor, "scaled_wd": scaled_wd,
205
206
  "vector_reshape": vector_reshape, "rms_rescaling": rms_rescaling,
206
207
  "orthogonal_gradient": orthogonal_gradient,
207
208
  'compiled_optimizer': compiled_optimizer,
@@ -115,6 +115,7 @@ class Prodigy_adv(torch.optim.Optimizer):
115
115
  weight_decay: float = 0.0,
116
116
  fisher_wd: bool = False,
117
117
  cautious_wd: bool = False,
118
+ scaled_wd: bool = False,
118
119
  # Stochastic Rounding for BF16
119
120
  stochastic_rounding: bool = True,
120
121
  # Adam_atan2 (scale invariant)
@@ -181,7 +182,7 @@ class Prodigy_adv(torch.optim.Optimizer):
181
182
 
182
183
  defaults = {
183
184
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
184
- "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
185
+ "fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "scaled_wd": scaled_wd,
185
186
  "use_atan2": use_atan2,
186
187
  "orthogonal_gradient": orthogonal_gradient,
187
188
  "compiled_optimizer": compiled_optimizer,
@@ -59,6 +59,7 @@ class SignSGD_adv(torch.optim.Optimizer):
59
59
  # weight decay features
60
60
  geometric_wd: bool = False,
61
61
  cautious_wd: bool = False,
62
+ scaled_wd: bool = False,
62
63
  # Stochastic Rounding for BF16
63
64
  stochastic_rounding: bool = True,
64
65
  # OrthoGrad
@@ -108,6 +109,7 @@ class SignSGD_adv(torch.optim.Optimizer):
108
109
  momentum=momentum,
109
110
  weight_decay=weight_decay,
110
111
  cautious_wd=cautious_wd,
112
+ scaled_wd=scaled_wd,
111
113
  geometric_wd=geometric_wd,
112
114
  vector_reshape=vector_reshape,
113
115
  orthogonal_gradient=orthogonal_gradient,
@@ -66,6 +66,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
66
66
  # weight decay features
67
67
  geometric_wd: bool = False,
68
68
  cautious_wd: bool = False,
69
+ scaled_wd: bool = False,
69
70
  # Stochastic Rounding for BF16
70
71
  stochastic_rounding: bool = True,
71
72
  # OrthoGrad
@@ -103,7 +104,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
103
104
  defaults = {
104
105
  "lr": lr, "momentum": momentum,
105
106
  "weight_decay": weight_decay, "nesterov": nesterov, "nesterov_coef": nesterov_coef, "normed_momentum": normed_momentum, "snr_cond": snr_cond,
106
- "geometric_wd": geometric_wd, "cautious_wd": cautious_wd,
107
+ "geometric_wd": geometric_wd, "cautious_wd": cautious_wd, "scaled_wd": scaled_wd,
107
108
  "orthogonal_gradient": orthogonal_gradient,
108
109
  "compiled_optimizer": compiled_optimizer,
109
110
  "sinkhorn_iterations": sinkhorn_iterations,
@@ -6,7 +6,7 @@ import torch.nn.functional as F
6
6
 
7
7
  from typing import Dict, Any
8
8
 
9
- from .scaled_optm import adjust_wds
9
+ from .scaled_optm import adjust_wds, scale_wd
10
10
  from .centered_decay import dequantize_anchor
11
11
 
12
12
  _generators: Dict[torch.device, torch.Generator] = {}
@@ -18,8 +18,8 @@ def _apply_weight_decay(
18
18
  p: Tensor,
19
19
  state: Dict[str, Any],
20
20
  group: Dict[str, Any],
21
- scaled_wd: float | Tensor | None,
22
- scaled_cwd: float | Tensor | None,
21
+ eff_wd: float | Tensor | None,
22
+ eff_cwd: float | Tensor | None,
23
23
  wd_target: Tensor | None = None,
24
24
  cwd_target: Tensor | None = None,
25
25
  ) -> None:
@@ -29,26 +29,26 @@ def _apply_weight_decay(
29
29
  cautious = group.get('cautious_wd', False)
30
30
 
31
31
  # Standard Weight Decay (pulls toward zero)
32
- if scaled_wd is not None:
32
+ if eff_wd is not None:
33
33
  if wd_target is None:
34
34
  wd_target = p_calc
35
35
  # Cautious Weight Decay: only decay if the update pushes in the same direction as the decay
36
36
  if cautious:
37
37
  mask = (update_calc * p_calc >= 0).to(p_calc.dtype)
38
- if isinstance(scaled_wd, Tensor):
39
- p_calc.addcmul_(wd_target, mask * scaled_wd, value=-1.0)
38
+ if isinstance(eff_wd, Tensor):
39
+ p_calc.addcmul_(wd_target, mask * eff_wd, value=-1.0)
40
40
  else:
41
- p_calc.addcmul_(wd_target, mask, value=-scaled_wd)
41
+ p_calc.addcmul_(wd_target, mask, value=-eff_wd)
42
42
  del mask
43
43
  else:
44
44
  # Standard decoupled weight decay
45
- if isinstance(scaled_wd, Tensor):
46
- p_calc.addcmul_(wd_target, scaled_wd, value=-1.0)
45
+ if isinstance(eff_wd, Tensor):
46
+ p_calc.addcmul_(wd_target, eff_wd, value=-1.0)
47
47
  else:
48
- p_calc.add_(wd_target, alpha=-scaled_wd)
48
+ p_calc.add_(wd_target, alpha=-eff_wd)
49
49
 
50
50
  # Centered Weight Decay (pulls toward anchor)
51
- if scaled_cwd is not None and 'anchor_data' in state:
51
+ if eff_cwd is not None and 'anchor_data' in state:
52
52
  if cwd_target is not None:
53
53
  decay_target = cwd_target
54
54
  else:
@@ -59,17 +59,17 @@ def _apply_weight_decay(
59
59
  if cautious:
60
60
  # Cautious Weight Decay: only decay if the update pushes in the same direction as the decay
61
61
  mask = (update_calc * decay_target >= 0).to(p_calc.dtype)
62
- if isinstance(scaled_cwd, Tensor):
63
- p_calc.addcmul_(decay_target, mask * scaled_cwd, value=-1.0)
62
+ if isinstance(eff_cwd, Tensor):
63
+ p_calc.addcmul_(decay_target, mask * eff_cwd, value=-1.0)
64
64
  else:
65
- p_calc.addcmul_(decay_target, mask, value=-scaled_cwd)
65
+ p_calc.addcmul_(decay_target, mask, value=-eff_cwd)
66
66
  del mask
67
67
  else:
68
68
  # Standard decoupled weight decay
69
- if isinstance(scaled_cwd, Tensor):
70
- p_calc.addcmul_(decay_target, scaled_cwd, value=-1.0)
69
+ if isinstance(eff_cwd, Tensor):
70
+ p_calc.addcmul_(decay_target, eff_cwd, value=-1.0)
71
71
  else:
72
- p_calc.add_(decay_target, alpha=-scaled_cwd)
72
+ p_calc.add_(decay_target, alpha=-eff_cwd)
73
73
 
74
74
  if cwd_target is None:
75
75
  del decay_target
@@ -105,18 +105,24 @@ def apply_parameter_update(
105
105
  wd = group["weight_decay"] if wd is None else wd
106
106
  cwd = group.get("centered_wd", 0.0)
107
107
  wd, cwd = adjust_wds(wd, cwd, p)
108
+ scaled_wd = group.get("scaled_wd", False)
109
+ decoupled = scaled_wd
108
110
 
109
111
  # Calculate global decay factor for decoupled vs standard
110
112
  decay_factor = (lr / self._init_lr) if decoupled else lr
111
113
 
112
- scaled_wd = (wd * decay_factor) if wd != 0 else None
113
- scaled_cwd = (cwd * decay_factor) if cwd != 0 else None
114
+ eff_wd = (wd * decay_factor) if wd != 0 else None
115
+ eff_cwd = (cwd * decay_factor) if cwd != 0 else None
114
116
 
115
117
  if wd_scaler is not None:
116
- if scaled_wd is not None:
117
- scaled_wd = scaled_wd * wd_scaler
118
- if scaled_cwd is not None:
119
- scaled_cwd = scaled_cwd * wd_scaler
118
+ if eff_wd is not None:
119
+ if scaled_wd:
120
+ eff_wd = scale_wd(eff_wd, p)
121
+ eff_wd = eff_wd * wd_scaler
122
+ if eff_cwd is not None:
123
+ if scaled_wd:
124
+ eff_cwd = scale_wd(eff_cwd, p)
125
+ eff_cwd = eff_cwd * wd_scaler
120
126
 
121
127
  state = self.state[p]
122
128
 
@@ -129,8 +135,8 @@ def apply_parameter_update(
129
135
  cwd_t = cwd_target.float() if cwd_target is not None else None
130
136
 
131
137
  # Apply weight decay if needed
132
- if scaled_wd is not None or scaled_cwd is not None:
133
- _apply_weight_decay(p_fp32, update_fp32, p, state, group, scaled_wd, scaled_cwd, wd_t, cwd_t)
138
+ if eff_wd is not None or eff_cwd is not None:
139
+ _apply_weight_decay(p_fp32, update_fp32, p, state, group, eff_wd, eff_cwd, wd_t, cwd_t)
134
140
 
135
141
  # Apply main update
136
142
  p_fp32.add_(-update_fp32)
@@ -147,8 +153,8 @@ def apply_parameter_update(
147
153
 
148
154
  else:
149
155
  # Standard path for non-bfloat16 or without stochastic rounding
150
- if scaled_wd is not None or scaled_cwd is not None:
151
- _apply_weight_decay(p, update, p, state, group, scaled_wd, scaled_cwd, wd_target, cwd_target)
156
+ if eff_wd is not None or eff_cwd is not None:
157
+ _apply_weight_decay(p, update, p, state, group, eff_wd, eff_cwd, wd_target, cwd_target)
152
158
 
153
159
  # Apply main update
154
160
  p.add_(-update)
@@ -97,6 +97,20 @@ def adjust_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
97
97
  # Centered WD safely regularizes the delta without collapsing base feature variance.
98
98
  return wd, cwd
99
99
 
100
def scale_wd(wd: float, p: torch.Tensor) -> float:
    """
    Scale-invariant, dimension-scaled weight decay.

    Args:
        wd: Weight-decay coefficient to rescale. Only plain arithmetic is
            applied to it, so a tensor-valued coefficient passes through
            the same way.
        p: Parameter the coefficient will be applied to.

    Returns:
        The rescaled coefficient:
        * parameters flagged with a truthy ``_is_oft`` attribute:
          ``2 * wd / (b - 1)``, where ``b`` is derived from ``p.shape[-1]``;
        * non-empty parameters with ``ndim >= 2``: ``wd / width`` with
          ``width = numel // shape[0]``;
        * everything else (0-D/1-D or empty parameters): ``wd`` unchanged.
    """
    if getattr(p, '_is_oft', False):
        n_el = p.shape[-1]
        # b is the positive root of n_el = b * (b - 1) / 2.
        # NOTE(review): presumably n_el counts the free entries of a b x b
        # skew-symmetric OFT block -- confirm against the OFT layout.
        b = (1.0 + math.sqrt(1.0 + 8.0 * n_el)) / 2.0
        return (2 * wd) / (b - 1)

    # numel() > 0 also guarantees shape[0] > 0, so neither division below
    # can hit zero for empty tensors.
    if p.ndim >= 2 and p.numel() > 0:
        width = p.numel() // p.shape[0]
        return wd / width

    # No width to normalise by: hand the coefficient back untouched instead
    # of implicitly returning None, which would break the caller's
    # subsequent `eff_wd * wd_scaler` multiplication.
    return wd
113
+
100
114
 
101
115
  def is_spectral(p: torch.Tensor) -> bool:
102
116
  """Determines if a parameter should undergo spectral normalization updates."""
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.5.10
3
+ Version: 2.6.1.dev2
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
5
5
 
6
6
  setup(
7
7
  name="adv_optm",
8
- version="2.5.10",
8
+ version="2.6.1.dev2",
9
9
  author="Koratahiu",
10
10
  author_email="hiuhonor@gmail.com",
11
11
  license='Apache 2.0',
File without changes
File without changes
File without changes