adv-optm 2.2.dev5__tar.gz → 2.2.1.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/PKG-INFO +1 -1
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/__init__.py +1 -1
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/AdaMuon_adv.py +4 -2
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/Muon_adv.py +4 -2
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/Muon_util.py +80 -2
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/param_update.py +2 -3
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/setup.py +1 -1
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/LICENSE +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/README.md +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/SignSGD_adv.py +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.2.dev5 → adv_optm-2.2.1.dev2}/setup.cfg +0 -0
|
@@ -499,7 +499,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
499
499
|
cns_a_bound=group['cns_a_bound'],
|
|
500
500
|
low_rank_ortho=group['low_rank_ortho'],
|
|
501
501
|
ortho_rank=group['ortho_rank'],
|
|
502
|
-
spectral_normalization=group.get('spectral_normalization', False)
|
|
502
|
+
spectral_normalization=group.get('spectral_normalization', False),
|
|
503
|
+
compiled=group.get('compiled_optimizer', False)
|
|
503
504
|
)
|
|
504
505
|
|
|
505
506
|
if group['normuon_variant']:
|
|
@@ -563,7 +564,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
563
564
|
cns_a_bound=group['cns_a_bound'],
|
|
564
565
|
low_rank_ortho=group['low_rank_ortho'],
|
|
565
566
|
ortho_rank=group['ortho_rank'],
|
|
566
|
-
spectral_normalization=group.get('spectral_normalization', False)
|
|
567
|
+
spectral_normalization=group.get('spectral_normalization', False),
|
|
568
|
+
compiled=group.get('compiled_optimizer', False)
|
|
567
569
|
)
|
|
568
570
|
|
|
569
571
|
# NorMuon Logic
|
|
@@ -461,7 +461,8 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
461
461
|
cns_a_bound=group['cns_a_bound'],
|
|
462
462
|
low_rank_ortho=group['low_rank_ortho'],
|
|
463
463
|
ortho_rank=group['ortho_rank'],
|
|
464
|
-
spectral_normalization=group.get('spectral_normalization', False)
|
|
464
|
+
spectral_normalization=group.get('spectral_normalization', False),
|
|
465
|
+
compiled=group.get('compiled_optimizer', False)
|
|
465
466
|
)
|
|
466
467
|
|
|
467
468
|
if group['normuon_variant']:
|
|
@@ -511,7 +512,8 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
511
512
|
cns_a_bound=group['cns_a_bound'],
|
|
512
513
|
low_rank_ortho=group['low_rank_ortho'],
|
|
513
514
|
ortho_rank=group['ortho_rank'],
|
|
514
|
-
spectral_normalization=group.get('spectral_normalization', False)
|
|
515
|
+
spectral_normalization=group.get('spectral_normalization', False),
|
|
516
|
+
compiled=group.get('compiled_optimizer', False)
|
|
515
517
|
)
|
|
516
518
|
|
|
517
519
|
# NorMuon Logic
|
|
@@ -120,6 +120,78 @@ def _newton_schulz_iteration(
|
|
|
120
120
|
|
|
121
121
|
return X.to(G.dtype)
|
|
122
122
|
|
|
123
|
+
@torch.no_grad()
|
|
124
|
+
def _compiled_newton_schulz_iteration(
|
|
125
|
+
G: torch.Tensor,
|
|
126
|
+
steps: int = 5,
|
|
127
|
+
eps: float = 1e-7,
|
|
128
|
+
coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
|
|
129
|
+
cns: bool = False,
|
|
130
|
+
cns_a_bound: float = 1e-4,
|
|
131
|
+
spectral_normalization: bool = False,
|
|
132
|
+
) -> torch.Tensor:
|
|
133
|
+
"""
|
|
134
|
+
Newton-Schulz iteration refactored for torch.compile compatibility.
|
|
135
|
+
Removes mutable buffers and in-place operations in favor of functional graph construction.
|
|
136
|
+
"""
|
|
137
|
+
assert G.ndim in (2, 3), f"Input must be 2D or 3D, got {G.ndim}D"
|
|
138
|
+
|
|
139
|
+
a, b, c = coeffs
|
|
140
|
+
|
|
141
|
+
X = G.to(torch.bfloat16)
|
|
142
|
+
|
|
143
|
+
# Transpose if needed
|
|
144
|
+
transposed = X.size(-2) > X.size(-1)
|
|
145
|
+
if transposed:
|
|
146
|
+
X = X.mT
|
|
147
|
+
|
|
148
|
+
# Normalize spectral norm to at most 1
|
|
149
|
+
if spectral_normalization:
|
|
150
|
+
X.div_(X.norm(dim=(-2, -1), keepdim=True).add_(eps))
|
|
151
|
+
else:
|
|
152
|
+
X.div_(X.norm(dim=(-2, -1), keepdim=True).clamp_min_(eps))
|
|
153
|
+
|
|
154
|
+
if cns:
|
|
155
|
+
# Chebyshev-accelerated Newton-Schulz (CANS)
|
|
156
|
+
lower_bound = cns_a_bound
|
|
157
|
+
upper_bound = 1.0
|
|
158
|
+
|
|
159
|
+
for _ in range(steps):
|
|
160
|
+
lb, ub = lower_bound, upper_bound
|
|
161
|
+
lb_ub = lb * ub
|
|
162
|
+
# Calculate Mean Square Error term
|
|
163
|
+
e_sq = (lb**2 + lb_ub + ub**2) / 3.0
|
|
164
|
+
|
|
165
|
+
# Calculate components for alpha and bounds update
|
|
166
|
+
K = 2.0 * e_sq**1.5
|
|
167
|
+
L = lb_ub * (lb + ub)
|
|
168
|
+
denom = K + L
|
|
169
|
+
alpha = 6.0 / denom
|
|
170
|
+
|
|
171
|
+
c1 = alpha * e_sq
|
|
172
|
+
c3 = -alpha / 3.0
|
|
173
|
+
|
|
174
|
+
# Apply the 3rd-order Newton-Schulz update
|
|
175
|
+
A = X @ X.mT
|
|
176
|
+
X = c1 * X + c3 * (A @ X)
|
|
177
|
+
|
|
178
|
+
# Update the singular value bounds for the next iteration based on the error
|
|
179
|
+
eps_val = (K - L) / denom
|
|
180
|
+
lower_bound, upper_bound = 1.0 - eps_val, 1.0 + eps_val
|
|
181
|
+
|
|
182
|
+
else:
|
|
183
|
+
# Standard Quintic Newton-Schulz
|
|
184
|
+
# Update: X = a*X + b*(A@X) + c*(A@A@X)
|
|
185
|
+
for _ in range(steps):
|
|
186
|
+
A = X @ X.mT
|
|
187
|
+
B = b * A + c * (A @ A)
|
|
188
|
+
X = a * X + B @ X
|
|
189
|
+
|
|
190
|
+
# Transpose back if necessary
|
|
191
|
+
if transposed:
|
|
192
|
+
X = X.mT
|
|
193
|
+
|
|
194
|
+
return X.to(G.dtype)
|
|
123
195
|
|
|
124
196
|
@torch.no_grad()
|
|
125
197
|
def newton_schulz(
|
|
@@ -132,6 +204,7 @@ def newton_schulz(
|
|
|
132
204
|
low_rank_ortho: bool = False,
|
|
133
205
|
ortho_rank: int = 128,
|
|
134
206
|
spectral_normalization: bool = False,
|
|
207
|
+
compiled: bool = False,
|
|
135
208
|
) -> torch.Tensor:
|
|
136
209
|
"""
|
|
137
210
|
Public entry point for Muon orthogonalization.
|
|
@@ -149,6 +222,11 @@ def newton_schulz(
|
|
|
149
222
|
low_rank_ortho (bool): Whether to project to low rank before orthogonalizing.
|
|
150
223
|
ortho_rank (int): Rank for low-rank projection.
|
|
151
224
|
"""
|
|
225
|
+
if compiled:
|
|
226
|
+
ns_fn = _compiled_newton_schulz_iteration
|
|
227
|
+
else:
|
|
228
|
+
ns_fn = _newton_schulz_iteration
|
|
229
|
+
|
|
152
230
|
if low_rank_ortho:
|
|
153
231
|
# Low-Rank Orthogonalization via Gaussian Sketching
|
|
154
232
|
M = G
|
|
@@ -172,7 +250,7 @@ def newton_schulz(
|
|
|
172
250
|
projected_M = Q.T @ M
|
|
173
251
|
|
|
174
252
|
# 4. Orthogonalize the smaller projected matrix
|
|
175
|
-
ortho_projected_M =
|
|
253
|
+
ortho_projected_M = ns_fn(
|
|
176
254
|
projected_M,
|
|
177
255
|
steps=steps,
|
|
178
256
|
eps=eps,
|
|
@@ -186,7 +264,7 @@ def newton_schulz(
|
|
|
186
264
|
return Q @ ortho_projected_M
|
|
187
265
|
|
|
188
266
|
# Standard Path
|
|
189
|
-
return
|
|
267
|
+
return ns_fn(
|
|
190
268
|
G,
|
|
191
269
|
steps=steps,
|
|
192
270
|
eps=eps,
|
|
@@ -60,6 +60,7 @@ def apply_parameter_update(
|
|
|
60
60
|
if random_int_tensor is not None:
|
|
61
61
|
# Compiled path: use the pre-computed random tensor
|
|
62
62
|
_copy_stochastic_core_(p, p_fp32, random_int_tensor)
|
|
63
|
+
del random_int_tensor
|
|
63
64
|
else:
|
|
64
65
|
# Uncompiled path: generate randoms inside
|
|
65
66
|
copy_stochastic_(p, p_fp32)
|
|
@@ -132,7 +133,7 @@ def _copy_stochastic_core_(target: Tensor, source: Tensor, random_int_tensor: Te
|
|
|
132
133
|
Core logic for stochastic rounding using a pre-computed random integer tensor.
|
|
133
134
|
This version is designed to be torch.compile-friendly.
|
|
134
135
|
"""
|
|
135
|
-
result = random_int_tensor
|
|
136
|
+
result = random_int_tensor
|
|
136
137
|
# add the random number to the lower 16 bit of the mantissa
|
|
137
138
|
result.add_(source.view(dtype=torch.int32))
|
|
138
139
|
|
|
@@ -142,8 +143,6 @@ def _copy_stochastic_core_(target: Tensor, source: Tensor, random_int_tensor: Te
|
|
|
142
143
|
# copy the higher 16 bit into the target tensor
|
|
143
144
|
target.copy_(result.view(dtype=torch.float32))
|
|
144
145
|
|
|
145
|
-
del result
|
|
146
|
-
|
|
147
146
|
|
|
148
147
|
def copy_stochastic_(target: Tensor, source: Tensor):
|
|
149
148
|
"""
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|