adv-optm 2.4.dev21__tar.gz → 2.4.dev23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/optim/SignSGD_adv.py +10 -10
  4. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/optim/SinkSGD_adv.py +13 -13
  5. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm.egg-info/PKG-INFO +1 -1
  6. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/setup.py +1 -1
  7. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/LICENSE +0 -0
  8. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/README.md +0 -0
  9. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/optim/AdaMuon_adv.py +0 -0
  10. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/optim/AdamW_adv.py +0 -0
  11. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/optim/Adopt_adv.py +0 -0
  12. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/optim/Lion_adv.py +0 -0
  13. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/optim/Muon_adv.py +0 -0
  14. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/optim/Prodigy_adv.py +0 -0
  15. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/optim/__init__.py +0 -0
  16. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/util/Kourkoutas.py +0 -0
  17. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/util/Muon_AuxAdam.py +0 -0
  18. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/util/Muon_util.py +0 -0
  19. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/util/OrthoGrad.py +0 -0
  20. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/util/__init__.py +0 -0
  21. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/util/centered_decay.py +0 -0
  22. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/util/factorization_util.py +0 -0
  23. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/util/lion_k.py +0 -0
  24. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/util/param_update.py +0 -0
  25. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/util/scaled_optm.py +0 -0
  26. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/util/signed_util.py +0 -0
  27. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/util/sinkhorn.py +0 -0
  28. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/util/state_util.py +0 -0
  29. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm/util/update_util.py +0 -0
  30. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm.egg-info/SOURCES.txt +0 -0
  31. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm.egg-info/dependency_links.txt +0 -0
  32. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm.egg-info/requires.txt +0 -0
  33. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/adv_optm.egg-info/top_level.txt +0 -0
  34. {adv_optm-2.4.dev21 → adv_optm-2.4.dev23}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev21
3
+ Version: 2.4.dev23
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -20,4 +20,4 @@ __all__ = [
20
20
  "SinkSGD_adv",
21
21
  ]
22
22
 
23
- __version__ = "2.4.dev21"
23
+ __version__ = "2.4.dev23"
@@ -70,8 +70,8 @@ class SignSGD_adv(torch.optim.Optimizer):
70
70
  nesterov_coef: float | None = None,
71
71
  # Normalization then Momentum
72
72
  normed_momentum: bool = False,
73
- # Centered Variance Precondition
74
- centered_vt: bool = False,
73
+ # SNR Precondition
74
+ snr_cond: bool = False,
75
75
  # Centered WD
76
76
  centered_wd: float = 0.0,
77
77
  centered_wd_mode: str = 'float8',
@@ -91,8 +91,8 @@ class SignSGD_adv(torch.optim.Optimizer):
91
91
  raise ValueError(f"momentum should be in [0.0, 1.0], but got {momentum}")
92
92
  if not weight_decay >= 0.0:
93
93
  raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
94
- if centered_vt and not normed_momentum and not momentum > 0:
95
- raise NotImplementedError(f"centered_vt is intended to be used with normed_momentum")
94
+ if snr_cond and not normed_momentum and not momentum > 0:
95
+ raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum")
96
96
 
97
97
  state_precision = state_precision.lower()
98
98
  valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
@@ -115,7 +115,7 @@ class SignSGD_adv(torch.optim.Optimizer):
115
115
  nesterov=nesterov,
116
116
  nesterov_coef=nesterov_coef,
117
117
  normed_momentum=normed_momentum,
118
- centered_vt=centered_vt,
118
+ snr_cond=snr_cond,
119
119
  spectral_normalization=spectral_normalization,
120
120
  centered_wd= centered_wd,
121
121
  centered_wd_mode= centered_wd_mode,
@@ -254,7 +254,7 @@ class SignSGD_adv(torch.optim.Optimizer):
254
254
  nesterov = group.get('nesterov', False)
255
255
  nesterov_coef = group.get('nesterov_coef', None)
256
256
  sso = group.get('stochastic_sign', False)
257
- centered_vt = group.get('centered_vt', False) and group.get('normed_momentum', False) and momentum > 0
257
+ snr_cond = group.get('snr_cond', False) and group.get('normed_momentum', False) and momentum > 0
258
258
 
259
259
  denom = None
260
260
  wd_target = None
@@ -278,7 +278,7 @@ class SignSGD_adv(torch.optim.Optimizer):
278
278
  # Reconstruct momentum m_{t-1}
279
279
  exp_avg = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True, shifter=state['shifter'])
280
280
 
281
- if centered_vt:
281
+ if snr_cond:
282
282
  denom = (1.0 - exp_avg.square()).clamp_min_(1e-30).sqrt_().view_as(p)
283
283
 
284
284
  exp_avg.lerp_(grad_reshaped, 1 - momentum)
@@ -302,7 +302,7 @@ class SignSGD_adv(torch.optim.Optimizer):
302
302
  actual_precision = group['actual_state_precision']
303
303
  exp_avg = get_state(state, 'exp_avg', actual_precision)
304
304
 
305
- if centered_vt:
305
+ if snr_cond:
306
306
  denom = (1.0 - exp_avg.square()).clamp_min_(1e-30).sqrt_()
307
307
 
308
308
  exp_avg.lerp_(grad, 1 - momentum)
@@ -325,7 +325,7 @@ class SignSGD_adv(torch.optim.Optimizer):
325
325
  else:
326
326
  update = raw_update
327
327
 
328
- if centered_vt:
328
+ if snr_cond:
329
329
  update.atan2_(denom)
330
330
 
331
331
  if group.get('geometric_wd', False) and group["weight_decay"] > 0 :
@@ -339,7 +339,7 @@ class SignSGD_adv(torch.optim.Optimizer):
339
339
  if group.get('spectral_normalization', False):
340
340
  update = scale_update(p, update, lr, state=state)
341
341
  else:
342
- update_scaling = lr * A if centered_vt else lr
342
+ update_scaling = lr * A if snr_cond else lr
343
343
  update.mul_(update_scaling)
344
344
 
345
345
  param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_target=wd_target, cwd_target=cwd_target)
@@ -58,8 +58,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
58
58
  orthogonal_sinkhorn: bool = False,
59
59
  # Normalization then Momentum
60
60
  normed_momentum: bool = False,
61
- # Centered Variance Precondition
62
- centered_vt: bool = False,
61
+ # SNR Precondition
62
+ snr_cond: bool = False,
63
63
  # Nesterov Momentum
64
64
  nesterov: bool = False,
65
65
  nesterov_coef: float | None = None,
@@ -89,8 +89,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
89
89
  raise ValueError(f"Momentum should be >= 0.0. Got {momentum}")
90
90
  if not (weight_decay >= 0.0):
91
91
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
92
- if centered_vt and not normed_momentum:
93
- raise NotImplementedError(f"centered_vt is intended to be used with normed_momentum")
92
+ if snr_cond and not normed_momentum:
93
+ raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum")
94
94
 
95
95
  state_precision = state_precision.lower()
96
96
  valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
@@ -102,7 +102,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
102
102
 
103
103
  defaults = {
104
104
  "lr": lr, "momentum": momentum,
105
- "weight_decay": weight_decay, "nesterov": nesterov, "nesterov_coef": nesterov_coef, "normed_momentum": normed_momentum, "centered_vt": centered_vt,
105
+ "weight_decay": weight_decay, "nesterov": nesterov, "nesterov_coef": nesterov_coef, "normed_momentum": normed_momentum, "snr_cond": snr_cond,
106
106
  "geometric_wd": geometric_wd, "cautious_wd": cautious_wd,
107
107
  "orthogonal_gradient": orthogonal_gradient,
108
108
  "compiled_optimizer": compiled_optimizer,
@@ -228,7 +228,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
228
228
  momentum = group['momentum']
229
229
  nesterov = group['nesterov']
230
230
  nesterov_coef = group.get('nesterov_coef', None)
231
- centered_vt = group.get('centered_vt', False)
231
+ snr_cond = group.get('snr_cond', False)
232
232
 
233
233
  vt_row = None
234
234
  vt_col = None
@@ -238,6 +238,9 @@ class SinkSGD_adv(torch.optim.Optimizer):
238
238
  wd_target = None
239
239
  cwd_target = None
240
240
 
241
+ if group["orthogonal_gradient"]:
242
+ grad = _orthogonalize_gradient(p, grad)
243
+
241
244
  if group.get('normed_momentum', False):
242
245
  if not is_vector:
243
246
  # Sinkhorn iterative normalization
@@ -246,9 +249,6 @@ class SinkSGD_adv(torch.optim.Optimizer):
246
249
  # For vectors, apply sign operation
247
250
  grad = grad.sign_()
248
251
 
249
- if group["orthogonal_gradient"]:
250
- grad = _orthogonalize_gradient(p, grad)
251
-
252
252
  if state['factored']:
253
253
  d1, d2 = state['effective_shape']
254
254
  grad_reshaped = grad.view(d1, d2)
@@ -256,7 +256,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
256
256
  if momentum != 0:
257
257
  buf = _reconstruct_state((state['mu_b_nmf'], state['mv_b_nmf'], state['sign'], d2), signed=True, shifter=state['shifter'])
258
258
 
259
- if centered_vt:
259
+ if snr_cond:
260
260
  if not is_vector:
261
261
  buf_2d_sq = buf.view(grad.shape[0], -1).square()
262
262
  vt_row = (1 - buf_2d_sq.mean(dim=-1)).clamp_min_(1e-30)
@@ -286,7 +286,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
286
286
  if momentum != 0:
287
287
  buf = get_state(state, 'momentum_buffer', actual_precision)
288
288
 
289
- if centered_vt:
289
+ if snr_cond:
290
290
  if not is_vector:
291
291
  buf_2d_sq = buf.view(grad.shape[0], -1).square()
292
292
  vt_row = (1 - buf_2d_sq.mean(dim=-1)).clamp_min_(1e-30)
@@ -309,7 +309,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
309
309
 
310
310
  del random_int_state_tensor
311
311
 
312
- if centered_vt:
312
+ if snr_cond:
313
313
  if not is_vector:
314
314
  # Align with Sinkhorn: Alternate row/col preconditioning
315
315
  update_2d = update.view(update.shape[0], -1)
@@ -342,7 +342,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
342
342
  if group.get('spectral_normalization', False):
343
343
  update = scale_update(p, update, update_scaling, state=state)
344
344
  else:
345
- if centered_vt:
345
+ if snr_cond:
346
346
  update_scaling = update_scaling * (4/math.pi)
347
347
  update.mul_(update_scaling)
348
348
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev21
3
+ Version: 2.4.dev23
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
5
5
 
6
6
  setup(
7
7
  name="adv_optm",
8
- version="2.4.dev21",
8
+ version="2.4.dev23",
9
9
  author="Koratahiu",
10
10
  author_email="hiuhonor@gmail.com",
11
11
  license='Apache 2.0',
File without changes
File without changes
File without changes