adv-optm 2.4.dev25__tar.gz → 2.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. {adv_optm-2.4.dev25 → adv_optm-2.5}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/AdaMuon_adv.py +7 -7
  4. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/AdamW_adv.py +5 -5
  5. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/Adopt_adv.py +4 -6
  6. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/Lion_adv.py +6 -5
  7. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/Muon_adv.py +7 -7
  8. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/Prodigy_adv.py +4 -4
  9. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/SignSGD_adv.py +7 -8
  10. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/SinkSGD_adv.py +7 -8
  11. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/Muon_AuxAdam.py +2 -3
  12. adv_optm-2.5/adv_optm/util/OrthoGrad.py +92 -0
  13. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/param_update.py +3 -3
  14. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/scaled_optm.py +2 -2
  15. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/state_util.py +1 -1
  16. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm.egg-info/PKG-INFO +1 -1
  17. {adv_optm-2.4.dev25 → adv_optm-2.5}/setup.py +1 -1
  18. adv_optm-2.4.dev25/adv_optm/util/OrthoGrad.py +0 -80
  19. {adv_optm-2.4.dev25 → adv_optm-2.5}/LICENSE +0 -0
  20. {adv_optm-2.4.dev25 → adv_optm-2.5}/README.md +0 -0
  21. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/__init__.py +0 -0
  22. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/Kourkoutas.py +0 -0
  23. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/Muon_util.py +0 -0
  24. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/__init__.py +0 -0
  25. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/centered_decay.py +0 -0
  26. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/factorization_util.py +0 -0
  27. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/lion_k.py +0 -0
  28. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/signed_util.py +0 -0
  29. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/sinkhorn.py +0 -0
  30. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/update_util.py +0 -0
  31. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm.egg-info/SOURCES.txt +0 -0
  32. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm.egg-info/dependency_links.txt +0 -0
  33. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm.egg-info/requires.txt +0 -0
  34. {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm.egg-info/top_level.txt +0 -0
  35. {adv_optm-2.4.dev25 → adv_optm-2.5}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev25
3
+ Version: 2.5
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -20,4 +20,4 @@ __all__ = [
20
20
  "SinkSGD_adv",
21
21
  ]
22
22
 
23
- __version__ = "2.4.dev25"
23
+ __version__ = "2.5"
@@ -57,7 +57,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
57
57
  (default: (3.4445, -4.7750, 2.0315)).
58
58
  stochastic_rounding (bool): whether to use stochastic rounding for
59
59
  BF16 parameter updates (default: True).
60
- orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
60
+ orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
61
+ 'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
61
62
  nesterov (bool): enables Nesterov momentum (default: False).
62
63
  use_atan2 (bool): whether to use the atan2 update rule. (default: False)
63
64
  vector_reshape (bool): whether to reshape 1D vectors into 2D
@@ -114,7 +115,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
114
115
  adam_fisher_wd (bool): Fisher Adam (FAdam) weight decay for the AdamW part. (default: False)
115
116
  adam_use_bias_correction (bool): Bias correction for AdamW.
116
117
  adam_use_atan2 (bool): Atan2 update rule for AdamW.
117
- adam_orthogonal_gradient (bool): OrthoGrad for AdamW.
118
+ adam_orthogonal_gradient (str): OrthoGrad for AdamW.
118
119
  adam_nesterov (bool): Nesterov momentum for AdamW. (default: False)
119
120
  adam_nesterov_coef (float, optional): Nesterov coefficient for AdamW. (default: None)
120
121
  adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
@@ -149,7 +150,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
149
150
  # Stochastic Rounding for BF16
150
151
  stochastic_rounding: bool = True,
151
152
  # OrthoGrad
152
- orthogonal_gradient: bool = False,
153
+ orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
153
154
  # Adam_atan2 (scale invariant)
154
155
  use_atan2: bool = False,
155
156
  # NorMuon
@@ -190,7 +191,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
190
191
  adam_fisher_wd: bool = False,
191
192
  adam_use_bias_correction: bool = True,
192
193
  adam_use_atan2: bool = False,
193
- adam_orthogonal_gradient: bool = False,
194
+ adam_orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
194
195
  adam_nesterov: bool = False,
195
196
  adam_nesterov_coef: float | None = None,
196
197
  adam_kourkoutas_beta: bool = False,
@@ -213,7 +214,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
213
214
  print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
214
215
  rms_rescaling = False
215
216
  if spectral_normalization and accelerated_ns:
216
- ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
217
+ raise ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
217
218
 
218
219
  # Legacy backwards compatibility support for `nnmf_factor=True`
219
220
  if nnmf_factor:
@@ -515,8 +516,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
515
516
  grad = approx_mars(grad, state['last_grad'], group['mars_gamma'], beta1)
516
517
 
517
518
 
518
- if group.get("orthogonal_gradient"):
519
- grad = _orthogonalize_gradient(p, grad)
519
+ grad = _orthogonalize_gradient(p, grad, group.get("orthogonal_gradient"))
520
520
 
521
521
  if state['factored']: # Factored Muon
522
522
  d1, d2 = state['effective_shape']
@@ -45,7 +45,8 @@ class AdamW_adv(torch.optim.Optimizer):
45
45
  stochastic_rounding (bool): whether to use stochastic
46
46
  rounding for BF16 parameter updates (default: True).
47
47
  use_atan2 (bool): whether to use the atan2 update rule. (default: False)
48
- orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
48
+ orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
49
+ 'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
49
50
  normed_momentum (bool): whether to compute the first moment on the normalized gradient. (default: False)
50
51
  kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
51
52
  If `False`, the optimizer behaves as standard AdamW. (default: False)
@@ -104,7 +105,7 @@ class AdamW_adv(torch.optim.Optimizer):
104
105
  # Adam_atan2 (scale invariant)
105
106
  use_atan2: bool = False,
106
107
  # OrthoGrad
107
- orthogonal_gradient: bool = False,
108
+ orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
108
109
  # Nesterov momentum
109
110
  nesterov: bool = False,
110
111
  nesterov_coef: float | None = None,
@@ -326,8 +327,7 @@ class AdamW_adv(torch.optim.Optimizer):
326
327
  def _step_parameter(self, p, grad, state, group, step_size, beta1, beta2, sqrt_bias_correction2, random_int_tensor, random_int_state_tensor):
327
328
  grad = upcast_grad_for_precision(grad, state, group['state_precision'])
328
329
 
329
- if group["orthogonal_gradient"]:
330
- grad = _orthogonalize_gradient(p, grad)
330
+ grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
331
331
 
332
332
  nesterov = group.get('nesterov', False)
333
333
  nesterov_coef = group.get('nesterov_coef', None)
@@ -462,7 +462,7 @@ class AdamW_adv(torch.optim.Optimizer):
462
462
  else:
463
463
  update.mul_(update_scaling)
464
464
 
465
- param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
465
+ param_update.apply_parameter_update(self, p, group, update, group['lr'], random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
466
466
 
467
467
  def compile(self, *args, **kwargs):
468
468
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -108,7 +108,7 @@ class Adopt_adv(torch.optim.Optimizer):
108
108
  # Stochastic Rounding for BF16
109
109
  stochastic_rounding: bool = True,
110
110
  # OrthoGrad
111
- orthogonal_gradient: bool = False,
111
+ orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
112
112
  # Nesterov momentum
113
113
  nesterov: bool = False,
114
114
  nesterov_coef: float | None = None,
@@ -158,7 +158,7 @@ class Adopt_adv(torch.optim.Optimizer):
158
158
 
159
159
  defaults = {
160
160
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
161
- "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
161
+ "fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "orthogonal_gradient": orthogonal_gradient,
162
162
  "nesterov": nesterov, "nesterov_coef": nesterov_coef,
163
163
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
164
164
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
@@ -172,7 +172,6 @@ class Adopt_adv(torch.optim.Optimizer):
172
172
  self.clip_lambda = clip_lambda
173
173
  self.stochastic_rounding = stochastic_rounding
174
174
  self.use_atan2 = use_atan2
175
- self.orthogonal_gradient = orthogonal_gradient
176
175
  self.kourkoutas_beta = kourkoutas_beta
177
176
  self.layer_key_fn = layer_key_fn
178
177
  self._init_lr = lr if lr > 0 else 1
@@ -237,7 +236,7 @@ class Adopt_adv(torch.optim.Optimizer):
237
236
  dtype = torch.float32 if (state['factored'] or req_precision == 'factored') else p.dtype
238
237
 
239
238
  vt_dtype = torch.float32 if (state['factored'] or state['factored_2nd'] or req_precision in ['factored', 'bf16_sr', 'int8_sr']) else dtype
240
- vt_init = grad.pow(2).to(vt_dtype) * (1 - group['betas'][1])
239
+ vt_init = grad.pow(2).to(vt_dtype)
241
240
 
242
241
  if state['factored']:
243
242
  state['effective_shape'] = _get_effective_shape(p.numel())
@@ -329,8 +328,7 @@ class Adopt_adv(torch.optim.Optimizer):
329
328
  def _step_parameter(self, p, grad, state, group, lr, beta1, beta2, random_int_tensor, random_int_state_tensor):
330
329
  grad = upcast_grad_for_precision(grad, state, group['state_precision'])
331
330
 
332
- if self.orthogonal_gradient:
333
- grad = _orthogonalize_gradient(p, grad)
331
+ grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
334
332
 
335
333
  nesterov = group.get('nesterov', False)
336
334
  nesterov_coef = group.get('nesterov_coef', None)
@@ -67,7 +67,7 @@ class Lion_adv(torch.optim.Optimizer):
67
67
  # Stochastic Rounding for BF16
68
68
  stochastic_rounding: bool = True,
69
69
  # OrthoGrad
70
- orthogonal_gradient: bool = False,
70
+ orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
71
71
  # Lion-k
72
72
  kappa_p: float = 1.0,
73
73
  auto_kappa_p: bool = False,
@@ -213,8 +213,9 @@ class Lion_adv(torch.optim.Optimizer):
213
213
  def _step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_noise_tensor):
214
214
  if grad.dtype != torch.float32 and state['factored']:
215
215
  grad = grad.float()
216
- if group["orthogonal_gradient"]:
217
- grad = _orthogonalize_gradient(p, grad)
216
+ is_vector = p.ndim < 2 or getattr(p, '_is_dora_scale', False) or getattr(p, 'is_vector', False)
217
+
218
+ grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
218
219
 
219
220
  # Lion-K Logic
220
221
  kappa_p = group.get("kappa_p", 1.0)
@@ -250,7 +251,7 @@ class Lion_adv(torch.optim.Optimizer):
250
251
  update = update.view(p.shape)
251
252
 
252
253
  if group.get('stochastic_sign', False):
253
- update = apply_stochastic_sign_(update, noise=random_noise_tensor)
254
+ update = apply_stochastic_sign_(update, noise=random_noise_tensor, is_vector=is_vector)
254
255
  else:
255
256
  update = _get_lion_k_update(update, kappa_p)
256
257
 
@@ -265,7 +266,7 @@ class Lion_adv(torch.optim.Optimizer):
265
266
  exp_avg.lerp_(grad, 1 - beta2)
266
267
 
267
268
  if group.get('stochastic_sign', False):
268
- update = apply_stochastic_sign_(update, noise=random_noise_tensor)
269
+ update = apply_stochastic_sign_(update, noise=random_noise_tensor, is_vector=is_vector)
269
270
  else:
270
271
  update = _get_lion_k_update(update, kappa_p)
271
272
 
@@ -39,7 +39,8 @@ class Muon_adv(torch.optim.Optimizer):
39
39
  (default: (3.4445, -4.7750, 2.0315)).
40
40
  stochastic_rounding (bool): whether to use stochastic rounding for
41
41
  BF16 parameter updates (default: True).
42
- orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
42
+ orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
43
+ 'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
43
44
  vector_reshape (bool): whether to reshape 1D vectors into 2D
44
45
  matrices to apply low-rank compression (default: True).
45
46
  nnmf_factor (bool): whether to use the factorization or disable it to use
@@ -89,7 +90,7 @@ class Muon_adv(torch.optim.Optimizer):
89
90
  adam_fisher_wd (bool): Fisher Adam (FAdam) weight decay for the AdamW part. (default: False)
90
91
  adam_use_bias_correction (bool): Bias correction for AdamW.
91
92
  adam_use_atan2 (bool): Atan2 update rule for AdamW.
92
- adam_orthogonal_gradient (bool): OrthoGrad for AdamW.
93
+ adam_orthogonal_gradient (str): OrthoGrad for AdamW.
93
94
  adam_nesterov (bool): Nesterov momentum for AdamW. (default: False)
94
95
  adam_nesterov_coef (float, optional): Nesterov coefficient for AdamW. (default: None)
95
96
  adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
@@ -121,7 +122,7 @@ class Muon_adv(torch.optim.Optimizer):
121
122
  # Stochastic Rounding for BF16
122
123
  stochastic_rounding: bool = True,
123
124
  # OrthoGrad
124
- orthogonal_gradient: bool = False,
125
+ orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
125
126
  # RMS Rescaling
126
127
  rms_rescaling: bool = True,
127
128
  # SMMF factorization
@@ -159,7 +160,7 @@ class Muon_adv(torch.optim.Optimizer):
159
160
  adam_fisher_wd: bool = False,
160
161
  adam_use_bias_correction: bool = True,
161
162
  adam_use_atan2: bool = False,
162
- adam_orthogonal_gradient: bool = False,
163
+ adam_orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
163
164
  adam_nesterov: bool = False,
164
165
  adam_nesterov_coef: float | None = None,
165
166
  adam_kourkoutas_beta: bool = False,
@@ -186,7 +187,7 @@ class Muon_adv(torch.optim.Optimizer):
186
187
  print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
187
188
  rms_rescaling = False
188
189
  if spectral_normalization and accelerated_ns:
189
- ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
190
+ raise ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
190
191
 
191
192
  # Legacy backwards compatibility support for `nnmf_factor=True`
192
193
  if nnmf_factor:
@@ -457,8 +458,7 @@ class Muon_adv(torch.optim.Optimizer):
457
458
  if grad.dtype != torch.float32 and state.get('factored', False):
458
459
  grad = grad.float()
459
460
 
460
- if group.get("orthogonal_gradient"):
461
- grad = _orthogonalize_gradient(p, grad)
461
+ grad = _orthogonalize_gradient(p, grad, group.get("orthogonal_gradient"))
462
462
 
463
463
  if state['factored']: # Factored Muon
464
464
  d1, d2 = state['effective_shape']
@@ -43,7 +43,8 @@ class Prodigy_adv(torch.optim.Optimizer):
43
43
  stochastic_rounding (bool): whether to use stochastic
44
44
  rounding for BF16 parameter updates (default: True).
45
45
  use_atan2 (bool): whether to use the atan2 update rule. (default: False)
46
- orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
46
+ orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
47
+ 'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
47
48
  nnmf_factor (bool): whether to use the factorization or disable it to use
48
49
  the uncompressed optimizer. (default: False)
49
50
  factored_2nd (bool): whether to keep the first moment uncompressed (dense)
@@ -119,7 +120,7 @@ class Prodigy_adv(torch.optim.Optimizer):
119
120
  # Adam_atan2 (scale invariant)
120
121
  use_atan2: bool = False,
121
122
  # OrthoGrad
122
- orthogonal_gradient: bool = False,
123
+ orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
123
124
  # Nesterov momentum
124
125
  nesterov: bool = False,
125
126
  nesterov_coef: float | None = None,
@@ -371,8 +372,7 @@ class Prodigy_adv(torch.optim.Optimizer):
371
372
  def _step_parameter(self, p, grad, state, group, beta2, d, dlr, random_int_tensor, random_int_state_tensor):
372
373
  grad = upcast_grad_for_precision(grad, state, group['state_precision'])
373
374
 
374
- if group["orthogonal_gradient"]:
375
- grad = _orthogonalize_gradient(p, grad)
375
+ grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
376
376
 
377
377
  nesterov = group.get('nesterov', False)
378
378
  nesterov_coef = group.get('nesterov_coef', None)
@@ -62,7 +62,7 @@ class SignSGD_adv(torch.optim.Optimizer):
62
62
  # Stochastic Rounding for BF16
63
63
  stochastic_rounding: bool = True,
64
64
  # OrthoGrad
65
- orthogonal_gradient: bool = False,
65
+ orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
66
66
  # Stochastic Sign Operator
67
67
  stochastic_sign: bool = False,
68
68
  # Nesterov momentum
@@ -171,7 +171,7 @@ class SignSGD_adv(torch.optim.Optimizer):
171
171
  def __init_state(self, p, group):
172
172
  state = self.state[p]
173
173
  # State Initialization
174
- if group["momentum"] > 0 and len(state) == 0:
174
+ if 'step' not in state:
175
175
  req_precision = group['state_precision']
176
176
  is_vector = len(p.shape) == 1 and not group['vector_reshape']
177
177
 
@@ -259,8 +259,7 @@ class SignSGD_adv(torch.optim.Optimizer):
259
259
  wd_target = None
260
260
  cwd_target = None
261
261
 
262
- if group["orthogonal_gradient"]:
263
- grad = _orthogonalize_gradient(p, grad)
262
+ grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
264
263
 
265
264
  if normed_mt:
266
265
  if sso:
@@ -282,7 +281,7 @@ class SignSGD_adv(torch.optim.Optimizer):
282
281
 
283
282
  if nesterov and normed_mt:
284
283
  # Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
285
- normed_grad = grad_reshaped * exp_avg.abs()
284
+ normed_grad = exp_avg.abs().mul_(grad_reshaped)
286
285
 
287
286
  exp_avg.lerp_(grad_reshaped, 1 - momentum)
288
287
 
@@ -313,7 +312,7 @@ class SignSGD_adv(torch.optim.Optimizer):
313
312
 
314
313
  if nesterov and normed_mt:
315
314
  # Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
316
- normed_grad = grad * exp_avg.abs()
315
+ normed_grad = exp_avg.abs().mul_(grad)
317
316
 
318
317
  exp_avg.lerp_(grad, 1 - momentum)
319
318
 
@@ -344,7 +343,7 @@ class SignSGD_adv(torch.optim.Optimizer):
344
343
  if group.get('geometric_wd', False) and group["weight_decay"] > 0 :
345
344
  wd_target = get_signsgd_wd_target(p, denom=denom, stochastic_sign=sso, noise=random_noise_tensor, is_vector=is_vector)
346
345
 
347
- if group.get('centered_wd', 0.0) > 0 and 'anchor_type' in state:
346
+ if group.get('centered_wd', 0.0) > 0 and 'anchor_data' in state:
348
347
  anchor = dequantize_anchor(p, state, group, p.dtype)
349
348
  cwd_target = get_signsgd_wd_target(p.sub(anchor), denom=denom, stochastic_sign=sso, noise=random_noise_tensor, is_vector=is_vector)
350
349
  del anchor
@@ -355,7 +354,7 @@ class SignSGD_adv(torch.optim.Optimizer):
355
354
  update_scaling = lr * A if snr_cond else lr
356
355
  update.mul_(update_scaling)
357
356
 
358
- param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_target=wd_target, cwd_target=cwd_target, decoupled=True)
357
+ param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_target=wd_target, cwd_target=cwd_target)
359
358
 
360
359
  def compile(self, *args, **kwargs):
361
360
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -69,7 +69,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
69
69
  # Stochastic Rounding for BF16
70
70
  stochastic_rounding: bool = True,
71
71
  # OrthoGrad
72
- orthogonal_gradient: bool = False,
72
+ orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
73
73
  # Spectral Normed Optimizer
74
74
  spectral_normalization: bool = False,
75
75
  # Centered WD
@@ -89,8 +89,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
89
89
  raise ValueError(f"Momentum should be >= 0.0. Got {momentum}")
90
90
  if not (weight_decay >= 0.0):
91
91
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
92
- if snr_cond and not normed_momentum:
93
- raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum")
92
+ if snr_cond and not normed_momentum and not momentum > 0:
93
+ raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum.")
94
94
 
95
95
  state_precision = state_precision.lower()
96
96
  valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
@@ -237,8 +237,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
237
237
  wd_target = None
238
238
  cwd_target = None
239
239
 
240
- if group["orthogonal_gradient"]:
241
- grad = _orthogonalize_gradient(p, grad)
240
+ grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
242
241
 
243
242
  if normed_mt:
244
243
  if not is_vector:
@@ -266,7 +265,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
266
265
 
267
266
  if nesterov and normed_mt:
268
267
  # Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
269
- normed_grad = grad_reshaped * buf.abs()
268
+ normed_grad = buf.abs().mul_(grad_reshaped)
270
269
 
271
270
  buf.lerp_(grad_reshaped, 1 - momentum)
272
271
 
@@ -303,7 +302,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
303
302
 
304
303
  if nesterov and normed_mt:
305
304
  # Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
306
- normed_grad = grad * buf.abs()
305
+ normed_grad = buf.abs().mul_(grad)
307
306
 
308
307
  buf.lerp_(grad, 1 - momentum)
309
308
 
@@ -346,7 +345,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
346
345
  wd_scaler = get_sinkhorn_wd_scaler(p, row_denom=vt_row, col_denom=vt_col)
347
346
  else:
348
347
  wd_target = get_signsgd_wd_target(p, denom=denom)
349
- if is_vector and group.get('centered_wd', 0.0) > 0 and 'anchor_type' in state:
348
+ if is_vector and group.get('centered_wd', 0.0) > 0 and 'anchor_data' in state:
350
349
  anchor = dequantize_anchor(p, state, group, p.dtype)
351
350
  cwd_target = get_signsgd_wd_target(p.sub(anchor), denom=denom)
352
351
  del anchor
@@ -71,8 +71,7 @@ def _init_auxadam_state(self, p, group):
71
71
  def _adam_step_parameter(self, p, grad, state, group, beta1_adam, beta2_adam, sqrt_bias_correction2, step_size, random_int_tensor, random_int_state_tensor=None):
72
72
  grad = upcast_grad_for_precision(grad, state, group.get('adam_state_precision', 'auto'))
73
73
 
74
- if group.get("adam_orthogonal_gradient"):
75
- grad = _orthogonalize_gradient(p, grad)
74
+ grad = _orthogonalize_gradient(p, grad, group.get("adam_orthogonal_gradient"))
76
75
 
77
76
  if hasattr(self, 'kourkoutas_helper') and self.kourkoutas_helper:
78
77
  # Accumulate current grad's norm for the *next* step
@@ -190,4 +189,4 @@ def _adam_step_parameter(self, p, grad, state, group, beta1_adam, beta2_adam, sq
190
189
  else:
191
190
  update.mul_(update_scaling)
192
191
 
193
- param_update.apply_parameter_update(self, p, group, update, step_size, group["adam_weight_decay"], random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
192
+ param_update.apply_parameter_update(self, p, group, update, group['lr'], group["adam_weight_decay"], random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
@@ -0,0 +1,92 @@
1
+ import torch
2
+ import math
3
+
4
+ def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor, mode: str) -> torch.Tensor:
5
+ """
6
+ Projects the gradient `grad` to be orthogonal to the parameter `p`.
7
+ Supports two modes: 'flattened' (vectorized) and 'iterative' (matrix-wise).
8
+ """
9
+ if mode == 'disabled':
10
+ return grad
11
+ elif mode == 'flattened':
12
+ return flattened_ortho_project(p, grad)
13
+ elif mode == 'iterative':
14
+ return iterative_ortho_project(p, grad, iters=3)
15
+
16
def flattened_ortho_project(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    """
    Projects the flattened gradient `grad` to be orthogonal to the flattened parameter `p`,
    then rescales the result back to the original gradient's L2 norm.

    Modified from:
    https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability/blob/720d2444df12b851d6cb417ab08cf125c822b2ae/orthograd.py

    The math runs in float32 and the result is cast back to ``grad.dtype`` and
    ``grad.shape``. ``grad`` itself is not modified.
    """
    original_shape = grad.shape
    original_dtype = grad.dtype
    # reshape (not view) so non-contiguous tensors (e.g. transposed or
    # channels_last weights and their grads) are accepted; view(-1) raises on them.
    w = p.reshape(-1).float()
    g = grad.reshape(-1).float()

    # Tiny epsilon guards the division when the parameter is all zeros.
    w_norm_sq = torch.dot(w, w).add_(1e-30)
    proj = torch.dot(w, g) / w_norm_sq
    g_orth = g.sub(w * proj)

    # Restore the pre-projection L2 norm of the gradient.
    g_norm = g.norm(2)
    g_orth_norm = g_orth.norm(2).add_(1e-30)
    g_orth_scaled = g_orth * (g_norm / g_orth_norm)
    return g_orth_scaled.view(original_shape).to(original_dtype)
33
+
34
+
35
def iterative_ortho_project(p: torch.Tensor, grad: torch.Tensor, iters: int = 3) -> torch.Tensor:
    """
    Applies iterative alternating orthogonal projection to a 2D view of the gradient.

    The gradient is viewed as ``(grad.shape[0], -1)`` and the parameter as
    ``(p.shape[0], -1)``. Each round projects the gradient to be orthogonal to the
    parameter slice-wise along one dimension, then along the other (Sinkhorn-style
    alternation), restoring each slice's norm after every projection. Per the
    author's note, 2-3 rounds reach a per-row/column cosine similarity on the order
    of 1e-4 to 1e-6 (semi-orthogonal).

    Parameters with ``ndim < 2``, or flagged ``_is_dora_scale`` / ``is_vector``,
    fall back to the flattened OrthoGrad projection.

    Args:
        p: parameter tensor.
        grad: gradient for ``p``.
        iters: number of alternating rounds.

    Returns:
        A new tensor with ``grad``'s shape and dtype. ``grad`` itself is not modified.
    """
    # 1D vector case: fall back to the standard (flattened) OrthoGrad.
    is_vector = p.ndim < 2 or getattr(p, '_is_dora_scale', False) or getattr(p, 'is_vector', False)
    if is_vector:
        # Previously this called `_orthogonalize_gradient(p, grad)` without the
        # now-required `mode` argument, raising TypeError for every vector param.
        return flattened_ortho_project(p, grad)

    original_shape = grad.shape
    original_dtype = grad.dtype

    # Work on a float32 *copy*. _ortho_normed_dim updates its input in place, and
    # the caller's tensor may be p.grad itself. This also matches the flattened
    # mode's float32 math. reshape() tolerates non-contiguous layouts where
    # view() would raise.
    grad_2d = grad.reshape(grad.shape[0], -1).to(torch.float32, copy=True)
    param_2d = p.reshape(p.shape[0], -1).float()

    m, n = grad_2d.shape

    # Dynamically determine the order based on aspect ratio: the first pass
    # reduces over the longer (or equal) axis.
    dim = 0 if m > n else 1

    param_sq = param_2d * param_2d
    p_norm_sq_dim = param_sq.sum(dim=dim, keepdim=True).add_(1e-30)
    p_norm_sq_adim = param_sq.sum(dim=1 - dim, keepdim=True).add_(1e-30)

    for _ in range(iters):
        # First dimension
        grad_2d = _ortho_normed_dim(param_2d, grad_2d, p_norm_sq_dim, dim)
        # Second dimension
        grad_2d = _ortho_normed_dim(param_2d, grad_2d, p_norm_sq_adim, 1 - dim)

    return grad_2d.view(original_shape).to(original_dtype)
70
+
71
+
72
def _ortho_normed_dim(p_2d: torch.Tensor, grad_2d: torch.Tensor, p_norm_sq: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Removes the component of ``grad_2d`` along ``p_2d`` slice-wise (reducing over
    ``dim``), then rescales each slice back to its pre-projection L2 norm.

    Args:
        p_2d: 2D parameter view.
        grad_2d: 2D gradient of the same shape; it is modified IN PLACE and returned.
        p_norm_sq: per-slice squared norms of ``p_2d`` over ``dim`` (keepdim=True).
        dim: axis that the dot products and norms reduce over.

    Returns:
        ``grad_2d`` (same object), after projection and rescaling.
    """
    # Shared lower bound for both norms, so near-zero slices are not blown up by the rescale.
    floor = 1 / math.sqrt(grad_2d.shape[dim])
    norm_before = grad_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(floor)

    # Coefficient of the projection onto p: <p, g> / ||p||^2, per slice.
    coeff = (p_2d * grad_2d).sum(dim=dim, keepdim=True) / p_norm_sq

    # Slight over-relaxation: the author notes -1.01 converges faster than -1.
    grad_2d.addcmul_(coeff, p_2d, value=-1.01)

    norm_after = grad_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(floor)
    return grad_2d.mul_(norm_before / norm_after)
@@ -6,7 +6,7 @@ import torch.nn.functional as F
6
6
 
7
7
  from typing import Dict, Any
8
8
 
9
- from .scaled_optm import adjust_wds, scale_wds
9
+ from .scaled_optm import adjust_wds
10
10
  from .centered_decay import dequantize_anchor
11
11
 
12
12
  _generators: Dict[torch.device, torch.Generator] = {}
@@ -48,7 +48,7 @@ def _apply_weight_decay(
48
48
  p_calc.add_(wd_target, alpha=-scaled_wd)
49
49
 
50
50
  # Centered Weight Decay (pulls toward anchor)
51
- if scaled_cwd is not None and 'anchor_type' in state:
51
+ if scaled_cwd is not None and 'anchor_data' in state:
52
52
  if cwd_target is not None:
53
53
  decay_target = cwd_target
54
54
  else:
@@ -330,7 +330,7 @@ def _copy_int8_sym_blockwise_stochastic_core_(
330
330
  target: torch.Tensor,
331
331
  source: torch.Tensor,
332
332
  scales: torch.Tensor,
333
- random_int_tensor: torch.Tensor | None,
333
+ random_int_tensor: torch.Tensor,
334
334
  block_size: int = 2048,
335
335
  val_blocks: torch.Tensor | None = None,
336
336
  ) -> None:
@@ -61,7 +61,7 @@ def adjust_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
61
61
  """
62
62
  # DoRA Scale (Magnitude Vector)
63
63
  if getattr(p, '_is_dora_scale', False):
64
- return 0.0, cwd
64
+ return wd, cwd
65
65
 
66
66
  if getattr(p, '_is_oft', False):
67
67
  return wd, 0.0
@@ -76,7 +76,7 @@ def adjust_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
76
76
  else:
77
77
  # 1D Biases or generic 1D parameters
78
78
  # Centered WD safely regularizes the delta without collapsing base feature variance.
79
- return 0.0, cwd
79
+ return wd, cwd
80
80
 
81
81
 
82
82
  def scale_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
@@ -209,7 +209,7 @@ def fix_loaded_state_dtype(state: dict, p: torch.Tensor, group: dict) -> None:
209
209
 
210
210
  # Pre-define sets for known exact-match keys
211
211
  uint8_keys = {'sign', 'sign_slow', 'sign_buf', 'shifter'}
212
- fp32_keys = {'mu_m_nmf', 'mv_m_nmf', 'mu_v_nmf', 'mv_v_nmf', 'mu_m_slow_nmf', 'mv_m_slow_nmf'}
212
+ fp32_keys = {'mu_m_nmf', 'mv_m_nmf', 'mu_v_nmf', 'mv_v_nmf', 'mu_m_slow_nmf', 'mv_m_slow_nmf', "mu_mbuf_nmf", "mv_mbuf_nmf", "mu_b_nmf", "normuon_v"}
213
213
 
214
214
  for key, val in state.items():
215
215
  if not isinstance(val, torch.Tensor):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev25
3
+ Version: 2.5
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
5
5
 
6
6
  setup(
7
7
  name="adv_optm",
8
- version="2.4.dev25",
8
+ version="2.5",
9
9
  author="Koratahiu",
10
10
  author_email="hiuhonor@gmail.com",
11
11
  license='Apache 2.0',
@@ -1,80 +0,0 @@
1
- import torch
2
- import math
3
-
4
- def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
5
- """
6
- Projects the gradient `grad` to be orthogonal to the parameter `p`.
7
- Modified from:
8
- https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability/blob/720d2444df12b851d6cb417ab08cf125c822b2ae/orthograd.py
9
- """
10
- original_shape = grad.shape
11
- original_dtype = grad.dtype
12
- w = p.view(-1).float()
13
- g = grad.view(-1).float()
14
- w_norm_sq = torch.dot(w, w).add_(1e-30)
15
- proj = torch.dot(w, g) / w_norm_sq
16
- g_orth = g.sub(w * proj)
17
- g_norm = g.norm(2)
18
- g_orth_norm = g_orth.norm(2).add_(1e-30)
19
- g_orth_scaled = g_orth * (g_norm / g_orth_norm)
20
- return g_orth_scaled.view(original_shape).to(original_dtype)
21
-
22
-
23
- def iterative_ortho_project(p: torch.Tensor, update: torch.Tensor, iters: int = 5) -> torch.Tensor:
24
- """
25
- Applies iterative alternating orthogonal projection to a 2D matrix.
26
- Projects the update to be orthogonal to the parameter matrix along
27
- rows and columns sequentially, alternating dimensions.
28
- Inspired from Sinkhorn algorithm, 2 iterations is enough to converge
29
- to cosine similarity of -1e4 to -1e-5 (semi orthogonal).
30
- """
31
- # 1D Vector Case fallback to the standard OrthoGrad
32
- is_vector = p.ndim < 2 or getattr(p, '_is_dora_scale', False) or getattr(p, 'is_vector', False)
33
- if is_vector:
34
- return _orthogonalize_gradient(p, update)
35
-
36
- original_shape = update.shape
37
-
38
- # 2D+ Matrix Case
39
- update_2d = update.view(update.shape[0], -1)
40
- param_2d = p.view(p.shape[0], -1)
41
-
42
- m, n = update_2d.shape
43
-
44
- # Dynamically determine the order based on aspect ratio
45
- row_first = m > n
46
- dim = 0 if row_first else 1
47
-
48
- p_norm_sq_dim = torch.sum(param_2d * param_2d, dim=dim, keepdim=True).add_(1e-30)
49
- p_norm_sq_adim = torch.sum(param_2d * param_2d, dim=1-dim, keepdim=True).add_(1e-30)
50
-
51
- for _ in range(iters):
52
- # First dimension
53
- update_2d = _ortho_normed_dim(param_2d, update_2d, p_norm_sq_dim, dim)
54
- # Second dimension
55
- update_2d = _ortho_normed_dim(param_2d, update_2d, p_norm_sq_adim, 1 - dim)
56
-
57
- return update_2d.view(original_shape)
58
-
59
-
60
- def _ortho_normed_dim(p_2d: torch.Tensor, update_2d: torch.Tensor, p_norm_sq: torch.Tensor, dim: int) -> torch.Tensor:
61
- """
62
- Projects the update to be orthogonal to p along 'dim' and dynamically restores
63
- the original magnitude of that dimension pre-projection.
64
- """
65
- # Record target magnitude before projection
66
- norm_lb = 1 / math.sqrt(update_2d.shape[dim])
67
- target_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm_lb)
68
-
69
- # Project: g_orth = g - (p * <p, g> / ||p||^2)
70
- dot_prod = torch.sum(p_2d * update_2d, dim=dim, keepdim=True)
71
- proj = dot_prod / p_norm_sq
72
-
73
- # In-place subtraction: update_2d = update_2d - (proj * p_2d)
74
- # Standard gamma is -1, but -1.01 proved to converge faster
75
- update_2d.addcmul_(proj, p_2d, value=-1.01)
76
-
77
- # Magnitude Preservation
78
- g_orth_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm_lb)
79
- scale_factor = target_norm / g_orth_norm
80
- return update_2d.mul_(scale_factor)
File without changes
File without changes
File without changes