adv-optm 2.4.dev6.tar.gz → 2.4.dev7.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/PKG-INFO +1 -1
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/__init__.py +3 -1
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/optim/AdaMuon_adv.py +101 -64
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/optim/AdamW_adv.py +111 -75
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/optim/Adopt_adv.py +118 -85
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/optim/Lion_adv.py +13 -10
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/optim/Muon_adv.py +56 -53
- adv_optm-2.4.dev7/adv_optm/optim/SGD_adv.py +283 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/optim/SignSGD_adv.py +79 -28
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/optim/Simplified_AdEMAMix.py +7 -7
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/optim/__init__.py +2 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/util/Kourkoutas.py +64 -8
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/util/Muon_util.py +3 -43
- adv_optm-2.4.dev7/adv_optm/util/OrthoGrad.py +19 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/util/centered_decay.py +9 -2
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/util/param_update.py +227 -66
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/util/scaled_optm.py +57 -47
- adv_optm-2.4.dev7/adv_optm/util/sinkhorn.py +42 -0
- adv_optm-2.4.dev7/adv_optm/util/state_util.py +289 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/util/update_util.py +10 -52
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm.egg-info/SOURCES.txt +3 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/setup.py +1 -1
- adv_optm-2.4.dev6/adv_optm/util/OrthoGrad.py +0 -50
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/LICENSE +0 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/README.md +0 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/setup.cfg +0 -0
{adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/__init__.py

```diff
@@ -8,6 +8,7 @@ from .optim import (
     Muon_adv,
     AdaMuon_adv,
     SignSGD_adv,
+    SGD_adv,
 )
 
 __all__ = [
@@ -20,6 +21,7 @@ __all__ = [
     "Muon_adv",
     "AdaMuon_adv",
     "SignSGD_adv",
+    "SGD_adv",
 ]
 
-__version__ = "2.4.dev6"
+__version__ = "2.4.dev7"
```
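The `__init__.py` change only re-exports the new `SGD_adv` optimizer (added in `adv_optm/optim/SGD_adv.py`, +283 lines) and bumps the version. A minimal usage sketch, assuming only that `SGD_adv` follows the standard `torch.optim.Optimizer` interface of its sibling optimizers; the `lr` keyword is an assumption, since the new file's body is not shown in this diff:

```python
# Hypothetical usage of the newly exported SGD_adv; constructor arguments
# beyond `params` are assumed, not taken from this diff.
import torch
from adv_optm import SGD_adv

model = torch.nn.Linear(128, 64)
opt = SGD_adv(model.parameters(), lr=1e-2)

loss = model(torch.randn(4, 128)).sum()
loss.backward()
opt.step()
opt.zero_grad()
```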
{adv_optm-2.4.dev6 → adv_optm-2.4.dev7}/adv_optm/optim/AdaMuon_adv.py

```diff
@@ -3,12 +3,15 @@ import torch
 import math
 
 from ..util import param_update
-from ..util.Muon_util import newton_schulz, _is_suitable_for_muon, rms_adjustment, normuon_update, approx_mars, _auto_projection_for_adamuon,
+from ..util.Muon_util import newton_schulz, _is_suitable_for_muon, rms_adjustment, normuon_update, approx_mars, _auto_projection_for_adamuon, get_spectral_scaling
+from ..util.scaled_optm import spectral_normalization, init_spectral_norm
 from ..util.factorization_util import _get_effective_shape, _factorize_state, _reconstruct_state
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.Kourkoutas import KourkoutasHelper
 from ..util import Muon_AuxAdam
 from ..util.centered_decay import _init_anchor
+from typing import Optional
+from ..util.state_util import init_state_tensor, get_state, set_state, upcast_grad_for_precision
 
 A = 4 / math.pi
 
@@ -101,6 +104,13 @@ class AdaMuon_adv(torch.optim.Optimizer):
             the uncompressed optimizer. (default: False)
         use_muon (bool | None): whether to use Muon or AuxAdamW. MUST be provided
             either here or via `optim_type` in parameter groups. (default: None)
+        state_precision (str): Precision for Muon optimizer states. Options: 'auto' (parameter dtype), 'fp32',
+            'bf16_sr' (BF16 with stochastic rounding), 'fp8_sr', 'int8_sr'.
+            (default: 'auto')
+        factored_2nd (bool): Factorize only the second moment (v_t) using SMMF
+            low-rank compression while keeping the first moment (momentum_buffer)
+            dense. Ignored when `nnmf_factor=True` (full SMMF) or `normuon_variant=True`.
+            Combines well with `state_precision` on the first moment. (default: False)
         n_layers (int): The depth of the network (L). Required for optimal epsilon scaling. (default: 1)
         spectral_normalization (bool): Enable explicit spectral normalization using power iteration. (default: False)
         --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
```
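The two new options compose: `state_precision` compresses the dense buffers, while `factored_2nd` replaces the dense v_t with a rank-1 pair. A construction sketch using only keyword names that appear in this diff (`lr`, `use_muon`, `state_precision`, `factored_2nd`); the values are illustrative:

```python
import torch
from adv_optm import AdaMuon_adv

model = torch.nn.Linear(1024, 1024)
opt = AdaMuon_adv(
    model.parameters(),
    lr=1e-3,
    use_muon=True,              # must be set here or via `optim_type` per group
    state_precision="bf16_sr",  # BF16 states, written back with stochastic rounding
    factored_2nd=True,          # rank-1 factorized v_t, dense momentum buffer
)
```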
```diff
@@ -129,7 +139,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
         weight_decay: float = 0,
         cautious_wd: bool = False,
         # Nesterov momentum
-        nesterov: bool =
+        nesterov: bool = True,
         # RMS Rescaling
         rms_rescaling: bool = True,
         # Newton Schulz
@@ -149,6 +159,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
         normuon_variant: bool = False,
         # Boolean to spilt param
         use_muon: bool | None = None,
+        # States precision (Muon path)
+        state_precision: str = "auto", # 'fp32', 'bf16_sr', 'fp8_sr', 'int8_sr'
+        # Factorized second moment only
+        factored_2nd: bool = False,
         # Update geometry parameters
         kappa_p: float = 1.0,
         auto_projection: bool = True,
@@ -174,7 +188,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
         compiled_optimizer: bool = False,
         # --- AdamW_adv specific parameters ---
         adam_betas: tuple[float, float] = (0.9, 0.99),
-        adam_eps: float = 1e-8,
+        adam_eps: float | None = 1e-8,
         adam_weight_decay: float = 0.0,
         adam_use_bias_correction: bool = True,
         adam_use_atan2: bool = False,
@@ -200,15 +214,17 @@ class AdaMuon_adv(torch.optim.Optimizer):
         if Simplified_AdEMAMix and nesterov:
             print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling nesterov.")
             nesterov = False
-        if normuon_variant and use_atan2:
-            print("Warning: AdaMuon atan2 is incompatible with NorMuon, Disabling AdaMuon atan2.")
-            use_atan2 = False
         if spectral_normalization and rms_rescaling:
             print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
             rms_rescaling = False
         if spectral_normalization and accelerated_ns:
             ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
 
+        state_precision = state_precision.lower()
+        valid_precisions = {"auto", "fp32", "bf16_sr", "fp8_sr", "int8_sr"}
+        if state_precision not in valid_precisions:
+            raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
+
         defaults = {
             "lr": lr, "betas": betas, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
             "eps": eps, "rms_rescaling": rms_rescaling, "ns_steps": ns_steps,
```
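The `_sr` suffixes stand for stochastic rounding: when a state buffer is stored below fp32, each write-back rounds up or down with probability proportional to the discarded fraction, so the quantization error stays unbiased across steps instead of systematically swallowing small updates. A minimal BF16 sketch of the idea, independent of the package's own helpers:

```python
import torch

def bf16_stochastic_round(x: torch.Tensor) -> torch.Tensor:
    """Minimal sketch: round fp32 -> bf16 with stochastic rounding.

    bf16 keeps the top 16 bits of the fp32 bit pattern, so adding a random
    16-bit offset before truncating the low mantissa bits rounds up or down
    with probability proportional to the discarded fraction. A production
    version must also guard inf/NaN inputs.
    """
    assert x.dtype == torch.float32
    bits = x.view(torch.int32)
    noise = torch.randint(0, 1 << 16, x.shape, device=x.device, dtype=torch.int32)
    rounded = (bits + noise) & -65536  # clear the low 16 (rounding) bits
    return rounded.view(torch.float32).to(torch.bfloat16)
```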
```diff
@@ -219,6 +235,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
             "normuon_variant": normuon_variant, "orthogonal_gradient": orthogonal_gradient,
             "compiled_optimizer":compiled_optimizer,
             "use_muon": use_muon,
+            # States precision (Muon path)
+            "state_precision": state_precision,
+            # Factorized second moment only (Muon path)
+            "factored_2nd": factored_2nd,
             # Lion-K
             "kappa_p": kappa_p, "auto_projection": auto_projection,
             # Low-rank Ortho
@@ -335,9 +355,32 @@ class AdaMuon_adv(torch.optim.Optimizer):
                 state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                 state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
             else:
-
-
-
+                # Determine effective state precision (small tensors always use fp32)
+                req_precision = group.get('state_precision', 'auto')
+                actual_precision = req_precision
+                if actual_precision != 'auto' and (p.numel() < 10000 or p.ndim == 1):
+                    actual_precision = 'fp32'
+                group['actual_state_precision'] = actual_precision
+
+                # factored_2nd: factorize v_t only; ignored for NorMuon (no v_t) and tiny params
+                use_factored_2nd = (
+                    group.get('factored_2nd', False)
+                    and not group['normuon_variant']
+                    and p.numel() >= 10000
+                    and p.ndim > 1
+                )
+                state['factored_2nd'] = use_factored_2nd
+
+                default_dtype = p.dtype
+                init_state_tensor(state, 'momentum_buffer', p.shape, actual_precision, p.device, default_dtype)
+
+                if use_factored_2nd:
+                    state['effective_shape'] = _get_effective_shape(p.numel())
+                    d1, d2 = state['effective_shape']
+                    state['mu_vbuf_nmf'] = torch.zeros(d1, device=p.device, dtype=torch.float32)
+                    state['mv_vbuf_nmf'] = torch.zeros(d2, device=p.device, dtype=torch.float32)
+                elif not group['normuon_variant']:
+                    init_state_tensor(state, 'second_momentum_buffer', p.shape, actual_precision, p.device, default_dtype, non_neg=True)
 
             # NorMuon state initialization
             if group['normuon_variant']:
```
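`init_state_tensor`, `get_state`, and `set_state` come from the new `adv_optm/util/state_util.py` (+289 lines), whose body is not part of this diff. The read/write cycle visible in this file suggests a compact-storage scheme along these lines; everything below (the storage dtypes, the omitted `non_neg` flag, the fp8/int8 paths) is a hedged guess, not the package's actual code:

```python
import torch

# Guessed storage scheme behind state_util: keep the buffer in a compact
# dtype, hand the step an fp32 working copy, write back on set_state
# (the real '_sr' modes would stochastically round at that point).
def init_state_tensor(state, key, shape, precision, device, default_dtype):
    dtype = {"fp32": torch.float32, "bf16_sr": torch.bfloat16}.get(precision, default_dtype)
    state[key] = torch.zeros(shape, device=device, dtype=dtype)

def get_state(state, key, precision):
    buf = state[key]
    return buf.float() if buf.dtype != torch.float32 else buf  # fp32 working copy

def set_state(state, key, value, precision):
    buf = state[key]
    if buf is not value:   # get_state returned a copy -> write it back
        buf.copy_(value)   # a real 'bf16_sr' path would round stochastically here
```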
```diff
@@ -349,25 +392,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
 
             # Spectral Normalization
             if group.get('spectral_normalization', False):
-
-
-                # Case A: Factored Muon
-                if state['factored']:
-                    d1, d2 = state['effective_shape']
-                    # We need a vector matching the 'inner' dimension d2
-                    state['spectral_v'] = torch.randn(d2, device=device, dtype=dtype, generator=gen)
-
-                # Case B: Standard Muon (Linear, Conv2d, etc.)
-                elif len(p.shape) >= 2:
-                    # Since Muon performs `update.flatten(1)`, the matrix becomes
-                    # (p.shape[0], product_of_rest).
-                    d_in_flat = p.numel() // p.shape[0]
-
-                    state['spectral_v'] = torch.randn(d_in_flat, device=device, dtype=dtype, generator=gen)
-
-                # Normalize initial vector for stability
-                if 'spectral_v' in state:
-                    state['spectral_v'].div_(state['spectral_v'].norm())
+                init_spectral_norm(group, state, p)
 
             # MARS-M state initialization
             if group.get('approx_mars', False):
```
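The inline setup above is folded into `init_spectral_norm`, which evidently now tracks a persistent pair `spectral_u`/`spectral_v` (both appear in the `spectral_normalization(update, state['spectral_u'], state['spectral_v'], step_scale)` call later in this diff). Neither helper body is shown; a minimal sketch of the standard one-step power iteration such a pair implies:

```python
import torch

@torch.no_grad()
def spectral_normalize_(update, u, v, step_scale, eps=1e-8):
    """Sketch of one power-iteration step; the real spectral_normalization
    in scaled_optm.py is not shown in this diff and may differ.

    Refines persistent vectors u, v toward the top singular pair of the
    flattened update, then divides by the estimated spectral norm.
    """
    m = update.flatten(1)                       # (d_out, d_in_flat)
    u.copy_(m @ v); u.div_(u.norm().add_(eps))  # u <- Mv / ||Mv||
    v.copy_(m.t() @ u); v.div_(v.norm().add_(eps))
    sigma = torch.dot(u, m @ v)                 # estimate of sigma_max(M)
    update.div_(sigma.clamp_min(eps)).mul_(step_scale)
```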
```diff
@@ -436,18 +461,31 @@ class AdaMuon_adv(torch.optim.Optimizer):
         if is_compiled:
             lr = torch.as_tensor(group['lr'])
             muon_step_param = self._compiled_muon_step_parameter
+
+            # Generate state SR random tensor when compiled
+            actual_precision = group['actual_state_precision']
+            random_int_state_tensor = random_int_tensor
+            if actual_precision == 'bf16_sr' and random_int_state_tensor is not None:
+                random_int_state_tensor = param_update._get_random_int_for_sr(p)
+            elif actual_precision == 'int8_sr':
+                random_int_state_tensor = param_update._get_random_int_for_int8_sr(p)
+            elif actual_precision == 'fp8_sr':
+                random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
         else:
             lr = group['lr']
             muon_step_param = self._muon_step_parameter
+            random_int_state_tensor = None
 
-        muon_step_param(p, grad, state, group, lr, random_int_tensor)
+        muon_step_param(p, grad, state, group, lr, random_int_tensor, random_int_state_tensor)
 
     def compile(self, *args, **kwargs):
         self._compiled_muon_step_parameter = torch.compile(self._muon_step_parameter, *args, **kwargs)
         self._compiled_adam_step_parameter = torch.compile(Muon_AuxAdam._adam_step_parameter, *args, **kwargs)
 
     @torch.no_grad()
-    def _muon_step_parameter(self, p, grad, state, group, lr, random_int_tensor):
+    def _muon_step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_int_state_tensor=None):
+        # Upcast grad for low-precision state modes (non-factored path)
+        grad = upcast_grad_for_precision(grad, state, group.get('state_precision', 'auto'))
         beta1, beta2 = group['betas']
         nesterov = group['nesterov']
         Simplified_AdEMAMix = group['Simplified_AdEMAMix']
@@ -465,21 +503,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
             kappa_p = 1.0
 
         if group.get('spectral_normalization', False):
-            # Compute Scaling Factors
-            if state['factored']:
-                shape_for_scaling = torch.Size(state['effective_shape'])
-            else:
-                shape_for_scaling = p.shape
-
-            scaled_eps, adaptive_eps, spectral_target, wd_scale = get_spectral_scaling(p, shape_for_scaling, group['n_layers'])
 
-
+            ns_eps, adaptive_eps, _, _ = get_spectral_scaling(p, p.shape, group.get('n_layers', 1))
             decoupled_wd = True
-
-            ns_eps = scaled_eps
-
         else:
-            weight_decay = group['weight_decay']
             decoupled_wd = False
             ns_eps = group['ns_eps']
             adaptive_eps = group['eps']
@@ -488,8 +515,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
         if group.get('approx_mars', False):
             grad = approx_mars(grad, state['last_grad'], group['mars_gamma'], beta1, Simplified_AdEMAMix=Simplified_AdEMAMix)
 
-        if grad.dtype != torch.float32 and state.get('factored', False):
-            grad = grad.float()
 
         if group.get("orthogonal_gradient"):
             grad = _orthogonalize_gradient(p, grad)
@@ -552,22 +577,15 @@ class AdaMuon_adv(torch.optim.Optimizer):
                 update.div_(denom)
             del denom, vt_buf
 
-            # RMS-aligned scaling
-            step_scale = lr * A if group['use_atan2'] and not group['normuon_variant'] else lr
-            # Spectral Normalization
-            if group.get('spectral_normalization', False):
-                spectral_norm_update(update, state['spectral_v'], spectral_target, step_scale)
-            else:
-                # Factored RMS-aligned scaling
-                rms_adjustment(update, group['rms_rescaling'], step_scale)
-
             update = update.reshape(p.shape)
 
         else: # Standard AdaMuon logic for non-factored tensors
             original_shape = p.shape
+            actual_precision = group['actual_state_precision']
+            factored_2nd = state.get('factored_2nd', False)
 
             # Momentum update
-            mt_buf = state['momentum_buffer']
+            mt_buf = get_state(state, 'momentum_buffer', actual_precision)
             if not Simplified_AdEMAMix:
                 mt_buf.lerp_(grad, 1 - beta1)
             else:
@@ -580,6 +598,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
             else:
                 update = mt_buf.clone()
 
+            set_state(state, 'momentum_buffer', mt_buf, actual_precision, random_int_state_tensor)
+
             # Apply update projection
             update = _auto_projection_for_adamuon(update, kappa_p)
 
@@ -603,10 +623,26 @@ class AdaMuon_adv(torch.optim.Optimizer):
             # NorMuon Logic
             if group['normuon_variant']:
                 normuon_update(update, state['normuon_v'], beta2, group['eps'])
+            elif factored_2nd:
+                # Factorized second moment: reconstruct → update → re-factorize
+                d1, d2 = state['effective_shape']
+                update = update.view(original_shape)
+                update_f32 = update.float()
+                vt_buf = _reconstruct_state((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']), signed=False)
+                vt_buf.mul_(beta2).addcmul_(update_f32.view(d1, d2), update_f32.view(d1, d2), value=1 - beta2)
+                state['mu_vbuf_nmf'], state['mv_vbuf_nmf'] = _factorize_state(vt_buf, signed=False)
+                # Apply second moment scaling
+                if group['use_atan2']:
+                    denom = vt_buf.sqrt_().view(original_shape)
+                    update.atan2_(denom.to(update.dtype))
+                else:
+                    denom = vt_buf.sqrt_().add_(adaptive_eps).view(original_shape)
+                    update.div_(denom.to(update.dtype))
+                del denom, vt_buf, update_f32
             else:
                 # Original AdaMuon Logic
                 update = update.view(original_shape)
-                vt_buf = state['second_momentum_buffer']
+                vt_buf = get_state(state, 'second_momentum_buffer', actual_precision)
                 vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
                 # Apply second momentum update (adaptive scaling)
                 if group['use_atan2']:
```
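`_factorize_state`/`_reconstruct_state` (from `factorization_util.py`, unchanged in this release) compress the non-negative v_t matrix into two vectors, so the factored path above keeps O(d1 + d2) memory between steps instead of O(d1·d2). A sketch of one standard rank-1 non-negative factorization of this shape (Adafactor-style row/column sums; the package's actual scheme may differ):

```python
import torch

def factorize_nonneg(V, eps=1e-30):
    """Rank-1 compression of a non-negative matrix: V ~ outer(r, c)."""
    r = V.sum(dim=1)                           # (d1,) row sums
    c = V.sum(dim=0) / V.sum().clamp_min(eps)  # (d2,) column sums, normalized
    return r, c

def reconstruct_nonneg(r, c):
    return torch.outer(r, c)  # exact when V is rank-1, an approximation otherwise

# Memory between steps: d1 + d2 floats instead of d1 * d2.
V = torch.rand(64, 32) ** 2
r, c = factorize_nonneg(V)
V_hat = reconstruct_nonneg(r, c)
```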
```diff
@@ -615,20 +651,21 @@ class AdaMuon_adv(torch.optim.Optimizer):
                 else:
                     denom = vt_buf.sqrt().add_(adaptive_eps)
                     update.div_(denom)
+                set_state(state, 'second_momentum_buffer', vt_buf, actual_precision, random_int_state_tensor, non_neg=True)
                 del denom
 
-
+            step_scale = lr * A if group['use_atan2'] and not group['normuon_variant'] else lr
 
-
-
-
-
-
-
+            if group.get('spectral_normalization', False):
+                # Spectral Normalization
+                spectral_normalization(update, state['spectral_u'], state['spectral_v'], step_scale)
+            else:
+                # RMS-aligned rescaling
+                rms_adjustment(update, group['rms_rescaling'], step_scale)
 
-
+            update = update.reshape(original_shape)
 
-        param_update.apply_parameter_update(self, p, group, update, lr,
+        param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, decoupled=decoupled_wd)
 
     @torch.no_grad()
     def step(self, closure=None):
```