adv-optm 2.4.dev2__tar.gz → 2.4.dev5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/AdaMuon_adv.py +9 -5
  4. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/AdamW_adv.py +30 -10
  5. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/Adopt_adv.py +45 -24
  6. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/Lion_Prodigy_adv.py +2 -2
  7. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/Lion_adv.py +42 -26
  8. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/Muon_adv.py +8 -4
  9. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/Prodigy_adv.py +27 -12
  10. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/SignSGD_adv.py +39 -24
  11. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/Simplified_AdEMAMix.py +12 -6
  12. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/Kourkoutas.py +43 -12
  13. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/Muon_AuxAdam.py +8 -2
  14. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/Muon_util.py +7 -5
  15. adv_optm-2.4.dev5/adv_optm/util/OrthoGrad.py +50 -0
  16. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/centered_decay.py +1 -1
  17. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/param_update.py +54 -12
  18. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/scaled_optm.py +28 -20
  19. adv_optm-2.4.dev5/adv_optm/util/signed_util.py +13 -0
  20. adv_optm-2.4.dev5/adv_optm/util/update_util.py +111 -0
  21. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm.egg-info/PKG-INFO +1 -1
  22. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm.egg-info/SOURCES.txt +1 -0
  23. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/setup.py +1 -1
  24. adv_optm-2.4.dev2/adv_optm/util/OrthoGrad.py +0 -21
  25. adv_optm-2.4.dev2/adv_optm/util/update_util.py +0 -32
  26. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/LICENSE +0 -0
  27. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/README.md +0 -0
  28. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/optim/__init__.py +0 -0
  29. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/__init__.py +0 -0
  30. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/factorization_util.py +0 -0
  31. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm/util/lion_k.py +0 -0
  32. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm.egg-info/dependency_links.txt +0 -0
  33. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm.egg-info/requires.txt +0 -0
  34. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/adv_optm.egg-info/top_level.txt +0 -0
  35. {adv_optm-2.4.dev2 → adv_optm-2.4.dev5}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 2.4.dev2
+ Version: 2.4.dev5
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -22,4 +22,4 @@ __all__ = [
  "SignSGD_adv",
  ]

- __version__ = "2.4.dev2"
+ __version__ = "2.4.dev5"
@@ -206,6 +206,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
  if spectral_normalization and rms_rescaling:
  print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
  rms_rescaling = False
+ if spectral_normalization and accelerated_ns:
+ ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")

  defaults = {
  "lr": lr, "betas": betas, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
@@ -260,6 +262,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
  if group.get('use_muon') is None: # Fallback
  group['use_muon'] = group.get('optim_type') == 'muon'

+ self.init_step()
+
  self.kourkoutas_helper = None
  if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
  self.kourkoutas_helper = KourkoutasHelper(self)
@@ -280,8 +284,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
  def load_state_dict(self, state_dict: dict) -> None:
  """
  Overrides default load_state_dict to implement a workaround for PyTorch's
- automatic dtype casting. It ensures factorized states remain float32 for
- stability, preserves integer/float8 quantized anchor states, and forces
+ automatic dtype casting. It ensures factorized states remain float32 for
+ stability, preserves integer/float8 quantized anchor states, and forces
  standard states onto the parameter's current dtype/device.
  """
  super().load_state_dict(state_dict)
@@ -419,7 +423,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
  step_size = group['lr'] / bias_correction1

  if is_compiled:
- step_size = torch.as_tensor(step_size, dtype=torch.float64)
+ step_size = torch.as_tensor(step_size)
  adam_step_param = self._compiled_adam_step_parameter
  else:
  adam_step_param = Muon_AuxAdam._adam_step_parameter
@@ -430,7 +434,7 @@ class AdaMuon_adv(torch.optim.Optimizer):

  else: # Muon path
  if is_compiled:
- lr = torch.as_tensor(group['lr'], dtype=torch.float64)
+ lr = torch.as_tensor(group['lr'])
  muon_step_param = self._compiled_muon_step_parameter
  else:
  lr = group['lr']
@@ -467,7 +471,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
  else:
  shape_for_scaling = p.shape

- scaled_eps, adaptive_eps, spectral_target, wd_scale = get_spectral_scaling(shape_for_scaling, group['n_layers'])
+ scaled_eps, adaptive_eps, spectral_target, wd_scale = get_spectral_scaling(p, shape_for_scaling, group['n_layers'])

  weight_decay = group['weight_decay'] * wd_scale
  decoupled_wd = True
@@ -6,7 +6,7 @@ from typing import Optional, Callable

  from ..util import param_update
  from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
- from ..util.update_util import _grams_update, _cautious_update
+ from ..util.update_util import _grams_update, _cautious_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
  from ..util.OrthoGrad import _orthogonalize_gradient
  from ..util.Kourkoutas import KourkoutasHelper
  from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
@@ -29,6 +29,9 @@ class AdamW_adv(torch.optim.Optimizer):
  eps (float): term added to the denominator to improve
  numerical stability (default: 1e-8)
  weight_decay (float): weight decay (L2 penalty) (default: 0).
+ fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
+ the decay direction through the empirical Fisher information matrix and
+ clipping its RMS. (default: False)
  cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
  applied only to parameter coordinates where the sign of the parameter
  and the sign of the optimizer update align (default: False).
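The fisher_wd description above says the decay direction is mapped through the empirical Fisher information matrix and RMS-clipped; in the step code further down this file, the scaler is derived from the same denom that divides the gradient. The sketch below only illustrates that idea for a plain dense Adam state; it is not the package's _init_fisher_wd_scaler/_get_fisher_wd_scaler implementation, and the exact scaling and clipping rule shown here are assumptions:

    import torch

    def fisher_weight_decay(p, denom, lr, weight_decay, max_rms=1.0):
        # Map the decay direction through the Fisher denominator (sqrt of the
        # EMA of squared gradients) instead of decaying p directly.
        decay = p / denom
        # Clip the RMS of the decay term so it cannot blow up where denom is tiny.
        rms = decay.pow(2).mean().sqrt()
        decay.div_(torch.clamp(rms / max_rms, min=1.0))
        p.add_(decay, alpha=-lr * weight_decay)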
@@ -91,7 +94,7 @@ class AdamW_adv(torch.optim.Optimizer):
  'int4': Uses 4-bit block-wise quantization (block size 32).
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
- factored_2nd (bool): whether to keep the first moment uncompressed (dense)
+ factored_2nd (bool): whether to keep the first moment uncompressed (dense)
  while only factorizing the second moment. (default: True)
  """

@@ -103,6 +106,7 @@ class AdamW_adv(torch.optim.Optimizer):
  eps: float = 1e-8,
  # Decoupled/cautious weight decay
  weight_decay: float = 0.0,
+ fisher_wd: bool = False,
  cautious_wd: bool = False,
  # Adam's Bias Correction
  use_bias_correction: bool = True,
@@ -155,7 +159,8 @@ class AdamW_adv(torch.optim.Optimizer):
  cautious_mask = False

  defaults = {
- "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
+ "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
+ "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
  "use_atan2": use_atan2,
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
  "beta3_ema": beta3_ema, "alpha": alpha, "compiled_optimizer": compiled_optimizer,
@@ -192,8 +197,8 @@ class AdamW_adv(torch.optim.Optimizer):
  def load_state_dict(self, state_dict: dict) -> None:
  """
  Overrides default load_state_dict to implement a workaround for PyTorch's
- automatic dtype casting. It ensures factorized states remain float32 for
- stability, preserves integer/float8 quantized anchor states, and forces
+ automatic dtype casting. It ensures factorized states remain float32 for
+ stability, preserves integer/float8 quantized anchor states, and forces
  standard states onto the parameter's current dtype/device.
  """
  super().load_state_dict(state_dict)
@@ -273,6 +278,8 @@ class AdamW_adv(torch.optim.Optimizer):

  _init_anchor(p, state, group)

+ _init_fisher_wd_scaler(group, state, p)
+
  beta1, beta2 = group['betas']

  current_step = state['step']
@@ -294,7 +301,7 @@ class AdamW_adv(torch.optim.Optimizer):
  random_int_tensor = None

  if group.get('compiled_optimizer', False):
- step_size = torch.as_tensor(step_size, dtype=torch.float64)
+ step_size = torch.as_tensor(step_size)
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
  # Pre-generate random tensor for stochastic rounding if needed.
  random_int_tensor = param_update._get_random_int_for_sr(p)
@@ -349,7 +356,11 @@ class AdamW_adv(torch.optim.Optimizer):
  update_mt = mt if not factored_2nd else mt.clone()

  vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
- vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
+
+ if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped * (1.0 - beta2))
+ else:
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)

  if self.use_AdEMAMix:
  if factored_2nd:
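The isinstance(beta2, torch.Tensor) branch added here (and mirrored in the non-factored and ADOPT paths below) exists because the value= argument of Tensor.addcmul_ only accepts a Python number; when beta2 arrives as a tensor, for example the dynamic value produced by the Kourkoutas-beta helper, the (1 - beta2) factor has to be folded into one of the multiplicands instead. Both branches compute the same EMA, v <- beta2*v + (1 - beta2)*g^2. A small self-contained check with hypothetical shapes (not package code):

    import torch

    g = torch.randn(4, 4)
    v = torch.rand(4, 4)

    # Scalar beta2: value= takes a plain Python float.
    beta2 = 0.999
    v_scalar = v.clone().mul_(beta2).addcmul_(g, g, value=1.0 - beta2)

    # Tensor-valued beta2: fold (1 - beta2) into the second multiplicand.
    beta2_t = torch.tensor([0.999])
    v_tensor = v.clone().mul_(beta2_t).addcmul_(g, g * (1.0 - beta2_t))

    print(torch.allclose(v_scalar, v_tensor))  # True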
@@ -363,7 +374,7 @@ class AdamW_adv(torch.optim.Optimizer):
  update = update_mt.add_(mt_slow, alpha=alpha)
  else:
  update = grad_reshaped.add(mt_slow, alpha=alpha)
-
+
  if not factored_2nd:
  # Factorize
  state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
@@ -385,6 +396,9 @@ class AdamW_adv(torch.optim.Optimizer):
  denom = vt.sqrt_()
  denom.div_(sqrt_bias_correction2).add_(group['eps'])
  update.div_(denom)
+
+ wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
+
  del vt

  update = update.view(p.shape)
@@ -413,7 +427,10 @@ class AdamW_adv(torch.optim.Optimizer):
  update = update_mt if beta1 > 0 else grad.clone()

  exp_avg_sq = state['exp_avg_sq']
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+ if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad * (1.0 - beta2))
+ else:
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

  if group['use_atan2']:
  denom = exp_avg_sq.sqrt()
@@ -423,6 +440,9 @@ class AdamW_adv(torch.optim.Optimizer):
  denom = exp_avg_sq.sqrt()
  denom.div_(sqrt_bias_correction2).add_(group['eps'])
  update.div_(denom)
+
+ wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
+
  del denom

  update_scaling = step_size * A if group['use_atan2'] else step_size
@@ -431,7 +451,7 @@ class AdamW_adv(torch.optim.Optimizer):
  else:
  update.mul_(update_scaling)

- param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor)
+ param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)

  def compile(self, *args, **kwargs):
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
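For orientation, fisher_wd is an ordinary constructor keyword next to the existing weight-decay options. A minimal usage sketch; the top-level import path is assumed from the package's __init__.py exports, and the model is a placeholder:

    import torch
    from adv_optm import AdamW_adv  # top-level export assumed from __all__

    model = torch.nn.Linear(128, 64)
    opt = AdamW_adv(
        model.parameters(),
        lr=1e-3,
        weight_decay=1e-2,
        fisher_wd=True,  # FAdam-style weight decay added in this release
    )

    loss = model(torch.randn(8, 128)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()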
@@ -7,7 +7,7 @@ from ..util import param_update
  from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state, _nnmf
  from ..util.OrthoGrad import _orthogonalize_gradient
  from ..util.Kourkoutas import KourkoutasHelper
- from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update
+ from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
  from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
  from ..util.centered_decay import _init_anchor

@@ -33,6 +33,9 @@ class Adopt_adv(torch.optim.Optimizer):
  eps (float): term added to the denominator to improve
  numerical stability (default: 1e-6)
  weight_decay (float): weight decay (L2 penalty) (default: 0)
+ fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
+ the decay direction through the empirical Fisher information matrix and
+ clipping its RMS. (default: False)
  cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
  applied only to parameter coordinates where the sign of the parameter
  and the sign of the optimizer update align (default: False).
@@ -107,7 +110,7 @@ class Adopt_adv(torch.optim.Optimizer):
  'int4': Uses 4-bit block-wise quantization (block size 32).
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
- factored_2nd (bool): whether to keep the first moment uncompressed (dense)
+ factored_2nd (bool): whether to keep the first moment uncompressed (dense)
  while only factorizing the second moment. (default: True)
  """

@@ -119,6 +122,7 @@ class Adopt_adv(torch.optim.Optimizer):
  eps: float = 1e-6,
  # Decoupled/cautious weight decay
  weight_decay: float = 0.0,
+ fisher_wd: bool = False,
  cautious_wd: bool = False,
  # ADOPT clipping
  clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
@@ -181,7 +185,8 @@ class Adopt_adv(torch.optim.Optimizer):
  print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")

  defaults = {
- "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
+ "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
+ "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
  "beta3_ema": beta3_ema, "alpha": alpha,
  "alpha_grad": alpha_grad,
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
@@ -189,7 +194,7 @@ class Adopt_adv(torch.optim.Optimizer):
  "scaled_optm": scaled_optm,
  "centered_wd": centered_wd,
  "centered_wd_mode": centered_wd_mode,
- "nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd,
+ "nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd,
  "compiled_optimizer": compiled_optimizer,
  }
  self.clip_lambda = clip_lambda
@@ -222,8 +227,8 @@ class Adopt_adv(torch.optim.Optimizer):
  def load_state_dict(self, state_dict: dict) -> None:
  """
  Overrides default load_state_dict to implement a workaround for PyTorch's
- automatic dtype casting. It ensures factorized states remain float32 for
- stability, preserves integer/float8 quantized anchor states, and forces
+ automatic dtype casting. It ensures factorized states remain float32 for
+ stability, preserves integer/float8 quantized anchor states, and forces
  standard states onto the parameter's current dtype/device.
  """
  super().load_state_dict(state_dict)
@@ -244,6 +249,19 @@ class Adopt_adv(torch.optim.Optimizer):
  grad = p.grad
  state = self.state[p]

+
+ beta1, beta2 = group['betas']
+
+ if group.get('kourkoutas_beta', False):
+ if 'step' not in state:
+ current_step = 0
+ else:
+ current_step = state['step']
+ # Call prepare_step() once at the beginning of the step for all params
+ self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
+ # Get the dynamic beta2 calculated in prepare_step()
+ beta2 = self.kourkoutas_helper.get_beta2(p, group)
+
  # State Initialization
  if 'step' not in state:
  state['step'] = 0
@@ -256,6 +274,12 @@ class Adopt_adv(torch.optim.Optimizer):

  dtype = torch.float32 if state['factored'] else p.dtype

+ vt_init = grad.pow(2).to(dtype)
+ if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
+ vt_init.mul_(beta2).addcmul_(grad.to(dtype), grad.to(dtype) * (1.0 - beta2))
+ else:
+ vt_init.mul_(beta2).addcmul_(grad.to(dtype), grad.to(dtype), value=1.0 - beta2)
+
  if state['factored']:
  state['effective_shape'] = _get_effective_shape(p.numel())
  d1, d2 = state['effective_shape']
@@ -279,33 +303,23 @@ class Adopt_adv(torch.optim.Optimizer):
  if self.use_AdEMAMix:
  state['exp_avg_slow'] = torch.zeros_like(p, device=p.device, dtype=dtype)
  # Second moment (v)
- vt_init = grad.to(dtype).view(d1, d2).square()
- # Allocate NMF factors for vt
- state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
- state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
- # Initialize v_0
- state['mu_v_nmf'], state['mv_v_nmf'] = _nnmf(vt_init)
+ state['mu_v_nmf'], state['mv_v_nmf'] = _nnmf(vt_init.view(d1, d2))
  del vt_init
  else: # Fallback for non-factored tensors
  if group['betas'][0] > 0:
  state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
  if self.use_AdEMAMix:
  state['exp_avg_slow'] = torch.zeros_like(p, device=p.device, dtype=dtype)
- state['exp_avg_sq'] = grad.to(dtype).square()
+ state['exp_avg_sq'] = vt_init

  if group.get('scaled_optm', False) and is_spectral(p):
  init_spectral_norm(group, state, p)

  _init_anchor(p, state, group)

- beta1, beta2 = group['betas']
+ _init_fisher_wd_scaler(group, state, p)

  current_step = state['step']
- if group.get('kourkoutas_beta', False):
- # Call prepare_step() once at the beginning of the step for all params
- self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
- # Get the dynamic beta2 calculated in prepare_step()
- beta2 = self.kourkoutas_helper.get_beta2(p, group)

  # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
  if state['step'] == 0 and not self.use_atan2:
@@ -315,7 +329,7 @@ class Adopt_adv(torch.optim.Optimizer):
  random_int_tensor = None

  if group.get('compiled_optimizer', False):
- lr = torch.as_tensor(group['lr'], dtype=torch.float64)
+ lr = torch.as_tensor(group['lr'])
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
  # Pre-generate random tensor for stochastic rounding if needed.
  random_int_tensor = param_update._get_random_int_for_sr(p)
@@ -359,9 +373,13 @@ class Adopt_adv(torch.optim.Optimizer):

  # ADOPT Step A: Decorrelate g_t using v_{t-1}
  denom = vt.sqrt()
+ wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, self.use_atan2)

  # Update second moment v_t for the *next* step using raw g_t
- vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
+ if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped * (1.0 - beta2))
+ else:
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
  # Factorize
  state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(vt, signed=False)
  del vt
@@ -434,6 +452,7 @@ class Adopt_adv(torch.optim.Optimizer):

  # ADOPT Step A: Decorrelate g_t using v_{t-1}
  denom = vt.sqrt()
+ wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, self.use_atan2)

  if self.use_atan2:
  normalized_grad = torch.atan2(grad, denom, out=denom)
@@ -475,9 +494,11 @@ class Adopt_adv(torch.optim.Optimizer):
  else:
  update = normalized_grad

-
  # Update second moment v_t for the next step using raw g_t
- vt.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+ if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
+ vt.mul_(beta2).addcmul_(grad, grad * (1.0 - beta2))
+ else:
+ vt.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

  update_scaling = lr * A if self.use_atan2 else lr

@@ -487,7 +508,7 @@ class Adopt_adv(torch.optim.Optimizer):
  update.mul_(update_scaling)

  # Parameter Update
- param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
+ param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)

  def compile(self, *args, **kwargs):
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -225,8 +225,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  # Pre-generate random tensor for stochastic rounding if needed.
  random_int_tensor = param_update._get_random_int_for_sr(p)
  # TODO, workaround until pytorch#169634 is fixed
- d = torch.as_tensor(group['d'], dtype=torch.float64)
- dlr = torch.as_tensor(dlr, dtype=torch.float64)
+ d = torch.as_tensor(group['d'])
+ dlr = torch.as_tensor(dlr)
  step_param_fn = self._compiled_step_parameter
  else:
  d = group['d']
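The torch.as_tensor wrapping here (and in the other compiled_optimizer branches) is marked as a workaround in the TODO above; a common motivation for passing scalar hyperparameters such as d, dlr, or lr into a torch.compile'd step as 0-dim tensors is that plain Python floats are specialized into the traced graph, so a value that changes between steps would force recompilation. The change in this release only drops the forced float64 dtype, so the wrapped scalar keeps PyTorch's default dtype. A generic illustration of the pattern, using a hypothetical apply_update rather than the package's step function:

    import torch

    @torch.compile
    def apply_update(p, update, lr):
        # lr arrives as a 0-dim tensor, so new values between calls do not
        # change the traced graph the way a changing Python float would.
        p.add_(update * lr, alpha=-1.0)

    p = torch.randn(4)
    update = torch.randn(4)
    for step_lr in (1e-3, 5e-4):
        apply_update(p, update, torch.as_tensor(step_lr))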
@@ -8,6 +8,8 @@ from ..util.factorization_util import _get_effective_shape, _reconstruct_state,
  from ..util.lion_k import _get_lion_k_update
  from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
  from ..util.centered_decay import _init_anchor
+ from ..util.update_util import _get_l1_adaptive_lr
+ from ..util.signed_util import apply_stochastic_sign


  class Lion_adv(torch.optim.Optimizer):
@@ -44,9 +46,10 @@ class Lion_adv(torch.optim.Optimizer):
  parameter dimensionality. Sets p=2.0 for 4D tensors (Conv2D) (Biases/Norms) to
  use Spherical updates, and p=1.0 for others (Linear/Embeddings) to use Sign
  updates. Overrides explicit kappa_p value. (default: False).
+ stochastic_sign (bool): whether to use the Stochastic Sign operator. (default: False)
  freeze_on_flip (bool): Projected SignGD One-hit freeze. Masks updates for
  coordinates where the gradient sign flips compared to the previous step. (default: False)
- l1_adaptive (bool): Scales learning rate dynamically
+ l1_adaptive (bool): Scales learning rate dynamically
  by the L1 norm of the gradient to handle gradient heterogeneity. (default: False).
  centered_wd (float): Centered Weight Decay coefficient. Instead of decaying weights
  toward zero, they are decayed toward their initial values (anchors). This
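stochastic_sign replaces the deterministic Lion-k/sign step with a randomized sign operator; the step code further down feeds apply_stochastic_sign a pre-generated noise tensor (param_update._get_random_noise_for_sso). Purely as a hedged illustration, one common formulation is the unbiased stochastic sign compressor sketched below; the package's actual operator may use a different rule:

    import torch

    def stochastic_sign(update, noise):
        # Output +1 with probability (1 + u/scale) / 2 and -1 otherwise, where
        # scale is the largest magnitude, so E[output] = u / scale (unbiased).
        scale = update.abs().max().clamp_min(1e-12)
        prob_plus = 0.5 * (1.0 + update / scale)
        ones = torch.ones_like(update)
        return torch.where(noise < prob_plus, ones, -ones)

    u = torch.randn(5)
    out = stochastic_sign(u, torch.rand_like(u))  # noise uniform in [0, 1)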
@@ -79,6 +82,8 @@ class Lion_adv(torch.optim.Optimizer):
  # Lion-k
  kappa_p: float = 1.0,
  auto_kappa_p: bool = False,
+ # Stochastic Sign Operator
+ stochastic_sign: bool = False,
  # Projected and adaptive sign
  freeze_on_flip: bool = False,
  l1_adaptive: bool = False,
@@ -110,6 +115,7 @@ class Lion_adv(torch.optim.Optimizer):
  clip_threshold=clip_threshold,
  kappa_p=kappa_p,
  auto_kappa_p=auto_kappa_p,
+ stochastic_sign=stochastic_sign,
  freeze_on_flip=freeze_on_flip,
  l1_adaptive=l1_adaptive,
  scaled_optm= scaled_optm,
@@ -137,8 +143,8 @@ class Lion_adv(torch.optim.Optimizer):
  def load_state_dict(self, state_dict: dict) -> None:
  """
  Overrides default load_state_dict to implement a workaround for PyTorch's
- automatic dtype casting. It ensures factorized states remain float32 for
- stability, preserves integer/float8 quantized anchor states, and forces
+ automatic dtype casting. It ensures factorized states remain float32 for
+ stability, preserves integer/float8 quantized anchor states, and forces
  standard states onto the parameter's current dtype/device.
  """
  super().load_state_dict(state_dict)
@@ -201,19 +207,22 @@ class Lion_adv(torch.optim.Optimizer):
  lr = group["lr"]

  random_int_tensor = None
+ random_noise_tensor = None

  if group.get('compiled_optimizer', False):
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
  # Pre-generate random tensor for stochastic rounding if needed.
  random_int_tensor = param_update._get_random_int_for_sr(p)
- lr = torch.as_tensor(lr, dtype=torch.float64)
+ if group.get('stochastic_sign', False):
+ random_noise_tensor = param_update._get_random_noise_for_sso(p)
+ lr = torch.as_tensor(lr)
  step_param_fn = self._compiled_step_parameter
  else:
  step_param_fn = self._step_parameter

- step_param_fn(p, grad, state, group, lr, random_int_tensor)
+ step_param_fn(p, grad, state, group, lr, random_int_tensor, random_noise_tensor)

- def _step_parameter(self, p, grad, state, group, lr, random_int_tensor):
+ def _step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_noise_tensor):
  if grad.dtype != torch.float32 and state['factored']:
  grad = grad.float()
  if group["clip_threshold"] > 0.0:
@@ -251,9 +260,6 @@ class Lion_adv(torch.optim.Optimizer):
  # Compute update term c_t
  update = torch.lerp(grad_reshaped, exp_avg, beta1)

- if group.get("l1_adaptive", False) and kappa_p == 1:
- lr = lr * (update.norm(p=1))
-
  # Standard Lion momentum update
  # m_t = beta2 * m_{t-1} + (1-beta2) * g_t
  exp_avg.lerp_(grad_reshaped, 1 - beta2)
@@ -262,7 +268,12 @@ class Lion_adv(torch.optim.Optimizer):
  state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(exp_avg, signed=True)
  del exp_avg

- update = _get_lion_k_update(update, kappa_p)
+ if freeze_on_flip:
+ # Fast binary diff (XOR) from momentum sign directly
+ flipped_packed = prev_sign_packed ^ state['sign']
+ flipped_mask = _unpack_bools(flipped_packed, original_m=d2).view_as(update)
+ update = torch.where(flipped_mask, 0.0, update)
+ del prev_sign_packed, flipped_packed, flipped_mask

  if self.cautious_mask:
  mask = (update * grad_reshaped > 0).to(grad_reshaped.dtype)
@@ -272,12 +283,12 @@ class Lion_adv(torch.optim.Optimizer):


  update = update.view(p.shape)
- if freeze_on_flip:
- # Fast binary diff (XOR) from momentum sign directly
- flipped_packed = prev_sign_packed ^ state['sign']
- flipped_mask = _unpack_bools(flipped_packed, original_m=d2).view_as(update)
- update = torch.where(flipped_mask, 0.0, update)
- del prev_sign_packed, flipped_packed, flipped_mask
+ l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p, rescale=False)
+
+ if group.get('stochastic_sign', False):
+ update = apply_stochastic_sign(update, noise=random_noise_tensor)
+ else:
+ update = _get_lion_k_update(update, kappa_p)

  else:
  # Fallback to standard Lion logic
@@ -286,10 +297,13 @@ class Lion_adv(torch.optim.Optimizer):

  # Compute update term
  update = torch.lerp(grad, exp_avg, beta1)
- if group.get("l1_adaptive", False) and kappa_p == 1:
- lr = lr * (update.norm(p=1))
+ # Standard Lion momentum update
+ exp_avg.lerp_(grad, 1 - beta2)

- update = _get_lion_k_update(update, kappa_p)
+ if freeze_on_flip:
+ current_sign = (update > 0).to(torch.uint8)
+ update = torch.where(current_sign == state['prev_sign'], update, 0.0)
+ state['prev_sign'] = current_sign

  if self.cautious_mask:
  mask = (update * grad > 0).to(grad.dtype)
@@ -297,20 +311,22 @@ class Lion_adv(torch.optim.Optimizer):
  update.mul_(mask)
  del mask

- # Standard Lion momentum update
- exp_avg.lerp_(grad, 1 - beta2)
+ l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p, rescale=False)

- if freeze_on_flip:
- current_sign = (update > 0).to(torch.uint8)
- update = torch.where(current_sign == state['prev_sign'], update, 0.0)
- state['prev_sign'] = current_sign
+ if group.get('stochastic_sign', False):
+ update = apply_stochastic_sign(update, noise=random_noise_tensor)
+ else:
+ update = _get_lion_k_update(update, kappa_p)
+
+ if l1_mean is not None:
+ update.mul_(l1_mean)

  if group.get('scaled_optm', False):
  update = scale_update(p, update, lr, vector_state=state.get('spectral_v'))
  else:
  update.mul_(lr)

- param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
+ param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_scaler=l1_mean)

  def compile(self, *args, **kwargs):
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -183,6 +183,8 @@ class Muon_adv(torch.optim.Optimizer):
  if spectral_normalization and rms_rescaling:
  print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
  rms_rescaling = False
+ if spectral_normalization and accelerated_ns:
+ ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")

  defaults = {
  "lr": lr, "beta1": beta1, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
@@ -239,6 +241,8 @@ class Muon_adv(torch.optim.Optimizer):
  if group.get('use_muon') is None: # Fallback
  group['use_muon'] = group.get('optim_type') == 'muon'

+ self.init_step()
+
  self.kourkoutas_helper = None
  if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
  self.kourkoutas_helper = KourkoutasHelper(self)
@@ -259,8 +263,8 @@ class Muon_adv(torch.optim.Optimizer):
  def load_state_dict(self, state_dict: dict) -> None:
  """
  Overrides default load_state_dict to implement a workaround for PyTorch's
- automatic dtype casting. It ensures factorized states remain float32 for
- stability, preserves integer/float8 quantized anchor states, and forces
+ automatic dtype casting. It ensures factorized states remain float32 for
+ stability, preserves integer/float8 quantized anchor states, and forces
  standard states onto the parameter's current dtype/device.
  """
  super().load_state_dict(state_dict)
@@ -393,7 +397,7 @@ class Muon_adv(torch.optim.Optimizer):
  step_size = group['lr'] / bias_correction1

  if is_compiled:
- step_size = torch.as_tensor(step_size, dtype=torch.float64)
+ step_size = torch.as_tensor(step_size)
  adam_step_param = self._compiled_adam_step_parameter
  else:
  adam_step_param = Muon_AuxAdam._adam_step_parameter
@@ -404,7 +408,7 @@ class Muon_adv(torch.optim.Optimizer):

  else: # Muon path
  if is_compiled:
- lr = torch.as_tensor(group['lr'], dtype=torch.float64)
+ lr = torch.as_tensor(group['lr'])
  muon_step_param = self._compiled_muon_step_parameter
  else:
  lr = group['lr']