adv-optm 2.4.dev4__tar.gz → 2.4.dev6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/optim/AdaMuon_adv.py +7 -3
  4. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/optim/AdamW_adv.py +25 -7
  5. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/optim/Adopt_adv.py +21 -7
  6. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/optim/Lion_Prodigy_adv.py +2 -2
  7. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/optim/Lion_adv.py +35 -21
  8. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/optim/Muon_adv.py +6 -2
  9. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/optim/Prodigy_adv.py +14 -5
  10. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/optim/SignSGD_adv.py +31 -15
  11. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/optim/Simplified_AdEMAMix.py +1 -1
  12. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/util/Muon_util.py +7 -5
  13. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/util/param_update.py +49 -7
  14. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/util/scaled_optm.py +38 -22
  15. adv_optm-2.4.dev6/adv_optm/util/signed_util.py +13 -0
  16. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/util/update_util.py +46 -8
  17. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm.egg-info/PKG-INFO +1 -1
  18. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm.egg-info/SOURCES.txt +1 -0
  19. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/setup.py +1 -1
  20. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/LICENSE +0 -0
  21. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/README.md +0 -0
  22. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/optim/__init__.py +0 -0
  23. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/util/Kourkoutas.py +0 -0
  24. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/util/Muon_AuxAdam.py +0 -0
  25. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/util/OrthoGrad.py +0 -0
  26. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/util/__init__.py +0 -0
  27. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/util/centered_decay.py +0 -0
  28. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/util/factorization_util.py +0 -0
  29. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm/util/lion_k.py +0 -0
  30. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm.egg-info/dependency_links.txt +0 -0
  31. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm.egg-info/requires.txt +0 -0
  32. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/adv_optm.egg-info/top_level.txt +0 -0
  33. {adv_optm-2.4.dev4 → adv_optm-2.4.dev6}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev4
3
+ Version: 2.4.dev6
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -22,4 +22,4 @@ __all__ = [
22
22
  "SignSGD_adv",
23
23
  ]
24
24
 
25
- __version__ = "2.4.dev4"
25
+ __version__ = "2.4.dev6"
@@ -206,6 +206,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
206
206
  if spectral_normalization and rms_rescaling:
207
207
  print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
208
208
  rms_rescaling = False
209
+ if spectral_normalization and accelerated_ns:
210
+ ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
209
211
 
210
212
  defaults = {
211
213
  "lr": lr, "betas": betas, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
@@ -260,6 +262,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
260
262
  if group.get('use_muon') is None: # Fallback
261
263
  group['use_muon'] = group.get('optim_type') == 'muon'
262
264
 
265
+ self.init_step()
266
+
263
267
  self.kourkoutas_helper = None
264
268
  if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
265
269
  self.kourkoutas_helper = KourkoutasHelper(self)
@@ -419,7 +423,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
419
423
  step_size = group['lr'] / bias_correction1
420
424
 
421
425
  if is_compiled:
422
- step_size = torch.as_tensor(step_size, dtype=torch.float64)
426
+ step_size = torch.as_tensor(step_size)
423
427
  adam_step_param = self._compiled_adam_step_parameter
424
428
  else:
425
429
  adam_step_param = Muon_AuxAdam._adam_step_parameter
@@ -430,7 +434,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
430
434
 
431
435
  else: # Muon path
432
436
  if is_compiled:
433
- lr = torch.as_tensor(group['lr'], dtype=torch.float64)
437
+ lr = torch.as_tensor(group['lr'])
434
438
  muon_step_param = self._compiled_muon_step_parameter
435
439
  else:
436
440
  lr = group['lr']
@@ -467,7 +471,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
467
471
  else:
468
472
  shape_for_scaling = p.shape
469
473
 
470
- scaled_eps, adaptive_eps, spectral_target, wd_scale = get_spectral_scaling(shape_for_scaling, group['n_layers'])
474
+ scaled_eps, adaptive_eps, spectral_target, wd_scale = get_spectral_scaling(p, shape_for_scaling, group['n_layers'])
471
475
 
472
476
  weight_decay = group['weight_decay'] * wd_scale
473
477
  decoupled_wd = True
@@ -6,10 +6,10 @@ from typing import Optional, Callable
6
6
 
7
7
  from ..util import param_update
8
8
  from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
9
- from ..util.update_util import _grams_update, _cautious_update
9
+ from ..util.update_util import _grams_update, _cautious_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
10
10
  from ..util.OrthoGrad import _orthogonalize_gradient
11
11
  from ..util.Kourkoutas import KourkoutasHelper
12
- from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
12
+ from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm, scale_eps
13
13
  from ..util.centered_decay import _init_anchor
14
14
 
15
15
  A = 4 / math.pi
@@ -29,6 +29,9 @@ class AdamW_adv(torch.optim.Optimizer):
29
29
  eps (float): term added to the denominator to improve
30
30
  numerical stability (default: 1e-8)
31
31
  weight_decay (float): weight decay (L2 penalty) (default: 0).
32
+ fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
33
+ the decay direction through the empirical Fisher information matrix and
34
+ clipping its RMS. (default: False)
32
35
  cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
33
36
  applied only to parameter coordinates where the sign of the parameter
34
37
  and the sign of the optimizer update align (default: False).
@@ -103,6 +106,7 @@ class AdamW_adv(torch.optim.Optimizer):
103
106
  eps: float = 1e-8,
104
107
  # Decoupled/cautious weight decay
105
108
  weight_decay: float = 0.0,
109
+ fisher_wd: bool = False,
106
110
  cautious_wd: bool = False,
107
111
  # Adam's Bias Correction
108
112
  use_bias_correction: bool = True,
@@ -149,13 +153,17 @@ class AdamW_adv(torch.optim.Optimizer):
149
153
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
150
154
  if kourkoutas_beta and not (betas[1] > beta2_min):
151
155
  raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
156
+ if scaled_optm and use_atan2:
157
+ print("Warning: use_atan2 is incompatible with scaled_optm, Disabling atan2.")
158
+ use_atan2 = False
152
159
 
153
160
  if cautious_mask and grams_moment:
154
161
  print("Warning: cautious is incompatible with grams, Disabling cautious.")
155
162
  cautious_mask = False
156
163
 
157
164
  defaults = {
158
- "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
165
+ "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
166
+ "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
159
167
  "use_atan2": use_atan2,
160
168
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
161
169
  "beta3_ema": beta3_ema, "alpha": alpha, "compiled_optimizer": compiled_optimizer,
@@ -273,6 +281,8 @@ class AdamW_adv(torch.optim.Optimizer):
273
281
 
274
282
  _init_anchor(p, state, group)
275
283
 
284
+ _init_fisher_wd_scaler(group, state, p)
285
+
276
286
  beta1, beta2 = group['betas']
277
287
 
278
288
  current_step = state['step']
@@ -294,7 +304,7 @@ class AdamW_adv(torch.optim.Optimizer):
294
304
  random_int_tensor = None
295
305
 
296
306
  if group.get('compiled_optimizer', False):
297
- step_size = torch.as_tensor(step_size, dtype=torch.float64)
307
+ step_size = torch.as_tensor(step_size)
298
308
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
299
309
  # Pre-generate random tensor for stochastic rounding if needed.
300
310
  random_int_tensor = param_update._get_random_int_for_sr(p)
@@ -323,6 +333,8 @@ class AdamW_adv(torch.optim.Optimizer):
323
333
  # Determine if we are using dense first-moments alongside a factored second-order second-moment
324
334
  factored_2nd = group.get('factored_2nd', False)
325
335
 
336
+ adaptive_eps = scale_eps(group, p)
337
+
326
338
  if state['factored']:
327
339
  d1, d2 = state['effective_shape']
328
340
  grad_reshaped = grad.view(d1, d2)
@@ -387,8 +399,11 @@ class AdamW_adv(torch.optim.Optimizer):
387
399
  update.atan2_(denom)
388
400
  else:
389
401
  denom = vt.sqrt_()
390
- denom.div_(sqrt_bias_correction2).add_(group['eps'])
402
+ denom.div_(sqrt_bias_correction2).add_(adaptive_eps)
391
403
  update.div_(denom)
404
+
405
+ wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
406
+
392
407
  del vt
393
408
 
394
409
  update = update.view(p.shape)
@@ -428,8 +443,11 @@ class AdamW_adv(torch.optim.Optimizer):
428
443
  update.atan2_(denom)
429
444
  else:
430
445
  denom = exp_avg_sq.sqrt()
431
- denom.div_(sqrt_bias_correction2).add_(group['eps'])
446
+ denom.div_(sqrt_bias_correction2).add_(adaptive_eps)
432
447
  update.div_(denom)
448
+
449
+ wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
450
+
433
451
  del denom
434
452
 
435
453
  update_scaling = step_size * A if group['use_atan2'] else step_size
@@ -438,7 +456,7 @@ class AdamW_adv(torch.optim.Optimizer):
438
456
  else:
439
457
  update.mul_(update_scaling)
440
458
 
441
- param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor)
459
+ param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
442
460
 
443
461
  def compile(self, *args, **kwargs):
444
462
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -7,8 +7,8 @@ from ..util import param_update
7
7
  from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state, _nnmf
8
8
  from ..util.OrthoGrad import _orthogonalize_gradient
9
9
  from ..util.Kourkoutas import KourkoutasHelper
10
- from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update
11
- from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
10
+ from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
11
+ from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm, scale_eps
12
12
  from ..util.centered_decay import _init_anchor
13
13
 
14
14
  A = 4 / math.pi
@@ -33,6 +33,9 @@ class Adopt_adv(torch.optim.Optimizer):
33
33
  eps (float): term added to the denominator to improve
34
34
  numerical stability (default: 1e-6)
35
35
  weight_decay (float): weight decay (L2 penalty) (default: 0)
36
+ fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
37
+ the decay direction through the empirical Fisher information matrix and
38
+ clipping its RMS. (default: False)
36
39
  cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
37
40
  applied only to parameter coordinates where the sign of the parameter
38
41
  and the sign of the optimizer update align (default: False).
@@ -119,6 +122,7 @@ class Adopt_adv(torch.optim.Optimizer):
119
122
  eps: float = 1e-6,
120
123
  # Decoupled/cautious weight decay
121
124
  weight_decay: float = 0.0,
125
+ fisher_wd: bool = False,
122
126
  cautious_wd: bool = False,
123
127
  # ADOPT clipping
124
128
  clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
@@ -179,9 +183,13 @@ class Adopt_adv(torch.optim.Optimizer):
179
183
  print("Warning: grams is incompatible with Simplified_AdEMAMix, Disabling grams.")
180
184
  if cautious_mask and Simplified_AdEMAMix:
181
185
  print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
186
+ if scaled_optm and use_atan2:
187
+ print("Warning: use_atan2 is incompatible with scaled_optm, Disabling atan2.")
188
+ use_atan2 = False
182
189
 
183
190
  defaults = {
184
- "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
191
+ "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
192
+ "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
185
193
  "beta3_ema": beta3_ema, "alpha": alpha,
186
194
  "alpha_grad": alpha_grad,
187
195
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
@@ -312,6 +320,8 @@ class Adopt_adv(torch.optim.Optimizer):
312
320
 
313
321
  _init_anchor(p, state, group)
314
322
 
323
+ _init_fisher_wd_scaler(group, state, p)
324
+
315
325
  current_step = state['step']
316
326
 
317
327
  # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
@@ -322,7 +332,7 @@ class Adopt_adv(torch.optim.Optimizer):
322
332
  random_int_tensor = None
323
333
 
324
334
  if group.get('compiled_optimizer', False):
325
- lr = torch.as_tensor(group['lr'], dtype=torch.float64)
335
+ lr = torch.as_tensor(group['lr'])
326
336
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
327
337
  # Pre-generate random tensor for stochastic rounding if needed.
328
338
  random_int_tensor = param_update._get_random_int_for_sr(p)
@@ -357,6 +367,8 @@ class Adopt_adv(torch.optim.Optimizer):
357
367
  # Determine if we are using dense first-moments alongside a factored second-order second-moment
358
368
  factored_2nd = group.get('factored_2nd', False)
359
369
 
370
+ adaptive_eps = scale_eps(group, p)
371
+
360
372
  if state['factored']:
361
373
  d1, d2 = state['effective_shape']
362
374
  grad_reshaped = grad.view(d1, d2)
@@ -366,6 +378,7 @@ class Adopt_adv(torch.optim.Optimizer):
366
378
 
367
379
  # ADOPT Step A: Decorrelate g_t using v_{t-1}
368
380
  denom = vt.sqrt()
381
+ wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, self.use_atan2)
369
382
 
370
383
  # Update second moment v_t for the *next* step using raw g_t
371
384
  if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
@@ -379,7 +392,7 @@ class Adopt_adv(torch.optim.Optimizer):
379
392
  if self.use_atan2:
380
393
  normalized_grad = torch.atan2(grad_reshaped, denom, out=denom)
381
394
  else:
382
- normalized_grad = torch.div(grad_reshaped, denom.add_(group['eps']), out=denom)
395
+ normalized_grad = torch.div(grad_reshaped, denom.add_(adaptive_eps), out=denom)
383
396
  if self.clip_lambda is not None:
384
397
  clip_val = self.clip_lambda(state['step'])
385
398
  normalized_grad.clamp_(-clip_val, clip_val)
@@ -444,11 +457,12 @@ class Adopt_adv(torch.optim.Optimizer):
444
457
 
445
458
  # ADOPT Step A: Decorrelate g_t using v_{t-1}
446
459
  denom = vt.sqrt()
460
+ wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, self.use_atan2)
447
461
 
448
462
  if self.use_atan2:
449
463
  normalized_grad = torch.atan2(grad, denom, out=denom)
450
464
  else:
451
- normalized_grad = torch.div(grad, denom.add_(group['eps']), out=denom)
465
+ normalized_grad = torch.div(grad, denom.add_(adaptive_eps), out=denom)
452
466
  if self.clip_lambda is not None:
453
467
  clip_val = self.clip_lambda(state['step'])
454
468
  normalized_grad.clamp_(-clip_val, clip_val)
@@ -499,7 +513,7 @@ class Adopt_adv(torch.optim.Optimizer):
499
513
  update.mul_(update_scaling)
500
514
 
501
515
  # Parameter Update
502
- param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
516
+ param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
503
517
 
504
518
  def compile(self, *args, **kwargs):
505
519
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -225,8 +225,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
225
225
  # Pre-generate random tensor for stochastic rounding if needed.
226
226
  random_int_tensor = param_update._get_random_int_for_sr(p)
227
227
  # TODO, workaround until pytorch#169634 is fixed
228
- d = torch.as_tensor(group['d'], dtype=torch.float64)
229
- dlr = torch.as_tensor(dlr, dtype=torch.float64)
228
+ d = torch.as_tensor(group['d'])
229
+ dlr = torch.as_tensor(dlr)
230
230
  step_param_fn = self._compiled_step_parameter
231
231
  else:
232
232
  d = group['d']
@@ -9,6 +9,7 @@ from ..util.lion_k import _get_lion_k_update
9
9
  from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
10
10
  from ..util.centered_decay import _init_anchor
11
11
  from ..util.update_util import _get_l1_adaptive_lr
12
+ from ..util.signed_util import apply_stochastic_sign
12
13
 
13
14
 
14
15
  class Lion_adv(torch.optim.Optimizer):
@@ -45,6 +46,7 @@ class Lion_adv(torch.optim.Optimizer):
45
46
  parameter dimensionality. Sets p=2.0 for 4D tensors (Conv2D) (Biases/Norms) to
46
47
  use Spherical updates, and p=1.0 for others (Linear/Embeddings) to use Sign
47
48
  updates. Overrides explicit kappa_p value. (default: False).
49
+ stochastic_sign (bool): whether to use the Stochastic Sign operator. (default: False)
48
50
  freeze_on_flip (bool): Projected SignGD One-hit freeze. Masks updates for
49
51
  coordinates where the gradient sign flips compared to the previous step. (default: False)
50
52
  l1_adaptive (bool): Scales learning rate dynamically
@@ -80,6 +82,8 @@ class Lion_adv(torch.optim.Optimizer):
80
82
  # Lion-k
81
83
  kappa_p: float = 1.0,
82
84
  auto_kappa_p: bool = False,
85
+ # Stochastic Sign Operator
86
+ stochastic_sign: bool = False,
83
87
  # Projected and adaptive sign
84
88
  freeze_on_flip: bool = False,
85
89
  l1_adaptive: bool = False,
@@ -111,6 +115,7 @@ class Lion_adv(torch.optim.Optimizer):
111
115
  clip_threshold=clip_threshold,
112
116
  kappa_p=kappa_p,
113
117
  auto_kappa_p=auto_kappa_p,
118
+ stochastic_sign=stochastic_sign,
114
119
  freeze_on_flip=freeze_on_flip,
115
120
  l1_adaptive=l1_adaptive,
116
121
  scaled_optm= scaled_optm,
@@ -202,19 +207,22 @@ class Lion_adv(torch.optim.Optimizer):
202
207
  lr = group["lr"]
203
208
 
204
209
  random_int_tensor = None
210
+ random_noise_tensor = None
205
211
 
206
212
  if group.get('compiled_optimizer', False):
207
213
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
208
214
  # Pre-generate random tensor for stochastic rounding if needed.
209
215
  random_int_tensor = param_update._get_random_int_for_sr(p)
210
- lr = torch.as_tensor(lr, dtype=torch.float64)
216
+ if group.get('stochastic_sign', False):
217
+ random_noise_tensor = param_update._get_random_noise_for_sso(p)
218
+ lr = torch.as_tensor(lr)
211
219
  step_param_fn = self._compiled_step_parameter
212
220
  else:
213
221
  step_param_fn = self._step_parameter
214
222
 
215
- step_param_fn(p, grad, state, group, lr, random_int_tensor)
223
+ step_param_fn(p, grad, state, group, lr, random_int_tensor, random_noise_tensor)
216
224
 
217
- def _step_parameter(self, p, grad, state, group, lr, random_int_tensor):
225
+ def _step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_noise_tensor):
218
226
  if grad.dtype != torch.float32 and state['factored']:
219
227
  grad = grad.float()
220
228
  if group["clip_threshold"] > 0.0:
@@ -252,8 +260,6 @@ class Lion_adv(torch.optim.Optimizer):
252
260
  # Compute update term c_t
253
261
  update = torch.lerp(grad_reshaped, exp_avg, beta1)
254
262
 
255
- l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p)
256
-
257
263
  # Standard Lion momentum update
258
264
  # m_t = beta2 * m_{t-1} + (1-beta2) * g_t
259
265
  exp_avg.lerp_(grad_reshaped, 1 - beta2)
@@ -262,7 +268,12 @@ class Lion_adv(torch.optim.Optimizer):
262
268
  state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(exp_avg, signed=True)
263
269
  del exp_avg
264
270
 
265
- update = _get_lion_k_update(update, kappa_p)
271
+ if freeze_on_flip:
272
+ # Fast binary diff (XOR) from momentum sign directly
273
+ flipped_packed = prev_sign_packed ^ state['sign']
274
+ flipped_mask = _unpack_bools(flipped_packed, original_m=d2).view_as(update)
275
+ update = torch.where(flipped_mask, 0.0, update)
276
+ del prev_sign_packed, flipped_packed, flipped_mask
266
277
 
267
278
  if self.cautious_mask:
268
279
  mask = (update * grad_reshaped > 0).to(grad_reshaped.dtype)
@@ -272,12 +283,12 @@ class Lion_adv(torch.optim.Optimizer):
272
283
 
273
284
  update = update.view(p.shape)
274
285
 
275
- if freeze_on_flip:
276
- # Fast binary diff (XOR) from momentum sign directly
277
- flipped_packed = prev_sign_packed ^ state['sign']
278
- flipped_mask = _unpack_bools(flipped_packed, original_m=d2).view_as(update)
279
- update = torch.where(flipped_mask, 0.0, update)
280
- del prev_sign_packed, flipped_packed, flipped_mask
286
+ l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p, rescale=False)
287
+
288
+ if group.get('stochastic_sign', False):
289
+ update = apply_stochastic_sign(update, noise=random_noise_tensor)
290
+ else:
291
+ update = _get_lion_k_update(update, kappa_p)
281
292
 
282
293
  else:
283
294
  # Fallback to standard Lion logic
@@ -286,9 +297,13 @@ class Lion_adv(torch.optim.Optimizer):
286
297
  # Compute update term
287
298
  update = torch.lerp(grad, exp_avg, beta1)
288
299
 
289
- l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p)
300
+ # Standard Lion momentum update
301
+ exp_avg.lerp_(grad, 1 - beta2)
290
302
 
291
- update = _get_lion_k_update(update, kappa_p)
303
+ if freeze_on_flip:
304
+ current_sign = (update > 0).to(torch.uint8)
305
+ update = torch.where(current_sign == state['prev_sign'], update, 0.0)
306
+ state['prev_sign'] = current_sign
292
307
 
293
308
  if self.cautious_mask:
294
309
  mask = (update * grad > 0).to(grad.dtype)
@@ -296,13 +311,12 @@ class Lion_adv(torch.optim.Optimizer):
296
311
  update.mul_(mask)
297
312
  del mask
298
313
 
299
- # Standard Lion momentum update
300
- exp_avg.lerp_(grad, 1 - beta2)
314
+ l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p, rescale=False)
301
315
 
302
- if freeze_on_flip:
303
- current_sign = (update > 0).to(torch.uint8)
304
- update = torch.where(current_sign == state['prev_sign'], update, 0.0)
305
- state['prev_sign'] = current_sign
316
+ if group.get('stochastic_sign', False):
317
+ update = apply_stochastic_sign(update, noise=random_noise_tensor)
318
+ else:
319
+ update = _get_lion_k_update(update, kappa_p)
306
320
 
307
321
  if l1_mean is not None:
308
322
  update.mul_(l1_mean)
@@ -312,7 +326,7 @@ class Lion_adv(torch.optim.Optimizer):
312
326
  else:
313
327
  update.mul_(lr)
314
328
 
315
- param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
329
+ param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_scaler=l1_mean)
316
330
 
317
331
  def compile(self, *args, **kwargs):
318
332
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -183,6 +183,8 @@ class Muon_adv(torch.optim.Optimizer):
183
183
  if spectral_normalization and rms_rescaling:
184
184
  print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
185
185
  rms_rescaling = False
186
+ if spectral_normalization and accelerated_ns:
187
+ ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
186
188
 
187
189
  defaults = {
188
190
  "lr": lr, "beta1": beta1, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
@@ -239,6 +241,8 @@ class Muon_adv(torch.optim.Optimizer):
239
241
  if group.get('use_muon') is None: # Fallback
240
242
  group['use_muon'] = group.get('optim_type') == 'muon'
241
243
 
244
+ self.init_step()
245
+
242
246
  self.kourkoutas_helper = None
243
247
  if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
244
248
  self.kourkoutas_helper = KourkoutasHelper(self)
@@ -393,7 +397,7 @@ class Muon_adv(torch.optim.Optimizer):
393
397
  step_size = group['lr'] / bias_correction1
394
398
 
395
399
  if is_compiled:
396
- step_size = torch.as_tensor(step_size, dtype=torch.float64)
400
+ step_size = torch.as_tensor(step_size)
397
401
  adam_step_param = self._compiled_adam_step_parameter
398
402
  else:
399
403
  adam_step_param = Muon_AuxAdam._adam_step_parameter
@@ -404,7 +408,7 @@ class Muon_adv(torch.optim.Optimizer):
404
408
 
405
409
  else: # Muon path
406
410
  if is_compiled:
407
- lr = torch.as_tensor(group['lr'], dtype=torch.float64)
411
+ lr = torch.as_tensor(group['lr'])
408
412
  muon_step_param = self._compiled_muon_step_parameter
409
413
  else:
410
414
  lr = group['lr']
@@ -9,7 +9,7 @@ from ..util import param_update
9
9
  from ..util.OrthoGrad import _orthogonalize_gradient
10
10
  from ..util.Kourkoutas import KourkoutasHelper
11
11
  from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
12
- from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update
12
+ from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
13
13
  from ..util.centered_decay import _init_anchor
14
14
 
15
15
  A = 4 / math.pi
@@ -29,6 +29,9 @@ class Prodigy_adv(torch.optim.Optimizer):
29
29
  eps (float): term added to the denominator to improve
30
30
  numerical stability (default: 1e-8)
31
31
  weight_decay (float): weight decay (L2 penalty) (default: 0)
32
+ fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
33
+ the decay direction through the empirical Fisher information matrix and
34
+ clipping its RMS. (default: False)
32
35
  cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
33
36
  applied only to parameter coordinates where the sign of the parameter
34
37
  and the sign of the optimizer update align (default: False).
@@ -133,6 +136,7 @@ class Prodigy_adv(torch.optim.Optimizer):
133
136
  eps: float = 1e-8,
134
137
  # Decoupled/cautious weight decay
135
138
  weight_decay: float = 0.0,
139
+ fisher_wd: bool = False,
136
140
  cautious_wd: bool = False,
137
141
  # Stochastic Rounding for BF16
138
142
  stochastic_rounding: bool = True,
@@ -206,7 +210,8 @@ class Prodigy_adv(torch.optim.Optimizer):
206
210
  raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
207
211
 
208
212
  defaults = {
209
- "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
213
+ "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
214
+ "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
210
215
  "use_atan2": use_atan2,
211
216
  "orthogonal_gradient": orthogonal_gradient,
212
217
  "beta3_ema": beta3_ema, "alpha": alpha, "compiled_optimizer": compiled_optimizer,
@@ -354,6 +359,8 @@ class Prodigy_adv(torch.optim.Optimizer):
354
359
 
355
360
  _init_anchor(p, state, group)
356
361
 
362
+ _init_fisher_wd_scaler(group, state, p)
363
+
357
364
  if not hasattr(self, 'd_denom'):
358
365
  self.d_denom = torch.tensor(0.0, device=p.device)
359
366
  self.d_numerator = torch.tensor(group.get('d_numerator', 0.0), device=p.device)
@@ -376,8 +383,8 @@ class Prodigy_adv(torch.optim.Optimizer):
376
383
  # Pre-generate random tensor for stochastic rounding if needed.
377
384
  random_int_tensor = param_update._get_random_int_for_sr(p)
378
385
  # TODO, workaround until pytorch#169634 is fixed
379
- d = torch.as_tensor(group['d'], dtype=torch.float64)
380
- dlr = torch.as_tensor(dlr, dtype=torch.float64)
386
+ d = torch.as_tensor(group['d'])
387
+ dlr = torch.as_tensor(dlr)
381
388
  step_param_fn = self._compiled_step_parameter
382
389
  else:
383
390
  d = group['d']
@@ -478,6 +485,7 @@ class Prodigy_adv(torch.optim.Optimizer):
478
485
  else:
479
486
  denom = vt.sqrt_()
480
487
  update.div_(denom.add_(d * group['eps']))
488
+ wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
481
489
  del vt
482
490
 
483
491
  update_scaling = dlr * A if group['use_atan2'] else dlr
@@ -528,6 +536,7 @@ class Prodigy_adv(torch.optim.Optimizer):
528
536
  else:
529
537
  denom = exp_avg_sq.sqrt()
530
538
  update.div_(denom.add_(d * group['eps']))
539
+ wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
531
540
  del denom
532
541
 
533
542
  update_scaling = dlr * A if group['use_atan2'] else dlr
@@ -557,7 +566,7 @@ class Prodigy_adv(torch.optim.Optimizer):
557
566
  if 'p0' in state:
558
567
  del state['p0']
559
568
 
560
- param_update.apply_parameter_update(self, p, group, update, dlr, random_int_tensor=random_int_tensor)
569
+ param_update.apply_parameter_update(self, p, group, update, dlr, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
561
570
 
562
571
  def compile(self, *args, **kwargs):
563
572
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -9,6 +9,7 @@ from ..util.lion_k import _get_lion_k_update
9
9
  from ..util.update_util import _get_l1_adaptive_lr
10
10
  from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
11
11
  from ..util.centered_decay import _init_anchor
12
+ from ..util.signed_util import apply_stochastic_sign
12
13
 
13
14
 
14
15
  class SignSGD_adv(torch.optim.Optimizer):
@@ -39,6 +40,7 @@ class SignSGD_adv(torch.optim.Optimizer):
39
40
  parameter dimensionality. Sets p=2.0 for 4D tensors (Conv2D) (Biases/Norms) to
40
41
  use Spherical updates, and p=1.0 for others (Linear/Embeddings) to use Sign
41
42
  updates. Overrides explicit kappa_p value. (default: False).
43
+ stochastic_sign (bool): whether to use the Stochastic Sign operator. (default: False)
42
44
  Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
43
45
  This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
44
46
  more responsive, especially for small batch sizes. (default: False)
@@ -79,6 +81,8 @@ class SignSGD_adv(torch.optim.Optimizer):
79
81
  # Projection-k
80
82
  kappa_p: float = 1.0,
81
83
  auto_kappa_p: bool = True,
84
+ # Stochastic Sign Operator
85
+ stochastic_sign: bool = False,
82
86
  # Simplified_AdEMAMix
83
87
  alpha_grad: float = 1.0,
84
88
  Simplified_AdEMAMix: bool = False,
@@ -112,6 +116,7 @@ class SignSGD_adv(torch.optim.Optimizer):
112
116
  orthogonal_gradient=orthogonal_gradient,
113
117
  kappa_p=kappa_p,
114
118
  auto_kappa_p=auto_kappa_p,
119
+ stochastic_sign=stochastic_sign,
115
120
  alpha_grad=alpha_grad,
116
121
  Simplified_AdEMAMix=Simplified_AdEMAMix,
117
122
  scaled_optm= scaled_optm,
@@ -203,23 +208,26 @@ class SignSGD_adv(torch.optim.Optimizer):
203
208
  lr = group["lr"]
204
209
 
205
210
  random_int_tensor = None
211
+ random_noise_tensor = None
206
212
 
207
213
  if group.get('compiled_optimizer', False):
208
214
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
209
215
  # Pre-generate random tensor for stochastic rounding if needed.
210
216
  random_int_tensor = param_update._get_random_int_for_sr(p)
211
- lr = torch.as_tensor(lr, dtype=torch.float64)
217
+ if group.get('stochastic_sign', False):
218
+ random_noise_tensor = param_update._get_random_noise_for_sso(p)
219
+ lr = torch.as_tensor(lr)
212
220
  step_param_fn = self._compiled_step_parameter
213
221
  else:
214
222
  step_param_fn = self._step_parameter
215
223
 
216
- step_param_fn(p, grad, state, group, lr, random_int_tensor)
224
+ step_param_fn(p, grad, state, group, lr, random_int_tensor, random_noise_tensor)
217
225
 
218
226
  if group.get("l1_adaptive", False):
219
227
  state["step"] += 1
220
228
 
221
- def _step_parameter(self, p, grad, state, group, lr, random_int_tensor):
222
- if grad.dtype != torch.float32 and state['factored']:
229
+ def _step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_noise_tensor):
230
+ if grad.dtype != torch.float32 and state.get('factored', False):
223
231
  grad = grad.float()
224
232
 
225
233
  if group["orthogonal_gradient"]:
@@ -269,18 +277,23 @@ class SignSGD_adv(torch.optim.Optimizer):
269
277
  if freeze_on_flip:
270
278
  state['sign'] = _pack_bools(raw_update > 0)
271
279
 
272
- l1_mean = _get_l1_adaptive_lr(p, raw_update, state, group, kappa_p)
273
280
 
274
- update = _get_lion_k_update(raw_update, kappa_p)
275
- update = update.view(p.shape)
281
+ raw_update = raw_update.view(p.shape)
276
282
 
277
283
  if freeze_on_flip:
278
284
  # Fast binary diff (XOR) from momentum sign directly
279
285
  flipped_packed = prev_sign_packed ^ state['sign']
280
- flipped_mask = _unpack_bools(flipped_packed, original_m=d2).view_as(update)
281
- update = torch.where(flipped_mask, 0.0, update)
286
+ flipped_mask = _unpack_bools(flipped_packed, original_m=d2).view_as(raw_update)
287
+ raw_update = torch.where(flipped_mask, 0.0, raw_update)
282
288
  del prev_sign_packed, flipped_packed, flipped_mask
283
289
 
290
+ l1_mean = _get_l1_adaptive_lr(p, raw_update, state, group, kappa_p)
291
+
292
+ if group.get('stochastic_sign', False):
293
+ update = apply_stochastic_sign(raw_update, noise=random_noise_tensor)
294
+ else:
295
+ update = _get_lion_k_update(raw_update, kappa_p)
296
+
284
297
  else:
285
298
  # Fallback to standard SignSGD logic
286
299
  if momentum > 0:
@@ -294,15 +307,18 @@ class SignSGD_adv(torch.optim.Optimizer):
294
307
  else:
295
308
  raw_update = grad.clone()
296
309
 
297
- l1_mean = _get_l1_adaptive_lr(p, raw_update, state, group, kappa_p)
298
-
299
- update = _get_lion_k_update(raw_update, kappa_p)
300
-
301
310
  if freeze_on_flip:
302
311
  current_sign = (raw_update > 0).to(torch.uint8)
303
- update = torch.where(current_sign == state['prev_sign'], update, 0.0)
312
+ raw_update = torch.where(current_sign == state['prev_sign'], raw_update, 0.0)
304
313
  state['prev_sign'] = current_sign
305
314
 
315
+ l1_mean = _get_l1_adaptive_lr(p, raw_update, state, group, kappa_p)
316
+
317
+ if group.get('stochastic_sign', False):
318
+ update = apply_stochastic_sign(raw_update, noise=random_noise_tensor)
319
+ else:
320
+ update = _get_lion_k_update(raw_update, kappa_p)
321
+
306
322
  if l1_mean is not None:
307
323
  update.mul_(l1_mean)
308
324
 
@@ -311,7 +327,7 @@ class SignSGD_adv(torch.optim.Optimizer):
311
327
  else:
312
328
  update.mul_(lr)
313
329
 
314
- param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
330
+ param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_scaler=l1_mean)
315
331
 
316
332
  def compile(self, *args, **kwargs):
317
333
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -288,7 +288,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
288
288
  # Pre-generate random tensor for stochastic rounding if needed.
289
289
  random_int_tensor = param_update._get_random_int_for_sr(p)
290
290
  # TODO, workaround until pytorch#169634 is fixed
291
- lr = torch.as_tensor(lr, dtype=torch.float64)
291
+ lr = torch.as_tensor(lr)
292
292
  step_param_fn = self._compiled_step_parameter
293
293
  else:
294
294
  step_param_fn = self._step_parameter
@@ -1,5 +1,7 @@
1
1
  import torch
2
2
 
3
+ import math
4
+
3
5
  @torch.no_grad()
4
6
  def _newton_schulz_iteration(
5
7
  G: torch.Tensor,
@@ -359,11 +361,11 @@ def rms_adjustment(update: torch.Tensor, rms_rescaling: bool, lr):
359
361
  # This is slower due to norm calculations but it worked the best for t2i models.
360
362
  rms_target = 0.2 # default (Adam) value for RMS
361
363
  update_norm = torch.linalg.vector_norm(update)
362
- return update.mul_(lr * rms_target * (update.numel()**0.5) / update_norm.clamp_min_(1e-8))
364
+ return update.mul_(lr * rms_target * (math.sqrt(update.numel())) / update_norm.clamp_min_(1e-8))
363
365
  else:
364
366
  # Original Muon scaling
365
367
  r, c = update.size(-2), update.size(-1)
366
- scaling_factor = max(1, r / c) ** 0.5
368
+ scaling_factor = math.sqrt(max(1, r / c))
367
369
  return update.mul_(lr * scaling_factor)
368
370
 
369
371
  def _auto_projection_for_adamuon(raw_update: torch.Tensor, kappa_p: float) -> torch.Tensor:
@@ -474,15 +476,15 @@ def get_spectral_scaling(shape: torch.Size, n_layers: int):
474
476
  # A) Newton-Schulz Damping
475
477
  # This ensures the matrix orthogonalization is stable across scales.
476
478
  # Formula: (1/L) * sqrt(d_in / d_out)
477
- ns_eps = (1.0 / L) * (d_in / d_out) ** 0.5
479
+ ns_eps = (1.0 / L) * math.sqrt(d_in / d_out)
478
480
 
479
481
  # B) Adaptive Denominator Epsilon
480
482
  # This ensures the Adam-style division doesn't explode or vanish.
481
483
  # Formula: (1/L) * (1 / sqrt(d_in * d_out))
482
- adaptive_eps = (1.0 / L) * (1.0 / (d_in * d_out)**0.5)
484
+ adaptive_eps = (1.0 / L) * (1.0 / math.sqrt(d_in * d_out))
483
485
 
484
486
  # Spectral Target (Section F) -> sqrt(d_out/d_in)
485
- spectral_target = (d_out / d_in) ** 0.5
487
+ spectral_target = math.sqrt(d_out / d_in)
486
488
 
487
489
  # Weight Decay (Section 3.4) -> 1/width
488
490
  wd_scale = 1.0 / d_in
@@ -4,7 +4,7 @@ from torch.optim import Optimizer
4
4
 
5
5
  from typing import Dict, Any
6
6
 
7
- from .scaled_optm import scale_wds
7
+ from .scaled_optm import adjust_wds, scale_wds
8
8
  from .centered_decay import dequantize_anchor
9
9
 
10
10
  _generators: Dict[torch.device, torch.Generator] = {}
@@ -29,11 +29,17 @@ def _apply_weight_decay(
29
29
  # Cautious Weight Decay: only decay if the update pushes in the same direction as the decay
30
30
  if cautious:
31
31
  mask = (update_calc * p_calc >= 0).to(p_calc.dtype)
32
- p_calc.addcmul_(p_calc, mask, value=-scaled_wd)
32
+ if isinstance(scaled_wd, Tensor):
33
+ p_calc.addcmul_(p_calc, mask * scaled_wd, value=-1.0)
34
+ else:
35
+ p_calc.addcmul_(p_calc, mask, value=-scaled_wd)
33
36
  del mask
34
37
  else:
35
38
  # Standard decoupled weight decay
36
- p_calc.add_(p_calc, alpha=-scaled_wd)
39
+ if isinstance(scaled_wd, Tensor):
40
+ p_calc.addcmul_(p_calc, scaled_wd, value=-1.0)
41
+ else:
42
+ p_calc.add_(p_calc, alpha=-scaled_wd)
37
43
 
38
44
  # Centered Weight Decay (pulls toward anchor)
39
45
  if scaled_cwd is not None and 'anchor_type' in state:
@@ -43,15 +49,20 @@ def _apply_weight_decay(
43
49
  if cautious:
44
50
  # Cautious Weight Decay: only decay if the update pushes in the same direction as the decay
45
51
  mask = (update_calc * decay_target >= 0).to(p_calc.dtype)
46
- p_calc.addcmul_(decay_target, mask, value=-scaled_cwd)
52
+ if isinstance(scaled_cwd, Tensor):
53
+ p_calc.addcmul_(decay_target, mask * scaled_cwd, value=-1.0)
54
+ else:
55
+ p_calc.addcmul_(decay_target, mask, value=-scaled_cwd)
47
56
  del mask
48
57
  else:
49
58
  # Standard decoupled weight decay
50
- p_calc.add_(decay_target, alpha=-scaled_cwd)
59
+ if isinstance(scaled_cwd, Tensor):
60
+ p_calc.addcmul_(decay_target, scaled_cwd, value=-1.0)
61
+ else:
62
+ p_calc.add_(decay_target, alpha=-scaled_cwd)
51
63
 
52
64
  del anchor, decay_target
53
65
 
54
-
55
66
  def apply_parameter_update(
56
67
  self,
57
68
  p: Tensor,
@@ -61,6 +72,7 @@ def apply_parameter_update(
61
72
  wd: float | None = None,
62
73
  random_int_tensor: Tensor | None = None,
63
74
  decoupled: bool = False,
75
+ wd_scaler: float | Tensor | None = None,
64
76
  ) -> None:
65
77
  """
66
78
  Applies decoupled weight decay (standard, cautious, centered) and the final
@@ -75,13 +87,16 @@ def apply_parameter_update(
75
87
  random_int_tensor: Optional pre-generated random tensor for stochastic
76
88
  rounding. Required for the `torch.compile` path.
77
89
  decoupled: Whenever to use the true decoupled weight decay.
90
+ wd_scaler: A multiplier/tensor to scale the calculated wd/cwd magnitude (e.g. for Fisher Adam WD).
78
91
  """
79
92
  wd = group["weight_decay"] if wd is None else wd
80
93
  cwd = group.get("centered_wd", 0.0)
81
94
 
82
95
  if group.get('scaled_optm', False):
83
96
  decoupled = True
84
- wd, cwd = scale_wds(wd, cwd, p)
97
+ wd, cwd = adjust_wds(wd, cwd, p)
98
+ if wd_scaler is None:
99
+ wd, cwd = scale_wds(wd, cwd, p)
85
100
 
86
101
  # Calculate global decay factor for decoupled vs standard
87
102
  decay_factor = (lr / self._init_lr) if decoupled else lr
@@ -89,6 +104,12 @@ def apply_parameter_update(
89
104
  scaled_wd = (wd * decay_factor) if wd != 0 else None
90
105
  scaled_cwd = (cwd * decay_factor) if cwd != 0 else None
91
106
 
107
+ if wd_scaler is not None:
108
+ if scaled_wd is not None:
109
+ scaled_wd = scaled_wd * wd_scaler
110
+ if scaled_cwd is not None:
111
+ scaled_cwd = scaled_cwd * wd_scaler
112
+
92
113
  state = self.state[p]
93
114
 
94
115
  # Compute full update in float32 if using bfloat16 with stochastic rounding
@@ -284,3 +305,24 @@ def post_process_loaded_state(optimizer: Optimizer) -> None:
284
305
  # Ensure device match
285
306
  if state[key].device != p.device:
286
307
  state[key] = state[key].to(p.device)
308
+
309
+
310
+ def _get_random_noise_for_sso(source: torch.Tensor) -> torch.Tensor:
311
+ """
312
+ Generates a random noise tensor for Stochastic Sign operator.
313
+ This function is not torch.compile-path friendly due to its use of torch.Generator.
314
+ """
315
+ global _generators
316
+ device = source.device
317
+ if device not in _generators:
318
+ set_seed(device)
319
+ # TODO, this is a workaround until torch compile error
320
+ # NotImplementedError: UserDefinedObjectVariable(generator) is fixed
321
+ generator = _generators[device]
322
+ # create a random noise tensor
323
+ return torch.randint(
324
+ size=source.shape,
325
+ device=source.device,
326
+ dtype=source.dtype,
327
+ generator=generator,
328
+ )
@@ -2,6 +2,8 @@ import torch
2
2
 
3
3
  from . import param_update
4
4
 
5
+ import math
6
+
5
7
  def scale_update(
6
8
  p: torch.Tensor,
7
9
  update: torch.Tensor,
@@ -26,16 +28,16 @@ def scale_update(
26
28
 
27
29
  # DoRA Magnitude Scales (1D) or 1D Bias/Norm layers
28
30
  if is_dora_scale or p.ndim == 1:
29
- return rms_normalization(update, dim=None, lr=lr)
31
+ return l2_normalization(update, dim=None, lr=lr)
30
32
 
31
33
  # Orthogonal Fine-Tuning (OFT)
32
34
  # This guarantees O(1) update complexity scaling, independent of block sizes.
33
35
  if is_oft:
34
36
  n = update.shape[1]
35
37
  # Calculate block size (b)
36
- b = (1 + (1 + 8 * n) ** 0.5) / 2
37
- target_norm = (b / 8) ** 0.5
38
- scale = target_norm / (n ** 0.5)
38
+ b = (1 + math.sqrt(1 + 8 * n)) / 2
39
+ target_norm = math.sqrt(b / 8)
40
+ scale = target_norm / math.sqrt(n)
39
41
  return rms_normalization(update, dim=1, lr=lr * scale)
40
42
 
41
43
  # LoRA Factors or Full Finetuning weights
@@ -46,35 +48,49 @@ def scale_update(
46
48
  return update.mul_(lr)
47
49
 
48
50
 
49
- def scale_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
51
+ def scale_eps(group: dict, p) -> tuple[float, float]:
50
52
  """
51
- Adjusts standard weight decay and centered weight decay based on the parameter's
52
- shape and type to maintain effective regularization strength.
53
+ Scales Adam eps to be scale-invariant.
54
+ """
55
+ if group.get('scaled_optm', False):
56
+ adaptive_eps = (1.0 / group['n_layers']) * (1.0 / math.sqrt(p.numel()))
57
+ else:
58
+ adaptive_eps = group['eps']
59
+ return adaptive_eps
60
+
61
+ def adjust_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
62
+ """
63
+ Adjusts standard weight decay and centered weight decay.
53
64
  """
54
65
  # DoRA Scale (Magnitude Vector)
55
66
  if getattr(p, '_is_dora_scale', False):
56
67
  return wd, cwd
57
68
 
58
- conflict = cwd != 0
59
-
60
69
  if getattr(p, '_is_oft', False):
61
- # Fallback to standard WD (using cwd value) if both are active.
62
- return (cwd if conflict else wd), 0.0
70
+ return wd, 0.0
63
71
 
64
72
  if p.ndim >= 2:
65
- fan_in = p.numel() // p.shape[0]
66
-
67
- # When both WDs are active on LoRA, fallback to standard WD (using cwd value)
68
- # Reverts the behavior for better DoRA tuning.
69
73
  is_lora = getattr(p, '_is_lora_A', False) or getattr(p, '_is_lora_B', False)
70
- if conflict and is_lora:
71
- return cwd / fan_in, 0.0
74
+ if is_lora:
75
+ return wd, 0.0
72
76
 
77
+ else:
78
+ # 1D Biases or generic 1D parameters
79
+ # Centered WD safely regularizes the delta without collapsing base feature variance.
80
+ return 0.0, cwd
81
+
82
+
83
+ def scale_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
84
+ """
85
+ Scales standard weight decay and centered weight decay based on the parameter's
86
+ shape and type to maintain effective regularization strength.
87
+ """
88
+ if p.ndim >= 2:
89
+ fan_in = p.numel() // p.shape[0]
73
90
  return wd / fan_in, cwd / fan_in
74
91
 
75
- # 1D Biases or generic 1D parameters
76
- # Centered WD safely regularizes the delta without collapsing base feature variance.
77
- return 0.0, cwd
92
+ # 1D tensors (like DoRA scale and Biases)
93
+ return wd, cwd
78
94
 
79
95
 
80
96
  @torch.no_grad()
@@ -89,7 +105,7 @@ def rms_normalization(update: torch.Tensor, dim: int | None, lr: float) -> torch
89
105
  """Performs Root Mean Square normalization on the update tensor."""
90
106
  n = update.numel() if dim is None else update.shape[dim]
91
107
  norm = torch.linalg.vector_norm(update, ord=2, dim=dim, keepdim=True).clamp_min_(1e-12)
92
- scale_n = n**0.5
108
+ scale_n = math.sqrt(n)
93
109
  return update.mul_(lr * scale_n / norm)
94
110
 
95
111
 
@@ -119,7 +135,7 @@ def spectral_normalization(update: torch.Tensor, vector_state: torch.Tensor, lr:
119
135
  update = update.to(vector_state.dtype)
120
136
  update_flat = update.view(d_out, d_in)
121
137
  # Target scale derived from the "Modular Norm" paper
122
- target_scale = (d_out / d_in) ** 0.5
138
+ target_scale = math.sqrt(d_out / d_in)
123
139
  # Power Iteration step to estimate the largest singular value (sigma)
124
140
  # u = Wv
125
141
  u = torch.mv(update_flat, vector_state)
@@ -0,0 +1,13 @@
1
+ import torch
2
+
3
+
4
+ def apply_stochastic_sign(update: torch.Tensor, noise: torch.Tensor | None) -> torch.Tensor:
5
+ """
6
+ Applies the Stochastic Sign operator S_R(v).
7
+ Uses uniform noise injection to compute the stochastic sign
8
+ """
9
+ R = update.abs().max().clamp_min(1e-12)
10
+
11
+ if noise is None:
12
+ noise = torch.rand_like(update) * 2.0 - 1.0
13
+ return torch.sign(update / R + noise, out=update)
@@ -1,5 +1,7 @@
1
1
  import torch
2
2
 
3
+ import math
4
+
3
5
  def _grams_update(mt: torch.Tensor, grad: torch.Tensor, inplace: bool=False):
4
6
  """
5
7
  Applies the update rule of "Gradient Descent with Adaptive Momentum Scaling"
@@ -31,27 +33,63 @@ def _scale_sim_AdEMAMix_update(beta: float, current_step: int, alpha_grad: float
31
33
  lr = lr * total_scale
32
34
  return lr
33
35
 
36
+
37
+ def _init_fisher_wd_scaler(group: dict, state: dict, p: torch.Tensor) -> torch.Tensor | None:
38
+ if not group.get('fisher_wd', False):
39
+ return
40
+
41
+ state["wd_scaler"] = torch.tensor(1.0, device=p.device)
42
+
43
+ def _get_fisher_wd_scaler(group: dict, stored_scaler: torch.Tensor, p: torch.Tensor, denom: torch.Tensor, atan2: bool) -> torch.Tensor | None:
44
+ """
45
+ Calculates the Fisher weight decay scaler.
46
+ Maps the decay direction through the empirical Fisher information matrix
47
+ and clips its RMS to ensure stability.
48
+ From the paper:
49
+ "FAdam: Adam is a natural gradient optimizer using diagonal empirical Fisher information"
50
+ """
51
+ if not group.get('fisher_wd', False):
52
+ return None
53
+
54
+ if atan2:
55
+ wd_scaler = torch.atan2(stored_scaler, denom).mul_(4 / math.pi)
56
+ else:
57
+ eps = group.get('eps', 1e-8)
58
+ wd_scaler = 1.0 / (denom + eps)
59
+
60
+ # Reshape scaler if necessary to match parameter shape (for factored states)
61
+ wd_scaler = wd_scaler.view(p.shape)
62
+
63
+ gw_rms = torch.sqrt(torch.mean((p * wd_scaler) ** 2))
64
+ clip_coef = torch.clamp(gw_rms / 1.0, min=1.0)
65
+ return wd_scaler / clip_coef
66
+
34
67
  def _get_l1_adaptive_lr(
35
68
  p: torch.Tensor,
36
69
  update: torch.Tensor,
37
70
  state: dict,
38
71
  group: dict,
39
- kappa_p: float
72
+ kappa_p: float,
73
+ rescale: bool = False,
40
74
  ) -> torch.Tensor:
41
75
  """
42
76
  Calculates the L1 adaptive learning rate based on gradient heterogeneity.
43
77
  """
44
- if not group.get("l1_adaptive", False) and kappa_p != 1:
78
+ if not group.get("l1_adaptive", False) or kappa_p != 1:
45
79
  return None
46
80
 
47
- momentum = group["momentum"]
48
- alpha_grad = group["alpha_grad"]
49
81
  update_view = update.view(p.shape)
50
82
 
51
- # Calculate scale factor based on momentum/update magnitude
52
- scale_factor = _scale_sim_AdEMAMix_update(
53
- momentum, state["step"] + 1, alpha_grad, 1, False
54
- )
83
+ if rescale:
84
+ momentum = group["momentum"]
85
+ alpha_grad = group["alpha_grad"]
86
+
87
+ # Calculate scale factor based on momentum/update magnitude
88
+ scale_factor = _scale_sim_AdEMAMix_update(
89
+ momentum, state["step"] + 1, alpha_grad, 1, False
90
+ )
91
+ else:
92
+ scale_factor = 1
55
93
 
56
94
  # Determine dimension for mean calculation based on parameter type
57
95
  if getattr(p, '_is_oft', False) or getattr(p, '_is_lora_A', False):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev4
3
+ Version: 2.4.dev6
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -27,4 +27,5 @@ adv_optm/util/factorization_util.py
27
27
  adv_optm/util/lion_k.py
28
28
  adv_optm/util/param_update.py
29
29
  adv_optm/util/scaled_optm.py
30
+ adv_optm/util/signed_util.py
30
31
  adv_optm/util/update_util.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
5
5
 
6
6
  setup(
7
7
  name="adv_optm",
8
- version="2.4.dev4",
8
+ version="2.4.dev6",
9
9
  author="Koratahiu",
10
10
  author_email="hiuhonor@gmail.com",
11
11
  license='Apache 2.0',
File without changes
File without changes
File without changes