adv-optm 2.2.1.dev1__tar.gz → 2.2.1.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/PKG-INFO +1 -1
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/__init__.py +1 -1
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/optim/AdaMuon_adv.py +4 -2
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/optim/Muon_adv.py +4 -2
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/util/Muon_util.py +82 -4
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/setup.py +1 -1
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/LICENSE +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/README.md +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/optim/SignSGD_adv.py +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/util/param_update.py +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.2.1.dev1 → adv_optm-2.2.1.dev2}/setup.cfg +0 -0
|
@@ -499,7 +499,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
499
499
|
cns_a_bound=group['cns_a_bound'],
|
|
500
500
|
low_rank_ortho=group['low_rank_ortho'],
|
|
501
501
|
ortho_rank=group['ortho_rank'],
|
|
502
|
-
spectral_normalization=group.get('spectral_normalization', False)
|
|
502
|
+
spectral_normalization=group.get('spectral_normalization', False),
|
|
503
|
+
compiled=group.get('compiled_optimizer', False)
|
|
503
504
|
)
|
|
504
505
|
|
|
505
506
|
if group['normuon_variant']:
|
|
@@ -563,7 +564,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
563
564
|
cns_a_bound=group['cns_a_bound'],
|
|
564
565
|
low_rank_ortho=group['low_rank_ortho'],
|
|
565
566
|
ortho_rank=group['ortho_rank'],
|
|
566
|
-
spectral_normalization=group.get('spectral_normalization', False)
|
|
567
|
+
spectral_normalization=group.get('spectral_normalization', False),
|
|
568
|
+
compiled=group.get('compiled_optimizer', False)
|
|
567
569
|
)
|
|
568
570
|
|
|
569
571
|
# NorMuon Logic
|
|
@@ -461,7 +461,8 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
461
461
|
cns_a_bound=group['cns_a_bound'],
|
|
462
462
|
low_rank_ortho=group['low_rank_ortho'],
|
|
463
463
|
ortho_rank=group['ortho_rank'],
|
|
464
|
-
spectral_normalization=group.get('spectral_normalization', False)
|
|
464
|
+
spectral_normalization=group.get('spectral_normalization', False),
|
|
465
|
+
compiled=group.get('compiled_optimizer', False)
|
|
465
466
|
)
|
|
466
467
|
|
|
467
468
|
if group['normuon_variant']:
|
|
@@ -511,7 +512,8 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
511
512
|
cns_a_bound=group['cns_a_bound'],
|
|
512
513
|
low_rank_ortho=group['low_rank_ortho'],
|
|
513
514
|
ortho_rank=group['ortho_rank'],
|
|
514
|
-
spectral_normalization=group.get('spectral_normalization', False)
|
|
515
|
+
spectral_normalization=group.get('spectral_normalization', False),
|
|
516
|
+
compiled=group.get('compiled_optimizer', False)
|
|
515
517
|
)
|
|
516
518
|
|
|
517
519
|
# NorMuon Logic
|
|
@@ -35,7 +35,7 @@ def _newton_schulz_iteration(
|
|
|
35
35
|
|
|
36
36
|
a, b, c = coeffs
|
|
37
37
|
|
|
38
|
-
X = G
|
|
38
|
+
X = G.to(torch.bfloat16)
|
|
39
39
|
|
|
40
40
|
# Transpose if needed
|
|
41
41
|
transposed = X.size(-2) > X.size(-1)
|
|
@@ -118,8 +118,80 @@ def _newton_schulz_iteration(
|
|
|
118
118
|
if transposed:
|
|
119
119
|
X = X.mT
|
|
120
120
|
|
|
121
|
-
return X
|
|
121
|
+
return X.to(G.dtype)
|
|
122
122
|
|
|
123
|
+
@torch.no_grad()
|
|
124
|
+
def _compiled_newton_schulz_iteration(
|
|
125
|
+
G: torch.Tensor,
|
|
126
|
+
steps: int = 5,
|
|
127
|
+
eps: float = 1e-7,
|
|
128
|
+
coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
|
|
129
|
+
cns: bool = False,
|
|
130
|
+
cns_a_bound: float = 1e-4,
|
|
131
|
+
spectral_normalization: bool = False,
|
|
132
|
+
) -> torch.Tensor:
|
|
133
|
+
"""
|
|
134
|
+
Newton-Schulz iteration refactored for torch.compile compatibility.
|
|
135
|
+
Removes mutable buffers and in-place operations in favor of functional graph construction.
|
|
136
|
+
"""
|
|
137
|
+
assert G.ndim in (2, 3), f"Input must be 2D or 3D, got {G.ndim}D"
|
|
138
|
+
|
|
139
|
+
a, b, c = coeffs
|
|
140
|
+
|
|
141
|
+
X = G.to(torch.bfloat16)
|
|
142
|
+
|
|
143
|
+
# Transpose if needed
|
|
144
|
+
transposed = X.size(-2) > X.size(-1)
|
|
145
|
+
if transposed:
|
|
146
|
+
X = X.mT
|
|
147
|
+
|
|
148
|
+
# Normalize spectral norm to at most 1
|
|
149
|
+
if spectral_normalization:
|
|
150
|
+
X.div_(X.norm(dim=(-2, -1), keepdim=True).add_(eps))
|
|
151
|
+
else:
|
|
152
|
+
X.div_(X.norm(dim=(-2, -1), keepdim=True).clamp_min_(eps))
|
|
153
|
+
|
|
154
|
+
if cns:
|
|
155
|
+
# Chebyshev-accelerated Newton-Schulz (CANS)
|
|
156
|
+
lower_bound = cns_a_bound
|
|
157
|
+
upper_bound = 1.0
|
|
158
|
+
|
|
159
|
+
for _ in range(steps):
|
|
160
|
+
lb, ub = lower_bound, upper_bound
|
|
161
|
+
lb_ub = lb * ub
|
|
162
|
+
# Calculate Mean Square Error term
|
|
163
|
+
e_sq = (lb**2 + lb_ub + ub**2) / 3.0
|
|
164
|
+
|
|
165
|
+
# Calculate components for alpha and bounds update
|
|
166
|
+
K = 2.0 * e_sq**1.5
|
|
167
|
+
L = lb_ub * (lb + ub)
|
|
168
|
+
denom = K + L
|
|
169
|
+
alpha = 6.0 / denom
|
|
170
|
+
|
|
171
|
+
c1 = alpha * e_sq
|
|
172
|
+
c3 = -alpha / 3.0
|
|
173
|
+
|
|
174
|
+
# Apply the 3rd-order Newton-Schulz update
|
|
175
|
+
A = X @ X.mT
|
|
176
|
+
X = c1 * X + c3 * (A @ X)
|
|
177
|
+
|
|
178
|
+
# Update the singular value bounds for the next iteration based on the error
|
|
179
|
+
eps_val = (K - L) / denom
|
|
180
|
+
lower_bound, upper_bound = 1.0 - eps_val, 1.0 + eps_val
|
|
181
|
+
|
|
182
|
+
else:
|
|
183
|
+
# Standard Quintic Newton-Schulz
|
|
184
|
+
# Update: X = a*X + b*(A@X) + c*(A@A@X)
|
|
185
|
+
for _ in range(steps):
|
|
186
|
+
A = X @ X.mT
|
|
187
|
+
B = b * A + c * (A @ A)
|
|
188
|
+
X = a * X + B @ X
|
|
189
|
+
|
|
190
|
+
# Transpose back if necessary
|
|
191
|
+
if transposed:
|
|
192
|
+
X = X.mT
|
|
193
|
+
|
|
194
|
+
return X.to(G.dtype)
|
|
123
195
|
|
|
124
196
|
@torch.no_grad()
|
|
125
197
|
def newton_schulz(
|
|
@@ -132,6 +204,7 @@ def newton_schulz(
|
|
|
132
204
|
low_rank_ortho: bool = False,
|
|
133
205
|
ortho_rank: int = 128,
|
|
134
206
|
spectral_normalization: bool = False,
|
|
207
|
+
compiled: bool = False,
|
|
135
208
|
) -> torch.Tensor:
|
|
136
209
|
"""
|
|
137
210
|
Public entry point for Muon orthogonalization.
|
|
@@ -149,6 +222,11 @@ def newton_schulz(
|
|
|
149
222
|
low_rank_ortho (bool): Whether to project to low rank before orthogonalizing.
|
|
150
223
|
ortho_rank (int): Rank for low-rank projection.
|
|
151
224
|
"""
|
|
225
|
+
if compiled:
|
|
226
|
+
ns_fn = _compiled_newton_schulz_iteration
|
|
227
|
+
else:
|
|
228
|
+
ns_fn = _newton_schulz_iteration
|
|
229
|
+
|
|
152
230
|
if low_rank_ortho:
|
|
153
231
|
# Low-Rank Orthogonalization via Gaussian Sketching
|
|
154
232
|
M = G
|
|
@@ -172,7 +250,7 @@ def newton_schulz(
|
|
|
172
250
|
projected_M = Q.T @ M
|
|
173
251
|
|
|
174
252
|
# 4. Orthogonalize the smaller projected matrix
|
|
175
|
-
ortho_projected_M =
|
|
253
|
+
ortho_projected_M = ns_fn(
|
|
176
254
|
projected_M,
|
|
177
255
|
steps=steps,
|
|
178
256
|
eps=eps,
|
|
@@ -186,7 +264,7 @@ def newton_schulz(
|
|
|
186
264
|
return Q @ ortho_projected_M
|
|
187
265
|
|
|
188
266
|
# Standard Path
|
|
189
|
-
return
|
|
267
|
+
return ns_fn(
|
|
190
268
|
G,
|
|
191
269
|
steps=steps,
|
|
192
270
|
eps=eps,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|