adv-optm 2.4.dev12__tar.gz → 2.4.dev13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/AdamW_adv.py +15 -5
  4. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/SinkSGD_adv.py +19 -10
  5. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/sinkhorn.py +8 -20
  6. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm.egg-info/PKG-INFO +1 -1
  7. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/setup.py +1 -1
  8. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/LICENSE +0 -0
  9. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/README.md +0 -0
  10. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/AdaMuon_adv.py +0 -0
  11. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/Adopt_adv.py +0 -0
  12. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  13. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/Lion_adv.py +0 -0
  14. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/Muon_adv.py +0 -0
  15. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/Prodigy_adv.py +0 -0
  16. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/SignSGD_adv.py +0 -0
  17. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  18. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/__init__.py +0 -0
  19. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/Kourkoutas.py +0 -0
  20. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/Muon_AuxAdam.py +0 -0
  21. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/Muon_util.py +0 -0
  22. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/OrthoGrad.py +0 -0
  23. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/__init__.py +0 -0
  24. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/centered_decay.py +0 -0
  25. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/factorization_util.py +0 -0
  26. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/lion_k.py +0 -0
  27. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/param_update.py +0 -0
  28. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/scaled_optm.py +0 -0
  29. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/signed_util.py +0 -0
  30. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/state_util.py +0 -0
  31. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/update_util.py +0 -0
  32. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm.egg-info/SOURCES.txt +0 -0
  33. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm.egg-info/dependency_links.txt +0 -0
  34. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm.egg-info/requires.txt +0 -0
  35. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm.egg-info/top_level.txt +0 -0
  36. {adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/setup.cfg +0 -0
{adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 2.4.dev12
+Version: 2.4.dev13
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu

{adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/__init__.py
@@ -24,4 +24,4 @@ __all__ = [
     "SinkSGD_adv",
 ]
 
-__version__ = "2.4.dev12"
+__version__ = "2.4.dev13"

{adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/AdamW_adv.py
@@ -232,12 +232,13 @@ class AdamW_adv(torch.optim.Optimizer):
     def supports_flat_params(self):
         return False
 
-    @torch.no_grad()
-    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
-        if p.grad is None:
-            return
+    def init_step(self):
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.__init_state(p, group)
 
-        grad = p.grad
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]
 
         # State Initialization
@@ -303,6 +304,15 @@ class AdamW_adv(torch.optim.Optimizer):
 
         _init_fisher_wd_scaler(group, state, p)
 
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        state = self.state[p]
+        self.__init_state(p, group)
+
         beta1, beta2 = group['betas']
 
         current_step = state['step']
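
Taken together, the two AdamW_adv hunks split lazy state creation out of the per-parameter step: init_step() walks every param group and builds state via the private __init_state() helper, and step_parameter() re-runs __init_state() before using the state. A minimal sketch of the resulting call pattern (the import path, constructor arguments, and training loop are assumptions, not taken from this diff):

    import torch
    from adv_optm import AdamW_adv  # assumed import path

    model = torch.nn.Linear(16, 4)
    opt = AdamW_adv(model.parameters(), lr=1e-3)       # assumed constructor signature

    opt.init_step()                                     # eagerly builds per-parameter state
    model(torch.randn(8, 16)).sum().backward()
    for group in opt.param_groups:
        for i, p in enumerate(group['params']):
            opt.step_parameter(p, group, i)             # per-parameter update; no-op if p.grad is None
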

{adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/optim/SinkSGD_adv.py
@@ -116,6 +116,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
         for device in devices:
             param_update.set_seed(device)
 
+        self.init_step()
+
         self._compiled_step_parameter = None
         if compiled_optimizer:
             self.compile(fullgraph=True)
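
Because init_step() now runs at construction time, every parameter's state exists before the first (possibly compiled) step. A small check of that effect (model and constructor arguments are assumptions, not taken from the package):

    import torch
    from adv_optm import SinkSGD_adv  # assumed import path

    model = torch.nn.Linear(16, 4)
    opt = SinkSGD_adv(model.parameters(), lr=1e-2)   # assumed signature; __init__ now calls init_step()
    for p in model.parameters():
        assert 'step' in opt.state[p]                # state pre-built, e.g. state['step'] == 0
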
@@ -136,14 +138,14 @@ class SinkSGD_adv(torch.optim.Optimizer):
     def supports_flat_params(self):
         return False
 
-    @torch.no_grad()
-    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
-        if p.grad is None:
-            return
+    def init_step(self):
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.__init_state(p, group)
 
-        grad = p.grad
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]
-
         # State Initialization
         if 'step' not in state:
             state['step'] = 0
@@ -180,6 +182,15 @@ class SinkSGD_adv(torch.optim.Optimizer):
 
         _init_anchor(p, state, group)
 
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        state = self.state[p]
+        self.__init_state(p, group)
+
         step_size = group['lr']
 
         random_int_tensor = None
@@ -219,7 +230,7 @@
 
         if momentum != 0:
             buf = _reconstruct_state((state['mu_b_nmf'], state['mv_b_nmf'], state['sign'], d2), signed=True)
-            buf.mul_(momentum).add_(grad_reshaped, alpha=1 - momentum)
+            buf.lerp_(grad_reshaped, 1 - momentum)
 
             # Factorize updated buffer
             state['mu_b_nmf'], state['mv_b_nmf'], state['sign'] = _factorize_state(buf.clone(), signed=True)
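
This hunk and the next one replace the explicit decay-and-add with Tensor.lerp_, which computes the same exponential moving average in a single in-place call. The equivalence, on throwaway tensors:

    import torch

    momentum = 0.9
    grad = torch.randn(4, 4)
    buf_old = torch.randn(4, 4)
    buf_new = buf_old.clone()

    buf_old.mul_(momentum).add_(grad, alpha=1 - momentum)   # previous formulation
    buf_new.lerp_(grad, 1 - momentum)                       # new: buf + (1 - m) * (grad - buf)

    assert torch.allclose(buf_old, buf_new, atol=1e-6)
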
@@ -239,9 +250,7 @@
 
         if momentum != 0:
             buf = get_state(state, 'momentum_buffer', actual_precision)
-
-            buf.mul_(momentum).add_(grad, alpha=1 - momentum)
-
+            buf.lerp_(grad, 1 - momentum)
 
         set_state(state, 'momentum_buffer', buf, actual_precision, random_int_state_tensor)
 

{adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm/util/sinkhorn.py
@@ -36,13 +36,10 @@ def apply_sr_sinkhorn(update: torch.Tensor, p: torch.Tensor, ortho_project: bool
 
 
     # Precompute scaling factors.
-    scale_first = m if scale_cond else n
-    scale_second = n if scale_cond else m
+    scale_first = math.sqrt(m if scale_cond else n)
+    scale_second = math.sqrt(n if scale_cond else m)
 
     if ortho_project:
-        # Pre-compute squares for the mathematical trick in ortho_normed
-        target_norm_sq_first = scale_first ** 2
-        target_norm_sq_second = scale_second ** 2
         param_2d = p.float().view(p.shape[0], -1)
         p_norm_sq_dim = torch.sum(param_2d * param_2d, dim=dim, keepdim=True).add_(1e-30)
         p_norm_sq_adim = torch.sum(param_2d * param_2d, dim=1-dim, keepdim=True).add_(1e-30)
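
With the square-root scaling, each normalized slice of the update ends up with L2 norm sqrt(dim) rather than dim, i.e. roughly unit RMS per element. A toy illustration (shapes and the normalization axis are arbitrary, not taken from the package):

    import math
    import torch

    rows, cols = 128, 64
    update = torch.randn(rows, cols)

    norm = update.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12)
    update = update * (math.sqrt(cols) / norm)      # new scaling: per-row norm == sqrt(cols)

    print(update.norm(dim=1)[:3])                   # ~8.0 == sqrt(64)
    print(update.pow(2).mean(dim=1).sqrt()[:3])     # ~1.0 RMS per element
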
@@ -53,23 +50,17 @@ def apply_sr_sinkhorn(update: torch.Tensor, p: torch.Tensor, ortho_project: bool
     norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
     update_2d.mul_(scale_first / norm1)
     if ortho_project:
-        update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_dim, dim, scale_first, target_norm_sq_first)
+        update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_dim, dim, scale_first)
 
     # Second normalization step
     norm2 = update_2d.norm(p=2, dim=1-dim, keepdim=True).clamp_min_(1e-12)
     update_2d.mul_(scale_second / norm2)
     if ortho_project:
-        update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_adim, 1-dim, scale_second, target_norm_sq_second)
-
-    # Final step
-    norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
-    update_2d.mul_(scale_first / norm1)
-    if ortho_project:
-        update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_dim, dim, scale_first, target_norm_sq_first)
+        update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_adim, 1-dim, scale_second)
 
     return update_2d.view(original_shape).to(original_dtype)
 
-def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm, target_norm_sq):
+def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm):
     """
     Projects the update to be orthogonal to p along 'dim' and restores the original norm.
     """
@@ -80,10 +71,7 @@ def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm, target_norm_sq):
     # In-place subtraction: update_2d = update_2d - (proj * p_2d)
     update_2d.addcmul_(proj, p_2d, value=-1.0)
 
-    # Magnitude Preservation via Pythagorean theorem
-    # ||g_orth||^2 = ||g||^2 - ||proj * p||^2
-    proj_norm_sq = (dot_prod ** 2) / p_norm_sq
-    g_orth_norm_sq = (target_norm_sq - proj_norm_sq).clamp_min_(1e-30)
-
-    scale_factor = target_norm / torch.sqrt(g_orth_norm_sq)
+    # Magnitude Preservation
+    g_orth_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
+    scale_factor = target_norm / g_orth_norm
     return update_2d.mul_(scale_factor)
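
ortho_normed now measures the norm of the projected update directly instead of deriving it from the Pythagorean identity, which is why the target_norm_sq arguments disappear from its signature. A hypothetical standalone version of the same two steps (function name and wrapper are illustrative only):

    import torch

    def project_and_rescale(p_2d, update_2d, dim, target_norm):
        # remove the component of the update along p ...
        p_norm_sq = (p_2d * p_2d).sum(dim=dim, keepdim=True).clamp_min(1e-30)
        dot = (p_2d * update_2d).sum(dim=dim, keepdim=True)
        update_2d = update_2d - (dot / p_norm_sq) * p_2d
        # ... then rescale using the measured norm of the result
        g_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min(1e-12)
        return update_2d * (target_norm / g_norm)

    u = project_and_rescale(torch.randn(8, 4), torch.randn(8, 4), dim=1, target_norm=2.0)
    print(u.norm(dim=1))   # each row ~2.0
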

{adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/adv_optm.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 2.4.dev12
+Version: 2.4.dev13
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu

{adv_optm-2.4.dev12 → adv_optm-2.4.dev13}/setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
 
 setup(
     name="adv_optm",
-    version="2.4.dev12",
+    version="2.4.dev13",
     author="Koratahiu",
     author_email="hiuhonor@gmail.com",
     license='Apache 2.0',