adv-optm 2.4.dev13__tar.gz → 2.4.dev14__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/PKG-INFO +1 -1
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/optim/AdaMuon_adv.py +50 -37
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/optim/AdamW_adv.py +26 -8
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/optim/Adopt_adv.py +59 -71
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/optim/Lion_adv.py +23 -53
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/optim/Muon_adv.py +49 -38
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/optim/Prodigy_adv.py +148 -130
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/optim/SignSGD_adv.py +55 -126
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/optim/SinkSGD_adv.py +41 -14
- adv_optm-2.4.dev14/adv_optm/util/Muon_AuxAdam.py +237 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/util/Muon_util.py +2 -5
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/util/scaled_optm.py +83 -29
- adv_optm-2.4.dev14/adv_optm/util/signed_util.py +53 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/setup.py +1 -1
- adv_optm-2.4.dev13/adv_optm/util/Muon_AuxAdam.py +0 -172
- adv_optm-2.4.dev13/adv_optm/util/signed_util.py +0 -13
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/LICENSE +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/README.md +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/util/param_update.py +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/util/sinkhorn.py +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/util/state_util.py +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev13 → adv_optm-2.4.dev14}/setup.cfg +0 -0
@@ -59,14 +59,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
 orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
 nesterov (bool): enables Nesterov momentum (default: False).
 use_atan2 (bool): whether to use the atan2 update rule. (default: False)
-Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
-This changes the update to `alpha_grad * grad + mt`, which can be
-more responsive, especially for small batch sizes. (default: False)
-alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
-(only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
-current gradient. For small batch sizes, use high values (e.g., 10-100) to be
-more responsive. For large batch sizes, use low values (e.g., 0-1) for
-stability. (default: 100.0)
 vector_reshape (bool): whether to reshape 1D vectors into 2D
 matrices to apply low-rank compression (default: True).
 kappa_p (float, optional): The p-value for the update geometry (domain [1.0, 2.0]).
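The removed docstring lines above describe the Simplified AdEMAMix numerator (`alpha_grad * grad + mt`); the code hunks further down replace that rule with a Nesterov-style blend via `lerp`. A minimal illustrative sketch of the difference (not the package's code; the `alpha_grad` and `beta1` values are arbitrary):

```python
import torch

grad, mt = torch.randn(4), torch.randn(4)
alpha_grad, beta1 = 100.0, 0.9

old_update = alpha_grad * grad + mt   # removed Simplified AdEMAMix numerator
new_update = grad.lerp(mt, beta1)     # dev14 Nesterov blend: (1 - beta1) * grad + beta1 * mt
assert torch.allclose(new_update, (1 - beta1) * grad + beta1 * mt)
```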
@@ -117,6 +109,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
 adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
 adam_eps (float): Epsilon for the AdamW optimizer part.
 adam_weight_decay (float): Weight decay for the AdamW optimizer part.
+adam_fisher_wd (bool): Fisher Adam (FAdam) weight decay for the AdamW part. (default: False)
 adam_use_bias_correction (bool): Bias correction for AdamW.
 adam_use_atan2 (bool): Atan2 update rule for AdamW.
 adam_cautious_mask (bool): Cautious masking for AdamW.
@@ -125,8 +118,17 @@ class AdaMuon_adv(torch.optim.Optimizer):
 adam_use_AdEMAMix (bool): AdEMAMix for AdamW.
 adam_beta3_ema (float): Beta3 for AdEMAMix.
 adam_alpha (float): Alpha for AdEMAMix.
+adam_nesterov (bool): Nesterov momentum for AdamW. (default: False)
+adam_nesterov_coef (float, optional): Nesterov coefficient for AdamW. (default: None)
 adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
+adam_beta2_min (float): Minimum beta2 for Kourkoutas-β. (default: 0.9)
+adam_ema_alpha (float): EMA alpha for Kourkoutas-β. (default: 0.95)
+adam_tiny_spike (float): Tiny spike for Kourkoutas-β. (default: 1e-9)
+adam_k_warmup_steps (int): Warmup steps for Kourkoutas-β. (default: 0)
+adam_spectral_normalization (bool): Enable explicit spectral normalization for AdamW. (default: False)
+adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp8_sr', 'int8_sr', 'factored'. (default: 'auto')
 adam_nnmf_factor (bool): 1-bit factored for AdamW.
+adam_factored_2nd (bool): Factorize only the second moment (v_t) for AuxAdam. (default: False)
 """

 def __init__(
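The new AuxAdam-related options above (Nesterov, Kourkoutas-β bounds, `adam_state_precision`, `adam_factored_2nd`) are plain constructor keywords, as the `__init__` hunks below show. A minimal construction sketch, assuming the class is importable from the package shown in the file listing (the exact export path and the remaining defaults are not confirmed by this diff):

```python
import torch
from adv_optm import AdaMuon_adv   # export path assumed

model = torch.nn.Linear(64, 64)    # placeholder module
opt = AdaMuon_adv(
    model.parameters(),
    lr=1e-3,
    adam_nesterov=True,              # new in dev14
    adam_nesterov_coef=None,         # None lets the optimizer pick its default blend coefficient
    adam_state_precision="bf16_sr",  # 'auto', 'fp32', 'factored', 'bf16_sr', 'fp8_sr', 'int8_sr'
    adam_factored_2nd=True,          # factorize only v_t for the AuxAdam states
)
```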
@@ -140,6 +142,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
 cautious_wd: bool = False,
 # Nesterov momentum
 nesterov: bool = True,
+nesterov_coef: float | None = None,
 # RMS Rescaling
 rms_rescaling: bool = True,
 # Newton Schulz
@@ -152,9 +155,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
 orthogonal_gradient: bool = False,
 # Adam_atan2 (scale invariant)
 use_atan2: bool = False,
-# One-EMA AdEMAMix
-Simplified_AdEMAMix: bool = False,
-alpha_grad: float = 100.0,
 # NorMuon
 normuon_variant: bool = False,
 # Boolean to spilt param
@@ -190,6 +190,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
 adam_betas: tuple[float, float] = (0.9, 0.99),
 adam_eps: float | None = 1e-8,
 adam_weight_decay: float = 0.0,
+adam_fisher_wd: bool = False,
 adam_use_bias_correction: bool = True,
 adam_use_atan2: bool = False,
 adam_cautious_mask: bool = False,
@@ -198,12 +199,17 @@ class AdaMuon_adv(torch.optim.Optimizer):
 adam_use_AdEMAMix: bool = False,
 adam_beta3_ema: float = 0.9999,
 adam_alpha: float = 5.0,
+adam_nesterov: bool = False,
+adam_nesterov_coef: float | None = None,
 adam_kourkoutas_beta: bool = False,
 adam_beta2_min: float = 0.9,
 adam_ema_alpha: float = 0.95,
 adam_tiny_spike: float = 1e-9,
 adam_k_warmup_steps: int = 0,
+adam_spectral_normalization: bool = False,
+adam_state_precision: str = "auto",
 adam_nnmf_factor: bool = False,
+adam_factored_2nd: bool = False,
 ):
 if not (lr >= 0.0):
 raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -211,9 +217,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
 raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
 if not (ns_steps > 0):
 raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
-if Simplified_AdEMAMix and nesterov:
-print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling nesterov.")
-nesterov = False
 if spectral_normalization and rms_rescaling:
 print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
 rms_rescaling = False
@@ -221,17 +224,20 @@ class AdaMuon_adv(torch.optim.Optimizer):
 ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")

 state_precision = state_precision.lower()
-valid_precisions = {"auto", "fp32", "bf16_sr", "fp8_sr", "int8_sr"}
+valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp8_sr", "int8_sr"}
 if state_precision not in valid_precisions:
 raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")

+adam_state_precision = adam_state_precision.lower()
+if adam_state_precision not in valid_precisions:
+raise ValueError(f"adam_state_precision must be one of {valid_precisions}. Got {adam_state_precision}")
+
 defaults = {
 "lr": lr, "betas": betas, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
 "eps": eps, "rms_rescaling": rms_rescaling, "ns_steps": ns_steps,
 "ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
 "vector_reshape": vector_reshape,
-"nesterov":nesterov, "use_atan2":use_atan2,
-"Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+"nesterov":nesterov, "nesterov_coef": nesterov_coef, "use_atan2":use_atan2,
 "normuon_variant": normuon_variant, "orthogonal_gradient": orthogonal_gradient,
 "compiled_optimizer":compiled_optimizer,
 "use_muon": use_muon,
@@ -254,13 +260,18 @@ class AdaMuon_adv(torch.optim.Optimizer):
 "centered_wd_mode": centered_wd_mode,
 # AdamW_adv defaults
 "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
+"adam_fisher_wd": adam_fisher_wd,
 "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
 "adam_cautious_mask": adam_cautious_mask, "adam_grams_moment": adam_grams_moment,
 "adam_orthogonal_gradient": adam_orthogonal_gradient,
 "adam_use_AdEMAMix": adam_use_AdEMAMix, "adam_beta3_ema": adam_beta3_ema, "adam_alpha": adam_alpha,
+"adam_nesterov": adam_nesterov, "adam_nesterov_coef": adam_nesterov_coef,
 "adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
 "adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
-"adam_k_warmup_steps": adam_k_warmup_steps,
+"adam_k_warmup_steps": adam_k_warmup_steps,
+"adam_spectral_normalization": adam_spectral_normalization,
+"adam_state_precision": adam_state_precision,
+"adam_nnmf_factor": adam_nnmf_factor, "adam_factored_2nd": adam_factored_2nd,
 }
 self.stochastic_rounding = stochastic_rounding
 self._init_lr = lr
@@ -447,13 +458,24 @@ class AdaMuon_adv(torch.optim.Optimizer):

 step_size = group['lr'] / bias_correction1

+random_int_state_tensor = None
 if is_compiled:
 step_size = torch.as_tensor(step_size)
 adam_step_param = self._compiled_adam_step_parameter
+
+# Generate state SR random tensor when compiled
+actual_precision = group.get('adam_actual_state_precision', 'auto')
+random_int_state_tensor = random_int_tensor
+if actual_precision == 'bf16_sr' and random_int_state_tensor is None:
+random_int_state_tensor = param_update._get_random_int_for_sr(p)
+elif actual_precision == 'int8_sr':
+random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
+elif actual_precision == 'fp8_sr':
+random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
 else:
 adam_step_param = Muon_AuxAdam._adam_step_parameter

-adam_step_param(self, p, grad, state, group, beta1_adam, beta2_adam, sqrt_bias_correction2, step_size, random_int_tensor)
+adam_step_param(self, p, grad, state, group, beta1_adam, beta2_adam, sqrt_bias_correction2, step_size, random_int_tensor, random_int_state_tensor)

 state['step'] += 1
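The block above threads a `random_int_state_tensor` into the AuxAdam step so that low-precision optimizer states can be written with stochastic rounding. As a rough illustration of what such a random tensor is typically used for (a generic bf16 stochastic-rounding sketch, not the contract of the package's `param_update._get_random_int_for_sr`):

```python
import torch

def stochastic_round_to_bf16(x_fp32: torch.Tensor, rand16: torch.Tensor) -> torch.Tensor:
    """Illustrative bf16 stochastic rounding: add random low-order bits, then truncate.

    `rand16` is assumed to be an int32 tensor of uniform values in [0, 2**16),
    playing the role of the random tensor passed around above.
    """
    bits = x_fp32.contiguous().view(torch.int32)   # reinterpret the fp32 bit pattern
    bits = bits + rand16                           # randomize the bits bf16 will drop
    bits = bits & ~0xFFFF                          # truncate to the top 16 bits
    return bits.view(torch.float32).to(torch.bfloat16)

x = torch.full((1024,), 1.0 + 2**-10)              # value not representable in bf16
r = torch.randint(0, 2**16, x.shape, dtype=torch.int32)
rounded = stochastic_round_to_bf16(x, r)
# On average the rounded values recover the original, unlike plain truncation.
```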
@@ -465,7 +487,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
 # Generate state SR random tensor when compiled
 actual_precision = group['actual_state_precision']
 random_int_state_tensor = random_int_tensor
-if actual_precision == 'bf16_sr' and random_int_state_tensor is
+if actual_precision == 'bf16_sr' and random_int_state_tensor is None:
 random_int_state_tensor = param_update._get_random_int_for_sr(p)
 elif actual_precision == 'int8_sr':
 random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
@@ -488,8 +510,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
 grad = upcast_grad_for_precision(grad, state, group.get('state_precision', 'auto'))
 beta1, beta2 = group['betas']
 nesterov = group['nesterov']
-
-alpha_grad = group['alpha_grad']
+nesterov_coef = group.get('nesterov_coef', None)

 # Update geometry
 kappa_p = group.get("kappa_p", 1.0)
@@ -513,7 +534,7 @@ class AdaMuon_adv(torch.optim.Optimizer):

 # MARS-M Approximated (Variance Reduction)
 if group.get('approx_mars', False):
-grad = approx_mars(grad, state['last_grad'], group['mars_gamma'], beta1
+grad = approx_mars(grad, state['last_grad'], group['mars_gamma'], beta1)


 if group.get("orthogonal_gradient"):
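The last hunk above completes the `approx_mars(...)` call that appears truncated in the dev13 rendering. For orientation, MARS-style approximate variance reduction usually forms a corrected gradient from the previous step's gradient; a hedged sketch of that shape (the body of the package's actual `approx_mars` is not shown in this diff):

```python
import torch

def approx_mars_sketch(grad: torch.Tensor, last_grad: torch.Tensor,
                       gamma: float, beta1: float) -> torch.Tensor:
    """Plausible shape of an approximate MARS correction (illustrative only)."""
    # Variance-reduced gradient: push the update along the recent gradient change.
    correction = (beta1 / (1.0 - beta1)) * gamma * (grad - last_grad)
    return grad + correction
```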
@@ -527,15 +548,11 @@ class AdaMuon_adv(torch.optim.Optimizer):
 mt_buf = _reconstruct_state((state['mu_mbuf_nmf'], state['mv_mbuf_nmf'], state['sign_buf'], d2), signed=True)

 # Update momentum in full-size
-
-mt_buf.lerp_(grad_reshaped, 1 - beta1)
-else:
-mt_buf.mul_(beta1).add_(grad_reshaped)
+mt_buf.lerp_(grad_reshaped, 1 - beta1)

 if nesterov:
-
-
-update = torch.add(mt_buf, grad_reshaped, alpha=alpha_grad)
+nv_coef = beta1 if nesterov_coef is None else nesterov_coef
+update = grad_reshaped.lerp(mt_buf, nv_coef)
 else:
 update = mt_buf.clone()
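Two things change in this hunk: the momentum buffer now always uses the EMA form, and the Nesterov branch blends gradient and momentum with `lerp` instead of adding `alpha_grad * grad`. The identities behind both lines, as a quick check with illustrative values:

```python
import torch

m, g, beta1 = torch.randn(5), torch.randn(5), 0.9

ema = m.clone().lerp_(g, 1 - beta1)   # kept form:    m <- beta1 * m + (1 - beta1) * g
acc = m.clone().mul_(beta1).add_(g)   # removed form: m <- beta1 * m + g  (accumulator)
assert torch.allclose(ema, beta1 * m + (1 - beta1) * g)
assert torch.allclose(acc, beta1 * m + g)

# Nesterov read-out in dev14: a convex blend instead of adding alpha_grad * g.
update = g.lerp(ema, beta1)           # == (1 - beta1) * g + beta1 * ema
```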
@@ -586,15 +603,11 @@ class AdaMuon_adv(torch.optim.Optimizer):

 # Momentum update
 mt_buf = get_state(state, 'momentum_buffer', actual_precision)
-
-mt_buf.lerp_(grad, 1 - beta1)
-else:
-mt_buf.mul_(beta1).add_(grad)
+mt_buf.lerp_(grad, 1 - beta1)

 if nesterov:
-
-
-update = mt_buf.add(grad, alpha=alpha_grad)
+nv_coef = beta1 if nesterov_coef is None else nesterov_coef
+update = grad.lerp(mt_buf, nv_coef)
 else:
 update = mt_buf.clone()

@@ -28,7 +28,8 @@ class AdamW_adv(torch.optim.Optimizer):
 betas (tuple[float, float]): coefficients used for computing running
 averages of gradient and its square (default: (0.9, 0.999))
 eps (float): term added to the denominator to improve
-numerical stability (
+numerical stability. Set to None for scale invariant eps (vector
+lower bound) (default: 1e-8)
 weight_decay (float): weight decay (L2 penalty) (default: 0).
 fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
 the decay direction through the empirical Fisher information matrix and
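The reworded `eps` docstring ties into the `use_atan2` option that appears throughout the diff: with an atan2-style denominator the step is bounded and invariant to a common rescaling of the moments, so no explicit eps hyperparameter is needed. A generic sketch of that property (not necessarily the exact formula used by this package):

```python
import torch

m_hat = torch.randn(8)            # bias-corrected first moment (illustrative)
v_hat = torch.rand(8) + 1e-12     # bias-corrected second moment (illustrative)

# Conventional Adam step direction needs an eps term:
step_eps = m_hat / (v_hat.sqrt() + 1e-8)

# atan2-style step direction: bounded in (-pi/2, pi/2) and unchanged if the
# gradient (and hence m and sqrt(v)) is rescaled by a common positive factor.
step_atan2 = torch.atan2(m_hat, v_hat.sqrt())
assert torch.allclose(step_atan2, torch.atan2(1e3 * m_hat, 1e3 * v_hat.sqrt()))
```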
@@ -127,6 +128,9 @@ class AdamW_adv(torch.optim.Optimizer):
 use_AdEMAMix: bool = False,
 beta3_ema: float = 0.9999,
 alpha: float = 5.0,
+# Nesterov momentum
+nesterov: bool = False,
+nesterov_coef: float | None = None,
 # K-b (adaptive beta2)
 kourkoutas_beta: bool = False,
 beta2_min: float = 0.9,
@@ -176,7 +180,7 @@ class AdamW_adv(torch.optim.Optimizer):
 defaults = {
 "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
 "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
-"use_atan2": use_atan2,
+"use_atan2": use_atan2, "nesterov": nesterov, "nesterov_coef": nesterov_coef,
 "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
 "beta3_ema": beta3_ema, "alpha": alpha, "compiled_optimizer": compiled_optimizer,
 "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
@@ -195,6 +199,8 @@ class AdamW_adv(torch.optim.Optimizer):
 self._init_lr = lr
 super().__init__(params, defaults)

+self.init_step()
+
 if self.kourkoutas_beta:
 self.kourkoutas_helper = KourkoutasHelper(self)

@@ -363,6 +369,9 @@ class AdamW_adv(torch.optim.Optimizer):
 if self.use_AdEMAMix:
 beta3_ema = group['beta3_ema']
 alpha = group['alpha']
+nesterov = group.get('nesterov', False)
+nesterov_coef = group.get('nesterov_coef', None)
+use_mt = group['betas'][0] > 0

 if group.get('kourkoutas_beta', False):
 # Accumulate current grad's norm for the *next* step
@@ -375,7 +384,7 @@ class AdamW_adv(torch.optim.Optimizer):
 grad_reshaped = grad.view(d1, d2)

 # Reconstruct momentum from previous step's factors
-if
+if use_mt:
 mt = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True)

 # Update momentum in full-size
@@ -391,6 +400,10 @@ class AdamW_adv(torch.optim.Optimizer):
 else:
 update_mt = mt

+if nesterov:
+nv_coef = beta1 if nesterov_coef is None else nesterov_coef
+update_mt = update_mt.lerp_(grad_reshaped, 1-nv_coef)
+
 vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)

 if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
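`vt` above is rebuilt from NNMF factors (`mu_v_nmf`, `mv_v_nmf`), i.e. the second moment is stored in factored form rather than at full size, which is also what the new `factored` precision and `adam_factored_2nd` options refer to. An Adafactor-style rank-1 illustration of the memory saving (not the package's `_factorize_state` algorithm):

```python
import torch

V = torch.rand(16, 32)                       # nonnegative second-moment matrix (illustrative)

# Rank-1 factorization in the spirit of Adafactor: keep one row vector and one
# column vector instead of the full d1 x d2 state.
row = V.mean(dim=1, keepdim=True)            # shape (16, 1)
col = V.mean(dim=0, keepdim=True)            # shape (1, 32)
V_hat = row @ col / V.mean()                 # rank-1 reconstruction of V

print(V.numel(), row.numel() + col.numel())  # 512 state values -> 48
```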
@@ -403,7 +416,7 @@ class AdamW_adv(torch.optim.Optimizer):

 mt_slow.lerp_(grad_reshaped, 1.0 - beta3_ema)

-if
+if use_mt:
 update = update_mt.add_(mt_slow, alpha=alpha)
 else:
 update = grad_reshaped.add(mt_slow, alpha=alpha)
@@ -412,7 +425,7 @@ class AdamW_adv(torch.optim.Optimizer):
 state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
 del mt_slow
 else:
-if
+if use_mt:
 update = update_mt
 else:
 update = grad_reshaped.clone()
@@ -439,7 +452,7 @@ class AdamW_adv(torch.optim.Optimizer):
 actual_precision = group['actual_state_precision']
 factored_2nd = state.get('factored_2nd', False)

-if
+if use_mt:
 exp_avg = get_state(state, 'exp_avg', actual_precision)
 exp_avg.lerp_(grad, 1.0 - beta1)

@@ -449,19 +462,24 @@ class AdamW_adv(torch.optim.Optimizer):
 update_mt = _cautious_update(exp_avg, grad)
 else:
 update_mt = exp_avg.clone()
+
+if nesterov:
+nv_coef = beta1 if nesterov_coef is None else nesterov_coef
+update_mt = update_mt.lerp_(grad, 1-nv_coef)
+
 set_state(state, 'exp_avg', exp_avg, actual_precision, random_int_state_tensor)

 if self.use_AdEMAMix:
 exp_avg_slow = get_state(state, 'exp_avg_slow', actual_precision)
 exp_avg_slow.lerp_(grad, 1.0 - beta3_ema)

-if
+if use_mt:
 update = update_mt.add_(exp_avg_slow, alpha=alpha)
 else:
 update = torch.add(grad, exp_avg_slow, alpha=alpha)
 set_state(state, 'exp_avg_slow', exp_avg_slow, actual_precision, random_int_state_tensor)
 else:
-update = update_mt if
+update = update_mt if use_mt else grad.clone()

 if factored_2nd:
 d1, d2 = state['effective_shape']
@@ -7,7 +7,7 @@ from ..util import param_update
 from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state, _nnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.Kourkoutas import KourkoutasHelper
-from ..util.update_util import _grams_update, _cautious_update,
+from ..util.update_util import _grams_update, _cautious_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
 from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm, scale_eps
 from ..util.centered_decay import _init_anchor
 from ..util.state_util import init_state_tensor, get_state, set_state, upcast_grad_for_precision
@@ -32,7 +32,8 @@ class Adopt_adv(torch.optim.Optimizer):
 betas (tuple[float, float]): coefficients used for computing running
 averages of momentum and variance (default: (0.9, 0.9999))
 eps (float): term added to the denominator to improve
-numerical stability (
+numerical stability. Set to None for scale invariant eps (vector
+lower bound) (default: 1e-6)
 weight_decay (float): weight decay (L2 penalty) (default: 0)
 fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
 the decay direction through the empirical Fisher information matrix and
@@ -68,16 +69,6 @@ class Adopt_adv(torch.optim.Optimizer):
 before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
 A higher value increases the stabilizing influence of the slow
 momentum. (default: 5.0)
-Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
-This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
-more responsive, especially for small batch sizes. Enabling this will
-automatically disable `use_AdEMAMix`, `cautious_mask`, `grams_moment`,
-and `use_atan2`. (default: False)
-alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
-(only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
-current gradient. For small batch sizes, use high values (e.g., 10-100) to be
-more responsive. For large batch sizes, use low values (e.g., 0-1) for
-stability. (default: 100.0)
 kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
 If `False`, the optimizer behaves as standard Adopt. (default: False)
 beta2_min (float): The minimum value for dynamic β₂, used during periods of
@@ -143,9 +134,9 @@ class Adopt_adv(torch.optim.Optimizer):
 use_AdEMAMix: bool = False,
 beta3_ema: float = 0.9999,
 alpha: float = 5.0,
-#
-
-
+# Nesterov momentum
+nesterov: bool = False,
+nesterov_coef: float | None = None,
 # K-b (adaptive beta2)
 kourkoutas_beta: bool = False,
 beta2_min: float = 0.9,
@@ -179,16 +170,8 @@ class Adopt_adv(torch.optim.Optimizer):
 if cautious_mask and grams_moment:
 print("Warning: cautious is incompatible with grams, Disabling cautious.")
 cautious_mask = False
-if betas[0] == 0.0 and Simplified_AdEMAMix:
-raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
 if kourkoutas_beta and not (betas[1] > beta2_min):
 raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
-if use_AdEMAMix and Simplified_AdEMAMix:
-print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
-if grams_moment and Simplified_AdEMAMix:
-print("Warning: grams is incompatible with Simplified_AdEMAMix, Disabling grams.")
-if cautious_mask and Simplified_AdEMAMix:
-print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")


 state_precision = state_precision.lower()
@@ -204,7 +187,7 @@ class Adopt_adv(torch.optim.Optimizer):
 "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
 "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
 "beta3_ema": beta3_ema, "alpha": alpha,
-"
+"nesterov": nesterov, "nesterov_coef": nesterov_coef,
 "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
 "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
 "spectral_normalization": spectral_normalization,
@@ -216,17 +199,18 @@ class Adopt_adv(torch.optim.Optimizer):
 }
 self.clip_lambda = clip_lambda
 self.stochastic_rounding = stochastic_rounding
-self.use_atan2 = use_atan2
-self.cautious_mask = cautious_mask
-self.grams_moment = grams_moment
+self.use_atan2 = use_atan2
+self.cautious_mask = cautious_mask
+self.grams_moment = grams_moment
 self.orthogonal_gradient = orthogonal_gradient
-self.use_AdEMAMix = use_AdEMAMix
-self.Simplified_AdEMAMix = Simplified_AdEMAMix
+self.use_AdEMAMix = use_AdEMAMix
 self.kourkoutas_beta = kourkoutas_beta
 self.layer_key_fn = layer_key_fn
 self._init_lr = lr
 super().__init__(params, defaults)

+self.init_step()
+
 if self.kourkoutas_beta:
 self.kourkoutas_helper = KourkoutasHelper(self)

@@ -258,26 +242,15 @@ class Adopt_adv(torch.optim.Optimizer):
 @property
 def supports_flat_params(self): return False

-
-
-
-
+def init_step(self):
+for group in self.param_groups:
+for i, p in enumerate(group['params']):
+self.__init_state(p, group)

-
+@torch.no_grad()
+def __init_state(self, p, group):
 state = self.state[p]

-beta1, beta2 = group['betas']
-
-if group.get('kourkoutas_beta', False):
-if 'step' not in state:
-current_step = 0
-else:
-current_step = state['step']
-# Call prepare_step() once at the beginning of the step for all params
-self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
-# Get the dynamic beta2 calculated in prepare_step()
-beta2 = self.kourkoutas_helper.get_beta2(p, group)
-
 # State Initialization
 if 'step' not in state:
 state['step'] = 0
@@ -340,6 +313,27 @@ class Adopt_adv(torch.optim.Optimizer):

 _init_fisher_wd_scaler(group, state, p)

+@torch.no_grad()
+def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+if p.grad is None:
+return
+
+grad = p.grad
+state = self.state[p]
+self.__init_state(p, group)
+
+beta1, beta2 = group['betas']
+
+if group.get('kourkoutas_beta', False):
+if 'step' not in state:
+current_step = 0
+else:
+current_step = state['step']
+# Call prepare_step() once at the beginning of the step for all params
+self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
+# Get the dynamic beta2 calculated in prepare_step()
+beta2 = self.kourkoutas_helper.get_beta2(p, group)
+
 current_step = state['step']

 # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
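`Adopt_adv` now initializes per-parameter state eagerly via `init_step()` and exposes `step_parameter(p, group, i)` for stepping a single parameter. A hypothetical driver loop, assuming the import path implied by the file listing and that gradients are already populated:

```python
import torch
from adv_optm import Adopt_adv      # import path assumed

model = torch.nn.Linear(8, 8)
optimizer = Adopt_adv(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()

# Per-parameter stepping, mirroring the step_parameter(p, group, i) hook added above.
for group in optimizer.param_groups:
    for i, p in enumerate(group['params']):
        if p.grad is not None:
            optimizer.step_parameter(p, group, i)

optimizer.zero_grad(set_to_none=True)
```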
@@ -367,9 +361,6 @@ class Adopt_adv(torch.optim.Optimizer):
 lr = group['lr']
 step_param_fn = self._step_parameter

-if self.Simplified_AdEMAMix:
-lr = _scale_sim_AdEMAMix_update(beta1, state['step'] + 1, group["alpha_grad"], lr, group.get('spectral_normalization', False))
-
 step_param_fn(p, grad, state, group, lr, beta1, beta2, random_int_tensor, random_int_state_tensor)

 state['step'] += 1
@@ -383,8 +374,9 @@ class Adopt_adv(torch.optim.Optimizer):
 if self.use_AdEMAMix:
 beta3_ema = group['beta3_ema']
 alpha = group['alpha']
-
-
+nesterov = group.get('nesterov', False)
+nesterov_coef = group.get('nesterov_coef', None)
+use_mt = group['betas'][0] > 0

 if group.get('kourkoutas_beta', False):
 # Accumulate current grad's norm for the *next* step
@@ -421,13 +413,10 @@ class Adopt_adv(torch.optim.Optimizer):
 normalized_grad.clamp_(-clip_val, clip_val)

 # ADOPT Step B: Update momentum m_t using normalized gradient
-if
+if use_mt:
 mt = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True)

-
-mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
-else:
-mt.lerp_(normalized_grad, 1.0 - beta1)
+mt.lerp_(normalized_grad, 1.0 - beta1)

 # Factorize
 state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(mt.clone(), signed=True)
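This hunk keeps only the EMA form for ADOPT's "Step B" and gates it on `use_mt`. For orientation, the two-stage structure looks roughly like the sketch below; the "Step A" normalization and the clip schedule are assumptions, only Step B and the dev14 Nesterov read-out mirror the lines shown:

```python
import torch

g = torch.randn(10)
v_prev = torch.rand(10) + 1e-8       # second moment from the previous step (assumed)
m = torch.zeros(10)
beta1, clip_val, eps = 0.9, 1.0, 1e-6

# Step A (assumed shape): normalize the raw gradient by the *previous* second moment.
normalized_grad = g / v_prev.sqrt().clamp_min(eps)
normalized_grad.clamp_(-clip_val, clip_val)

# Step B (as in the hunk): fold the normalized gradient into the momentum EMA.
m.lerp_(normalized_grad, 1.0 - beta1)

# Optional Nesterov read-out added in dev14 (nesterov_coef defaulting to beta1):
update = m.lerp(g, 1.0 - beta1)      # == beta1 * m + (1 - beta1) * g
```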
@@ -439,13 +428,17 @@ class Adopt_adv(torch.optim.Optimizer):
 else:
 update_mt = mt

+if nesterov:
+nv_coef = beta1 if nesterov_coef is None else nesterov_coef
+update_mt = update_mt.lerp_(grad_reshaped, 1-nv_coef)
+
 if self.use_AdEMAMix:
 # Reconstruct AdEMAMix EMA
 mt_slow = _reconstruct_state((state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'], d2), signed=True)

 mt_slow.lerp_(normalized_grad, 1.0 - beta3_ema)

-if
+if use_mt:
 update = update_mt.add_(mt_slow, alpha=alpha)
 del normalized_grad
 else:
@@ -453,12 +446,8 @@ class Adopt_adv(torch.optim.Optimizer):
 # Factorize
 state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
 del mt_slow
-
-elif self.Simplified_AdEMAMix:
-update = update_mt.add_(normalized_grad, alpha=alpha_grad)
-del normalized_grad
 else:
-if
+if use_mt:
 update = update_mt
 del normalized_grad
 else:
@@ -490,12 +479,9 @@ class Adopt_adv(torch.optim.Optimizer):
 normalized_grad.clamp_(-clip_val, clip_val)

 # ADOPT Step B: Update momentum m_t
-if
+if use_mt:
 mt = get_state(state, 'exp_avg', actual_precision) # m_{t-1}
-
-mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
-else:
-mt.lerp_(normalized_grad, 1.0 - beta1)
+mt.lerp_(normalized_grad, 1.0 - beta1)

 if self.grams_moment:
 update_mt = _grams_update(mt, grad)
@@ -504,21 +490,23 @@ class Adopt_adv(torch.optim.Optimizer):
 else:
 update_mt = mt.clone()

+if nesterov:
+nv_coef = beta1 if nesterov_coef is None else nesterov_coef
+update_mt = update_mt.lerp_(grad, 1-nv_coef)
+
 set_state(state, 'exp_avg', mt, actual_precision, random_int_state_tensor)

 if self.use_AdEMAMix:
 m_slow = get_state(state, 'exp_avg_slow', actual_precision)
 m_slow.lerp_(normalized_grad, 1.0 - beta3_ema)
-if
+if use_mt:
 update = update_mt.add_(m_slow, alpha=alpha)
 del normalized_grad
 else:
 update = normalized_grad.add_(m_slow, alpha=alpha)
 set_state(state, 'exp_avg_slow', m_slow, actual_precision, random_int_state_tensor)
-elif self.Simplified_AdEMAMix:
-update = update_mt.add_(normalized_grad, alpha=alpha_grad)
 else:
-if
+if use_mt:
 update = update_mt
 del normalized_grad
 else: