adv-optm 2.2.dev5__tar.gz → 2.2.1.dev2__tar.gz

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (30)
  1. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/PKG-INFO +1 -1
  2. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/AdaMuon_adv.py +4 -2
  4. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/Muon_adv.py +4 -2
  5. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/Muon_util.py +80 -2
  6. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/param_update.py +2 -3
  7. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm.egg-info/PKG-INFO +1 -1
  8. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/setup.py +1 -1
  9. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/LICENSE +0 -0
  10. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/README.md +0 -0
  11. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/AdamW_adv.py +0 -0
  12. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/Adopt_adv.py +0 -0
  13. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  14. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/Lion_adv.py +0 -0
  15. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/Prodigy_adv.py +0 -0
  16. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/SignSGD_adv.py +0 -0
  17. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  18. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/__init__.py +0 -0
  19. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/Kourkoutas.py +0 -0
  20. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/Muon_AuxAdam.py +0 -0
  21. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/OrthoGrad.py +0 -0
  22. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/__init__.py +0 -0
  23. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/factorization_util.py +0 -0
  24. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/lion_k.py +0 -0
  25. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/update_util.py +0 -0
  26. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm.egg-info/SOURCES.txt +0 -0
  27. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm.egg-info/dependency_links.txt +0 -0
  28. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm.egg-info/requires.txt +0 -0
  29. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm.egg-info/top_level.txt +0 -0
  30. {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.2.dev5
3
+ Version: 2.2.1.dev2
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -22,4 +22,4 @@ __all__ = [
22
22
  "SignSGD_adv",
23
23
  ]
24
24
 
25
- __version__ = "2.2.dev5"
25
+ __version__ = "2.2.1.dev2"
@@ -499,7 +499,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
499
499
  cns_a_bound=group['cns_a_bound'],
500
500
  low_rank_ortho=group['low_rank_ortho'],
501
501
  ortho_rank=group['ortho_rank'],
502
- spectral_normalization=group.get('spectral_normalization', False)
502
+ spectral_normalization=group.get('spectral_normalization', False),
503
+ compiled=group.get('compiled_optimizer', False)
503
504
  )
504
505
 
505
506
  if group['normuon_variant']:
@@ -563,7 +564,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
563
564
  cns_a_bound=group['cns_a_bound'],
564
565
  low_rank_ortho=group['low_rank_ortho'],
565
566
  ortho_rank=group['ortho_rank'],
566
- spectral_normalization=group.get('spectral_normalization', False)
567
+ spectral_normalization=group.get('spectral_normalization', False),
568
+ compiled=group.get('compiled_optimizer', False)
567
569
  )
568
570
 
569
571
  # NorMuon Logic
@@ -461,7 +461,8 @@ class Muon_adv(torch.optim.Optimizer):
461
461
  cns_a_bound=group['cns_a_bound'],
462
462
  low_rank_ortho=group['low_rank_ortho'],
463
463
  ortho_rank=group['ortho_rank'],
464
- spectral_normalization=group.get('spectral_normalization', False)
464
+ spectral_normalization=group.get('spectral_normalization', False),
465
+ compiled=group.get('compiled_optimizer', False)
465
466
  )
466
467
 
467
468
  if group['normuon_variant']:
@@ -511,7 +512,8 @@ class Muon_adv(torch.optim.Optimizer):
511
512
  cns_a_bound=group['cns_a_bound'],
512
513
  low_rank_ortho=group['low_rank_ortho'],
513
514
  ortho_rank=group['ortho_rank'],
514
- spectral_normalization=group.get('spectral_normalization', False)
515
+ spectral_normalization=group.get('spectral_normalization', False),
516
+ compiled=group.get('compiled_optimizer', False)
515
517
  )
516
518
 
517
519
  # NorMuon Logic
@@ -120,6 +120,78 @@ def _newton_schulz_iteration(
120
120
 
121
121
  return X.to(G.dtype)
122
122
 
123
+ @torch.no_grad()
124
+ def _compiled_newton_schulz_iteration(
125
+ G: torch.Tensor,
126
+ steps: int = 5,
127
+ eps: float = 1e-7,
128
+ coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
129
+ cns: bool = False,
130
+ cns_a_bound: float = 1e-4,
131
+ spectral_normalization: bool = False,
132
+ ) -> torch.Tensor:
133
+ """
134
+ Newton-Schulz iteration refactored for torch.compile compatibility.
135
+ Removes mutable buffers and in-place operations in favor of functional graph construction.
136
+ """
137
+ assert G.ndim in (2, 3), f"Input must be 2D or 3D, got {G.ndim}D"
138
+
139
+ a, b, c = coeffs
140
+
141
+ X = G.to(torch.bfloat16)
142
+
143
+ # Transpose if needed
144
+ transposed = X.size(-2) > X.size(-1)
145
+ if transposed:
146
+ X = X.mT
147
+
148
+ # Normalize spectral norm to at most 1
149
+ if spectral_normalization:
150
+ X.div_(X.norm(dim=(-2, -1), keepdim=True).add_(eps))
151
+ else:
152
+ X.div_(X.norm(dim=(-2, -1), keepdim=True).clamp_min_(eps))
153
+
154
+ if cns:
155
+ # Chebyshev-accelerated Newton-Schulz (CANS)
156
+ lower_bound = cns_a_bound
157
+ upper_bound = 1.0
158
+
159
+ for _ in range(steps):
160
+ lb, ub = lower_bound, upper_bound
161
+ lb_ub = lb * ub
162
+ # Calculate Mean Square Error term
163
+ e_sq = (lb**2 + lb_ub + ub**2) / 3.0
164
+
165
+ # Calculate components for alpha and bounds update
166
+ K = 2.0 * e_sq**1.5
167
+ L = lb_ub * (lb + ub)
168
+ denom = K + L
169
+ alpha = 6.0 / denom
170
+
171
+ c1 = alpha * e_sq
172
+ c3 = -alpha / 3.0
173
+
174
+ # Apply the 3rd-order Newton-Schulz update
175
+ A = X @ X.mT
176
+ X = c1 * X + c3 * (A @ X)
177
+
178
+ # Update the singular value bounds for the next iteration based on the error
179
+ eps_val = (K - L) / denom
180
+ lower_bound, upper_bound = 1.0 - eps_val, 1.0 + eps_val
181
+
182
+ else:
183
+ # Standard Quintic Newton-Schulz
184
+ # Update: X = a*X + b*(A@X) + c*(A@A@X)
185
+ for _ in range(steps):
186
+ A = X @ X.mT
187
+ B = b * A + c * (A @ A)
188
+ X = a * X + B @ X
189
+
190
+ # Transpose back if necessary
191
+ if transposed:
192
+ X = X.mT
193
+
194
+ return X.to(G.dtype)
123
195
 
124
196
  @torch.no_grad()
125
197
  def newton_schulz(
@@ -132,6 +204,7 @@ def newton_schulz(
132
204
  low_rank_ortho: bool = False,
133
205
  ortho_rank: int = 128,
134
206
  spectral_normalization: bool = False,
207
+ compiled: bool = False,
135
208
  ) -> torch.Tensor:
136
209
  """
137
210
  Public entry point for Muon orthogonalization.
@@ -149,6 +222,11 @@ def newton_schulz(
149
222
  low_rank_ortho (bool): Whether to project to low rank before orthogonalizing.
150
223
  ortho_rank (int): Rank for low-rank projection.
151
224
  """
225
+ if compiled:
226
+ ns_fn = _compiled_newton_schulz_iteration
227
+ else:
228
+ ns_fn = _newton_schulz_iteration
229
+
152
230
  if low_rank_ortho:
153
231
  # Low-Rank Orthogonalization via Gaussian Sketching
154
232
  M = G
@@ -172,7 +250,7 @@ def newton_schulz(
172
250
  projected_M = Q.T @ M
173
251
 
174
252
  # 4. Orthogonalize the smaller projected matrix
175
- ortho_projected_M = _newton_schulz_iteration(
253
+ ortho_projected_M = ns_fn(
176
254
  projected_M,
177
255
  steps=steps,
178
256
  eps=eps,
@@ -186,7 +264,7 @@ def newton_schulz(
186
264
  return Q @ ortho_projected_M
187
265
 
188
266
  # Standard Path
189
- return _newton_schulz_iteration(
267
+ return ns_fn(
190
268
  G,
191
269
  steps=steps,
192
270
  eps=eps,
@@ -60,6 +60,7 @@ def apply_parameter_update(
60
60
  if random_int_tensor is not None:
61
61
  # Compiled path: use the pre-computed random tensor
62
62
  _copy_stochastic_core_(p, p_fp32, random_int_tensor)
63
+ del random_int_tensor
63
64
  else:
64
65
  # Uncompiled path: generate randoms inside
65
66
  copy_stochastic_(p, p_fp32)
@@ -132,7 +133,7 @@ def _copy_stochastic_core_(target: Tensor, source: Tensor, random_int_tensor: Te
132
133
  Core logic for stochastic rounding using a pre-computed random integer tensor.
133
134
  This version is designed to be torch.compile-friendly.
134
135
  """
135
- result = random_int_tensor.clone()
136
+ result = random_int_tensor
136
137
  # add the random number to the lower 16 bit of the mantissa
137
138
  result.add_(source.view(dtype=torch.int32))
138
139
 
@@ -142,8 +143,6 @@ def _copy_stochastic_core_(target: Tensor, source: Tensor, random_int_tensor: Te
142
143
  # copy the higher 16 bit into the target tensor
143
144
  target.copy_(result.view(dtype=torch.float32))
144
145
 
145
- del result
146
-
147
146
 
148
147
  def copy_stochastic_(target: Tensor, source: Tensor):
149
148
  """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.2.dev5
3
+ Version: 2.2.1.dev2
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
5
5
 
6
6
  setup(
7
7
  name="adv_optm",
8
- version="2.2.dev5",
8
+ version="2.2.1.dev2",
9
9
  author="Koratahiu",
10
10
  author_email="hiuhonor@gmail.com",
11
11
  license='Apache 2.0',
File without changes
File without changes
File without changes