adv-optm 2.4.dev4__tar.gz → 2.4.dev5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33):
  1. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/AdaMuon_adv.py +7 -3
  4. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/AdamW_adv.py +17 -4
  5. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/Adopt_adv.py +13 -4
  6. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/Lion_Prodigy_adv.py +2 -2
  7. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/Lion_adv.py +35 -21
  8. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/Muon_adv.py +6 -2
  9. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/Prodigy_adv.py +14 -5
  10. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/SignSGD_adv.py +31 -15
  11. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/Simplified_AdEMAMix.py +1 -1
  12. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/Muon_util.py +7 -5
  13. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/param_update.py +49 -7
  14. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/scaled_optm.py +20 -16
  15. adv_optm-2.4.dev5/adv_optm/util/signed_util.py +13 -0
  16. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/update_util.py +46 -8
  17. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm.egg-info/PKG-INFO +1 -1
  18. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm.egg-info/SOURCES.txt +1 -0
  19. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/setup.py +1 -1
  20. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/LICENSE +0 -0
  21. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/README.md +0 -0
  22. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/optim/__init__.py +0 -0
  23. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/Kourkoutas.py +0 -0
  24. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/Muon_AuxAdam.py +0 -0
  25. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/OrthoGrad.py +0 -0
  26. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/__init__.py +0 -0
  27. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/centered_decay.py +0 -0
  28. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/factorization_util.py +0 -0
  29. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm/util/lion_k.py +0 -0
  30. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm.egg-info/dependency_links.txt +0 -0
  31. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm.egg-info/requires.txt +0 -0
  32. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/adv_optm.egg-info/top_level.txt +0 -0
  33. {adv_optm-2.4.dev4 → adv_optm-2.4.dev5}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev4
3
+ Version: 2.4.dev5
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -22,4 +22,4 @@ __all__ = [
22
22
  "SignSGD_adv",
23
23
  ]
24
24
 
25
- __version__ = "2.4.dev4"
25
+ __version__ = "2.4.dev5"
@@ -206,6 +206,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
206
206
  if spectral_normalization and rms_rescaling:
207
207
  print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
208
208
  rms_rescaling = False
209
+ if spectral_normalization and accelerated_ns:
210
+ ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
209
211
 
210
212
  defaults = {
211
213
  "lr": lr, "betas": betas, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
@@ -260,6 +262,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
260
262
  if group.get('use_muon') is None: # Fallback
261
263
  group['use_muon'] = group.get('optim_type') == 'muon'
262
264
 
265
+ self.init_step()
266
+
263
267
  self.kourkoutas_helper = None
264
268
  if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
265
269
  self.kourkoutas_helper = KourkoutasHelper(self)
@@ -419,7 +423,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
419
423
  step_size = group['lr'] / bias_correction1
420
424
 
421
425
  if is_compiled:
422
- step_size = torch.as_tensor(step_size, dtype=torch.float64)
426
+ step_size = torch.as_tensor(step_size)
423
427
  adam_step_param = self._compiled_adam_step_parameter
424
428
  else:
425
429
  adam_step_param = Muon_AuxAdam._adam_step_parameter
@@ -430,7 +434,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
430
434
 
431
435
  else: # Muon path
432
436
  if is_compiled:
433
- lr = torch.as_tensor(group['lr'], dtype=torch.float64)
437
+ lr = torch.as_tensor(group['lr'])
434
438
  muon_step_param = self._compiled_muon_step_parameter
435
439
  else:
436
440
  lr = group['lr']
@@ -467,7 +471,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
467
471
  else:
468
472
  shape_for_scaling = p.shape
469
473
 
470
- scaled_eps, adaptive_eps, spectral_target, wd_scale = get_spectral_scaling(shape_for_scaling, group['n_layers'])
474
+ scaled_eps, adaptive_eps, spectral_target, wd_scale = get_spectral_scaling(p, shape_for_scaling, group['n_layers'])
471
475
 
472
476
  weight_decay = group['weight_decay'] * wd_scale
473
477
  decoupled_wd = True
@@ -6,7 +6,7 @@ from typing import Optional, Callable
6
6
 
7
7
  from ..util import param_update
8
8
  from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
9
- from ..util.update_util import _grams_update, _cautious_update
9
+ from ..util.update_util import _grams_update, _cautious_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
10
10
  from ..util.OrthoGrad import _orthogonalize_gradient
11
11
  from ..util.Kourkoutas import KourkoutasHelper
12
12
  from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
@@ -29,6 +29,9 @@ class AdamW_adv(torch.optim.Optimizer):
29
29
  eps (float): term added to the denominator to improve
30
30
  numerical stability (default: 1e-8)
31
31
  weight_decay (float): weight decay (L2 penalty) (default: 0).
32
+ fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
33
+ the decay direction through the empirical Fisher information matrix and
34
+ clipping its RMS. (default: False)
32
35
  cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
33
36
  applied only to parameter coordinates where the sign of the parameter
34
37
  and the sign of the optimizer update align (default: False).
@@ -103,6 +106,7 @@ class AdamW_adv(torch.optim.Optimizer):
103
106
  eps: float = 1e-8,
104
107
  # Decoupled/cautious weight decay
105
108
  weight_decay: float = 0.0,
109
+ fisher_wd: bool = False,
106
110
  cautious_wd: bool = False,
107
111
  # Adam's Bias Correction
108
112
  use_bias_correction: bool = True,
@@ -155,7 +159,8 @@ class AdamW_adv(torch.optim.Optimizer):
155
159
  cautious_mask = False
156
160
 
157
161
  defaults = {
158
- "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
162
+ "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
163
+ "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
159
164
  "use_atan2": use_atan2,
160
165
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
161
166
  "beta3_ema": beta3_ema, "alpha": alpha, "compiled_optimizer": compiled_optimizer,
@@ -273,6 +278,8 @@ class AdamW_adv(torch.optim.Optimizer):
273
278
 
274
279
  _init_anchor(p, state, group)
275
280
 
281
+ _init_fisher_wd_scaler(group, state, p)
282
+
276
283
  beta1, beta2 = group['betas']
277
284
 
278
285
  current_step = state['step']
@@ -294,7 +301,7 @@ class AdamW_adv(torch.optim.Optimizer):
294
301
  random_int_tensor = None
295
302
 
296
303
  if group.get('compiled_optimizer', False):
297
- step_size = torch.as_tensor(step_size, dtype=torch.float64)
304
+ step_size = torch.as_tensor(step_size)
298
305
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
299
306
  # Pre-generate random tensor for stochastic rounding if needed.
300
307
  random_int_tensor = param_update._get_random_int_for_sr(p)
@@ -389,6 +396,9 @@ class AdamW_adv(torch.optim.Optimizer):
389
396
  denom = vt.sqrt_()
390
397
  denom.div_(sqrt_bias_correction2).add_(group['eps'])
391
398
  update.div_(denom)
399
+
400
+ wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
401
+
392
402
  del vt
393
403
 
394
404
  update = update.view(p.shape)
@@ -430,6 +440,9 @@ class AdamW_adv(torch.optim.Optimizer):
430
440
  denom = exp_avg_sq.sqrt()
431
441
  denom.div_(sqrt_bias_correction2).add_(group['eps'])
432
442
  update.div_(denom)
443
+
444
+ wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
445
+
433
446
  del denom
434
447
 
435
448
  update_scaling = step_size * A if group['use_atan2'] else step_size
@@ -438,7 +451,7 @@ class AdamW_adv(torch.optim.Optimizer):
438
451
  else:
439
452
  update.mul_(update_scaling)
440
453
 
441
- param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor)
454
+ param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
442
455
 
443
456
  def compile(self, *args, **kwargs):
444
457
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -7,7 +7,7 @@ from ..util import param_update
7
7
  from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state, _nnmf
8
8
  from ..util.OrthoGrad import _orthogonalize_gradient
9
9
  from ..util.Kourkoutas import KourkoutasHelper
10
- from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update
10
+ from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
11
11
  from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
12
12
  from ..util.centered_decay import _init_anchor
13
13
 
@@ -33,6 +33,9 @@ class Adopt_adv(torch.optim.Optimizer):
33
33
  eps (float): term added to the denominator to improve
34
34
  numerical stability (default: 1e-6)
35
35
  weight_decay (float): weight decay (L2 penalty) (default: 0)
36
+ fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
37
+ the decay direction through the empirical Fisher information matrix and
38
+ clipping its RMS. (default: False)
36
39
  cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
37
40
  applied only to parameter coordinates where the sign of the parameter
38
41
  and the sign of the optimizer update align (default: False).
@@ -119,6 +122,7 @@ class Adopt_adv(torch.optim.Optimizer):
119
122
  eps: float = 1e-6,
120
123
  # Decoupled/cautious weight decay
121
124
  weight_decay: float = 0.0,
125
+ fisher_wd: bool = False,
122
126
  cautious_wd: bool = False,
123
127
  # ADOPT clipping
124
128
  clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
@@ -181,7 +185,8 @@ class Adopt_adv(torch.optim.Optimizer):
181
185
  print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
182
186
 
183
187
  defaults = {
184
- "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
188
+ "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
189
+ "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
185
190
  "beta3_ema": beta3_ema, "alpha": alpha,
186
191
  "alpha_grad": alpha_grad,
187
192
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
@@ -312,6 +317,8 @@ class Adopt_adv(torch.optim.Optimizer):
312
317
 
313
318
  _init_anchor(p, state, group)
314
319
 
320
+ _init_fisher_wd_scaler(group, state, p)
321
+
315
322
  current_step = state['step']
316
323
 
317
324
  # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
@@ -322,7 +329,7 @@ class Adopt_adv(torch.optim.Optimizer):
322
329
  random_int_tensor = None
323
330
 
324
331
  if group.get('compiled_optimizer', False):
325
- lr = torch.as_tensor(group['lr'], dtype=torch.float64)
332
+ lr = torch.as_tensor(group['lr'])
326
333
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
327
334
  # Pre-generate random tensor for stochastic rounding if needed.
328
335
  random_int_tensor = param_update._get_random_int_for_sr(p)
@@ -366,6 +373,7 @@ class Adopt_adv(torch.optim.Optimizer):
366
373
 
367
374
  # ADOPT Step A: Decorrelate g_t using v_{t-1}
368
375
  denom = vt.sqrt()
376
+ wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, self.use_atan2)
369
377
 
370
378
  # Update second moment v_t for the *next* step using raw g_t
371
379
  if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
@@ -444,6 +452,7 @@ class Adopt_adv(torch.optim.Optimizer):
444
452
 
445
453
  # ADOPT Step A: Decorrelate g_t using v_{t-1}
446
454
  denom = vt.sqrt()
455
+ wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, self.use_atan2)
447
456
 
448
457
  if self.use_atan2:
449
458
  normalized_grad = torch.atan2(grad, denom, out=denom)
@@ -499,7 +508,7 @@ class Adopt_adv(torch.optim.Optimizer):
499
508
  update.mul_(update_scaling)
500
509
 
501
510
  # Parameter Update
502
- param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
511
+ param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
503
512
 
504
513
  def compile(self, *args, **kwargs):
505
514
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -225,8 +225,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
225
225
  # Pre-generate random tensor for stochastic rounding if needed.
226
226
  random_int_tensor = param_update._get_random_int_for_sr(p)
227
227
  # TODO, workaround until pytorch#169634 is fixed
228
- d = torch.as_tensor(group['d'], dtype=torch.float64)
229
- dlr = torch.as_tensor(dlr, dtype=torch.float64)
228
+ d = torch.as_tensor(group['d'])
229
+ dlr = torch.as_tensor(dlr)
230
230
  step_param_fn = self._compiled_step_parameter
231
231
  else:
232
232
  d = group['d']
@@ -9,6 +9,7 @@ from ..util.lion_k import _get_lion_k_update
9
9
  from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
10
10
  from ..util.centered_decay import _init_anchor
11
11
  from ..util.update_util import _get_l1_adaptive_lr
12
+ from ..util.signed_util import apply_stochastic_sign
12
13
 
13
14
 
14
15
  class Lion_adv(torch.optim.Optimizer):
@@ -45,6 +46,7 @@ class Lion_adv(torch.optim.Optimizer):
45
46
  parameter dimensionality. Sets p=2.0 for 4D tensors (Conv2D) (Biases/Norms) to
46
47
  use Spherical updates, and p=1.0 for others (Linear/Embeddings) to use Sign
47
48
  updates. Overrides explicit kappa_p value. (default: False).
49
+ stochastic_sign (bool): whether to use the Stochastic Sign operator. (default: False)
48
50
  freeze_on_flip (bool): Projected SignGD One-hit freeze. Masks updates for
49
51
  coordinates where the gradient sign flips compared to the previous step. (default: False)
50
52
  l1_adaptive (bool): Scales learning rate dynamically
@@ -80,6 +82,8 @@ class Lion_adv(torch.optim.Optimizer):
80
82
  # Lion-k
81
83
  kappa_p: float = 1.0,
82
84
  auto_kappa_p: bool = False,
85
+ # Stochastic Sign Operator
86
+ stochastic_sign: bool = False,
83
87
  # Projected and adaptive sign
84
88
  freeze_on_flip: bool = False,
85
89
  l1_adaptive: bool = False,
@@ -111,6 +115,7 @@ class Lion_adv(torch.optim.Optimizer):
111
115
  clip_threshold=clip_threshold,
112
116
  kappa_p=kappa_p,
113
117
  auto_kappa_p=auto_kappa_p,
118
+ stochastic_sign=stochastic_sign,
114
119
  freeze_on_flip=freeze_on_flip,
115
120
  l1_adaptive=l1_adaptive,
116
121
  scaled_optm= scaled_optm,
@@ -202,19 +207,22 @@ class Lion_adv(torch.optim.Optimizer):
202
207
  lr = group["lr"]
203
208
 
204
209
  random_int_tensor = None
210
+ random_noise_tensor = None
205
211
 
206
212
  if group.get('compiled_optimizer', False):
207
213
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
208
214
  # Pre-generate random tensor for stochastic rounding if needed.
209
215
  random_int_tensor = param_update._get_random_int_for_sr(p)
210
- lr = torch.as_tensor(lr, dtype=torch.float64)
216
+ if group.get('stochastic_sign', False):
217
+ random_noise_tensor = param_update._get_random_noise_for_sso(p)
218
+ lr = torch.as_tensor(lr)
211
219
  step_param_fn = self._compiled_step_parameter
212
220
  else:
213
221
  step_param_fn = self._step_parameter
214
222
 
215
- step_param_fn(p, grad, state, group, lr, random_int_tensor)
223
+ step_param_fn(p, grad, state, group, lr, random_int_tensor, random_noise_tensor)
216
224
 
217
- def _step_parameter(self, p, grad, state, group, lr, random_int_tensor):
225
+ def _step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_noise_tensor):
218
226
  if grad.dtype != torch.float32 and state['factored']:
219
227
  grad = grad.float()
220
228
  if group["clip_threshold"] > 0.0:
@@ -252,8 +260,6 @@ class Lion_adv(torch.optim.Optimizer):
252
260
  # Compute update term c_t
253
261
  update = torch.lerp(grad_reshaped, exp_avg, beta1)
254
262
 
255
- l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p)
256
-
257
263
  # Standard Lion momentum update
258
264
  # m_t = beta2 * m_{t-1} + (1-beta2) * g_t
259
265
  exp_avg.lerp_(grad_reshaped, 1 - beta2)
@@ -262,7 +268,12 @@ class Lion_adv(torch.optim.Optimizer):
262
268
  state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(exp_avg, signed=True)
263
269
  del exp_avg
264
270
 
265
- update = _get_lion_k_update(update, kappa_p)
271
+ if freeze_on_flip:
272
+ # Fast binary diff (XOR) from momentum sign directly
273
+ flipped_packed = prev_sign_packed ^ state['sign']
274
+ flipped_mask = _unpack_bools(flipped_packed, original_m=d2).view_as(update)
275
+ update = torch.where(flipped_mask, 0.0, update)
276
+ del prev_sign_packed, flipped_packed, flipped_mask
266
277
 
267
278
  if self.cautious_mask:
268
279
  mask = (update * grad_reshaped > 0).to(grad_reshaped.dtype)
@@ -272,12 +283,12 @@ class Lion_adv(torch.optim.Optimizer):
272
283
 
273
284
  update = update.view(p.shape)
274
285
 
275
- if freeze_on_flip:
276
- # Fast binary diff (XOR) from momentum sign directly
277
- flipped_packed = prev_sign_packed ^ state['sign']
278
- flipped_mask = _unpack_bools(flipped_packed, original_m=d2).view_as(update)
279
- update = torch.where(flipped_mask, 0.0, update)
280
- del prev_sign_packed, flipped_packed, flipped_mask
286
+ l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p, rescale=False)
287
+
288
+ if group.get('stochastic_sign', False):
289
+ update = apply_stochastic_sign(update, noise=random_noise_tensor)
290
+ else:
291
+ update = _get_lion_k_update(update, kappa_p)
281
292
 
282
293
  else:
283
294
  # Fallback to standard Lion logic
@@ -286,9 +297,13 @@ class Lion_adv(torch.optim.Optimizer):
286
297
  # Compute update term
287
298
  update = torch.lerp(grad, exp_avg, beta1)
288
299
 
289
- l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p)
300
+ # Standard Lion momentum update
301
+ exp_avg.lerp_(grad, 1 - beta2)
290
302
 
291
- update = _get_lion_k_update(update, kappa_p)
303
+ if freeze_on_flip:
304
+ current_sign = (update > 0).to(torch.uint8)
305
+ update = torch.where(current_sign == state['prev_sign'], update, 0.0)
306
+ state['prev_sign'] = current_sign
292
307
 
293
308
  if self.cautious_mask:
294
309
  mask = (update * grad > 0).to(grad.dtype)
@@ -296,13 +311,12 @@ class Lion_adv(torch.optim.Optimizer):
296
311
  update.mul_(mask)
297
312
  del mask
298
313
 
299
- # Standard Lion momentum update
300
- exp_avg.lerp_(grad, 1 - beta2)
314
+ l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p, rescale=False)
301
315
 
302
- if freeze_on_flip:
303
- current_sign = (update > 0).to(torch.uint8)
304
- update = torch.where(current_sign == state['prev_sign'], update, 0.0)
305
- state['prev_sign'] = current_sign
316
+ if group.get('stochastic_sign', False):
317
+ update = apply_stochastic_sign(update, noise=random_noise_tensor)
318
+ else:
319
+ update = _get_lion_k_update(update, kappa_p)
306
320
 
307
321
  if l1_mean is not None:
308
322
  update.mul_(l1_mean)
@@ -312,7 +326,7 @@ class Lion_adv(torch.optim.Optimizer):
312
326
  else:
313
327
  update.mul_(lr)
314
328
 
315
- param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
329
+ param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_scaler=l1_mean)
316
330
 
317
331
  def compile(self, *args, **kwargs):
318
332
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -183,6 +183,8 @@ class Muon_adv(torch.optim.Optimizer):
183
183
  if spectral_normalization and rms_rescaling:
184
184
  print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
185
185
  rms_rescaling = False
186
+ if spectral_normalization and accelerated_ns:
187
+ ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
186
188
 
187
189
  defaults = {
188
190
  "lr": lr, "beta1": beta1, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
@@ -239,6 +241,8 @@ class Muon_adv(torch.optim.Optimizer):
239
241
  if group.get('use_muon') is None: # Fallback
240
242
  group['use_muon'] = group.get('optim_type') == 'muon'
241
243
 
244
+ self.init_step()
245
+
242
246
  self.kourkoutas_helper = None
243
247
  if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
244
248
  self.kourkoutas_helper = KourkoutasHelper(self)
@@ -393,7 +397,7 @@ class Muon_adv(torch.optim.Optimizer):
393
397
  step_size = group['lr'] / bias_correction1
394
398
 
395
399
  if is_compiled:
396
- step_size = torch.as_tensor(step_size, dtype=torch.float64)
400
+ step_size = torch.as_tensor(step_size)
397
401
  adam_step_param = self._compiled_adam_step_parameter
398
402
  else:
399
403
  adam_step_param = Muon_AuxAdam._adam_step_parameter
@@ -404,7 +408,7 @@ class Muon_adv(torch.optim.Optimizer):
404
408
 
405
409
  else: # Muon path
406
410
  if is_compiled:
407
- lr = torch.as_tensor(group['lr'], dtype=torch.float64)
411
+ lr = torch.as_tensor(group['lr'])
408
412
  muon_step_param = self._compiled_muon_step_parameter
409
413
  else:
410
414
  lr = group['lr']
@@ -9,7 +9,7 @@ from ..util import param_update
9
9
  from ..util.OrthoGrad import _orthogonalize_gradient
10
10
  from ..util.Kourkoutas import KourkoutasHelper
11
11
  from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
12
- from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update
12
+ from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
13
13
  from ..util.centered_decay import _init_anchor
14
14
 
15
15
  A = 4 / math.pi
@@ -29,6 +29,9 @@ class Prodigy_adv(torch.optim.Optimizer):
29
29
  eps (float): term added to the denominator to improve
30
30
  numerical stability (default: 1e-8)
31
31
  weight_decay (float): weight decay (L2 penalty) (default: 0)
32
+ fisher_wd (bool): whether to use Fisher Adam (FAdam) weight decay, mapping
33
+ the decay direction through the empirical Fisher information matrix and
34
+ clipping its RMS. (default: False)
32
35
  cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
33
36
  applied only to parameter coordinates where the sign of the parameter
34
37
  and the sign of the optimizer update align (default: False).
@@ -133,6 +136,7 @@ class Prodigy_adv(torch.optim.Optimizer):
133
136
  eps: float = 1e-8,
134
137
  # Decoupled/cautious weight decay
135
138
  weight_decay: float = 0.0,
139
+ fisher_wd: bool = False,
136
140
  cautious_wd: bool = False,
137
141
  # Stochastic Rounding for BF16
138
142
  stochastic_rounding: bool = True,
@@ -206,7 +210,8 @@ class Prodigy_adv(torch.optim.Optimizer):
206
210
  raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
207
211
 
208
212
  defaults = {
209
- "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
213
+ "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
214
+ "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
210
215
  "use_atan2": use_atan2,
211
216
  "orthogonal_gradient": orthogonal_gradient,
212
217
  "beta3_ema": beta3_ema, "alpha": alpha, "compiled_optimizer": compiled_optimizer,
@@ -354,6 +359,8 @@ class Prodigy_adv(torch.optim.Optimizer):
354
359
 
355
360
  _init_anchor(p, state, group)
356
361
 
362
+ _init_fisher_wd_scaler(group, state, p)
363
+
357
364
  if not hasattr(self, 'd_denom'):
358
365
  self.d_denom = torch.tensor(0.0, device=p.device)
359
366
  self.d_numerator = torch.tensor(group.get('d_numerator', 0.0), device=p.device)
@@ -376,8 +383,8 @@ class Prodigy_adv(torch.optim.Optimizer):
376
383
  # Pre-generate random tensor for stochastic rounding if needed.
377
384
  random_int_tensor = param_update._get_random_int_for_sr(p)
378
385
  # TODO, workaround until pytorch#169634 is fixed
379
- d = torch.as_tensor(group['d'], dtype=torch.float64)
380
- dlr = torch.as_tensor(dlr, dtype=torch.float64)
386
+ d = torch.as_tensor(group['d'])
387
+ dlr = torch.as_tensor(dlr)
381
388
  step_param_fn = self._compiled_step_parameter
382
389
  else:
383
390
  d = group['d']
@@ -478,6 +485,7 @@ class Prodigy_adv(torch.optim.Optimizer):
478
485
  else:
479
486
  denom = vt.sqrt_()
480
487
  update.div_(denom.add_(d * group['eps']))
488
+ wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
481
489
  del vt
482
490
 
483
491
  update_scaling = dlr * A if group['use_atan2'] else dlr
@@ -528,6 +536,7 @@ class Prodigy_adv(torch.optim.Optimizer):
528
536
  else:
529
537
  denom = exp_avg_sq.sqrt()
530
538
  update.div_(denom.add_(d * group['eps']))
539
+ wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
531
540
  del denom
532
541
 
533
542
  update_scaling = dlr * A if group['use_atan2'] else dlr
@@ -557,7 +566,7 @@ class Prodigy_adv(torch.optim.Optimizer):
557
566
  if 'p0' in state:
558
567
  del state['p0']
559
568
 
560
- param_update.apply_parameter_update(self, p, group, update, dlr, random_int_tensor=random_int_tensor)
569
+ param_update.apply_parameter_update(self, p, group, update, dlr, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
561
570
 
562
571
  def compile(self, *args, **kwargs):
563
572
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -9,6 +9,7 @@ from ..util.lion_k import _get_lion_k_update
9
9
  from ..util.update_util import _get_l1_adaptive_lr
10
10
  from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
11
11
  from ..util.centered_decay import _init_anchor
12
+ from ..util.signed_util import apply_stochastic_sign
12
13
 
13
14
 
14
15
  class SignSGD_adv(torch.optim.Optimizer):
@@ -39,6 +40,7 @@ class SignSGD_adv(torch.optim.Optimizer):
39
40
  parameter dimensionality. Sets p=2.0 for 4D tensors (Conv2D) (Biases/Norms) to
40
41
  use Spherical updates, and p=1.0 for others (Linear/Embeddings) to use Sign
41
42
  updates. Overrides explicit kappa_p value. (default: False).
43
+ stochastic_sign (bool): whether to use the Stochastic Sign operator. (default: False)
42
44
  Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
43
45
  This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
44
46
  more responsive, especially for small batch sizes. (default: False)
@@ -79,6 +81,8 @@ class SignSGD_adv(torch.optim.Optimizer):
79
81
  # Projection-k
80
82
  kappa_p: float = 1.0,
81
83
  auto_kappa_p: bool = True,
84
+ # Stochastic Sign Operator
85
+ stochastic_sign: bool = False,
82
86
  # Simplified_AdEMAMix
83
87
  alpha_grad: float = 1.0,
84
88
  Simplified_AdEMAMix: bool = False,
@@ -112,6 +116,7 @@ class SignSGD_adv(torch.optim.Optimizer):
112
116
  orthogonal_gradient=orthogonal_gradient,
113
117
  kappa_p=kappa_p,
114
118
  auto_kappa_p=auto_kappa_p,
119
+ stochastic_sign=stochastic_sign,
115
120
  alpha_grad=alpha_grad,
116
121
  Simplified_AdEMAMix=Simplified_AdEMAMix,
117
122
  scaled_optm= scaled_optm,
@@ -203,23 +208,26 @@ class SignSGD_adv(torch.optim.Optimizer):
203
208
  lr = group["lr"]
204
209
 
205
210
  random_int_tensor = None
211
+ random_noise_tensor = None
206
212
 
207
213
  if group.get('compiled_optimizer', False):
208
214
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
209
215
  # Pre-generate random tensor for stochastic rounding if needed.
210
216
  random_int_tensor = param_update._get_random_int_for_sr(p)
211
- lr = torch.as_tensor(lr, dtype=torch.float64)
217
+ if group.get('stochastic_sign', False):
218
+ random_noise_tensor = param_update._get_random_noise_for_sso(p)
219
+ lr = torch.as_tensor(lr)
212
220
  step_param_fn = self._compiled_step_parameter
213
221
  else:
214
222
  step_param_fn = self._step_parameter
215
223
 
216
- step_param_fn(p, grad, state, group, lr, random_int_tensor)
224
+ step_param_fn(p, grad, state, group, lr, random_int_tensor, random_noise_tensor)
217
225
 
218
226
  if group.get("l1_adaptive", False):
219
227
  state["step"] += 1
220
228
 
221
- def _step_parameter(self, p, grad, state, group, lr, random_int_tensor):
222
- if grad.dtype != torch.float32 and state['factored']:
229
+ def _step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_noise_tensor):
230
+ if grad.dtype != torch.float32 and state.get('factored', False):
223
231
  grad = grad.float()
224
232
 
225
233
  if group["orthogonal_gradient"]:
@@ -269,18 +277,23 @@ class SignSGD_adv(torch.optim.Optimizer):
269
277
  if freeze_on_flip:
270
278
  state['sign'] = _pack_bools(raw_update > 0)
271
279
 
272
- l1_mean = _get_l1_adaptive_lr(p, raw_update, state, group, kappa_p)
273
280
 
274
- update = _get_lion_k_update(raw_update, kappa_p)
275
- update = update.view(p.shape)
281
+ raw_update = raw_update.view(p.shape)
276
282
 
277
283
  if freeze_on_flip:
278
284
  # Fast binary diff (XOR) from momentum sign directly
279
285
  flipped_packed = prev_sign_packed ^ state['sign']
280
- flipped_mask = _unpack_bools(flipped_packed, original_m=d2).view_as(update)
281
- update = torch.where(flipped_mask, 0.0, update)
286
+ flipped_mask = _unpack_bools(flipped_packed, original_m=d2).view_as(raw_update)
287
+ raw_update = torch.where(flipped_mask, 0.0, raw_update)
282
288
  del prev_sign_packed, flipped_packed, flipped_mask
283
289
 
290
+ l1_mean = _get_l1_adaptive_lr(p, raw_update, state, group, kappa_p)
291
+
292
+ if group.get('stochastic_sign', False):
293
+ update = apply_stochastic_sign(raw_update, noise=random_noise_tensor)
294
+ else:
295
+ update = _get_lion_k_update(raw_update, kappa_p)
296
+
284
297
  else:
285
298
  # Fallback to standard SignSGD logic
286
299
  if momentum > 0:
@@ -294,15 +307,18 @@ class SignSGD_adv(torch.optim.Optimizer):
294
307
  else:
295
308
  raw_update = grad.clone()
296
309
 
297
- l1_mean = _get_l1_adaptive_lr(p, raw_update, state, group, kappa_p)
298
-
299
- update = _get_lion_k_update(raw_update, kappa_p)
300
-
301
310
  if freeze_on_flip:
302
311
  current_sign = (raw_update > 0).to(torch.uint8)
303
- update = torch.where(current_sign == state['prev_sign'], update, 0.0)
312
+ raw_update = torch.where(current_sign == state['prev_sign'], raw_update, 0.0)
304
313
  state['prev_sign'] = current_sign
305
314
 
315
+ l1_mean = _get_l1_adaptive_lr(p, raw_update, state, group, kappa_p)
316
+
317
+ if group.get('stochastic_sign', False):
318
+ update = apply_stochastic_sign(raw_update, noise=random_noise_tensor)
319
+ else:
320
+ update = _get_lion_k_update(raw_update, kappa_p)
321
+
306
322
  if l1_mean is not None:
307
323
  update.mul_(l1_mean)
308
324
 
@@ -311,7 +327,7 @@ class SignSGD_adv(torch.optim.Optimizer):
311
327
  else:
312
328
  update.mul_(lr)
313
329
 
314
- param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
330
+ param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_scaler=l1_mean)
315
331
 
316
332
  def compile(self, *args, **kwargs):
317
333
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -288,7 +288,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
288
288
  # Pre-generate random tensor for stochastic rounding if needed.
289
289
  random_int_tensor = param_update._get_random_int_for_sr(p)
290
290
  # TODO, workaround until pytorch#169634 is fixed
291
- lr = torch.as_tensor(lr, dtype=torch.float64)
291
+ lr = torch.as_tensor(lr)
292
292
  step_param_fn = self._compiled_step_parameter
293
293
  else:
294
294
  step_param_fn = self._step_parameter
@@ -1,5 +1,7 @@
1
1
  import torch
2
2
 
3
+ import math
4
+
3
5
  @torch.no_grad()
4
6
  def _newton_schulz_iteration(
5
7
  G: torch.Tensor,
@@ -359,11 +361,11 @@ def rms_adjustment(update: torch.Tensor, rms_rescaling: bool, lr):
359
361
  # This is slower due to norm calculations but it worked the best for t2i models.
360
362
  rms_target = 0.2 # default (Adam) value for RMS
361
363
  update_norm = torch.linalg.vector_norm(update)
362
- return update.mul_(lr * rms_target * (update.numel()**0.5) / update_norm.clamp_min_(1e-8))
364
+ return update.mul_(lr * rms_target * (math.sqrt(update.numel())) / update_norm.clamp_min_(1e-8))
363
365
  else:
364
366
  # Original Muon scaling
365
367
  r, c = update.size(-2), update.size(-1)
366
- scaling_factor = max(1, r / c) ** 0.5
368
+ scaling_factor = math.sqrt(max(1, r / c))
367
369
  return update.mul_(lr * scaling_factor)
368
370
 
369
371
  def _auto_projection_for_adamuon(raw_update: torch.Tensor, kappa_p: float) -> torch.Tensor:
@@ -474,15 +476,15 @@ def get_spectral_scaling(shape: torch.Size, n_layers: int):
474
476
  # A) Newton-Schulz Damping
475
477
  # This ensures the matrix orthogonalization is stable across scales.
476
478
  # Formula: (1/L) * sqrt(d_in / d_out)
477
- ns_eps = (1.0 / L) * (d_in / d_out) ** 0.5
479
+ ns_eps = (1.0 / L) * math.sqrt(d_in / d_out)
478
480
 
479
481
  # B) Adaptive Denominator Epsilon
480
482
  # This ensures the Adam-style division doesn't explode or vanish.
481
483
  # Formula: (1/L) * (1 / sqrt(d_in * d_out))
482
- adaptive_eps = (1.0 / L) * (1.0 / (d_in * d_out)**0.5)
484
+ adaptive_eps = (1.0 / L) * (1.0 / math.sqrt(d_in * d_out))
483
485
 
484
486
  # Spectral Target (Section F) -> sqrt(d_out/d_in)
485
- spectral_target = (d_out / d_in) ** 0.5
487
+ spectral_target = math.sqrt(d_out / d_in)
486
488
 
487
489
  # Weight Decay (Section 3.4) -> 1/width
488
490
  wd_scale = 1.0 / d_in
@@ -4,7 +4,7 @@ from torch.optim import Optimizer
4
4
 
5
5
  from typing import Dict, Any
6
6
 
7
- from .scaled_optm import scale_wds
7
+ from .scaled_optm import adjust_wds, scale_wds
8
8
  from .centered_decay import dequantize_anchor
9
9
 
10
10
  _generators: Dict[torch.device, torch.Generator] = {}
@@ -29,11 +29,17 @@ def _apply_weight_decay(
29
29
  # Cautious Weight Decay: only decay if the update pushes in the same direction as the decay
30
30
  if cautious:
31
31
  mask = (update_calc * p_calc >= 0).to(p_calc.dtype)
32
- p_calc.addcmul_(p_calc, mask, value=-scaled_wd)
32
+ if isinstance(scaled_wd, Tensor):
33
+ p_calc.addcmul_(p_calc, mask * scaled_wd, value=-1.0)
34
+ else:
35
+ p_calc.addcmul_(p_calc, mask, value=-scaled_wd)
33
36
  del mask
34
37
  else:
35
38
  # Standard decoupled weight decay
36
- p_calc.add_(p_calc, alpha=-scaled_wd)
39
+ if isinstance(scaled_wd, Tensor):
40
+ p_calc.addcmul_(p_calc, scaled_wd, value=-1.0)
41
+ else:
42
+ p_calc.add_(p_calc, alpha=-scaled_wd)
37
43
 
38
44
  # Centered Weight Decay (pulls toward anchor)
39
45
  if scaled_cwd is not None and 'anchor_type' in state:
@@ -43,15 +49,20 @@ def _apply_weight_decay(
43
49
  if cautious:
44
50
  # Cautious Weight Decay: only decay if the update pushes in the same direction as the decay
45
51
  mask = (update_calc * decay_target >= 0).to(p_calc.dtype)
46
- p_calc.addcmul_(decay_target, mask, value=-scaled_cwd)
52
+ if isinstance(scaled_cwd, Tensor):
53
+ p_calc.addcmul_(decay_target, mask * scaled_cwd, value=-1.0)
54
+ else:
55
+ p_calc.addcmul_(decay_target, mask, value=-scaled_cwd)
47
56
  del mask
48
57
  else:
49
58
  # Standard decoupled weight decay
50
- p_calc.add_(decay_target, alpha=-scaled_cwd)
59
+ if isinstance(scaled_cwd, Tensor):
60
+ p_calc.addcmul_(decay_target, scaled_cwd, value=-1.0)
61
+ else:
62
+ p_calc.add_(decay_target, alpha=-scaled_cwd)
51
63
 
52
64
  del anchor, decay_target
53
65
 
54
-
55
66
  def apply_parameter_update(
56
67
  self,
57
68
  p: Tensor,
@@ -61,6 +72,7 @@ def apply_parameter_update(
61
72
  wd: float | None = None,
62
73
  random_int_tensor: Tensor | None = None,
63
74
  decoupled: bool = False,
75
+ wd_scaler: float | Tensor | None = None,
64
76
  ) -> None:
65
77
  """
66
78
  Applies decoupled weight decay (standard, cautious, centered) and the final
@@ -75,13 +87,16 @@ def apply_parameter_update(
75
87
  random_int_tensor: Optional pre-generated random tensor for stochastic
76
88
  rounding. Required for the `torch.compile` path.
77
89
  decoupled: Whenever to use the true decoupled weight decay.
90
+ wd_scaler: A multiplier/tensor to scale the calculated wd/cwd magnitude (e.g. for Fisher Adam WD).
78
91
  """
79
92
  wd = group["weight_decay"] if wd is None else wd
80
93
  cwd = group.get("centered_wd", 0.0)
81
94
 
82
95
  if group.get('scaled_optm', False):
83
96
  decoupled = True
84
- wd, cwd = scale_wds(wd, cwd, p)
97
+ wd, cwd = adjust_wds(wd, cwd, p)
98
+ if wd_scaler is None:
99
+ wd, cwd = scale_wds(wd, cwd, p)
85
100
 
86
101
  # Calculate global decay factor for decoupled vs standard
87
102
  decay_factor = (lr / self._init_lr) if decoupled else lr
@@ -89,6 +104,12 @@ def apply_parameter_update(
89
104
  scaled_wd = (wd * decay_factor) if wd != 0 else None
90
105
  scaled_cwd = (cwd * decay_factor) if cwd != 0 else None
91
106
 
107
+ if wd_scaler is not None:
108
+ if scaled_wd is not None:
109
+ scaled_wd = scaled_wd * wd_scaler
110
+ if scaled_cwd is not None:
111
+ scaled_cwd = scaled_cwd * wd_scaler
112
+
92
113
  state = self.state[p]
93
114
 
94
115
  # Compute full update in float32 if using bfloat16 with stochastic rounding
@@ -284,3 +305,24 @@ def post_process_loaded_state(optimizer: Optimizer) -> None:
284
305
  # Ensure device match
285
306
  if state[key].device != p.device:
286
307
  state[key] = state[key].to(p.device)
308
+
309
+
310
+ def _get_random_noise_for_sso(source: torch.Tensor) -> torch.Tensor:
311
+ """
312
+ Generates a random noise tensor for Stochastic Sign operator.
313
+ This function is not torch.compile-path friendly due to its use of torch.Generator.
314
+ """
315
+ global _generators
316
+ device = source.device
317
+ if device not in _generators:
318
+ set_seed(device)
319
+ # TODO, this is a workaround until torch compile error
320
+ # NotImplementedError: UserDefinedObjectVariable(generator) is fixed
321
+ generator = _generators[device]
322
+ # create a random noise tensor
323
+ return torch.randint(
324
+ size=source.shape,
325
+ device=source.device,
326
+ dtype=source.dtype,
327
+ generator=generator,
328
+ )
@@ -46,35 +46,39 @@ def scale_update(
46
46
  return update.mul_(lr)
47
47
 
48
48
 
49
- def scale_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
49
+ def adjust_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
50
50
  """
51
- Adjusts standard weight decay and centered weight decay based on the parameter's
52
- shape and type to maintain effective regularization strength.
51
+ Adjusts standard weight decay and centered weight decay.
53
52
  """
54
53
  # DoRA Scale (Magnitude Vector)
55
54
  if getattr(p, '_is_dora_scale', False):
56
55
  return wd, cwd
57
56
 
58
- conflict = cwd != 0
59
-
60
57
  if getattr(p, '_is_oft', False):
61
- # Fallback to standard WD (using cwd value) if both are active.
62
- return (cwd if conflict else wd), 0.0
58
+ return wd, 0.0
63
59
 
64
60
  if p.ndim >= 2:
65
- fan_in = p.numel() // p.shape[0]
66
-
67
- # When both WDs are active on LoRA, fallback to standard WD (using cwd value)
68
- # Reverts the behavior for better DoRA tuning.
69
61
  is_lora = getattr(p, '_is_lora_A', False) or getattr(p, '_is_lora_B', False)
70
- if conflict and is_lora:
71
- return cwd / fan_in, 0.0
62
+ if is_lora:
63
+ return wd, 0.0
64
+
65
+ else:
66
+ # 1D Biases or generic 1D parameters
67
+ # Centered WD safely regularizes the delta without collapsing base feature variance.
68
+ return 0.0, cwd
69
+
72
70
 
71
+ def scale_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
72
+ """
73
+ Scales standard weight decay and centered weight decay based on the parameter's
74
+ shape and type to maintain effective regularization strength.
75
+ """
76
+ if p.ndim >= 2:
77
+ fan_in = p.numel() // p.shape[0]
73
78
  return wd / fan_in, cwd / fan_in
74
79
 
75
- # 1D Biases or generic 1D parameters
76
- # Centered WD safely regularizes the delta without collapsing base feature variance.
77
- return 0.0, cwd
80
+ # 1D tensors (like DoRA scale and Biases)
81
+ return wd, cwd
78
82
 
79
83
 
80
84
  @torch.no_grad()
@@ -0,0 +1,13 @@
1
+ import torch
2
+
3
+
4
+ def apply_stochastic_sign(update: torch.Tensor, noise: torch.Tensor | None) -> torch.Tensor:
5
+ """
6
+ Applies the Stochastic Sign operator S_R(v).
7
+ Uses uniform noise injection to compute the stochastic sign
8
+ """
9
+ R = update.abs().max().clamp_min(1e-12)
10
+
11
+ if noise is None:
12
+ noise = torch.rand_like(update) * 2.0 - 1.0
13
+ return torch.sign(update / R + noise, out=update)
@@ -1,5 +1,7 @@
1
1
  import torch
2
2
 
3
+ import math
4
+
3
5
  def _grams_update(mt: torch.Tensor, grad: torch.Tensor, inplace: bool=False):
4
6
  """
5
7
  Applies the update rule of "Gradient Descent with Adaptive Momentum Scaling"
@@ -31,27 +33,63 @@ def _scale_sim_AdEMAMix_update(beta: float, current_step: int, alpha_grad: float
31
33
  lr = lr * total_scale
32
34
  return lr
33
35
 
36
+
37
+ def _init_fisher_wd_scaler(group: dict, state: dict, p: torch.Tensor) -> torch.Tensor | None:
38
+ if not group.get('fisher_wd', False):
39
+ return
40
+
41
+ state["wd_scaler"] = torch.tensor(1.0, device=p.device)
42
+
43
+ def _get_fisher_wd_scaler(group: dict, stored_scaler: torch.Tensor, p: torch.Tensor, denom: torch.Tensor, atan2: bool) -> torch.Tensor | None:
44
+ """
45
+ Calculates the Fisher weight decay scaler.
46
+ Maps the decay direction through the empirical Fisher information matrix
47
+ and clips its RMS to ensure stability.
48
+ From the paper:
49
+ "FAdam: Adam is a natural gradient optimizer using diagonal empirical Fisher information"
50
+ """
51
+ if not group.get('fisher_wd', False):
52
+ return None
53
+
54
+ if atan2:
55
+ wd_scaler = torch.atan2(stored_scaler, denom).mul_(4 / math.pi)
56
+ else:
57
+ eps = group.get('eps', 1e-8)
58
+ wd_scaler = 1.0 / (denom + eps)
59
+
60
+ # Reshape scaler if necessary to match parameter shape (for factored states)
61
+ wd_scaler = wd_scaler.view(p.shape)
62
+
63
+ gw_rms = torch.sqrt(torch.mean((p * wd_scaler) ** 2))
64
+ clip_coef = torch.clamp(gw_rms / 1.0, min=1.0)
65
+ return wd_scaler / clip_coef
66
+
34
67
  def _get_l1_adaptive_lr(
35
68
  p: torch.Tensor,
36
69
  update: torch.Tensor,
37
70
  state: dict,
38
71
  group: dict,
39
- kappa_p: float
72
+ kappa_p: float,
73
+ rescale: bool = False,
40
74
  ) -> torch.Tensor:
41
75
  """
42
76
  Calculates the L1 adaptive learning rate based on gradient heterogeneity.
43
77
  """
44
- if not group.get("l1_adaptive", False) and kappa_p != 1:
78
+ if not group.get("l1_adaptive", False) or kappa_p != 1:
45
79
  return None
46
80
 
47
- momentum = group["momentum"]
48
- alpha_grad = group["alpha_grad"]
49
81
  update_view = update.view(p.shape)
50
82
 
51
- # Calculate scale factor based on momentum/update magnitude
52
- scale_factor = _scale_sim_AdEMAMix_update(
53
- momentum, state["step"] + 1, alpha_grad, 1, False
54
- )
83
+ if rescale:
84
+ momentum = group["momentum"]
85
+ alpha_grad = group["alpha_grad"]
86
+
87
+ # Calculate scale factor based on momentum/update magnitude
88
+ scale_factor = _scale_sim_AdEMAMix_update(
89
+ momentum, state["step"] + 1, alpha_grad, 1, False
90
+ )
91
+ else:
92
+ scale_factor = 1
55
93
 
56
94
  # Determine dimension for mean calculation based on parameter type
57
95
  if getattr(p, '_is_oft', False) or getattr(p, '_is_lora_A', False):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev4
3
+ Version: 2.4.dev5
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -27,4 +27,5 @@ adv_optm/util/factorization_util.py
27
27
  adv_optm/util/lion_k.py
28
28
  adv_optm/util/param_update.py
29
29
  adv_optm/util/scaled_optm.py
30
+ adv_optm/util/signed_util.py
30
31
  adv_optm/util/update_util.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
5
5
 
6
6
  setup(
7
7
  name="adv_optm",
8
- version="2.4.dev4",
8
+ version="2.4.dev5",
9
9
  author="Koratahiu",
10
10
  author_email="hiuhonor@gmail.com",
11
11
  license='Apache 2.0',
File without changes
File without changes
File without changes