adv-optm 2.4.dev11__tar.gz → 2.4.dev12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/__init__.py +3 -3
  3. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/optim/AdaMuon_adv.py +1 -1
  4. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/optim/AdamW_adv.py +1 -1
  5. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/optim/Adopt_adv.py +1 -1
  6. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/optim/Muon_adv.py +1 -1
  7. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/optim/SignSGD_adv.py +1 -1
  8. adv_optm-2.4.dev11/adv_optm/optim/SGD_adv.py → adv_optm-2.4.dev12/adv_optm/optim/SinkSGD_adv.py +12 -12
  9. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/optim/__init__.py +2 -2
  10. adv_optm-2.4.dev12/adv_optm/util/sinkhorn.py +89 -0
  11. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm.egg-info/PKG-INFO +1 -1
  12. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm.egg-info/SOURCES.txt +1 -1
  13. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/setup.py +1 -1
  14. adv_optm-2.4.dev11/adv_optm/util/sinkhorn.py +0 -42
  15. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/LICENSE +0 -0
  16. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/README.md +0 -0
  17. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  18. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/optim/Lion_adv.py +0 -0
  19. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/optim/Prodigy_adv.py +0 -0
  20. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  21. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/util/Kourkoutas.py +0 -0
  22. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/util/Muon_AuxAdam.py +0 -0
  23. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/util/Muon_util.py +0 -0
  24. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/util/OrthoGrad.py +0 -0
  25. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/util/__init__.py +0 -0
  26. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/util/centered_decay.py +0 -0
  27. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/util/factorization_util.py +0 -0
  28. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/util/lion_k.py +0 -0
  29. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/util/param_update.py +0 -0
  30. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/util/scaled_optm.py +0 -0
  31. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/util/signed_util.py +0 -0
  32. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/util/state_util.py +0 -0
  33. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/util/update_util.py +0 -0
  34. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm.egg-info/dependency_links.txt +0 -0
  35. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm.egg-info/requires.txt +0 -0
  36. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm.egg-info/top_level.txt +0 -0
  37. {adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/setup.cfg +0 -0
{adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 2.4.dev11
+ Version: 2.4.dev12
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
{adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/__init__.py
@@ -8,7 +8,7 @@ from .optim import (
      Muon_adv,
      AdaMuon_adv,
      SignSGD_adv,
-     SGD_adv,
+     SinkSGD_adv,
  )

  __all__ = [
@@ -21,7 +21,7 @@ __all__ = [
      "Muon_adv",
      "AdaMuon_adv",
      "SignSGD_adv",
-     "SGD_adv",
+     "SinkSGD_adv",
  ]

- __version__ = "2.4.dev11"
+ __version__ = "2.4.dev12"
{adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/optim/AdaMuon_adv.py
@@ -468,7 +468,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
  if actual_precision == 'bf16_sr' and random_int_state_tensor is not None:
      random_int_state_tensor = param_update._get_random_int_for_sr(p)
  elif actual_precision == 'int8_sr':
-     random_int_state_tensor = param_update._get_random_int_for_int8_sr(p)
+     random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
  elif actual_precision == 'fp8_sr':
      random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
  else:
{adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/optim/AdamW_adv.py
@@ -333,7 +333,7 @@ class AdamW_adv(torch.optim.Optimizer):
  if group['actual_state_precision'] == 'bf16_sr' and random_int_state_tensor is None:
      random_int_state_tensor = param_update._get_random_int_for_sr(p)
  elif group['actual_state_precision'] == 'int8_sr':
-     random_int_state_tensor = param_update._get_random_int_for_int8_sr(p)
+     random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
  elif group['actual_state_precision'] == 'fp8_sr':
      random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
  step_param_fn = self._compiled_step_parameter
{adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/optim/Adopt_adv.py
@@ -359,7 +359,7 @@ class Adopt_adv(torch.optim.Optimizer):
  if group['actual_state_precision'] == 'bf16_sr' and random_int_state_tensor is None:
      random_int_state_tensor = param_update._get_random_int_for_sr(p)
  elif group['actual_state_precision'] == 'int8_sr':
-     random_int_state_tensor = param_update._get_random_int_for_int8_sr(p)
+     random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
  elif group['actual_state_precision'] == 'fp8_sr':
      random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
  step_param_fn = self._compiled_step_parameter
{adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/optim/Muon_adv.py
@@ -422,7 +422,7 @@ class Muon_adv(torch.optim.Optimizer):
  if actual_precision == 'bf16_sr' and random_int_state_tensor is not None:
      random_int_state_tensor = param_update._get_random_int_for_sr(p)
  elif actual_precision == 'int8_sr':
-     random_int_state_tensor = param_update._get_random_int_for_int8_sr(p)
+     random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
  elif actual_precision == 'fp8_sr':
      random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
  else:
{adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/optim/SignSGD_adv.py
@@ -252,7 +252,7 @@ class SignSGD_adv(torch.optim.Optimizer):
  if group['actual_state_precision'] == 'bf16_sr' and random_int_state_tensor is None:
      random_int_state_tensor = param_update._get_random_int_for_sr(p)
  elif group['actual_state_precision'] == 'int8_sr':
-     random_int_state_tensor = param_update._get_random_int_for_int8_sr(p)
+     random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
  elif group['actual_state_precision'] == 'fp8_sr':
      random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)

adv_optm-2.4.dev11/adv_optm/optim/SGD_adv.py → adv_optm-2.4.dev12/adv_optm/optim/SinkSGD_adv.py
@@ -11,12 +11,11 @@ from ..util.centered_decay import _init_anchor
  from ..util.state_util import init_state_tensor, get_state, set_state, upcast_grad_for_precision
  from ..util.sinkhorn import apply_sr_sinkhorn

- class SGD_adv(torch.optim.Optimizer):
+ class SinkSGD_adv(torch.optim.Optimizer):
      """
-     Implements an advanced Stochastic Gradient Descent (SGD) algorithm.
-     This is an advanced version of SGD with optional features like
-     low-rank factorization of optimizer states (SMMF), OrthoGrad,
-     Cautious updating, and AdEMAMix extensions.
+     Implements an advanced Stochastic Gradient Descent (SGD) with Sinkhorn Iterative Normalization (SinkSGD) algorithm.
+     This is an advanced version of SinkSGD with optional features like
+     low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.

      Args:
          params (iterable): iterable of parameters to optimize or dicts defining
@@ -62,11 +61,11 @@ class SGD_adv(torch.optim.Optimizer):
      cautious_wd: bool = False,
      # Stochastic Rounding for BF16
      stochastic_rounding: bool = True,
-     # OrthoGrad
-     orthogonal_gradient: bool = False,
      # Sinkhorn Iterative Normalization
-     sinkhorn: bool = False,
      sinkhorn_iterations: int = 5,
+     orthogonal_sinkhorn: bool = False,
+     # OrthoGrad
+     orthogonal_gradient: bool = False,
      # Spectral Normed Optimizer
      spectral_normalization: bool = False,
      # Centered WD
@@ -101,7 +100,8 @@ class SGD_adv(torch.optim.Optimizer):
      "decoupled_wd": decoupled_wd, "cautious_wd": cautious_wd,
      "orthogonal_gradient": orthogonal_gradient,
      "compiled_optimizer": compiled_optimizer,
-     "sinkhorn": sinkhorn, "sinkhorn_iterations": sinkhorn_iterations,
+     "sinkhorn_iterations": sinkhorn_iterations,
+     "orthogonal_sinkhorn": orthogonal_sinkhorn,
      "spectral_normalization": spectral_normalization,
      "centered_wd": centered_wd, "centered_wd_mode": centered_wd_mode,
      "state_precision": state_precision,
@@ -193,7 +193,7 @@ class SGD_adv(torch.optim.Optimizer):
  if group['actual_state_precision'] == 'bf16_sr' and random_int_state_tensor is None:
      random_int_state_tensor = param_update._get_random_int_for_sr(p)
  elif group['actual_state_precision'] == 'int8_sr':
-     random_int_state_tensor = param_update._get_random_int_for_int8_sr(p)
+     random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
  elif group['actual_state_precision'] == 'fp8_sr':
      random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
  step_param_fn = self._compiled_step_parameter
@@ -254,8 +254,8 @@ class SGD_adv(torch.optim.Optimizer):

  del random_int_state_tensor

- if group['sinkhorn']:
-     update = apply_sr_sinkhorn(update, iters=group['sinkhorn_iterations'])
+ # Sinkhorn iterative normalization
+ update = apply_sr_sinkhorn(update, p, ortho_project=group['orthogonal_sinkhorn'], iters=group['sinkhorn_iterations'])

  update_scaling = step_size
  if group.get('spectral_normalization', False):
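A minimal usage sketch of the renamed optimizer follows. It is not part of the package: only `sinkhorn_iterations` and `orthogonal_sinkhorn` are taken from the constructor signature shown in this diff; the `lr` argument, its value, and the model are assumptions for illustration, along with the standard `torch.optim.Optimizer` interface.

# Minimal sketch (not from the package): drive SinkSGD_adv like any
# torch.optim optimizer. sinkhorn_iterations and orthogonal_sinkhorn come
# from the diff above; lr and the model are assumptions.
import torch
from adv_optm import SinkSGD_adv

model = torch.nn.Linear(128, 64)
opt = SinkSGD_adv(
    model.parameters(),
    lr=1e-3,                    # assumed; not shown in this diff
    sinkhorn_iterations=5,      # default per the diff
    orthogonal_sinkhorn=False,  # new flag introduced in dev12
)

loss = model(torch.randn(32, 128)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()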
{adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm/optim/__init__.py
@@ -7,7 +7,7 @@ from .Lion_Prodigy_adv import Lion_Prodigy_adv
  from .Muon_adv import Muon_adv
  from .AdaMuon_adv import AdaMuon_adv
  from .SignSGD_adv import SignSGD_adv
- from .SGD_adv import SGD_adv
+ from .SinkSGD_adv import SinkSGD_adv

  __all__ = [
      "AdamW_adv",
@@ -19,5 +19,5 @@ __all__ = [
      "Muon_adv",
      "AdaMuon_adv",
      "SignSGD_adv",
-     "SGD_adv",
+     "SinkSGD_adv",
  ]
adv_optm-2.4.dev12/adv_optm/util/sinkhorn.py
@@ -0,0 +1,89 @@
+ import math
+ import torch
+
+ def apply_sr_sinkhorn(update: torch.Tensor, p: torch.Tensor, ortho_project: bool, iters: int = 5) -> torch.Tensor:
+     """
+     Applies Square-Root Sinkhorn (SR-Sinkhorn) multi-normalization.
+     As described in 'Gradient Multi-Normalization for Efficient LLM Training'.
+
+     This technique normalizes a 2D matrix alternatively by its row-wise L2 norm
+     and column-wise L2 norm, driving it toward a fixed point that uniformly
+     distributes update magnitudes.
+     """
+     original_shape = update.shape
+     original_dtype = update.dtype
+     update = update.float()
+
+     # 1D Vector Case
+     if update.dim() == 1:
+         if ortho_project:
+             p_float = p.float()
+             p_norm_sq = torch.dot(p_float, p_float).add_(1e-30)
+             proj = torch.dot(p_float, update) / p_norm_sq
+             update.sub_(p_float * proj)
+         norm = update.norm(p=2).clamp_min_(1e-12)
+         return update.mul_(math.sqrt(update.numel()) / norm).view(original_shape).to(original_dtype)
+
+     # 2D+ Matrix Case
+     update_2d = update.view(update.shape[0], -1)
+
+     m, n = update_2d.shape
+
+     # Dynamically determine the order of normalization based on aspect ratio
+     # Normalizing the longer dimension first aids stability.
+     scale_cond = update_2d.shape[0] > update_2d.shape[1]
+     dim = 0 if scale_cond else 1
+
+
+     # Precompute scaling factors.
+     scale_first = m if scale_cond else n
+     scale_second = n if scale_cond else m
+
+     if ortho_project:
+         # Pre-compute squares for the mathematical trick in ortho_normed
+         target_norm_sq_first = scale_first ** 2
+         target_norm_sq_second = scale_second ** 2
+         param_2d = p.float().view(p.shape[0], -1)
+         p_norm_sq_dim = torch.sum(param_2d * param_2d, dim=dim, keepdim=True).add_(1e-30)
+         p_norm_sq_adim = torch.sum(param_2d * param_2d, dim=1-dim, keepdim=True).add_(1e-30)
+
+     # In-place alternating Sinkhorn normalization steps
+     for _ in range(iters):
+         # First normalization step
+         norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
+         update_2d.mul_(scale_first / norm1)
+         if ortho_project:
+             update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_dim, dim, scale_first, target_norm_sq_first)
+
+         # Second normalization step
+         norm2 = update_2d.norm(p=2, dim=1-dim, keepdim=True).clamp_min_(1e-12)
+         update_2d.mul_(scale_second / norm2)
+         if ortho_project:
+             update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_adim, 1-dim, scale_second, target_norm_sq_second)
+
+     # Final step
+     norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
+     update_2d.mul_(scale_first / norm1)
+     if ortho_project:
+         update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_dim, dim, scale_first, target_norm_sq_first)

+     return update_2d.view(original_shape).to(original_dtype)
+
+ def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm, target_norm_sq):
+     """
+     Projects the update to be orthogonal to p along 'dim' and restores the original norm.
+     """
+     # Project: g_orth = g - (p * <p, g> / ||p||^2)
+     dot_prod = torch.sum(p_2d * update_2d, dim=dim, keepdim=True)
+     proj = dot_prod / p_norm_sq
+
+     # In-place subtraction: update_2d = update_2d - (proj * p_2d)
+     update_2d.addcmul_(proj, p_2d, value=-1.0)
+
+     # Magnitude Preservation via Pythagorean theorem
+     # ||g_orth||^2 = ||g||^2 - ||proj * p||^2
+     proj_norm_sq = (dot_prod ** 2) / p_norm_sq
+     g_orth_norm_sq = (target_norm_sq - proj_norm_sq).clamp_min_(1e-30)
+
+     scale_factor = target_norm / torch.sqrt(g_orth_norm_sq)
+     return update_2d.mul_(scale_factor)
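For orientation, a short standalone check of `apply_sr_sinkhorn` as added above; this sketch is not part of the package, and the import path simply mirrors the file location in this diff. With a 64x16 input, `scale_cond` is true, so columns (dim 0) are normalized first and last: each column of the result has L2 norm exactly `scale_first = m = 64`, while row norms come out approximately uniform. With `ortho_project=True`, the result is additionally orthogonal to `p` along the dimension projected last.

# Standalone sketch (not from the package) exercising the new function.
import torch
from adv_optm.util.sinkhorn import apply_sr_sinkhorn

p = torch.randn(64, 16)   # stand-in parameter tensor
g = torch.randn(64, 16)   # stand-in raw update

# The function mutates float32 inputs in place, hence the clones.
out = apply_sr_sinkhorn(g.clone(), p, ortho_project=False, iters=5)
print(out.norm(dim=0))    # every column norm == 64.0 (scale_first = m)
print(out.norm(dim=1))    # row norms, approximately uniform

out_orth = apply_sr_sinkhorn(g.clone(), p, ortho_project=True, iters=5)
print((out_orth * p).sum(dim=0))  # ~0: orthogonal to p along dim 0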
{adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 2.4.dev11
+ Version: 2.4.dev12
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
{adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/adv_optm.egg-info/SOURCES.txt
@@ -14,9 +14,9 @@ adv_optm/optim/Lion_Prodigy_adv.py
  adv_optm/optim/Lion_adv.py
  adv_optm/optim/Muon_adv.py
  adv_optm/optim/Prodigy_adv.py
- adv_optm/optim/SGD_adv.py
  adv_optm/optim/SignSGD_adv.py
  adv_optm/optim/Simplified_AdEMAMix.py
+ adv_optm/optim/SinkSGD_adv.py
  adv_optm/optim/__init__.py
  adv_optm/util/Kourkoutas.py
  adv_optm/util/Muon_AuxAdam.py
{adv_optm-2.4.dev11 → adv_optm-2.4.dev12}/setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:

  setup(
      name="adv_optm",
-     version="2.4.dev11",
+     version="2.4.dev12",
      author="Koratahiu",
      author_email="hiuhonor@gmail.com",
      license='Apache 2.0',
adv_optm-2.4.dev11/adv_optm/util/sinkhorn.py
@@ -1,42 +0,0 @@
- import math
- import torch
-
- def apply_sr_sinkhorn(update: torch.Tensor, iters: int = 5) -> torch.Tensor:
-     """
-     Applies Square-Root Sinkhorn (SR-Sinkhorn) multi-normalization.
-     As described in 'Gradient Multi-Normalization for Efficient LLM Training'.
-
-     This technique normalizes a 2D matrix alternatively by its row-wise L2 norm
-     and column-wise L2 norm, driving it toward a fixed point that uniformly
-     distributes update magnitudes.
-     """
-     original_shape = update.shape
-
-     if update.dim() == 1:
-         norm = update.norm(p=2).clamp_min_(1e-12)
-         return update.mul_(math.sqrt(update.numel()) / norm)
-     else:
-         # Flatten >= 3D tensors into 2D matrices
-         update_2d = update.view(update.shape[0], -1)
-
-         m, n = update_2d.shape
-
-         # Dynamically determine the order of normalization based on aspect ratio
-         # Normalizing the longer dimension first aids stability.
-         dim = 0 if m > n else 1
-
-         # Precompute scaling factors.
-         scale_first = math.sqrt(m) if dim == 0 else math.sqrt(n)
-         scale_second = math.sqrt(n) if dim == 0 else math.sqrt(m)
-
-         # In-place alternating Sinkhorn normalization steps
-         for _ in range(iters):
-             # First normalization step
-             norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
-             update_2d.mul_(scale_first / norm1)
-
-             # Second normalization step
-             norm2 = update_2d.norm(p=2, dim=1-dim, keepdim=True).clamp_min_(1e-12)
-             update_2d.mul_(scale_second / norm2)
-
-         return update_2d.view(original_shape)