adv-optm 2.2.dev1__tar.gz → 2.2.dev3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/PKG-INFO +1 -1
  2. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/AdaMuon_adv.py +6 -1
  4. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/Muon_adv.py +5 -1
  5. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/Muon_util.py +7 -2
  6. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/param_update.py +10 -4
  7. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm.egg-info/PKG-INFO +1 -1
  8. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/setup.py +1 -1
  9. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/LICENSE +0 -0
  10. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/README.md +0 -0
  11. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/AdamW_adv.py +0 -0
  12. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/Adopt_adv.py +0 -0
  13. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  14. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/Lion_adv.py +0 -0
  15. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/Prodigy_adv.py +0 -0
  16. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/SignSGD_adv.py +0 -0
  17. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  18. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/__init__.py +0 -0
  19. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/Kourkoutas.py +0 -0
  20. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/Muon_AuxAdam.py +0 -0
  21. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/OrthoGrad.py +0 -0
  22. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/__init__.py +0 -0
  23. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/factorization_util.py +0 -0
  24. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/lion_k.py +0 -0
  25. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/update_util.py +0 -0
  26. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm.egg-info/SOURCES.txt +0 -0
  27. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm.egg-info/dependency_links.txt +0 -0
  28. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm.egg-info/requires.txt +0 -0
  29. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm.egg-info/top_level.txt +0 -0
  30. {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/setup.cfg +0 -0
{adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 2.2.dev1
+ Version: 2.2.dev3
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
{adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/__init__.py
@@ -22,4 +22,4 @@ __all__ = [
  "SignSGD_adv",
  ]

- __version__ = "2.2.dev1"
+ __version__ = "2.2.dev3"
{adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/AdaMuon_adv.py
@@ -225,6 +225,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
  "adam_k_warmup_steps": adam_k_warmup_steps, "adam_nnmf_factor": adam_nnmf_factor,
  }
  self.stochastic_rounding = stochastic_rounding
+ self._init_lr = lr

  super().__init__(params, defaults)

@@ -439,9 +440,13 @@ class AdaMuon_adv(torch.optim.Optimizer):
  scaled_eps, adaptive_eps, spectral_target, wd_scale = get_spectral_scaling(shape_for_scaling, group['n_layers'])

  weight_decay = group['weight_decay'] * wd_scale
+ decoupled_wd = True
+
  ns_eps = scaled_eps
+
  else:
  weight_decay = group['weight_decay']
+ decoupled_wd = False
  ns_eps = group['ns_eps']
  adaptive_eps = group['eps']

@@ -587,7 +592,7 @@ class AdaMuon_adv(torch.optim.Optimizer):

  update = update.reshape(original_shape)

- param_update.apply_parameter_update(self, p, group, update, lr, wd=weight_decay, random_int_tensor=random_int_tensor)
+ param_update.apply_parameter_update(self, p, group, update, lr, wd=weight_decay, random_int_tensor=random_int_tensor, decoupled=decoupled_wd)

  @torch.no_grad()
  def step(self, closure=None):
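Beyond the version bump, AdaMuon_adv now captures the constructor learning rate (`self._init_lr = lr`), records per branch whether the spectral-scaling path was taken (`decoupled_wd`), and threads that flag into `param_update.apply_parameter_update`; Muon_adv below receives the identical treatment. A minimal sketch of the init-lr capture pattern, using a hypothetical `_SketchOpt` class that is not part of adv_optm:

    import torch

    class _SketchOpt(torch.optim.Optimizer):
        # Hypothetical minimal optimizer illustrating the _init_lr pattern.
        def __init__(self, params, lr=1e-3):
            defaults = dict(lr=lr)
            # Capture the constructor lr once: schedulers mutate
            # group['lr'] later, while self._init_lr stays fixed so the
            # update path can form ratios like lr / self._init_lr.
            self._init_lr = lr
            super().__init__(params, defaults)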
{adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/Muon_adv.py
@@ -204,6 +204,7 @@ class Muon_adv(torch.optim.Optimizer):
  }
  self.stochastic_rounding = stochastic_rounding
  self.compiled_optimizer = compiled_optimizer
+ self._init_lr = lr

  super().__init__(params, defaults)

@@ -404,9 +405,12 @@ class Muon_adv(torch.optim.Optimizer):
  scaled_eps, _, spectral_target, wd_scale = get_spectral_scaling(shape_for_scaling, group['n_layers'])

  weight_decay = group['weight_decay'] * wd_scale
+ decoupled_wd = True
+
  ns_eps = scaled_eps
  else:
  weight_decay = group['weight_decay']
+ decoupled_wd = False
  ns_eps = group['ns_eps']

  # MARS-M Approximated (Variance Reduction)
@@ -521,7 +525,7 @@ class Muon_adv(torch.optim.Optimizer):

  update = update.reshape(original_shape)

- param_update.apply_parameter_update(self, p, group, update, lr, wd=weight_decay, random_int_tensor=random_int_tensor)
+ param_update.apply_parameter_update(self, p, group, update, lr, wd=weight_decay, random_int_tensor=random_int_tensor, decoupled=decoupled_wd)

  @torch.no_grad()
  def step(self, closure=None):
{adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/Muon_util.py
@@ -352,7 +352,12 @@ def spectral_norm_update(update: torch.Tensor, vector_state: torch.Tensor, targe
  # Normalize v_new to get next state
  v_norm = torch.linalg.vector_norm(v_new)

- vector_state.copy_(v_new.div_(v_norm.clamp_min_(1e-12)).to(vector_state.dtype))
+ # if v_norm >= 0.5:
+ #     vector_state.copy_(v_new.div_(v_norm.clamp_min_(1e-12)).to(vector_state.dtype))
+ candidate_v = v_new / v_norm
+ next_state = torch.where(v_norm >= 0.5, candidate_v, vector_state)
+ vector_state.copy_(next_state.to(vector_state.dtype))
+ # Else: we keep the old vector_state (which is a random unit vector at init)

  # Estimate sigma = ||A @ v|| (since v is unit norm)
  # Re-compute A @ v_new with the updated vector for better estimate
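The rewritten state update makes the power iteration defensive: the freshly computed direction is accepted only when its norm is at least 0.5; otherwise the previous unit vector (a random unit vector at init) is kept. A self-contained sketch of one guarded step, assuming the conventional v_new = Aᵀ(A v) iteration that the surrounding comments suggest (`power_iteration_step` is illustrative, not an adv_optm API):

    import torch

    def power_iteration_step(A: torch.Tensor, v: torch.Tensor):
        # One power-iteration step toward the top right-singular vector of A.
        v_new = A.t() @ (A @ v)
        v_norm = torch.linalg.vector_norm(v_new)
        # Guard mirroring the hunk above: accept the new direction only
        # when its norm is safely away from zero; otherwise keep the old
        # unit vector.
        candidate = v_new / v_norm.clamp_min(1e-12)
        v = torch.where(v_norm >= 0.5, candidate, v)
        # With unit-norm v, ||A @ v|| estimates the top singular value.
        sigma = torch.linalg.vector_norm(A @ v)
        return v, sigma

For well-conditioned inputs the guard never fires; it matters when v drifts toward the null space of A, where normalizing a near-zero vector would otherwise amplify noise.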
@@ -379,7 +384,7 @@ def get_spectral_scaling(shape: torch.Size, n_layers: int):
  wd_scale: Weight decay scale
  """
  d_out, d_in = shape[0], shape[1]
-
+
  # Handle Convolutional/Flattened tensors
  if len(shape) > 2:
  d_in = shape[1:].numel()
{adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/param_update.py
@@ -14,6 +14,7 @@ def apply_parameter_update(
  lr: float | Tensor,
  wd: float | None = None,
  random_int_tensor: Tensor | None = None,
+ decoupled: bool = False,
  ) -> None:
  """
  Applies decoupled weight decay (standard or cautious) and the final
@@ -27,9 +28,14 @@
  wd: Optional float value for weight decay, if another value other than group["weight_decay"] is needed.
  random_int_tensor: Optional pre-generated random tensor for stochastic
  rounding. Required for the `torch.compile` path.
+ decoupled: Whether to use the true decoupled weight decay.
  """
  wd = group["weight_decay"] if wd is None else wd
  cautious = group.get('cautious_wd', False)
+ if decoupled:
+     scaled_wd = 1 / (1.0 + wd * lr / self._init_lr)
+ else:
+     scaled_wd = wd * lr

  # Compute full update in float32 if using bfloat16 with stochastic rounding
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
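The new `decoupled` path computes a multiplicative factor 1 / (1 + wd · lr / init_lr), which has the shape of the proximal (implicit) form of decoupled weight decay, p ← p / (1 + λ) with λ = wd · lr / lr₀, whereas the default path keeps the familiar wd · lr coefficient. A standalone sketch of these two textbook forms (hedged: it illustrates the standard formulas, not a line-for-line restatement of how `apply_parameter_update` applies the factor in the hunks below):

    import torch

    def explicit_decoupled_decay_(p: torch.Tensor, wd: float, lr: float):
        # AdamW-style decoupled decay: p <- p - wd * lr * p
        p.mul_(1.0 - wd * lr)

    def proximal_decoupled_decay_(p: torch.Tensor, wd: float, lr: float, init_lr: float):
        # Proximal (implicit) decay: exact minimizer of the quadratic
        # penalty step, p <- p / (1 + lambda), with the decay strength
        # normalized by the initial learning rate.
        lam = wd * lr / init_lr
        p.mul_(1.0 / (1.0 + lam))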
@@ -41,11 +47,11 @@
  if cautious:
  # Cautious Weight Decay
  mask = (update_fp32 * p_fp32 >= 0).float()
- p_fp32.addcmul_(p_fp32, mask, value=-wd * lr)
+ p_fp32.addcmul_(p_fp32, mask, value=-scaled_wd)
  del mask
  else:
  # Standard decoupled weight decay
- p_fp32.add_(p_fp32, alpha=-wd * lr)
+ p_fp32.add_(p_fp32, alpha=-scaled_wd)

  # Apply main update
  p_fp32.add_(-update_fp32)
@@ -65,11 +71,11 @@
  if cautious:
  # Cautious Weight Decay
  mask = (update * p >= 0).to(p.dtype)
- p.addcmul_(p, mask, value=-wd * lr)
+ p.addcmul_(p, mask, value=-scaled_wd)
  del mask
  else:
  # Standard decoupled weight decay
- p.add_(p, alpha=-wd * lr)
+ p.add_(p, alpha=-scaled_wd)

  # Apply main update
  p.add_(-update)
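Both the bfloat16/stochastic-rounding path and the regular path above now consume the precomputed `scaled_wd`; the `cautious` branch additionally masks the decay so that only coordinates where the incoming update and the parameter agree in sign are decayed. A small sketch of that masking idea in isolation (`cautious_decay_` is a hypothetical helper, not part of adv_optm):

    import torch

    def cautious_decay_(p: torch.Tensor, update: torch.Tensor, decay: float):
        # Decay only where subtracting the update would already shrink the
        # weight, i.e. where update and p share a sign.
        mask = (update * p >= 0).to(p.dtype)
        # p <- p - decay * mask * p  (elementwise)
        p.addcmul_(p, mask, value=-decay)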
{adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 2.2.dev1
+ Version: 2.2.dev3
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
{adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:

  setup(
  name="adv_optm",
- version="2.2.dev1",
+ version="2.2.dev3",
  author="Koratahiu",
  author_email="hiuhonor@gmail.com",
  license='Apache 2.0',