adv-optm 2.4.dev24__tar.gz → 2.4.dev25__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/SignSGD_adv.py +12 -8
  4. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/SinkSGD_adv.py +10 -6
  5. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm.egg-info/PKG-INFO +1 -1
  6. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/setup.py +1 -1
  7. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/LICENSE +0 -0
  8. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/README.md +0 -0
  9. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/AdaMuon_adv.py +0 -0
  10. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/AdamW_adv.py +0 -0
  11. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/Adopt_adv.py +0 -0
  12. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/Lion_adv.py +0 -0
  13. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/Muon_adv.py +0 -0
  14. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/Prodigy_adv.py +0 -0
  15. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/__init__.py +0 -0
  16. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/Kourkoutas.py +0 -0
  17. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/Muon_AuxAdam.py +0 -0
  18. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/Muon_util.py +0 -0
  19. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/OrthoGrad.py +0 -0
  20. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/__init__.py +0 -0
  21. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/centered_decay.py +0 -0
  22. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/factorization_util.py +0 -0
  23. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/lion_k.py +0 -0
  24. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/param_update.py +0 -0
  25. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/scaled_optm.py +0 -0
  26. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/signed_util.py +0 -0
  27. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/sinkhorn.py +0 -0
  28. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/state_util.py +0 -0
  29. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/update_util.py +0 -0
  30. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm.egg-info/SOURCES.txt +0 -0
  31. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm.egg-info/dependency_links.txt +0 -0
  32. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm.egg-info/requires.txt +0 -0
  33. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm.egg-info/top_level.txt +0 -0
  34. {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev24
3
+ Version: 2.4.dev25
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -20,4 +20,4 @@ __all__ = [
20
20
  "SinkSGD_adv",
21
21
  ]
22
22
 
23
- __version__ = "2.4.dev24"
23
+ __version__ = "2.4.dev25"
@@ -280,16 +280,18 @@ class SignSGD_adv(torch.optim.Optimizer):
280
280
  if snr_cond:
281
281
  denom = (1.0 - exp_avg.square()).clamp_min_(1e-30).sqrt_().view_as(p)
282
282
 
283
+ if nesterov and normed_mt:
284
+ # Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
285
+ normed_grad = grad_reshaped * exp_avg.abs()
286
+
283
287
  exp_avg.lerp_(grad_reshaped, 1 - momentum)
284
288
 
285
289
  if nesterov:
286
290
  nv_coef = momentum if nesterov_coef is None else nesterov_coef
287
291
  if normed_mt:
288
- # Scale the normalized gradient down to match the buffer's variance
289
- ema_std = math.sqrt((1 - momentum) / (1 + momentum))
290
- raw_update = (grad_reshaped * ema_std).lerp_(exp_avg, nv_coef)
292
+ raw_update = normed_grad.lerp_(exp_avg, nv_coef)
291
293
  else:
292
- raw_update = grad.lerp(exp_avg, nv_coef)
294
+ raw_update = grad_reshaped.lerp(exp_avg, nv_coef)
293
295
  else:
294
296
  raw_update = exp_avg.clone()
295
297
 
@@ -309,14 +311,16 @@ class SignSGD_adv(torch.optim.Optimizer):
309
311
  if snr_cond:
310
312
  denom = (1.0 - exp_avg.square()).clamp_min_(1e-30).sqrt_()
311
313
 
314
+ if nesterov and normed_mt:
315
+ # Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
316
+ normed_grad = grad * exp_avg.abs()
317
+
312
318
  exp_avg.lerp_(grad, 1 - momentum)
313
319
 
314
320
  if nesterov:
315
321
  nv_coef = momentum if nesterov_coef is None else nesterov_coef
316
322
  if normed_mt:
317
- # Scale the normalized gradient down to match the buffer's variance
318
- ema_std = math.sqrt((1 - momentum) / (1 + momentum))
319
- raw_update = (grad * ema_std).lerp_(exp_avg, nv_coef)
323
+ raw_update = normed_grad.lerp_(exp_avg, nv_coef)
320
324
  else:
321
325
  raw_update = grad.lerp(exp_avg, nv_coef)
322
326
  else:
@@ -351,7 +355,7 @@ class SignSGD_adv(torch.optim.Optimizer):
351
355
  update_scaling = lr * A if snr_cond else lr
352
356
  update.mul_(update_scaling)
353
357
 
354
- param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_target=wd_target, cwd_target=cwd_target)
358
+ param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_target=wd_target, cwd_target=cwd_target, decoupled=True)
355
359
 
356
360
  def compile(self, *args, **kwargs):
357
361
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -264,6 +264,10 @@ class SinkSGD_adv(torch.optim.Optimizer):
264
264
  else:
265
265
  denom = (1.0 - buf.square()).clamp_min_(1e-30).sqrt_().view_as(p)
266
266
 
267
+ if nesterov and normed_mt:
268
+ # Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
269
+ normed_grad = grad_reshaped * buf.abs()
270
+
267
271
  buf.lerp_(grad_reshaped, 1 - momentum)
268
272
 
269
273
  # Factorize updated buffer
@@ -272,9 +276,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
272
276
  if nesterov:
273
277
  nv_coef = momentum if nesterov_coef is None else nesterov_coef
274
278
  if normed_mt:
275
- # Scale the normalized gradient down to match the buffer's variance
276
- ema_std = math.sqrt((1 - momentum) / (1 + momentum))
277
- update = (grad_reshaped * ema_std).lerp_(buf, nv_coef)
279
+ update = normed_grad.lerp_(buf, nv_coef)
278
280
  else:
279
281
  update = grad_reshaped.lerp(buf, nv_coef)
280
282
  else:
@@ -299,6 +301,10 @@ class SinkSGD_adv(torch.optim.Optimizer):
299
301
  else:
300
302
  denom = (1.0 - buf.square()).clamp_min_(1e-30).sqrt_()
301
303
 
304
+ if nesterov and normed_mt:
305
+ # Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
306
+ normed_grad = grad * buf.abs()
307
+
302
308
  buf.lerp_(grad, 1 - momentum)
303
309
 
304
310
  set_state(state, 'momentum_buffer', buf, actual_precision, random_int_state_tensor)
@@ -306,9 +312,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
306
312
  if nesterov:
307
313
  nv_coef = momentum if nesterov_coef is None else nesterov_coef
308
314
  if normed_mt:
309
- # Scale the normalized gradient down to match the buffer's variance
310
- ema_std = math.sqrt((1 - momentum) / (1 + momentum))
311
- update = (grad * ema_std).lerp_(buf, nv_coef)
315
+ update = normed_grad.lerp_(buf, nv_coef)
312
316
  else:
313
317
  update = grad.lerp(buf, nv_coef)
314
318
  else:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev24
3
+ Version: 2.4.dev25
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
5
5
 
6
6
  setup(
7
7
  name="adv_optm",
8
- version="2.4.dev24",
8
+ version="2.4.dev25",
9
9
  author="Koratahiu",
10
10
  author_email="hiuhonor@gmail.com",
11
11
  license='Apache 2.0',
File without changes
File without changes
File without changes