adv-optm 2.6.dev1__tar.gz → 2.6.1.dev1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/PKG-INFO +1 -1
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/__init__.py +1 -1
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/AdaMuon_adv.py +32 -20
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/AdamW_adv.py +20 -15
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/Adopt_adv.py +20 -14
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/Lion_adv.py +13 -11
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/Muon_adv.py +27 -17
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/Prodigy_adv.py +20 -16
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/SignSGD_adv.py +16 -12
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/SinkSGD_adv.py +19 -14
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/Muon_AuxAdam.py +0 -1
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/OrthoGrad.py +1 -1
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/param_update.py +32 -39
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/scaled_optm.py +24 -37
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm.egg-info/SOURCES.txt +0 -1
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/setup.py +1 -1
- adv_optm-2.6.dev1/adv_optm/util/msign.py +0 -114
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/LICENSE +0 -0
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/README.md +0 -0
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/sinkhorn.py +0 -0
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/state_util.py +0 -0
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/setup.cfg +0 -0
|
@@ -137,6 +137,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
137
137
|
# Decoupled/cautious weight decay
|
|
138
138
|
weight_decay: float = 0,
|
|
139
139
|
cautious_wd: bool = False,
|
|
140
|
+
scaled_wd: bool = False,
|
|
140
141
|
# Nesterov momentum
|
|
141
142
|
nesterov: bool = True,
|
|
142
143
|
nesterov_coef: float | None = None,
|
|
@@ -177,8 +178,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
177
178
|
mars_gamma: float = 0.025,
|
|
178
179
|
# Spectral Normalization
|
|
179
180
|
spectral_normalization: bool = False,
|
|
180
|
-
# Orthogonalize the weights (Matrix Sign - MSign) every x steps
|
|
181
|
-
MSign_interval: int | None = None,
|
|
182
181
|
# Centered WD
|
|
183
182
|
centered_wd: float = 0.0,
|
|
184
183
|
centered_wd_mode: str = 'float8',
|
|
@@ -229,7 +228,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
229
228
|
|
|
230
229
|
defaults = {
|
|
231
230
|
"lr": lr, "betas": betas, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
|
|
232
|
-
"eps": eps, "rms_rescaling": rms_rescaling, "ns_steps": ns_steps,
|
|
231
|
+
"eps": eps, "rms_rescaling": rms_rescaling, "ns_steps": ns_steps, "scaled_wd": scaled_wd,
|
|
233
232
|
"ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
|
|
234
233
|
"vector_reshape": vector_reshape,
|
|
235
234
|
"nesterov":nesterov, "nesterov_coef": nesterov_coef, "use_atan2":use_atan2,
|
|
@@ -249,7 +248,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
249
248
|
# MARS-M
|
|
250
249
|
"approx_mars": approx_mars, "mars_gamma": mars_gamma,
|
|
251
250
|
# Spectral Normalization
|
|
252
|
-
"spectral_normalization": spectral_normalization,
|
|
251
|
+
"spectral_normalization": spectral_normalization,
|
|
253
252
|
# Centered WD
|
|
254
253
|
"centered_wd": centered_wd,
|
|
255
254
|
"centered_wd_mode": centered_wd_mode,
|
|
@@ -284,11 +283,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
284
283
|
for device in devices:
|
|
285
284
|
param_update.set_seed(device)
|
|
286
285
|
|
|
287
|
-
# Initialize compiled
|
|
288
|
-
self.
|
|
289
|
-
self.
|
|
290
|
-
if compiled_optimizer:
|
|
291
|
-
self.compile(fullgraph=True)
|
|
286
|
+
# Initialize compiled functions (by parameter shape)
|
|
287
|
+
self._compiled_muon_step_fns = {}
|
|
288
|
+
self._compiled_adam_step_fns = {}
|
|
292
289
|
|
|
293
290
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
294
291
|
"""
|
|
@@ -446,8 +443,16 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
446
443
|
random_int_state_tensor = None
|
|
447
444
|
if is_compiled:
|
|
448
445
|
step_size = torch.as_tensor(step_size)
|
|
449
|
-
|
|
450
|
-
|
|
446
|
+
# Cache compiled function per-shape
|
|
447
|
+
cache_key = (p.shape, state.get('factored', False))
|
|
448
|
+
if cache_key not in self._compiled_adam_step_fns:
|
|
449
|
+
self._compiled_adam_step_fns[cache_key] = torch.compile(
|
|
450
|
+
Muon_AuxAdam._adam_step_parameter,
|
|
451
|
+
fullgraph=True,
|
|
452
|
+
dynamic=False
|
|
453
|
+
)
|
|
454
|
+
adam_step_param = self._compiled_adam_step_fns[cache_key]
|
|
455
|
+
|
|
451
456
|
# Generate state SR random tensor when compiled
|
|
452
457
|
actual_precision = group.get('adam_actual_state_precision', 'auto')
|
|
453
458
|
random_int_state_tensor = random_int_tensor
|
|
@@ -466,7 +471,15 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
466
471
|
random_G_sketch = None
|
|
467
472
|
if is_compiled:
|
|
468
473
|
lr = torch.as_tensor(group['lr'])
|
|
469
|
-
|
|
474
|
+
# Cache compiled function per-shape
|
|
475
|
+
cache_key = (p.shape, state.get('factored', False))
|
|
476
|
+
if cache_key not in self._compiled_muon_step_fns:
|
|
477
|
+
self._compiled_muon_step_fns[cache_key] = torch.compile(
|
|
478
|
+
self._muon_step_parameter,
|
|
479
|
+
fullgraph=True,
|
|
480
|
+
dynamic=False
|
|
481
|
+
)
|
|
482
|
+
muon_step_param = self._compiled_muon_step_fns[cache_key]
|
|
470
483
|
|
|
471
484
|
# Generate state SR random tensor when compiled
|
|
472
485
|
actual_precision = group['actual_state_precision']
|
|
@@ -484,10 +497,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
484
497
|
|
|
485
498
|
muon_step_param(p, grad, state, group, lr, random_int_tensor, random_int_state_tensor, random_G_sketch)
|
|
486
499
|
|
|
487
|
-
def compile(self, *args, **kwargs):
|
|
488
|
-
self._compiled_muon_step_parameter = torch.compile(self._muon_step_parameter, *args, **kwargs)
|
|
489
|
-
self._compiled_adam_step_parameter = torch.compile(Muon_AuxAdam._adam_step_parameter, *args, **kwargs)
|
|
490
|
-
|
|
491
500
|
@torch.no_grad()
|
|
492
501
|
def _muon_step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_int_state_tensor, random_G_sketch):
|
|
493
502
|
# Upcast grad for low-precision state modes (non-factored path)
|
|
@@ -533,8 +542,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
533
542
|
else:
|
|
534
543
|
update = mt_buf.clone()
|
|
535
544
|
|
|
536
|
-
#
|
|
537
|
-
|
|
545
|
+
# Compress new momentum and store factors
|
|
546
|
+
for key, val in zip(('mu_mbuf_nmf', 'mv_mbuf_nmf', 'sign_buf'), _factorize_state(mt_buf, signed=True, shifter=state['shifter'])):
|
|
547
|
+
state[key].copy_(val)
|
|
538
548
|
del mt_buf
|
|
539
549
|
|
|
540
550
|
# Apply update projection
|
|
@@ -561,7 +571,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
561
571
|
vt_buf = _reconstruct_state((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']), signed=False, shifter=state['shifter'])
|
|
562
572
|
# Update second momentum in full-size
|
|
563
573
|
vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
|
|
564
|
-
|
|
574
|
+
for key, val in zip(('mu_vbuf_nmf', 'mv_vbuf_nmf'), _factorize_state(vt_buf, signed=False, shifter=state['shifter'])):
|
|
575
|
+
state[key].copy_(val)
|
|
565
576
|
# Apply second momentum update (adaptive scaling)
|
|
566
577
|
if group['use_atan2']:
|
|
567
578
|
denom = vt_buf.sqrt_()
|
|
@@ -620,7 +631,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
620
631
|
update_f32 = update.float()
|
|
621
632
|
vt_buf = _reconstruct_state((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']), signed=False, shifter=state['shifter'])
|
|
622
633
|
vt_buf.mul_(beta2).addcmul_(update_f32.view(d1, d2), update_f32.view(d1, d2), value=1 - beta2)
|
|
623
|
-
|
|
634
|
+
for key, val in zip(('mu_vbuf_nmf', 'mv_vbuf_nmf'), _factorize_state(vt_buf, signed=False, shifter=state['shifter'])):
|
|
635
|
+
state[key].copy_(val)
|
|
624
636
|
# Apply second moment scaling
|
|
625
637
|
if group['use_atan2']:
|
|
626
638
|
denom = vt_buf.sqrt_().view(original_shape)
|
|
@@ -98,6 +98,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
98
98
|
weight_decay: float = 0.0,
|
|
99
99
|
fisher_wd: bool = False,
|
|
100
100
|
cautious_wd: bool = False,
|
|
101
|
+
scaled_wd: bool = False,
|
|
101
102
|
# Adam's Bias Correction
|
|
102
103
|
use_bias_correction: bool = True,
|
|
103
104
|
# Stochastic Rounding for BF16
|
|
@@ -121,8 +122,6 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
121
122
|
layer_key_fn: Optional[Callable] = None,
|
|
122
123
|
# Spectral Normed Optimizer
|
|
123
124
|
spectral_normalization: bool = False,
|
|
124
|
-
# Orthogonalize the weights (Matrix Sign - MSign) every x steps
|
|
125
|
-
MSign_interval: int | None = None,
|
|
126
125
|
# Centered WD
|
|
127
126
|
centered_wd: float = 0.0,
|
|
128
127
|
centered_wd_mode: str = 'float8',
|
|
@@ -158,14 +157,14 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
158
157
|
|
|
159
158
|
defaults = {
|
|
160
159
|
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
161
|
-
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
|
|
160
|
+
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "scaled_wd": scaled_wd,
|
|
162
161
|
"use_atan2": use_atan2, "nesterov": nesterov, "nesterov_coef": nesterov_coef,
|
|
163
162
|
"normed_momentum": normed_momentum,
|
|
164
163
|
"orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
|
|
165
164
|
"compiled_optimizer": compiled_optimizer,
|
|
166
165
|
"kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
|
|
167
166
|
"tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
|
|
168
|
-
"spectral_normalization": spectral_normalization,
|
|
167
|
+
"spectral_normalization": spectral_normalization,
|
|
169
168
|
"centered_wd": centered_wd, "centered_wd_mode": centered_wd_mode,
|
|
170
169
|
"state_precision": state_precision,
|
|
171
170
|
"nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd
|
|
@@ -188,10 +187,8 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
188
187
|
for device in devices:
|
|
189
188
|
param_update.set_seed(device)
|
|
190
189
|
|
|
191
|
-
# Initialize compiled function
|
|
192
|
-
self.
|
|
193
|
-
if compiled_optimizer:
|
|
194
|
-
self.compile(fullgraph=True)
|
|
190
|
+
# Initialize compiled function (by parameter shape)
|
|
191
|
+
self._compiled_step_fns = {}
|
|
195
192
|
|
|
196
193
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
197
194
|
"""
|
|
@@ -323,7 +320,15 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
323
320
|
random_int_state_tensor = param_update._get_random_int_for_sr(p)
|
|
324
321
|
elif group['actual_state_precision'] == 'int8_sr':
|
|
325
322
|
random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
|
|
326
|
-
|
|
323
|
+
# Cache compiled function per-shape
|
|
324
|
+
cache_key = (p.shape, state.get('factored', False))
|
|
325
|
+
if cache_key not in self._compiled_step_fns:
|
|
326
|
+
self._compiled_step_fns[cache_key] = torch.compile(
|
|
327
|
+
self._step_parameter,
|
|
328
|
+
fullgraph=True,
|
|
329
|
+
dynamic=False
|
|
330
|
+
)
|
|
331
|
+
step_param_fn = self._compiled_step_fns[cache_key]
|
|
327
332
|
else:
|
|
328
333
|
step_param_fn = self._step_parameter
|
|
329
334
|
|
|
@@ -359,7 +364,8 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
359
364
|
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
|
|
360
365
|
|
|
361
366
|
# Factorize
|
|
362
|
-
|
|
367
|
+
for key, val in zip(('mu_v_nmf', 'mv_v_nmf'), _factorize_state(vt, signed=False, shifter=state['shifter'])):
|
|
368
|
+
state[key].copy_(val)
|
|
363
369
|
|
|
364
370
|
if group['use_atan2']:
|
|
365
371
|
denom = vt.sqrt_()
|
|
@@ -380,7 +386,8 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
380
386
|
mt.lerp_(grad_reshaped, 1.0 - beta1)
|
|
381
387
|
|
|
382
388
|
# Factorize
|
|
383
|
-
|
|
389
|
+
for key, val in zip(('mu_m_nmf', 'mv_m_nmf', 'sign'), _factorize_state(mt.clone(), signed=True, shifter=state['shifter'])):
|
|
390
|
+
state[key].copy_(val)
|
|
384
391
|
|
|
385
392
|
update_mt = mt
|
|
386
393
|
|
|
@@ -424,7 +431,8 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
424
431
|
exp_avg_sq.mul_(beta2).addcmul_(grad_vt, grad_vt, value=1.0 - beta2)
|
|
425
432
|
|
|
426
433
|
if factored_2nd:
|
|
427
|
-
|
|
434
|
+
for key, val in zip(('mu_v_nmf', 'mv_v_nmf'), _factorize_state(exp_avg_sq.view(d1, d2), signed=False, shifter=state['shifter'])):
|
|
435
|
+
state[key].copy_(val)
|
|
428
436
|
else:
|
|
429
437
|
set_state(state, 'exp_avg_sq', exp_avg_sq, actual_precision, random_int_state_tensor, non_neg=True)
|
|
430
438
|
|
|
@@ -471,9 +479,6 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
471
479
|
|
|
472
480
|
param_update.apply_parameter_update(self, p, group, update, group['lr'], random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
|
|
473
481
|
|
|
474
|
-
def compile(self, *args, **kwargs):
|
|
475
|
-
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
476
|
-
|
|
477
482
|
@torch.no_grad()
|
|
478
483
|
def step(self, closure=None):
|
|
479
484
|
"""Performs a single optimization step."""
|
|
@@ -101,6 +101,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
101
101
|
weight_decay: float = 0.0,
|
|
102
102
|
fisher_wd: bool = False,
|
|
103
103
|
cautious_wd: bool = False,
|
|
104
|
+
scaled_wd: bool = False,
|
|
104
105
|
# ADOPT clipping
|
|
105
106
|
clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
|
|
106
107
|
# Adam_atan2 (scale invariant)
|
|
@@ -122,8 +123,6 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
122
123
|
layer_key_fn: Optional[Callable] = None,
|
|
123
124
|
# Spectral Normed Optimizer
|
|
124
125
|
spectral_normalization: bool = False,
|
|
125
|
-
# Orthogonalize the weights (Matrix Sign - MSign) every x steps
|
|
126
|
-
MSign_interval: int | None = None,
|
|
127
126
|
# Centered WD
|
|
128
127
|
centered_wd: float = 0.0,
|
|
129
128
|
centered_wd_mode: str = 'float8',
|
|
@@ -159,12 +158,12 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
159
158
|
state_precision = "factored"
|
|
160
159
|
|
|
161
160
|
defaults = {
|
|
162
|
-
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
161
|
+
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "scaled_wd": scaled_wd,
|
|
163
162
|
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "orthogonal_gradient": orthogonal_gradient,
|
|
164
163
|
"nesterov": nesterov, "nesterov_coef": nesterov_coef,
|
|
165
164
|
"kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
|
|
166
165
|
"tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
|
|
167
|
-
"spectral_normalization": spectral_normalization,
|
|
166
|
+
"spectral_normalization": spectral_normalization,
|
|
168
167
|
"centered_wd": centered_wd,
|
|
169
168
|
"centered_wd_mode": centered_wd_mode,
|
|
170
169
|
"state_precision": state_precision,
|
|
@@ -191,9 +190,8 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
191
190
|
for device in devices:
|
|
192
191
|
param_update.set_seed(device)
|
|
193
192
|
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
self.compile(fullgraph=True)
|
|
193
|
+
# Initialize compiled function (by parameter shape)
|
|
194
|
+
self._compiled_step_fns = {}
|
|
197
195
|
|
|
198
196
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
199
197
|
"""
|
|
@@ -333,7 +331,15 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
333
331
|
random_int_state_tensor = param_update._get_random_int_for_sr(p)
|
|
334
332
|
elif group['actual_state_precision'] == 'int8_sr':
|
|
335
333
|
random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
|
|
336
|
-
|
|
334
|
+
# Cache compiled function per-shape
|
|
335
|
+
cache_key = (p.shape, state.get('factored', False))
|
|
336
|
+
if cache_key not in self._compiled_step_fns:
|
|
337
|
+
self._compiled_step_fns[cache_key] = torch.compile(
|
|
338
|
+
self._step_parameter,
|
|
339
|
+
fullgraph=True,
|
|
340
|
+
dynamic=False
|
|
341
|
+
)
|
|
342
|
+
step_param_fn = self._compiled_step_fns[cache_key]
|
|
337
343
|
else:
|
|
338
344
|
lr = group['lr']
|
|
339
345
|
step_param_fn = self._step_parameter
|
|
@@ -375,7 +381,8 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
375
381
|
else:
|
|
376
382
|
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
|
|
377
383
|
# Factorize
|
|
378
|
-
|
|
384
|
+
for key, val in zip(('mu_v_nmf', 'mv_v_nmf'), _factorize_state(vt, signed=False, shifter=state['shifter'])):
|
|
385
|
+
state[key].copy_(val)
|
|
379
386
|
del vt
|
|
380
387
|
|
|
381
388
|
if self.use_atan2:
|
|
@@ -393,7 +400,8 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
393
400
|
mt.lerp_(normalized_grad, 1.0 - beta1)
|
|
394
401
|
|
|
395
402
|
# Factorize
|
|
396
|
-
|
|
403
|
+
for key, val in zip(('mu_m_nmf', 'mv_m_nmf', 'sign'), _factorize_state(mt.clone(), signed=True, shifter=state['shifter'])):
|
|
404
|
+
state[key].copy_(val)
|
|
397
405
|
|
|
398
406
|
update_mt = mt
|
|
399
407
|
|
|
@@ -460,7 +468,8 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
460
468
|
vt.mul_(beta2).addcmul_(grad_vt, grad_vt, value=1 - beta2)
|
|
461
469
|
|
|
462
470
|
if factored_2nd:
|
|
463
|
-
|
|
471
|
+
for key, val in zip(('mu_v_nmf', 'mv_v_nmf'), _factorize_state(vt.view(d1, d2), signed=False, shifter=state['shifter'])):
|
|
472
|
+
state[key].copy_(val)
|
|
464
473
|
else:
|
|
465
474
|
set_state(state, 'exp_avg_sq', vt, actual_precision, random_int_state_tensor, non_neg=True)
|
|
466
475
|
del random_int_state_tensor
|
|
@@ -475,9 +484,6 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
475
484
|
# Parameter Update
|
|
476
485
|
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
|
|
477
486
|
|
|
478
|
-
def compile(self, *args, **kwargs):
|
|
479
|
-
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
480
|
-
|
|
481
487
|
@torch.no_grad()
|
|
482
488
|
def step(self, closure=None):
|
|
483
489
|
"""Performs a single optimization step."""
|
|
@@ -64,6 +64,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
64
64
|
# Decoupled/cautious weight decay
|
|
65
65
|
weight_decay: float = 0.0,
|
|
66
66
|
cautious_wd: bool = False,
|
|
67
|
+
scaled_wd: bool = False,
|
|
67
68
|
# Stochastic Rounding for BF16
|
|
68
69
|
stochastic_rounding: bool = True,
|
|
69
70
|
# OrthoGrad
|
|
@@ -78,8 +79,6 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
78
79
|
centered_wd_mode: str = 'float8',
|
|
79
80
|
# Spectral Normed Optimizer
|
|
80
81
|
spectral_normalization: bool = False,
|
|
81
|
-
# Orthogonalize the weights (Matrix Sign - MSign) every x steps
|
|
82
|
-
MSign_interval: int | None = None,
|
|
83
82
|
# SMMF factorization
|
|
84
83
|
nnmf_factor: bool = False,
|
|
85
84
|
vector_reshape: bool = False,
|
|
@@ -98,13 +97,13 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
98
97
|
betas=betas,
|
|
99
98
|
weight_decay=weight_decay,
|
|
100
99
|
cautious_wd=cautious_wd,
|
|
100
|
+
scaled_wd=scaled_wd,
|
|
101
101
|
vector_reshape=vector_reshape,
|
|
102
102
|
orthogonal_gradient=orthogonal_gradient,
|
|
103
103
|
kappa_p=kappa_p,
|
|
104
104
|
auto_kappa_p=auto_kappa_p,
|
|
105
105
|
stochastic_sign=stochastic_sign,
|
|
106
106
|
spectral_normalization=spectral_normalization,
|
|
107
|
-
MSign_interval=MSign_interval,
|
|
108
107
|
nnmf_factor=nnmf_factor,
|
|
109
108
|
centered_wd= centered_wd,
|
|
110
109
|
centered_wd_mode= centered_wd_mode,
|
|
@@ -122,10 +121,8 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
122
121
|
for device in devices:
|
|
123
122
|
param_update.set_seed(device)
|
|
124
123
|
|
|
125
|
-
# Initialize compiled function
|
|
126
|
-
self.
|
|
127
|
-
if compiled_optimizer:
|
|
128
|
-
self.compile(fullgraph=True)
|
|
124
|
+
# Initialize compiled function (by parameter shape)
|
|
125
|
+
self._compiled_step_fns = {}
|
|
129
126
|
|
|
130
127
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
131
128
|
"""
|
|
@@ -208,7 +205,15 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
208
205
|
if group.get('stochastic_sign', False):
|
|
209
206
|
random_noise_tensor = param_update._get_random_noise_for_sso(p)
|
|
210
207
|
lr = torch.as_tensor(lr)
|
|
211
|
-
|
|
208
|
+
# Cache compiled function per-shape
|
|
209
|
+
cache_key = (p.shape, state.get('factored', False))
|
|
210
|
+
if cache_key not in self._compiled_step_fns:
|
|
211
|
+
self._compiled_step_fns[cache_key] = torch.compile(
|
|
212
|
+
self._step_parameter,
|
|
213
|
+
fullgraph=True,
|
|
214
|
+
dynamic=False
|
|
215
|
+
)
|
|
216
|
+
step_param_fn = self._compiled_step_fns[cache_key]
|
|
212
217
|
else:
|
|
213
218
|
step_param_fn = self._step_parameter
|
|
214
219
|
|
|
@@ -281,9 +286,6 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
281
286
|
|
|
282
287
|
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
|
|
283
288
|
|
|
284
|
-
def compile(self, *args, **kwargs):
|
|
285
|
-
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
286
|
-
|
|
287
289
|
@torch.no_grad()
|
|
288
290
|
def step(self, closure: Optional[callable] = None):
|
|
289
291
|
"""Performs a single optimization step."""
|
|
@@ -111,6 +111,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
111
111
|
# Decoupled/cautious weight decay
|
|
112
112
|
weight_decay: float = 0.0,
|
|
113
113
|
cautious_wd: bool = False,
|
|
114
|
+
scaled_wd: bool = False,
|
|
114
115
|
# Nesterov momentum
|
|
115
116
|
nesterov: bool = True,
|
|
116
117
|
nesterov_coef: float | None = None,
|
|
@@ -146,8 +147,6 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
146
147
|
mars_gamma: float = 0.025,
|
|
147
148
|
# Spectral Normalization
|
|
148
149
|
spectral_normalization: bool = False,
|
|
149
|
-
# Orthogonalize the weights (Matrix Sign - MSign) every x steps
|
|
150
|
-
MSign_interval: int | None = None,
|
|
151
150
|
# Centered WD
|
|
152
151
|
centered_wd: float = 0.0,
|
|
153
152
|
centered_wd_mode: str = 'float8',
|
|
@@ -203,7 +202,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
203
202
|
defaults = {
|
|
204
203
|
"lr": lr, "beta1": beta1, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
|
|
205
204
|
"nesterov": nesterov, "nesterov_coef": nesterov_coef, "ns_steps": ns_steps, "ns_eps": ns_eps,
|
|
206
|
-
"ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
|
|
205
|
+
"ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor, "scaled_wd": scaled_wd,
|
|
207
206
|
"vector_reshape": vector_reshape, "rms_rescaling": rms_rescaling,
|
|
208
207
|
"orthogonal_gradient": orthogonal_gradient,
|
|
209
208
|
'compiled_optimizer': compiled_optimizer,
|
|
@@ -220,7 +219,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
220
219
|
# MARS-M
|
|
221
220
|
"approx_mars": approx_mars, "mars_gamma": mars_gamma,
|
|
222
221
|
# Spectral Normalization
|
|
223
|
-
"spectral_normalization": spectral_normalization,
|
|
222
|
+
"spectral_normalization": spectral_normalization,
|
|
224
223
|
# Centered WD
|
|
225
224
|
"centered_wd": centered_wd,
|
|
226
225
|
"centered_wd_mode": centered_wd_mode,
|
|
@@ -256,11 +255,9 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
256
255
|
for device in devices:
|
|
257
256
|
param_update.set_seed(device)
|
|
258
257
|
|
|
259
|
-
# Initialize compiled
|
|
260
|
-
self.
|
|
261
|
-
self.
|
|
262
|
-
if compiled_optimizer:
|
|
263
|
-
self.compile(fullgraph=True)
|
|
258
|
+
# Initialize compiled functions (by parameter shape)
|
|
259
|
+
self._compiled_muon_step_fns = {}
|
|
260
|
+
self._compiled_adam_step_fns = {}
|
|
264
261
|
|
|
265
262
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
266
263
|
"""
|
|
@@ -398,7 +395,15 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
398
395
|
random_int_state_tensor = None
|
|
399
396
|
if is_compiled:
|
|
400
397
|
step_size = torch.as_tensor(step_size)
|
|
401
|
-
|
|
398
|
+
# Cache compiled function per-shape
|
|
399
|
+
cache_key = (p.shape, state.get('factored', False))
|
|
400
|
+
if cache_key not in self._compiled_adam_step_fns:
|
|
401
|
+
self._compiled_adam_step_fns[cache_key] = torch.compile(
|
|
402
|
+
Muon_AuxAdam._adam_step_parameter,
|
|
403
|
+
fullgraph=True,
|
|
404
|
+
dynamic=False
|
|
405
|
+
)
|
|
406
|
+
adam_step_param = self._compiled_adam_step_fns[cache_key]
|
|
402
407
|
|
|
403
408
|
actual_precision = group.get('adam_actual_state_precision', 'auto')
|
|
404
409
|
random_int_state_tensor = random_int_tensor
|
|
@@ -417,7 +422,15 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
417
422
|
random_G_sketch = None
|
|
418
423
|
if is_compiled:
|
|
419
424
|
lr = torch.as_tensor(group['lr'])
|
|
420
|
-
|
|
425
|
+
# Cache compiled function per-shape
|
|
426
|
+
cache_key = (p.shape, state.get('factored', False))
|
|
427
|
+
if cache_key not in self._compiled_muon_step_fns:
|
|
428
|
+
self._compiled_muon_step_fns[cache_key] = torch.compile(
|
|
429
|
+
self._muon_step_parameter,
|
|
430
|
+
fullgraph=True,
|
|
431
|
+
dynamic=False
|
|
432
|
+
)
|
|
433
|
+
muon_step_param = self._compiled_muon_step_fns[cache_key]
|
|
421
434
|
|
|
422
435
|
# Generate state SR random tensor when compiled
|
|
423
436
|
actual_precision = group['actual_state_precision']
|
|
@@ -435,10 +448,6 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
435
448
|
|
|
436
449
|
muon_step_param(p, grad, state, group, lr, random_int_tensor, random_int_state_tensor, random_G_sketch)
|
|
437
450
|
|
|
438
|
-
def compile(self, *args, **kwargs):
|
|
439
|
-
self._compiled_muon_step_parameter = torch.compile(self._muon_step_parameter, *args, **kwargs)
|
|
440
|
-
self._compiled_adam_step_parameter = torch.compile(Muon_AuxAdam._adam_step_parameter, *args, **kwargs)
|
|
441
|
-
|
|
442
451
|
@torch.no_grad()
|
|
443
452
|
def _muon_step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_int_state_tensor, random_G_sketch):
|
|
444
453
|
# Upcast grad for low-precision state modes (non-factored path)
|
|
@@ -477,8 +486,9 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
477
486
|
# Standard momentum
|
|
478
487
|
update = mt_buf.clone()
|
|
479
488
|
|
|
480
|
-
#
|
|
481
|
-
|
|
489
|
+
# Compress new momentum and store factors
|
|
490
|
+
for key, val in zip(('mu_mbuf_nmf', 'mv_mbuf_nmf', 'sign_buf'), _factorize_state(mt_buf, signed=True, shifter=state['shifter'])):
|
|
491
|
+
state[key].copy_(val)
|
|
482
492
|
del mt_buf
|
|
483
493
|
|
|
484
494
|
# Orthogonalization step
|
|
@@ -115,6 +115,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
115
115
|
weight_decay: float = 0.0,
|
|
116
116
|
fisher_wd: bool = False,
|
|
117
117
|
cautious_wd: bool = False,
|
|
118
|
+
scaled_wd: bool = False,
|
|
118
119
|
# Stochastic Rounding for BF16
|
|
119
120
|
stochastic_rounding: bool = True,
|
|
120
121
|
# Adam_atan2 (scale invariant)
|
|
@@ -156,8 +157,6 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
156
157
|
centered_wd_mode: str = 'float8',
|
|
157
158
|
# Spectral Normalization
|
|
158
159
|
spectral_normalization: bool = False,
|
|
159
|
-
# Orthogonalize the weights (Matrix Sign - MSign) every x steps
|
|
160
|
-
MSign_interval: int | None = None,
|
|
161
160
|
):
|
|
162
161
|
if not (lr >= 0.0):
|
|
163
162
|
raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
|
|
@@ -183,7 +182,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
183
182
|
|
|
184
183
|
defaults = {
|
|
185
184
|
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
186
|
-
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
|
|
185
|
+
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "scaled_wd": scaled_wd,
|
|
187
186
|
"use_atan2": use_atan2,
|
|
188
187
|
"orthogonal_gradient": orthogonal_gradient,
|
|
189
188
|
"compiled_optimizer": compiled_optimizer,
|
|
@@ -195,7 +194,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
195
194
|
"tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
|
|
196
195
|
"centered_wd": centered_wd, "centered_wd_mode": centered_wd_mode,
|
|
197
196
|
"nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd,
|
|
198
|
-
"spectral_normalization": spectral_normalization,
|
|
197
|
+
"spectral_normalization": spectral_normalization,
|
|
199
198
|
}
|
|
200
199
|
self.stochastic_rounding = stochastic_rounding
|
|
201
200
|
self.fsdp_in_use = fsdp_in_use
|
|
@@ -220,11 +219,8 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
220
219
|
for device in devices:
|
|
221
220
|
param_update.set_seed(device)
|
|
222
221
|
|
|
223
|
-
# Initialize compiled function
|
|
224
|
-
self.
|
|
225
|
-
|
|
226
|
-
if compiled_optimizer:
|
|
227
|
-
self.compile(fullgraph=True)
|
|
222
|
+
# Initialize compiled function (by parameter shape)
|
|
223
|
+
self._compiled_step_fns = {}
|
|
228
224
|
|
|
229
225
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
230
226
|
"""
|
|
@@ -366,7 +362,15 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
366
362
|
random_int_state_tensor = param_update._get_random_int_for_sr(p)
|
|
367
363
|
elif group['actual_state_precision'] == 'int8_sr':
|
|
368
364
|
random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
|
|
369
|
-
|
|
365
|
+
# Cache compiled function per-shape
|
|
366
|
+
cache_key = (p.shape, state.get('factored', False))
|
|
367
|
+
if cache_key not in self._compiled_step_fns:
|
|
368
|
+
self._compiled_step_fns[cache_key] = torch.compile(
|
|
369
|
+
self._step_parameter,
|
|
370
|
+
fullgraph=True,
|
|
371
|
+
dynamic=False
|
|
372
|
+
)
|
|
373
|
+
step_param_fn = self._compiled_step_fns[cache_key]
|
|
370
374
|
else:
|
|
371
375
|
d = group['d']
|
|
372
376
|
step_param_fn = self._step_parameter
|
|
@@ -402,7 +406,8 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
402
406
|
mt.mul_(self.beta1).add_(grad_reshaped, alpha=d * (1.0 - self.beta1))
|
|
403
407
|
|
|
404
408
|
# Factorize
|
|
405
|
-
|
|
409
|
+
for key, val in zip(('mu_m_nmf', 'mv_m_nmf', 'sign'), _factorize_state(mt.clone(), signed=True, shifter=state['shifter'])):
|
|
410
|
+
state[key].copy_(val)
|
|
406
411
|
|
|
407
412
|
update_mt = mt
|
|
408
413
|
|
|
@@ -423,7 +428,8 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
423
428
|
update = grad_reshaped.mul(d)
|
|
424
429
|
|
|
425
430
|
# Factorize
|
|
426
|
-
|
|
431
|
+
for key, val in zip(('mu_v_nmf', 'mv_v_nmf'), _factorize_state(vt, signed=False, shifter=state['shifter'])):
|
|
432
|
+
state[key].copy_(val)
|
|
427
433
|
|
|
428
434
|
if group['use_atan2']:
|
|
429
435
|
denom = vt.sqrt_()
|
|
@@ -475,7 +481,8 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
475
481
|
exp_avg_sq.mul_(beta2).addcmul_(grad_vt, grad_vt, value=d * d * (1.0 - beta2))
|
|
476
482
|
|
|
477
483
|
if factored_2nd:
|
|
478
|
-
|
|
484
|
+
for key, val in zip(('mu_v_nmf', 'mv_v_nmf'), _factorize_state(exp_avg_sq.view(d1, d2), signed=False, shifter=state['shifter'])):
|
|
485
|
+
state[key].copy_(val)
|
|
479
486
|
else:
|
|
480
487
|
set_state(state, 'exp_avg_sq', exp_avg_sq, actual_precision, random_int_state_tensor, non_neg=True)
|
|
481
488
|
del random_int_state_tensor
|
|
@@ -524,9 +531,6 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
524
531
|
|
|
525
532
|
param_update.apply_parameter_update(self, p, group, update, dlr, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
|
|
526
533
|
|
|
527
|
-
def compile(self, *args, **kwargs):
|
|
528
|
-
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
529
|
-
|
|
530
534
|
@torch.no_grad()
|
|
531
535
|
def step(self, closure=None):
|
|
532
536
|
"""Performs a single optimization step."""
|