adv-optm 2.6.dev1__tar.gz → 2.6.1.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/PKG-INFO +1 -1
  2. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/AdaMuon_adv.py +32 -20
  4. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/AdamW_adv.py +20 -15
  5. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/Adopt_adv.py +20 -14
  6. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/Lion_adv.py +13 -11
  7. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/Muon_adv.py +27 -17
  8. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/Prodigy_adv.py +20 -16
  9. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/SignSGD_adv.py +16 -12
  10. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/SinkSGD_adv.py +19 -14
  11. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/Muon_AuxAdam.py +0 -1
  12. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/OrthoGrad.py +1 -1
  13. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/param_update.py +32 -39
  14. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/scaled_optm.py +24 -37
  15. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm.egg-info/PKG-INFO +1 -1
  16. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm.egg-info/SOURCES.txt +0 -1
  17. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/setup.py +1 -1
  18. adv_optm-2.6.dev1/adv_optm/util/msign.py +0 -114
  19. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/LICENSE +0 -0
  20. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/README.md +0 -0
  21. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/optim/__init__.py +0 -0
  22. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/Kourkoutas.py +0 -0
  23. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/Muon_util.py +0 -0
  24. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/__init__.py +0 -0
  25. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/centered_decay.py +0 -0
  26. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/factorization_util.py +0 -0
  27. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/lion_k.py +0 -0
  28. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/signed_util.py +0 -0
  29. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/sinkhorn.py +0 -0
  30. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/state_util.py +0 -0
  31. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm/util/update_util.py +0 -0
  32. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm.egg-info/dependency_links.txt +0 -0
  33. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm.egg-info/requires.txt +0 -0
  34. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/adv_optm.egg-info/top_level.txt +0 -0
  35. {adv_optm-2.6.dev1 → adv_optm-2.6.1.dev1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.6.dev1
3
+ Version: 2.6.1.dev1
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -20,4 +20,4 @@ __all__ = [
20
20
  "SinkSGD_adv",
21
21
  ]
22
22
 
23
- __version__ = "2.6.dev1"
23
+ __version__ = "2.6.1.dev1"
@@ -137,6 +137,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
137
137
  # Decoupled/cautious weight decay
138
138
  weight_decay: float = 0,
139
139
  cautious_wd: bool = False,
140
+ scaled_wd: bool = False,
140
141
  # Nesterov momentum
141
142
  nesterov: bool = True,
142
143
  nesterov_coef: float | None = None,
@@ -177,8 +178,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
177
178
  mars_gamma: float = 0.025,
178
179
  # Spectral Normalization
179
180
  spectral_normalization: bool = False,
180
- # Orthogonalize the weights (Matrix Sign - MSign) every x steps
181
- MSign_interval: int | None = None,
182
181
  # Centered WD
183
182
  centered_wd: float = 0.0,
184
183
  centered_wd_mode: str = 'float8',
@@ -229,7 +228,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
229
228
 
230
229
  defaults = {
231
230
  "lr": lr, "betas": betas, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
232
- "eps": eps, "rms_rescaling": rms_rescaling, "ns_steps": ns_steps,
231
+ "eps": eps, "rms_rescaling": rms_rescaling, "ns_steps": ns_steps, "scaled_wd": scaled_wd,
233
232
  "ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
234
233
  "vector_reshape": vector_reshape,
235
234
  "nesterov":nesterov, "nesterov_coef": nesterov_coef, "use_atan2":use_atan2,
@@ -249,7 +248,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
249
248
  # MARS-M
250
249
  "approx_mars": approx_mars, "mars_gamma": mars_gamma,
251
250
  # Spectral Normalization
252
- "spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
251
+ "spectral_normalization": spectral_normalization,
253
252
  # Centered WD
254
253
  "centered_wd": centered_wd,
255
254
  "centered_wd_mode": centered_wd_mode,
@@ -284,11 +283,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
284
283
  for device in devices:
285
284
  param_update.set_seed(device)
286
285
 
287
- # Initialize compiled function
288
- self._compiled_muon_step_parameter = None
289
- self._compiled_adam_step_parameter = None
290
- if compiled_optimizer:
291
- self.compile(fullgraph=True)
286
+ # Initialize compiled functions (by parameter shape)
287
+ self._compiled_muon_step_fns = {}
288
+ self._compiled_adam_step_fns = {}
292
289
 
293
290
  def load_state_dict(self, state_dict: dict) -> None:
294
291
  """
@@ -446,8 +443,16 @@ class AdaMuon_adv(torch.optim.Optimizer):
446
443
  random_int_state_tensor = None
447
444
  if is_compiled:
448
445
  step_size = torch.as_tensor(step_size)
449
- adam_step_param = self._compiled_adam_step_parameter
450
-
446
+ # Cache compiled function per-shape
447
+ cache_key = (p.shape, state.get('factored', False))
448
+ if cache_key not in self._compiled_adam_step_fns:
449
+ self._compiled_adam_step_fns[cache_key] = torch.compile(
450
+ Muon_AuxAdam._adam_step_parameter,
451
+ fullgraph=True,
452
+ dynamic=False
453
+ )
454
+ adam_step_param = self._compiled_adam_step_fns[cache_key]
455
+
451
456
  # Generate state SR random tensor when compiled
452
457
  actual_precision = group.get('adam_actual_state_precision', 'auto')
453
458
  random_int_state_tensor = random_int_tensor
@@ -466,7 +471,15 @@ class AdaMuon_adv(torch.optim.Optimizer):
466
471
  random_G_sketch = None
467
472
  if is_compiled:
468
473
  lr = torch.as_tensor(group['lr'])
469
- muon_step_param = self._compiled_muon_step_parameter
474
+ # Cache compiled function per-shape
475
+ cache_key = (p.shape, state.get('factored', False))
476
+ if cache_key not in self._compiled_muon_step_fns:
477
+ self._compiled_muon_step_fns[cache_key] = torch.compile(
478
+ self._muon_step_parameter,
479
+ fullgraph=True,
480
+ dynamic=False
481
+ )
482
+ muon_step_param = self._compiled_muon_step_fns[cache_key]
470
483
 
471
484
  # Generate state SR random tensor when compiled
472
485
  actual_precision = group['actual_state_precision']
@@ -484,10 +497,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
484
497
 
485
498
  muon_step_param(p, grad, state, group, lr, random_int_tensor, random_int_state_tensor, random_G_sketch)
486
499
 
487
- def compile(self, *args, **kwargs):
488
- self._compiled_muon_step_parameter = torch.compile(self._muon_step_parameter, *args, **kwargs)
489
- self._compiled_adam_step_parameter = torch.compile(Muon_AuxAdam._adam_step_parameter, *args, **kwargs)
490
-
491
500
  @torch.no_grad()
492
501
  def _muon_step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_int_state_tensor, random_G_sketch):
493
502
  # Upcast grad for low-precision state modes (non-factored path)
@@ -533,8 +542,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
533
542
  else:
534
543
  update = mt_buf.clone()
535
544
 
536
- # Factorize
537
- state['mu_mbuf_nmf'], state['mv_mbuf_nmf'], state['sign_buf'] = _factorize_state(mt_buf, signed=True, shifter=state['shifter'])
545
+ # Compress new momentum and store factors
546
+ for key, val in zip(('mu_mbuf_nmf', 'mv_mbuf_nmf', 'sign_buf'), _factorize_state(mt_buf, signed=True, shifter=state['shifter'])):
547
+ state[key].copy_(val)
538
548
  del mt_buf
539
549
 
540
550
  # Apply update projection
@@ -561,7 +571,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
561
571
  vt_buf = _reconstruct_state((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']), signed=False, shifter=state['shifter'])
562
572
  # Update second momentum in full-size
563
573
  vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
564
- state['mu_vbuf_nmf'], state['mv_vbuf_nmf'] = _factorize_state(vt_buf, signed=False, shifter=state['shifter'])
574
+ for key, val in zip(('mu_vbuf_nmf', 'mv_vbuf_nmf'), _factorize_state(vt_buf, signed=False, shifter=state['shifter'])):
575
+ state[key].copy_(val)
565
576
  # Apply second momentum update (adaptive scaling)
566
577
  if group['use_atan2']:
567
578
  denom = vt_buf.sqrt_()
@@ -620,7 +631,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
620
631
  update_f32 = update.float()
621
632
  vt_buf = _reconstruct_state((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']), signed=False, shifter=state['shifter'])
622
633
  vt_buf.mul_(beta2).addcmul_(update_f32.view(d1, d2), update_f32.view(d1, d2), value=1 - beta2)
623
- state['mu_vbuf_nmf'], state['mv_vbuf_nmf'] = _factorize_state(vt_buf, signed=False, shifter=state['shifter'])
634
+ for key, val in zip(('mu_vbuf_nmf', 'mv_vbuf_nmf'), _factorize_state(vt_buf, signed=False, shifter=state['shifter'])):
635
+ state[key].copy_(val)
624
636
  # Apply second moment scaling
625
637
  if group['use_atan2']:
626
638
  denom = vt_buf.sqrt_().view(original_shape)
@@ -98,6 +98,7 @@ class AdamW_adv(torch.optim.Optimizer):
98
98
  weight_decay: float = 0.0,
99
99
  fisher_wd: bool = False,
100
100
  cautious_wd: bool = False,
101
+ scaled_wd: bool = False,
101
102
  # Adam's Bias Correction
102
103
  use_bias_correction: bool = True,
103
104
  # Stochastic Rounding for BF16
@@ -121,8 +122,6 @@ class AdamW_adv(torch.optim.Optimizer):
121
122
  layer_key_fn: Optional[Callable] = None,
122
123
  # Spectral Normed Optimizer
123
124
  spectral_normalization: bool = False,
124
- # Orthogonalize the weights (Matrix Sign - MSign) every x steps
125
- MSign_interval: int | None = None,
126
125
  # Centered WD
127
126
  centered_wd: float = 0.0,
128
127
  centered_wd_mode: str = 'float8',
@@ -158,14 +157,14 @@ class AdamW_adv(torch.optim.Optimizer):
158
157
 
159
158
  defaults = {
160
159
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
161
- "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
160
+ "fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "scaled_wd": scaled_wd,
162
161
  "use_atan2": use_atan2, "nesterov": nesterov, "nesterov_coef": nesterov_coef,
163
162
  "normed_momentum": normed_momentum,
164
163
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
165
164
  "compiled_optimizer": compiled_optimizer,
166
165
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
167
166
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
168
- "spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
167
+ "spectral_normalization": spectral_normalization,
169
168
  "centered_wd": centered_wd, "centered_wd_mode": centered_wd_mode,
170
169
  "state_precision": state_precision,
171
170
  "nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd
@@ -188,10 +187,8 @@ class AdamW_adv(torch.optim.Optimizer):
188
187
  for device in devices:
189
188
  param_update.set_seed(device)
190
189
 
191
- # Initialize compiled function
192
- self._compiled_step_parameter = None
193
- if compiled_optimizer:
194
- self.compile(fullgraph=True)
190
+ # Initialize compiled function (by parameter shape)
191
+ self._compiled_step_fns = {}
195
192
 
196
193
  def load_state_dict(self, state_dict: dict) -> None:
197
194
  """
@@ -323,7 +320,15 @@ class AdamW_adv(torch.optim.Optimizer):
323
320
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
324
321
  elif group['actual_state_precision'] == 'int8_sr':
325
322
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
326
- step_param_fn = self._compiled_step_parameter
323
+ # Cache compiled function per-shape
324
+ cache_key = (p.shape, state.get('factored', False))
325
+ if cache_key not in self._compiled_step_fns:
326
+ self._compiled_step_fns[cache_key] = torch.compile(
327
+ self._step_parameter,
328
+ fullgraph=True,
329
+ dynamic=False
330
+ )
331
+ step_param_fn = self._compiled_step_fns[cache_key]
327
332
  else:
328
333
  step_param_fn = self._step_parameter
329
334
 
@@ -359,7 +364,8 @@ class AdamW_adv(torch.optim.Optimizer):
359
364
  vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
360
365
 
361
366
  # Factorize
362
- state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(vt, signed=False, shifter=state['shifter'])
367
+ for key, val in zip(('mu_v_nmf', 'mv_v_nmf'), _factorize_state(vt, signed=False, shifter=state['shifter'])):
368
+ state[key].copy_(val)
363
369
 
364
370
  if group['use_atan2']:
365
371
  denom = vt.sqrt_()
@@ -380,7 +386,8 @@ class AdamW_adv(torch.optim.Optimizer):
380
386
  mt.lerp_(grad_reshaped, 1.0 - beta1)
381
387
 
382
388
  # Factorize
383
- state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(mt.clone(), signed=True, shifter=state['shifter'])
389
+ for key, val in zip(('mu_m_nmf', 'mv_m_nmf', 'sign'), _factorize_state(mt.clone(), signed=True, shifter=state['shifter'])):
390
+ state[key].copy_(val)
384
391
 
385
392
  update_mt = mt
386
393
 
@@ -424,7 +431,8 @@ class AdamW_adv(torch.optim.Optimizer):
424
431
  exp_avg_sq.mul_(beta2).addcmul_(grad_vt, grad_vt, value=1.0 - beta2)
425
432
 
426
433
  if factored_2nd:
427
- state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(exp_avg_sq.view(d1, d2), signed=False, shifter=state['shifter'])
434
+ for key, val in zip(('mu_v_nmf', 'mv_v_nmf'), _factorize_state(exp_avg_sq.view(d1, d2), signed=False, shifter=state['shifter'])):
435
+ state[key].copy_(val)
428
436
  else:
429
437
  set_state(state, 'exp_avg_sq', exp_avg_sq, actual_precision, random_int_state_tensor, non_neg=True)
430
438
 
@@ -471,9 +479,6 @@ class AdamW_adv(torch.optim.Optimizer):
471
479
 
472
480
  param_update.apply_parameter_update(self, p, group, update, group['lr'], random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
473
481
 
474
- def compile(self, *args, **kwargs):
475
- self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
476
-
477
482
  @torch.no_grad()
478
483
  def step(self, closure=None):
479
484
  """Performs a single optimization step."""
@@ -101,6 +101,7 @@ class Adopt_adv(torch.optim.Optimizer):
101
101
  weight_decay: float = 0.0,
102
102
  fisher_wd: bool = False,
103
103
  cautious_wd: bool = False,
104
+ scaled_wd: bool = False,
104
105
  # ADOPT clipping
105
106
  clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
106
107
  # Adam_atan2 (scale invariant)
@@ -122,8 +123,6 @@ class Adopt_adv(torch.optim.Optimizer):
122
123
  layer_key_fn: Optional[Callable] = None,
123
124
  # Spectral Normed Optimizer
124
125
  spectral_normalization: bool = False,
125
- # Orthogonalize the weights (Matrix Sign - MSign) every x steps
126
- MSign_interval: int | None = None,
127
126
  # Centered WD
128
127
  centered_wd: float = 0.0,
129
128
  centered_wd_mode: str = 'float8',
@@ -159,12 +158,12 @@ class Adopt_adv(torch.optim.Optimizer):
159
158
  state_precision = "factored"
160
159
 
161
160
  defaults = {
162
- "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
161
+ "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "scaled_wd": scaled_wd,
163
162
  "fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "orthogonal_gradient": orthogonal_gradient,
164
163
  "nesterov": nesterov, "nesterov_coef": nesterov_coef,
165
164
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
166
165
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
167
- "spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
166
+ "spectral_normalization": spectral_normalization,
168
167
  "centered_wd": centered_wd,
169
168
  "centered_wd_mode": centered_wd_mode,
170
169
  "state_precision": state_precision,
@@ -191,9 +190,8 @@ class Adopt_adv(torch.optim.Optimizer):
191
190
  for device in devices:
192
191
  param_update.set_seed(device)
193
192
 
194
- self._compiled_step_parameter = None
195
- if compiled_optimizer:
196
- self.compile(fullgraph=True)
193
+ # Initialize compiled function (by parameter shape)
194
+ self._compiled_step_fns = {}
197
195
 
198
196
  def load_state_dict(self, state_dict: dict) -> None:
199
197
  """
@@ -333,7 +331,15 @@ class Adopt_adv(torch.optim.Optimizer):
333
331
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
334
332
  elif group['actual_state_precision'] == 'int8_sr':
335
333
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
336
- step_param_fn = self._compiled_step_parameter
334
+ # Cache compiled function per-shape
335
+ cache_key = (p.shape, state.get('factored', False))
336
+ if cache_key not in self._compiled_step_fns:
337
+ self._compiled_step_fns[cache_key] = torch.compile(
338
+ self._step_parameter,
339
+ fullgraph=True,
340
+ dynamic=False
341
+ )
342
+ step_param_fn = self._compiled_step_fns[cache_key]
337
343
  else:
338
344
  lr = group['lr']
339
345
  step_param_fn = self._step_parameter
@@ -375,7 +381,8 @@ class Adopt_adv(torch.optim.Optimizer):
375
381
  else:
376
382
  vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
377
383
  # Factorize
378
- state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(vt, signed=False, shifter=state['shifter'])
384
+ for key, val in zip(('mu_v_nmf', 'mv_v_nmf'), _factorize_state(vt, signed=False, shifter=state['shifter'])):
385
+ state[key].copy_(val)
379
386
  del vt
380
387
 
381
388
  if self.use_atan2:
@@ -393,7 +400,8 @@ class Adopt_adv(torch.optim.Optimizer):
393
400
  mt.lerp_(normalized_grad, 1.0 - beta1)
394
401
 
395
402
  # Factorize
396
- state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(mt.clone(), signed=True, shifter=state['shifter'])
403
+ for key, val in zip(('mu_m_nmf', 'mv_m_nmf', 'sign'), _factorize_state(mt.clone(), signed=True, shifter=state['shifter'])):
404
+ state[key].copy_(val)
397
405
 
398
406
  update_mt = mt
399
407
 
@@ -460,7 +468,8 @@ class Adopt_adv(torch.optim.Optimizer):
460
468
  vt.mul_(beta2).addcmul_(grad_vt, grad_vt, value=1 - beta2)
461
469
 
462
470
  if factored_2nd:
463
- state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(vt.view(d1, d2), signed=False, shifter=state['shifter'])
471
+ for key, val in zip(('mu_v_nmf', 'mv_v_nmf'), _factorize_state(vt.view(d1, d2), signed=False, shifter=state['shifter'])):
472
+ state[key].copy_(val)
464
473
  else:
465
474
  set_state(state, 'exp_avg_sq', vt, actual_precision, random_int_state_tensor, non_neg=True)
466
475
  del random_int_state_tensor
@@ -475,9 +484,6 @@ class Adopt_adv(torch.optim.Optimizer):
475
484
  # Parameter Update
476
485
  param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
477
486
 
478
- def compile(self, *args, **kwargs):
479
- self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
480
-
481
487
  @torch.no_grad()
482
488
  def step(self, closure=None):
483
489
  """Performs a single optimization step."""
@@ -64,6 +64,7 @@ class Lion_adv(torch.optim.Optimizer):
64
64
  # Decoupled/cautious weight decay
65
65
  weight_decay: float = 0.0,
66
66
  cautious_wd: bool = False,
67
+ scaled_wd: bool = False,
67
68
  # Stochastic Rounding for BF16
68
69
  stochastic_rounding: bool = True,
69
70
  # OrthoGrad
@@ -78,8 +79,6 @@ class Lion_adv(torch.optim.Optimizer):
78
79
  centered_wd_mode: str = 'float8',
79
80
  # Spectral Normed Optimizer
80
81
  spectral_normalization: bool = False,
81
- # Orthogonalize the weights (Matrix Sign - MSign) every x steps
82
- MSign_interval: int | None = None,
83
82
  # SMMF factorization
84
83
  nnmf_factor: bool = False,
85
84
  vector_reshape: bool = False,
@@ -98,13 +97,13 @@ class Lion_adv(torch.optim.Optimizer):
98
97
  betas=betas,
99
98
  weight_decay=weight_decay,
100
99
  cautious_wd=cautious_wd,
100
+ scaled_wd=scaled_wd,
101
101
  vector_reshape=vector_reshape,
102
102
  orthogonal_gradient=orthogonal_gradient,
103
103
  kappa_p=kappa_p,
104
104
  auto_kappa_p=auto_kappa_p,
105
105
  stochastic_sign=stochastic_sign,
106
106
  spectral_normalization=spectral_normalization,
107
- MSign_interval=MSign_interval,
108
107
  nnmf_factor=nnmf_factor,
109
108
  centered_wd= centered_wd,
110
109
  centered_wd_mode= centered_wd_mode,
@@ -122,10 +121,8 @@ class Lion_adv(torch.optim.Optimizer):
122
121
  for device in devices:
123
122
  param_update.set_seed(device)
124
123
 
125
- # Initialize compiled function
126
- self._compiled_step_parameter = None
127
- if compiled_optimizer:
128
- self.compile(fullgraph=True)
124
+ # Initialize compiled function (by parameter shape)
125
+ self._compiled_step_fns = {}
129
126
 
130
127
  def load_state_dict(self, state_dict: dict) -> None:
131
128
  """
@@ -208,7 +205,15 @@ class Lion_adv(torch.optim.Optimizer):
208
205
  if group.get('stochastic_sign', False):
209
206
  random_noise_tensor = param_update._get_random_noise_for_sso(p)
210
207
  lr = torch.as_tensor(lr)
211
- step_param_fn = self._compiled_step_parameter
208
+ # Cache compiled function per-shape
209
+ cache_key = (p.shape, state.get('factored', False))
210
+ if cache_key not in self._compiled_step_fns:
211
+ self._compiled_step_fns[cache_key] = torch.compile(
212
+ self._step_parameter,
213
+ fullgraph=True,
214
+ dynamic=False
215
+ )
216
+ step_param_fn = self._compiled_step_fns[cache_key]
212
217
  else:
213
218
  step_param_fn = self._step_parameter
214
219
 
@@ -281,9 +286,6 @@ class Lion_adv(torch.optim.Optimizer):
281
286
 
282
287
  param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
283
288
 
284
- def compile(self, *args, **kwargs):
285
- self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
286
-
287
289
  @torch.no_grad()
288
290
  def step(self, closure: Optional[callable] = None):
289
291
  """Performs a single optimization step."""
@@ -111,6 +111,7 @@ class Muon_adv(torch.optim.Optimizer):
111
111
  # Decoupled/cautious weight decay
112
112
  weight_decay: float = 0.0,
113
113
  cautious_wd: bool = False,
114
+ scaled_wd: bool = False,
114
115
  # Nesterov momentum
115
116
  nesterov: bool = True,
116
117
  nesterov_coef: float | None = None,
@@ -146,8 +147,6 @@ class Muon_adv(torch.optim.Optimizer):
146
147
  mars_gamma: float = 0.025,
147
148
  # Spectral Normalization
148
149
  spectral_normalization: bool = False,
149
- # Orthogonalize the weights (Matrix Sign - MSign) every x steps
150
- MSign_interval: int | None = None,
151
150
  # Centered WD
152
151
  centered_wd: float = 0.0,
153
152
  centered_wd_mode: str = 'float8',
@@ -203,7 +202,7 @@ class Muon_adv(torch.optim.Optimizer):
203
202
  defaults = {
204
203
  "lr": lr, "beta1": beta1, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
205
204
  "nesterov": nesterov, "nesterov_coef": nesterov_coef, "ns_steps": ns_steps, "ns_eps": ns_eps,
206
- "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
205
+ "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor, "scaled_wd": scaled_wd,
207
206
  "vector_reshape": vector_reshape, "rms_rescaling": rms_rescaling,
208
207
  "orthogonal_gradient": orthogonal_gradient,
209
208
  'compiled_optimizer': compiled_optimizer,
@@ -220,7 +219,7 @@ class Muon_adv(torch.optim.Optimizer):
220
219
  # MARS-M
221
220
  "approx_mars": approx_mars, "mars_gamma": mars_gamma,
222
221
  # Spectral Normalization
223
- "spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
222
+ "spectral_normalization": spectral_normalization,
224
223
  # Centered WD
225
224
  "centered_wd": centered_wd,
226
225
  "centered_wd_mode": centered_wd_mode,
@@ -256,11 +255,9 @@ class Muon_adv(torch.optim.Optimizer):
256
255
  for device in devices:
257
256
  param_update.set_seed(device)
258
257
 
259
- # Initialize compiled function
260
- self._compiled_muon_step_parameter = None
261
- self._compiled_adam_step_parameter = None
262
- if compiled_optimizer:
263
- self.compile(fullgraph=True)
258
+ # Initialize compiled functions (by parameter shape)
259
+ self._compiled_muon_step_fns = {}
260
+ self._compiled_adam_step_fns = {}
264
261
 
265
262
  def load_state_dict(self, state_dict: dict) -> None:
266
263
  """
@@ -398,7 +395,15 @@ class Muon_adv(torch.optim.Optimizer):
398
395
  random_int_state_tensor = None
399
396
  if is_compiled:
400
397
  step_size = torch.as_tensor(step_size)
401
- adam_step_param = self._compiled_adam_step_parameter
398
+ # Cache compiled function per-shape
399
+ cache_key = (p.shape, state.get('factored', False))
400
+ if cache_key not in self._compiled_adam_step_fns:
401
+ self._compiled_adam_step_fns[cache_key] = torch.compile(
402
+ Muon_AuxAdam._adam_step_parameter,
403
+ fullgraph=True,
404
+ dynamic=False
405
+ )
406
+ adam_step_param = self._compiled_adam_step_fns[cache_key]
402
407
 
403
408
  actual_precision = group.get('adam_actual_state_precision', 'auto')
404
409
  random_int_state_tensor = random_int_tensor
@@ -417,7 +422,15 @@ class Muon_adv(torch.optim.Optimizer):
417
422
  random_G_sketch = None
418
423
  if is_compiled:
419
424
  lr = torch.as_tensor(group['lr'])
420
- muon_step_param = self._compiled_muon_step_parameter
425
+ # Cache compiled function per-shape
426
+ cache_key = (p.shape, state.get('factored', False))
427
+ if cache_key not in self._compiled_muon_step_fns:
428
+ self._compiled_muon_step_fns[cache_key] = torch.compile(
429
+ self._muon_step_parameter,
430
+ fullgraph=True,
431
+ dynamic=False
432
+ )
433
+ muon_step_param = self._compiled_muon_step_fns[cache_key]
421
434
 
422
435
  # Generate state SR random tensor when compiled
423
436
  actual_precision = group['actual_state_precision']
@@ -435,10 +448,6 @@ class Muon_adv(torch.optim.Optimizer):
435
448
 
436
449
  muon_step_param(p, grad, state, group, lr, random_int_tensor, random_int_state_tensor, random_G_sketch)
437
450
 
438
- def compile(self, *args, **kwargs):
439
- self._compiled_muon_step_parameter = torch.compile(self._muon_step_parameter, *args, **kwargs)
440
- self._compiled_adam_step_parameter = torch.compile(Muon_AuxAdam._adam_step_parameter, *args, **kwargs)
441
-
442
451
  @torch.no_grad()
443
452
  def _muon_step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_int_state_tensor, random_G_sketch):
444
453
  # Upcast grad for low-precision state modes (non-factored path)
@@ -477,8 +486,9 @@ class Muon_adv(torch.optim.Optimizer):
477
486
  # Standard momentum
478
487
  update = mt_buf.clone()
479
488
 
480
- # Factorize
481
- state['mu_mbuf_nmf'], state['mv_mbuf_nmf'], state['sign_buf'] = _factorize_state(mt_buf, signed=True, shifter=state['shifter'])
489
+ # Compress new momentum and store factors
490
+ for key, val in zip(('mu_mbuf_nmf', 'mv_mbuf_nmf', 'sign_buf'), _factorize_state(mt_buf, signed=True, shifter=state['shifter'])):
491
+ state[key].copy_(val)
482
492
  del mt_buf
483
493
 
484
494
  # Orthogonalization step
@@ -115,6 +115,7 @@ class Prodigy_adv(torch.optim.Optimizer):
115
115
  weight_decay: float = 0.0,
116
116
  fisher_wd: bool = False,
117
117
  cautious_wd: bool = False,
118
+ scaled_wd: bool = False,
118
119
  # Stochastic Rounding for BF16
119
120
  stochastic_rounding: bool = True,
120
121
  # Adam_atan2 (scale invariant)
@@ -156,8 +157,6 @@ class Prodigy_adv(torch.optim.Optimizer):
156
157
  centered_wd_mode: str = 'float8',
157
158
  # Spectral Normalization
158
159
  spectral_normalization: bool = False,
159
- # Orthogonalize the weights (Matrix Sign - MSign) every x steps
160
- MSign_interval: int | None = None,
161
160
  ):
162
161
  if not (lr >= 0.0):
163
162
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -183,7 +182,7 @@ class Prodigy_adv(torch.optim.Optimizer):
183
182
 
184
183
  defaults = {
185
184
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
186
- "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
185
+ "fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "scaled_wd": scaled_wd,
187
186
  "use_atan2": use_atan2,
188
187
  "orthogonal_gradient": orthogonal_gradient,
189
188
  "compiled_optimizer": compiled_optimizer,
@@ -195,7 +194,7 @@ class Prodigy_adv(torch.optim.Optimizer):
195
194
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
196
195
  "centered_wd": centered_wd, "centered_wd_mode": centered_wd_mode,
197
196
  "nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd,
198
- "spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
197
+ "spectral_normalization": spectral_normalization,
199
198
  }
200
199
  self.stochastic_rounding = stochastic_rounding
201
200
  self.fsdp_in_use = fsdp_in_use
@@ -220,11 +219,8 @@ class Prodigy_adv(torch.optim.Optimizer):
220
219
  for device in devices:
221
220
  param_update.set_seed(device)
222
221
 
223
- # Initialize compiled function
224
- self._compiled_step_parameter = None
225
-
226
- if compiled_optimizer:
227
- self.compile(fullgraph=True)
222
+ # Initialize compiled function (by parameter shape)
223
+ self._compiled_step_fns = {}
228
224
 
229
225
  def load_state_dict(self, state_dict: dict) -> None:
230
226
  """
@@ -366,7 +362,15 @@ class Prodigy_adv(torch.optim.Optimizer):
366
362
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
367
363
  elif group['actual_state_precision'] == 'int8_sr':
368
364
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
369
- step_param_fn = self._compiled_step_parameter
365
+ # Cache compiled function per-shape
366
+ cache_key = (p.shape, state.get('factored', False))
367
+ if cache_key not in self._compiled_step_fns:
368
+ self._compiled_step_fns[cache_key] = torch.compile(
369
+ self._step_parameter,
370
+ fullgraph=True,
371
+ dynamic=False
372
+ )
373
+ step_param_fn = self._compiled_step_fns[cache_key]
370
374
  else:
371
375
  d = group['d']
372
376
  step_param_fn = self._step_parameter
@@ -402,7 +406,8 @@ class Prodigy_adv(torch.optim.Optimizer):
402
406
  mt.mul_(self.beta1).add_(grad_reshaped, alpha=d * (1.0 - self.beta1))
403
407
 
404
408
  # Factorize
405
- state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(mt.clone(), signed=True, shifter=state['shifter'])
409
+ for key, val in zip(('mu_m_nmf', 'mv_m_nmf', 'sign'), _factorize_state(mt.clone(), signed=True, shifter=state['shifter'])):
410
+ state[key].copy_(val)
406
411
 
407
412
  update_mt = mt
408
413
 
@@ -423,7 +428,8 @@ class Prodigy_adv(torch.optim.Optimizer):
423
428
  update = grad_reshaped.mul(d)
424
429
 
425
430
  # Factorize
426
- state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(vt, signed=False, shifter=state['shifter'])
431
+ for key, val in zip(('mu_v_nmf', 'mv_v_nmf'), _factorize_state(vt, signed=False, shifter=state['shifter'])):
432
+ state[key].copy_(val)
427
433
 
428
434
  if group['use_atan2']:
429
435
  denom = vt.sqrt_()
@@ -475,7 +481,8 @@ class Prodigy_adv(torch.optim.Optimizer):
475
481
  exp_avg_sq.mul_(beta2).addcmul_(grad_vt, grad_vt, value=d * d * (1.0 - beta2))
476
482
 
477
483
  if factored_2nd:
478
- state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(exp_avg_sq.view(d1, d2), signed=False, shifter=state['shifter'])
484
+ for key, val in zip(('mu_v_nmf', 'mv_v_nmf'), _factorize_state(exp_avg_sq.view(d1, d2), signed=False, shifter=state['shifter'])):
485
+ state[key].copy_(val)
479
486
  else:
480
487
  set_state(state, 'exp_avg_sq', exp_avg_sq, actual_precision, random_int_state_tensor, non_neg=True)
481
488
  del random_int_state_tensor
@@ -524,9 +531,6 @@ class Prodigy_adv(torch.optim.Optimizer):
524
531
 
525
532
  param_update.apply_parameter_update(self, p, group, update, dlr, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
526
533
 
527
- def compile(self, *args, **kwargs):
528
- self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
529
-
530
534
  @torch.no_grad()
531
535
  def step(self, closure=None):
532
536
  """Performs a single optimization step."""