adv-optm 2.4.dev12__tar.gz → 2.4.dev14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/optim/AdaMuon_adv.py +50 -37
  4. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/optim/AdamW_adv.py +41 -13
  5. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/optim/Adopt_adv.py +59 -71
  6. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/optim/Lion_adv.py +23 -53
  7. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/optim/Muon_adv.py +49 -38
  8. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/optim/Prodigy_adv.py +148 -130
  9. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/optim/SignSGD_adv.py +55 -126
  10. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/optim/SinkSGD_adv.py +60 -24
  11. adv_optm-2.4.dev14/adv_optm/util/Muon_AuxAdam.py +237 -0
  12. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/util/Muon_util.py +2 -5
  13. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/util/scaled_optm.py +83 -29
  14. adv_optm-2.4.dev14/adv_optm/util/signed_util.py +53 -0
  15. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/util/sinkhorn.py +8 -20
  16. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm.egg-info/PKG-INFO +1 -1
  17. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/setup.py +1 -1
  18. adv_optm-2.4.dev12/adv_optm/util/Muon_AuxAdam.py +0 -172
  19. adv_optm-2.4.dev12/adv_optm/util/signed_util.py +0 -13
  20. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/LICENSE +0 -0
  21. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/README.md +0 -0
  22. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  23. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  24. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/optim/__init__.py +0 -0
  25. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/util/Kourkoutas.py +0 -0
  26. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/util/OrthoGrad.py +0 -0
  27. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/util/__init__.py +0 -0
  28. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/util/centered_decay.py +0 -0
  29. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/util/factorization_util.py +0 -0
  30. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/util/lion_k.py +0 -0
  31. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/util/param_update.py +0 -0
  32. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/util/state_util.py +0 -0
  33. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm/util/update_util.py +0 -0
  34. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm.egg-info/SOURCES.txt +0 -0
  35. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm.egg-info/dependency_links.txt +0 -0
  36. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm.egg-info/requires.txt +0 -0
  37. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/adv_optm.egg-info/top_level.txt +0 -0
  38. {adv_optm-2.4.dev12 → adv_optm-2.4.dev14}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 2.4.dev12
+ Version: 2.4.dev14
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -24,4 +24,4 @@ __all__ = [
  "SinkSGD_adv",
  ]

- __version__ = "2.4.dev12"
+ __version__ = "2.4.dev14"
@@ -59,14 +59,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
  orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
  nesterov (bool): enables Nesterov momentum (default: False).
  use_atan2 (bool): whether to use the atan2 update rule. (default: False)
- Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
- This changes the update to `alpha_grad * grad + mt`, which can be
- more responsive, especially for small batch sizes. (default: False)
- alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
- (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
- current gradient. For small batch sizes, use high values (e.g., 10-100) to be
- more responsive. For large batch sizes, use low values (e.g., 0-1) for
- stability. (default: 100.0)
  vector_reshape (bool): whether to reshape 1D vectors into 2D
  matrices to apply low-rank compression (default: True).
  kappa_p (float, optional): The p-value for the update geometry (domain [1.0, 2.0]).
@@ -117,6 +109,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
  adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
  adam_eps (float): Epsilon for the AdamW optimizer part.
  adam_weight_decay (float): Weight decay for the AdamW optimizer part.
+ adam_fisher_wd (bool): Fisher Adam (FAdam) weight decay for the AdamW part. (default: False)
  adam_use_bias_correction (bool): Bias correction for AdamW.
  adam_use_atan2 (bool): Atan2 update rule for AdamW.
  adam_cautious_mask (bool): Cautious masking for AdamW.
@@ -125,8 +118,17 @@ class AdaMuon_adv(torch.optim.Optimizer):
  adam_use_AdEMAMix (bool): AdEMAMix for AdamW.
  adam_beta3_ema (float): Beta3 for AdEMAMix.
  adam_alpha (float): Alpha for AdEMAMix.
+ adam_nesterov (bool): Nesterov momentum for AdamW. (default: False)
+ adam_nesterov_coef (float, optional): Nesterov coefficient for AdamW. (default: None)
  adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
+ adam_beta2_min (float): Minimum beta2 for Kourkoutas-β. (default: 0.9)
+ adam_ema_alpha (float): EMA alpha for Kourkoutas-β. (default: 0.95)
+ adam_tiny_spike (float): Tiny spike for Kourkoutas-β. (default: 1e-9)
+ adam_k_warmup_steps (int): Warmup steps for Kourkoutas-β. (default: 0)
+ adam_spectral_normalization (bool): Enable explicit spectral normalization for AdamW. (default: False)
+ adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp8_sr', 'int8_sr', 'factored'. (default: 'auto')
  adam_nnmf_factor (bool): 1-bit factored for AdamW.
+ adam_factored_2nd (bool): Factorize only the second moment (v_t) for AuxAdam. (default: False)
  """

  def __init__(
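The hunks above surface the auxiliary AdamW ("AuxAdam") controls directly on AdaMuon_adv. A hedged construction sketch follows; the parameter names come from the docstring above, while the top-level import, the placeholder model, and the values are assumptions for illustration only, not taken from this diff.

```python
import torch
from adv_optm import AdaMuon_adv  # assumed export; adjust to the actual import path

model = torch.nn.Linear(128, 128)  # placeholder module

optimizer = AdaMuon_adv(
    model.parameters(),
    lr=1e-3,
    adam_betas=(0.9, 0.99),
    adam_fisher_wd=False,              # new: FAdam-style weight decay for the AdamW part
    adam_nesterov=True,                # new: Nesterov momentum for the AdamW part
    adam_nesterov_coef=None,           # None presumably falls back to the Adam beta1
    adam_spectral_normalization=False,
    adam_state_precision="auto",       # 'auto' | 'fp32' | 'bf16_sr' | 'fp8_sr' | 'int8_sr' | 'factored'
    adam_factored_2nd=False,           # factorize only the second moment (v_t)
)
```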
@@ -140,6 +142,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
  cautious_wd: bool = False,
  # Nesterov momentum
  nesterov: bool = True,
+ nesterov_coef: float | None = None,
  # RMS Rescaling
  rms_rescaling: bool = True,
  # Newton Schulz
@@ -152,9 +155,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
  orthogonal_gradient: bool = False,
  # Adam_atan2 (scale invariant)
  use_atan2: bool = False,
- # One-EMA AdEMAMix
- Simplified_AdEMAMix: bool = False,
- alpha_grad: float = 100.0,
  # NorMuon
  normuon_variant: bool = False,
  # Boolean to spilt param
@@ -190,6 +190,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
  adam_betas: tuple[float, float] = (0.9, 0.99),
  adam_eps: float | None = 1e-8,
  adam_weight_decay: float = 0.0,
+ adam_fisher_wd: bool = False,
  adam_use_bias_correction: bool = True,
  adam_use_atan2: bool = False,
  adam_cautious_mask: bool = False,
@@ -198,12 +199,17 @@ class AdaMuon_adv(torch.optim.Optimizer):
  adam_use_AdEMAMix: bool = False,
  adam_beta3_ema: float = 0.9999,
  adam_alpha: float = 5.0,
+ adam_nesterov: bool = False,
+ adam_nesterov_coef: float | None = None,
  adam_kourkoutas_beta: bool = False,
  adam_beta2_min: float = 0.9,
  adam_ema_alpha: float = 0.95,
  adam_tiny_spike: float = 1e-9,
  adam_k_warmup_steps: int = 0,
+ adam_spectral_normalization: bool = False,
+ adam_state_precision: str = "auto",
  adam_nnmf_factor: bool = False,
+ adam_factored_2nd: bool = False,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -211,9 +217,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
  if not (ns_steps > 0):
  raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
- if Simplified_AdEMAMix and nesterov:
- print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling nesterov.")
- nesterov = False
  if spectral_normalization and rms_rescaling:
  print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
  rms_rescaling = False
@@ -221,17 +224,20 @@ class AdaMuon_adv(torch.optim.Optimizer):
  ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")

  state_precision = state_precision.lower()
- valid_precisions = {"auto", "fp32", "bf16_sr", "fp8_sr", "int8_sr"}
+ valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp8_sr", "int8_sr"}
  if state_precision not in valid_precisions:
  raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")

+ adam_state_precision = adam_state_precision.lower()
+ if adam_state_precision not in valid_precisions:
+ raise ValueError(f"adam_state_precision must be one of {valid_precisions}. Got {adam_state_precision}")
+
  defaults = {
  "lr": lr, "betas": betas, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
  "eps": eps, "rms_rescaling": rms_rescaling, "ns_steps": ns_steps,
  "ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
  "vector_reshape": vector_reshape,
- "nesterov":nesterov, "use_atan2":use_atan2,
- "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+ "nesterov":nesterov, "nesterov_coef": nesterov_coef, "use_atan2":use_atan2,
  "normuon_variant": normuon_variant, "orthogonal_gradient": orthogonal_gradient,
  "compiled_optimizer":compiled_optimizer,
  "use_muon": use_muon,
@@ -254,13 +260,18 @@ class AdaMuon_adv(torch.optim.Optimizer):
  "centered_wd_mode": centered_wd_mode,
  # AdamW_adv defaults
  "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
+ "adam_fisher_wd": adam_fisher_wd,
  "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
  "adam_cautious_mask": adam_cautious_mask, "adam_grams_moment": adam_grams_moment,
  "adam_orthogonal_gradient": adam_orthogonal_gradient,
  "adam_use_AdEMAMix": adam_use_AdEMAMix, "adam_beta3_ema": adam_beta3_ema, "adam_alpha": adam_alpha,
+ "adam_nesterov": adam_nesterov, "adam_nesterov_coef": adam_nesterov_coef,
  "adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
  "adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
- "adam_k_warmup_steps": adam_k_warmup_steps, "adam_nnmf_factor": adam_nnmf_factor,
+ "adam_k_warmup_steps": adam_k_warmup_steps,
+ "adam_spectral_normalization": adam_spectral_normalization,
+ "adam_state_precision": adam_state_precision,
+ "adam_nnmf_factor": adam_nnmf_factor, "adam_factored_2nd": adam_factored_2nd,
  }
  self.stochastic_rounding = stochastic_rounding
  self._init_lr = lr
@@ -447,13 +458,24 @@ class AdaMuon_adv(torch.optim.Optimizer):

  step_size = group['lr'] / bias_correction1

+ random_int_state_tensor = None
  if is_compiled:
  step_size = torch.as_tensor(step_size)
  adam_step_param = self._compiled_adam_step_parameter
+
+ # Generate state SR random tensor when compiled
+ actual_precision = group.get('adam_actual_state_precision', 'auto')
+ random_int_state_tensor = random_int_tensor
+ if actual_precision == 'bf16_sr' and random_int_state_tensor is None:
+ random_int_state_tensor = param_update._get_random_int_for_sr(p)
+ elif actual_precision == 'int8_sr':
+ random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
+ elif actual_precision == 'fp8_sr':
+ random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
  else:
  adam_step_param = Muon_AuxAdam._adam_step_parameter

- adam_step_param(self, p, grad, state, group, beta1_adam, beta2_adam, sqrt_bias_correction2, step_size, random_int_tensor)
+ adam_step_param(self, p, grad, state, group, beta1_adam, beta2_adam, sqrt_bias_correction2, step_size, random_int_tensor, random_int_state_tensor)

  state['step'] += 1

@@ -465,7 +487,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
  # Generate state SR random tensor when compiled
  actual_precision = group['actual_state_precision']
  random_int_state_tensor = random_int_tensor
- if actual_precision == 'bf16_sr' and random_int_state_tensor is not None:
+ if actual_precision == 'bf16_sr' and random_int_state_tensor is None:
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
  elif actual_precision == 'int8_sr':
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
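One behavioral fix sits in the hunk above: the bf16 stochastic-rounding tensor is now generated when none was supplied (the dev12 code tested `is not None`, so that branch never ran). A self-contained sketch of the intended lazy pattern; the package's helper calls are replaced here by stand-in `torch.randint` draws, so the bit widths are assumptions, not the library's actual behavior.

```python
import torch

def lazy_sr_tensor(p: torch.Tensor, actual_precision: str,
                   random_int_state_tensor: torch.Tensor | None):
    """Sketch of the lazy stochastic-rounding setup above (stand-in randomness)."""
    if actual_precision == 'bf16_sr' and random_int_state_tensor is None:
        # Only draw fresh randomness when the caller did not pass a tensor in.
        random_int_state_tensor = torch.randint(0, 2**16, p.shape, device=p.device)
    elif actual_precision in ('int8_sr', 'fp8_sr'):
        random_int_state_tensor = torch.randint(0, 2**8, p.shape, device=p.device)
    return random_int_state_tensor
```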
@@ -488,8 +510,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
  grad = upcast_grad_for_precision(grad, state, group.get('state_precision', 'auto'))
  beta1, beta2 = group['betas']
  nesterov = group['nesterov']
- Simplified_AdEMAMix = group['Simplified_AdEMAMix']
- alpha_grad = group['alpha_grad']
+ nesterov_coef = group.get('nesterov_coef', None)

  # Update geometry
  kappa_p = group.get("kappa_p", 1.0)
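The `nesterov_coef` read here feeds the Nesterov-style interpolation used in the momentum hunks below: `grad.lerp(mt_buf, c)` computes `(1 - c) * grad + c * mt_buf`, and `c` falls back to `beta1` when the coefficient is left at `None`. A small self-contained check of that identity, with illustrative values:

```python
import torch

beta1, nesterov_coef = 0.9, None
nv_coef = beta1 if nesterov_coef is None else nesterov_coef

grad = torch.randn(4, 4)
mt_buf = torch.zeros_like(grad)
mt_buf.lerp_(grad, 1 - beta1)        # m_t = beta1 * m_{t-1} + (1 - beta1) * g_t

update = grad.lerp(mt_buf, nv_coef)  # Nesterov-style blend of gradient and momentum
expected = (1 - nv_coef) * grad + nv_coef * mt_buf
assert torch.allclose(update, expected)
```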
@@ -513,7 +534,7 @@ class AdaMuon_adv(torch.optim.Optimizer):

  # MARS-M Approximated (Variance Reduction)
  if group.get('approx_mars', False):
- grad = approx_mars(grad, state['last_grad'], group['mars_gamma'], beta1, Simplified_AdEMAMix=Simplified_AdEMAMix)
+ grad = approx_mars(grad, state['last_grad'], group['mars_gamma'], beta1)


  if group.get("orthogonal_gradient"):
@@ -527,15 +548,11 @@ class AdaMuon_adv(torch.optim.Optimizer):
  mt_buf = _reconstruct_state((state['mu_mbuf_nmf'], state['mv_mbuf_nmf'], state['sign_buf'], d2), signed=True)

  # Update momentum in full-size
- if not Simplified_AdEMAMix:
- mt_buf.lerp_(grad_reshaped, 1 - beta1)
- else:
- mt_buf.mul_(beta1).add_(grad_reshaped)
+ mt_buf.lerp_(grad_reshaped, 1 - beta1)

  if nesterov:
- update = grad_reshaped.lerp(mt_buf, beta1)
- elif Simplified_AdEMAMix:
- update = torch.add(mt_buf, grad_reshaped, alpha=alpha_grad)
+ nv_coef = beta1 if nesterov_coef is None else nesterov_coef
+ update = grad_reshaped.lerp(mt_buf, nv_coef)
  else:
  update = mt_buf.clone()

@@ -586,15 +603,11 @@ class AdaMuon_adv(torch.optim.Optimizer):

  # Momentum update
  mt_buf = get_state(state, 'momentum_buffer', actual_precision)
- if not Simplified_AdEMAMix:
- mt_buf.lerp_(grad, 1 - beta1)
- else:
- mt_buf.mul_(beta1).add_(grad)
+ mt_buf.lerp_(grad, 1 - beta1)

  if nesterov:
- update = grad.lerp(mt_buf, beta1)
- elif Simplified_AdEMAMix:
- update = mt_buf.add(grad, alpha=alpha_grad)
+ nv_coef = beta1 if nesterov_coef is None else nesterov_coef
+ update = grad.lerp(mt_buf, nv_coef)
  else:
  update = mt_buf.clone()

@@ -28,7 +28,8 @@ class AdamW_adv(torch.optim.Optimizer):
  betas (tuple[float, float]): coefficients used for computing running
  averages of gradient and its square (default: (0.9, 0.999))
  eps (float): term added to the denominator to improve
- numerical stability (default: 1e-8)
+ numerical stability. Set to None for scale invariant eps (vector
+ lower bound) (default: 1e-8)
  weight_decay (float): weight decay (L2 penalty) (default: 0).
  fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
  the decay direction through the empirical Fisher information matrix and
@@ -127,6 +128,9 @@ class AdamW_adv(torch.optim.Optimizer):
  use_AdEMAMix: bool = False,
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
+ # Nesterov momentum
+ nesterov: bool = False,
+ nesterov_coef: float | None = None,
  # K-b (adaptive beta2)
  kourkoutas_beta: bool = False,
  beta2_min: float = 0.9,
@@ -176,7 +180,7 @@ class AdamW_adv(torch.optim.Optimizer):
  defaults = {
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
  "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
- "use_atan2": use_atan2,
+ "use_atan2": use_atan2, "nesterov": nesterov, "nesterov_coef": nesterov_coef,
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
  "beta3_ema": beta3_ema, "alpha": alpha, "compiled_optimizer": compiled_optimizer,
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
@@ -195,6 +199,8 @@ class AdamW_adv(torch.optim.Optimizer):
  self._init_lr = lr
  super().__init__(params, defaults)

+ self.init_step()
+
  if self.kourkoutas_beta:
  self.kourkoutas_helper = KourkoutasHelper(self)

@@ -232,12 +238,13 @@ class AdamW_adv(torch.optim.Optimizer):
  def supports_flat_params(self):
  return False

- @torch.no_grad()
- def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
- if p.grad is None:
- return
+ def init_step(self):
+ for group in self.param_groups:
+ for i, p in enumerate(group['params']):
+ self.__init_state(p, group)

- grad = p.grad
+ @torch.no_grad()
+ def __init_state(self, p, group):
  state = self.state[p]

  # State Initialization
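The refactor above moves per-parameter state creation out of `step_parameter` and into an `init_step` pass run once at construction. A minimal, self-contained sketch of that eager-initialization pattern, using a toy optimizer rather than the package's actual implementation:

```python
import torch

class EagerInitSGD(torch.optim.Optimizer):
    """Toy optimizer illustrating the init_step / per-parameter init split used above."""

    def __init__(self, params, lr=1e-3, momentum=0.9):
        super().__init__(params, {"lr": lr, "momentum": momentum})
        self.init_step()                      # allocate all state up front

    def init_step(self):
        for group in self.param_groups:
            for p in group["params"]:
                self._init_state(p)

    @torch.no_grad()
    def _init_state(self, p):
        state = self.state[p]
        if "step" not in state:               # idempotent, so a later call is a no-op
            state["step"] = 0
            state["momentum_buffer"] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                buf = state["momentum_buffer"]
                buf.mul_(group["momentum"]).add_(p.grad)
                p.add_(buf, alpha=-group["lr"])
                state["step"] += 1
```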
@@ -303,6 +310,15 @@ class AdamW_adv(torch.optim.Optimizer):

  _init_fisher_wd_scaler(group, state, p)

+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ state = self.state[p]
+ self.__init_state(p, group)
+
  beta1, beta2 = group['betas']

  current_step = state['step']
@@ -353,6 +369,9 @@ class AdamW_adv(torch.optim.Optimizer):
  if self.use_AdEMAMix:
  beta3_ema = group['beta3_ema']
  alpha = group['alpha']
+ nesterov = group.get('nesterov', False)
+ nesterov_coef = group.get('nesterov_coef', None)
+ use_mt = group['betas'][0] > 0

  if group.get('kourkoutas_beta', False):
  # Accumulate current grad's norm for the *next* step
@@ -365,7 +384,7 @@
  grad_reshaped = grad.view(d1, d2)

  # Reconstruct momentum from previous step's factors
- if beta1 > 0:
+ if use_mt:
  mt = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True)

  # Update momentum in full-size
@@ -381,6 +400,10 @@
  else:
  update_mt = mt

+ if nesterov:
+ nv_coef = beta1 if nesterov_coef is None else nesterov_coef
+ update_mt = update_mt.lerp_(grad_reshaped, 1-nv_coef)
+
  vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)

  if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
@@ -393,7 +416,7 @@

  mt_slow.lerp_(grad_reshaped, 1.0 - beta3_ema)

- if beta1 > 0:
+ if use_mt:
  update = update_mt.add_(mt_slow, alpha=alpha)
  else:
  update = grad_reshaped.add(mt_slow, alpha=alpha)
@@ -402,7 +425,7 @@
  state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
  del mt_slow
  else:
- if beta1 > 0:
+ if use_mt:
  update = update_mt
  else:
  update = grad_reshaped.clone()
@@ -429,7 +452,7 @@
  actual_precision = group['actual_state_precision']
  factored_2nd = state.get('factored_2nd', False)

- if beta1 > 0:
+ if use_mt:
  exp_avg = get_state(state, 'exp_avg', actual_precision)
  exp_avg.lerp_(grad, 1.0 - beta1)

@@ -439,19 +462,24 @@
  update_mt = _cautious_update(exp_avg, grad)
  else:
  update_mt = exp_avg.clone()
+
+ if nesterov:
+ nv_coef = beta1 if nesterov_coef is None else nesterov_coef
+ update_mt = update_mt.lerp_(grad, 1-nv_coef)
+
  set_state(state, 'exp_avg', exp_avg, actual_precision, random_int_state_tensor)

  if self.use_AdEMAMix:
  exp_avg_slow = get_state(state, 'exp_avg_slow', actual_precision)
  exp_avg_slow.lerp_(grad, 1.0 - beta3_ema)

- if beta1 > 0:
+ if use_mt:
  update = update_mt.add_(exp_avg_slow, alpha=alpha)
  else:
  update = torch.add(grad, exp_avg_slow, alpha=alpha)
  set_state(state, 'exp_avg_slow', exp_avg_slow, actual_precision, random_int_state_tensor)
  else:
- update = update_mt if beta1 > 0 else grad.clone()
+ update = update_mt if use_mt else grad.clone()

  if factored_2nd:
  d1, d2 = state['effective_shape']
@@ -7,7 +7,7 @@ from ..util import param_update
  from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state, _nnmf
  from ..util.OrthoGrad import _orthogonalize_gradient
  from ..util.Kourkoutas import KourkoutasHelper
- from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
+ from ..util.update_util import _grams_update, _cautious_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
  from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm, scale_eps
  from ..util.centered_decay import _init_anchor
  from ..util.state_util import init_state_tensor, get_state, set_state, upcast_grad_for_precision
@@ -32,7 +32,8 @@ class Adopt_adv(torch.optim.Optimizer):
  betas (tuple[float, float]): coefficients used for computing running
  averages of momentum and variance (default: (0.9, 0.9999))
  eps (float): term added to the denominator to improve
- numerical stability (default: 1e-6)
+ numerical stability. Set to None for scale invariant eps (vector
+ lower bound) (default: 1e-6)
  weight_decay (float): weight decay (L2 penalty) (default: 0)
  fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
  the decay direction through the empirical Fisher information matrix and
@@ -68,16 +69,6 @@ class Adopt_adv(torch.optim.Optimizer):
  before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
  A higher value increases the stabilizing influence of the slow
  momentum. (default: 5.0)
- Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
- This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
- more responsive, especially for small batch sizes. Enabling this will
- automatically disable `use_AdEMAMix`, `cautious_mask`, `grams_moment`,
- and `use_atan2`. (default: False)
- alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
- (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
- current gradient. For small batch sizes, use high values (e.g., 10-100) to be
- more responsive. For large batch sizes, use low values (e.g., 0-1) for
- stability. (default: 100.0)
  kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
  If `False`, the optimizer behaves as standard Adopt. (default: False)
  beta2_min (float): The minimum value for dynamic β₂, used during periods of
@@ -143,9 +134,9 @@ class Adopt_adv(torch.optim.Optimizer):
  use_AdEMAMix: bool = False,
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
- # One-EMA AdEMAMix
- Simplified_AdEMAMix: bool = False,
- alpha_grad: float = 100.0,
+ # Nesterov momentum
+ nesterov: bool = False,
+ nesterov_coef: float | None = None,
  # K-b (adaptive beta2)
  kourkoutas_beta: bool = False,
  beta2_min: float = 0.9,
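These signature changes drop `Simplified_AdEMAMix` and `alpha_grad` in favour of the Nesterov options. A hedged before/after sketch for callers of Adopt_adv; the import path, model, and values are assumptions for illustration:

```python
import torch
from adv_optm import Adopt_adv  # assumed export; adjust to the actual import path

model = torch.nn.Linear(8, 8)  # placeholder module

# 2.4.dev12 (arguments removed in 2.4.dev14):
# opt = Adopt_adv(model.parameters(), lr=1e-3,
#                 Simplified_AdEMAMix=True, alpha_grad=100.0)

# 2.4.dev14: Nesterov momentum instead
opt = Adopt_adv(model.parameters(), lr=1e-3,
                nesterov=True, nesterov_coef=None)  # None presumably defaults to beta1
```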
@@ -179,16 +170,8 @@ class Adopt_adv(torch.optim.Optimizer):
  if cautious_mask and grams_moment:
  print("Warning: cautious is incompatible with grams, Disabling cautious.")
  cautious_mask = False
- if betas[0] == 0.0 and Simplified_AdEMAMix:
- raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
  if kourkoutas_beta and not (betas[1] > beta2_min):
  raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
- if use_AdEMAMix and Simplified_AdEMAMix:
- print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
- if grams_moment and Simplified_AdEMAMix:
- print("Warning: grams is incompatible with Simplified_AdEMAMix, Disabling grams.")
- if cautious_mask and Simplified_AdEMAMix:
- print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")


  state_precision = state_precision.lower()
@@ -204,7 +187,7 @@ class Adopt_adv(torch.optim.Optimizer):
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
  "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
  "beta3_ema": beta3_ema, "alpha": alpha,
- "alpha_grad": alpha_grad,
+ "nesterov": nesterov, "nesterov_coef": nesterov_coef,
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
  "spectral_normalization": spectral_normalization,
@@ -216,17 +199,18 @@ class Adopt_adv(torch.optim.Optimizer):
  }
  self.clip_lambda = clip_lambda
  self.stochastic_rounding = stochastic_rounding
- self.use_atan2 = use_atan2 and not Simplified_AdEMAMix
- self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
- self.grams_moment = grams_moment and not Simplified_AdEMAMix
+ self.use_atan2 = use_atan2
+ self.cautious_mask = cautious_mask
+ self.grams_moment = grams_moment
  self.orthogonal_gradient = orthogonal_gradient
- self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
- self.Simplified_AdEMAMix = Simplified_AdEMAMix
+ self.use_AdEMAMix = use_AdEMAMix
  self.kourkoutas_beta = kourkoutas_beta
  self.layer_key_fn = layer_key_fn
  self._init_lr = lr
  super().__init__(params, defaults)

+ self.init_step()
+
  if self.kourkoutas_beta:
  self.kourkoutas_helper = KourkoutasHelper(self)

@@ -258,26 +242,15 @@ class Adopt_adv(torch.optim.Optimizer):
  @property
  def supports_flat_params(self): return False

- @torch.no_grad()
- def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
- if p.grad is None:
- return
+ def init_step(self):
+ for group in self.param_groups:
+ for i, p in enumerate(group['params']):
+ self.__init_state(p, group)

- grad = p.grad
+ @torch.no_grad()
+ def __init_state(self, p, group):
  state = self.state[p]

- beta1, beta2 = group['betas']
-
- if group.get('kourkoutas_beta', False):
- if 'step' not in state:
- current_step = 0
- else:
- current_step = state['step']
- # Call prepare_step() once at the beginning of the step for all params
- self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
- # Get the dynamic beta2 calculated in prepare_step()
- beta2 = self.kourkoutas_helper.get_beta2(p, group)
-
  # State Initialization
  if 'step' not in state:
  state['step'] = 0
@@ -340,6 +313,27 @@ class Adopt_adv(torch.optim.Optimizer):

  _init_fisher_wd_scaler(group, state, p)

+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ state = self.state[p]
+ self.__init_state(p, group)
+
+ beta1, beta2 = group['betas']
+
+ if group.get('kourkoutas_beta', False):
+ if 'step' not in state:
+ current_step = 0
+ else:
+ current_step = state['step']
+ # Call prepare_step() once at the beginning of the step for all params
+ self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
+ # Get the dynamic beta2 calculated in prepare_step()
+ beta2 = self.kourkoutas_helper.get_beta2(p, group)
+
  current_step = state['step']

  # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
@@ -367,9 +361,6 @@ class Adopt_adv(torch.optim.Optimizer):
  lr = group['lr']
  step_param_fn = self._step_parameter

- if self.Simplified_AdEMAMix:
- lr = _scale_sim_AdEMAMix_update(beta1, state['step'] + 1, group["alpha_grad"], lr, group.get('spectral_normalization', False))
-
  step_param_fn(p, grad, state, group, lr, beta1, beta2, random_int_tensor, random_int_state_tensor)

  state['step'] += 1
@@ -383,8 +374,9 @@
  if self.use_AdEMAMix:
  beta3_ema = group['beta3_ema']
  alpha = group['alpha']
- if self.Simplified_AdEMAMix:
- alpha_grad = group["alpha_grad"]
+ nesterov = group.get('nesterov', False)
+ nesterov_coef = group.get('nesterov_coef', None)
+ use_mt = group['betas'][0] > 0

  if group.get('kourkoutas_beta', False):
  # Accumulate current grad's norm for the *next* step
@@ -421,13 +413,10 @@
  normalized_grad.clamp_(-clip_val, clip_val)

  # ADOPT Step B: Update momentum m_t using normalized gradient
- if beta1 > 0:
+ if use_mt:
  mt = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True)

- if self.Simplified_AdEMAMix:
- mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
- else:
- mt.lerp_(normalized_grad, 1.0 - beta1)
+ mt.lerp_(normalized_grad, 1.0 - beta1)

  # Factorize
  state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(mt.clone(), signed=True)
@@ -439,13 +428,17 @@
  else:
  update_mt = mt

+ if nesterov:
+ nv_coef = beta1 if nesterov_coef is None else nesterov_coef
+ update_mt = update_mt.lerp_(grad_reshaped, 1-nv_coef)
+
  if self.use_AdEMAMix:
  # Reconstruct AdEMAMix EMA
  mt_slow = _reconstruct_state((state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'], d2), signed=True)

  mt_slow.lerp_(normalized_grad, 1.0 - beta3_ema)

- if beta1 > 0:
+ if use_mt:
  update = update_mt.add_(mt_slow, alpha=alpha)
  del normalized_grad
  else:
@@ -453,12 +446,8 @@
  # Factorize
  state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
  del mt_slow
-
- elif self.Simplified_AdEMAMix:
- update = update_mt.add_(normalized_grad, alpha=alpha_grad)
- del normalized_grad
  else:
- if beta1 > 0:
+ if use_mt:
  update = update_mt
  del normalized_grad
  else:
@@ -490,12 +479,9 @@
  normalized_grad.clamp_(-clip_val, clip_val)

  # ADOPT Step B: Update momentum m_t
- if beta1 > 0:
+ if use_mt:
  mt = get_state(state, 'exp_avg', actual_precision) # m_{t-1}
- if self.Simplified_AdEMAMix:
- mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
- else:
- mt.lerp_(normalized_grad, 1.0 - beta1)
+ mt.lerp_(normalized_grad, 1.0 - beta1)

  if self.grams_moment:
  update_mt = _grams_update(mt, grad)
@@ -504,21 +490,23 @@
  else:
  update_mt = mt.clone()

+ if nesterov:
+ nv_coef = beta1 if nesterov_coef is None else nesterov_coef
+ update_mt = update_mt.lerp_(grad, 1-nv_coef)
+
  set_state(state, 'exp_avg', mt, actual_precision, random_int_state_tensor)

  if self.use_AdEMAMix:
  m_slow = get_state(state, 'exp_avg_slow', actual_precision)
  m_slow.lerp_(normalized_grad, 1.0 - beta3_ema)
- if beta1 > 0:
+ if use_mt:
  update = update_mt.add_(m_slow, alpha=alpha)
  del normalized_grad
  else:
  update = normalized_grad.add_(m_slow, alpha=alpha)
  set_state(state, 'exp_avg_slow', m_slow, actual_precision, random_int_state_tensor)
- elif self.Simplified_AdEMAMix:
- update = update_mt.add_(normalized_grad, alpha=alpha_grad)
  else:
- if beta1 > 0:
+ if use_mt:
  update = update_mt
  del normalized_grad
  else: