adv-optm 1.2.dev14.tar.gz → 1.2.dev16.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm might be problematic.

Files changed (29)
  1. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/PKG-INFO +1 -1
  2. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/__init__.py +1 -1
  3. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/AdaMuon_adv.py +132 -66
  4. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/Muon_adv.py +36 -12
  5. adv_optm-1.2.dev16/adv_optm/util/Newton_Schulz.py +87 -0
  6. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm.egg-info/PKG-INFO +1 -1
  7. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/setup.py +1 -1
  8. adv_optm-1.2.dev14/adv_optm/util/Newton_Schulz.py +0 -48
  9. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/LICENSE +0 -0
  10. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/README.md +0 -0
  11. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/AdamW_adv.py +0 -0
  12. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/Adopt_adv.py +0 -0
  13. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  14. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/Lion_adv.py +0 -0
  15. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/Prodigy_adv.py +0 -0
  16. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  17. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/__init__.py +0 -0
  18. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  19. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/util/Effective_Shape.py +0 -0
  20. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/util/Kourkoutas.py +0 -0
  21. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/util/NNMF.py +0 -0
  22. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/util/One_Bit_Boolean.py +0 -0
  23. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/util/OrthoGrad.py +0 -0
  24. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/util/__init__.py +0 -0
  25. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm.egg-info/SOURCES.txt +0 -0
  26. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm.egg-info/dependency_links.txt +0 -0
  27. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm.egg-info/requires.txt +0 -0
  28. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm.egg-info/top_level.txt +0 -0
  29. {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.2.dev14
+ Version: 1.2.dev16
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -20,4 +20,4 @@ __all__ = [
  "AdaMuon_adv",
  ]

- __version__ = "1.2.dev14"
+ __version__ = "1.2.dev16"
@@ -63,6 +63,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
  (default: False)
  ortho_rank (int): The rank for low-rank orthogonalization.
  (default: 128)
+ accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
+ dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
+ cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
  --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
@@ -96,11 +99,16 @@ class AdaMuon_adv(torch.optim.Optimizer):
  nesterov: bool = False,
  Simplified_AdEMAMix: bool = False,
  alpha_grad: float = 100.0,
- vector_reshape: bool = False,
+ normuon_variant: bool = False,
  # Low-rank Muon
  low_rank_ortho: bool = False,
  ortho_rank: int = 128,
+ # Factored
+ vector_reshape: bool = False,
  nnmf_factor: bool = False,
+ # CANS
+ accelerated_ns: bool = False,
+ cns_a_bound: float = 1e-4,
  # Compiled
  compiled_optimizer: bool = False,
  # --- AdamW_adv specific parameters ---
@@ -139,9 +147,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
  "vector_reshape": vector_reshape,
  "nesterov":nesterov, "use_atan2":use_atan2,
  "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+ "normuon_variant": normuon_variant,
  # Low-rank Ortho
  "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
  "compiled_optimizer":compiled_optimizer,
+ # CANS
+ "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
  # AdamW_adv defaults
  "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
  "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
@@ -156,7 +167,6 @@ class AdaMuon_adv(torch.optim.Optimizer):

  super().__init__(params, defaults)

- self.global_step = 0 # For Adam bias correction and Kourkoutas
  self.kourkoutas_helper = None
  if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
  self.kourkoutas_helper = KourkoutasHelper(self)
@@ -214,16 +224,25 @@ class AdaMuon_adv(torch.optim.Optimizer):
  state['mv_mbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign_buf'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
- state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
- state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+ if not group['normuon_variant']:
+ state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  else:
- if len(p.shape) >= 2:
+ if len(p.shape) >= 2 and not group['normuon_variant']:
  state['second_momentum_buffer'] = torch.zeros_like(p)
  state['momentum_buffer'] = torch.zeros_like(p)

+ # NorMuon state initialization
+ if group['normuon_variant']:
+ if state['factored']:
+ state['normuon_v'] = torch.zeros(d1, device=p.device, dtype=torch.float32)
+ elif len(p.shape) >= 2:
+ state['normuon_v'] = torch.zeros(p.shape[0], device=p.device, dtype=torch.float32)

  elif optim_type == 'adam':

+ state['step'] = 0
+
  state['factored'] = (
  group['adam_nnmf_factor'] and
  not (len(p.shape) == 1 and not group['vector_reshape'])
@@ -301,12 +320,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
  Q, _ = torch.linalg.qr(MG)
  projected_M = Q.T @ M
  ortho_projected_M = _newton_schulz_iteration(
- projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  update = Q @ ortho_projected_M
  else: # Fallback for invalid rank
  update = _newton_schulz_iteration(
- signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  else:
  # Original full Newton-Schulz
@@ -315,41 +334,61 @@ class AdaMuon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )
  del signed_m_buf

- # Reconstruct second momentum from previous step's factors
- vt_buf = _unnmf((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
+ if group['normuon_variant']:
+ v_t = state['normuon_v']
+ # Update 2nd moment estimate
+ mean_squared_update = torch.mean(update.square(), dim=1)
+ v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
+ # Normalize update
+ if group['use_atan2']:
+ a = 1.2732395
+ update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+ else:
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+ # Scale learning rate
+ update_norm = torch.linalg.vector_norm(update)
+ scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
+ update = update.view(p.shape).mul_(scaled_lr)
+ del mean_squared_update, update_norm, scaled_lr
+ else:
+ # Reconstruct second momentum from previous step's factors
+ vt_buf = _unnmf((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))

- # Update second momentum in full-size
- vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+ # Update second momentum in full-size
+ vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)

- # Apply second momentum update (adaptive scaling)
- if group['use_atan2']:
- a = 1.2732395
- denom = vt_buf.sqrt()
- update.atan2_(denom).mul_(a)
- else:
- denom = vt_buf.sqrt().add_(group['eps'])
- update.div_(denom)
- del denom
+ # Apply second momentum update (adaptive scaling)
+ if group['use_atan2']:
+ a = 1.2732395
+ denom = vt_buf.sqrt()
+ update.atan2_(denom).mul_(a)
+ else:
+ denom = vt_buf.sqrt().add_(group['eps'])
+ update.div_(denom)
+ del denom

- # RMS-aligned rescaling
- rms_target = group['rms_target']
- num_elements = update.numel()
- # Add eps to prevent division by zero
- update.mul_(rms_target * (num_elements ** 0.5) / (update.norm() + group['eps']))
+ # RMS-aligned rescaling
+ rms_target = group['rms_target']
+ num_elements = update.numel()
+ # Add eps to prevent division by zero
+ update.mul_(rms_target * (num_elements ** 0.5) / (update.norm().add_(group['eps'])))

- update = update.view(p.shape).mul_(lr)
- del num_elements
+ update = update.view(p.shape).mul_(lr)
+ del num_elements

  # Compress updated moments and store new factors
  state['sign_buf'] = _pack_bools(mt_buf > 0)
  _nnmf(mt_buf.abs(), out=(state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
  del mt_buf

- _nnmf(vt_buf.abs(), out=(state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
- del vt_buf
+ if not group['normuon_variant']:
+ _nnmf(vt_buf.abs(), out=(state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
+ del vt_buf

  else: # Standard AdaMuon logic for non-factored tensors

@@ -387,12 +426,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
  Q, _ = torch.linalg.qr(MG)
  projected_M = Q.T @ M
  ortho_projected_M = _newton_schulz_iteration(
- projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  update = Q @ ortho_projected_M
  else: # Fallback for invalid rank
  update = _newton_schulz_iteration(
- signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  else:
  # Original full Newton-Schulz
@@ -401,42 +440,59 @@ class AdaMuon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )
  del signed_m_buf

  update = update.view(original_shape)

- vt_buf = state['second_momentum_buffer']
- vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
-
- # Apply second momentum update (adaptive scaling)
- if group['use_atan2']:
- a = 1.2732395
- denom = vt_buf.sqrt()
- update.atan2_(denom).mul_(a)
+ if group['normuon_variant']:
+ # NorMuon Logic
+ v_t = state['normuon_v']
+ # Update 2nd moment estimate
+ mean_squared_update = torch.mean(update.square(), dim=1)
+ v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
+ # Normalize update
+ if group['use_atan2']:
+ a = 1.2732395
+ update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+ else:
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+ # Scale learning rate
+ update_norm = torch.linalg.vector_norm(update)
+ scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
+ update.mul_(scaled_lr)
+ del mean_squared_update, update_norm, scaled_lr
  else:
- denom = vt_buf.sqrt().add_(group['eps'])
- update.div_(denom)
- del denom
-
- # RMS-aligned rescaling
- rms_target = group['rms_target']
- num_elements = update.numel()
- # Add eps to prevent division by zero
- update.mul_(rms_target * (num_elements ** 0.5) / (update.norm() + group['eps']))
-
- del num_elements
-
- update.mul_(lr)
+ # Original AdaMuon Logic
+ vt_buf = state['second_momentum_buffer']
+ vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+ # Apply second momentum update (adaptive scaling)
+ if group['use_atan2']:
+ a = 1.2732395
+ denom = vt_buf.sqrt()
+ update.atan2_(denom).mul_(a)
+ else:
+ denom = vt_buf.sqrt().add_(group['eps'])
+ update.div_(denom)
+ del denom
+ # RMS-aligned rescaling
+ rms_target = group['rms_target']
+ num_elements = update.numel()
+ # Add eps to prevent division by zero
+ update.mul_(rms_target * (num_elements ** 0.5) / (update.norm().add_(group['eps'])))
+ del num_elements
+ update.mul_(lr)

  else: # Fallback to standard SGD with momentum for 1D params (biases, etc.)
  # Momentum update
  mt_buf = state['momentum_buffer']
  mt_buf.mul_(beta1).add_(grad)
  if nesterov:
- # Nesterov momentum
  update = grad.add(mt_buf, alpha=beta1)
- # elif Simplified_AdEMAMix: # TODO, it will break SGD since it requires x100 lower LR
+ # FIXME, Simplified_AdEMAMix will break SGD since it requires x100 lower LR
+ # elif Simplified_AdEMAMix:
  # update = mt_buf.add(grad, alpha=alpha_grad)
  else:
  update = mt_buf.clone()
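
For reference, the NorMuon branch introduced in the two hunks above amounts to a per-row normalization of the orthogonalized update (notation reconstructed from the code, not taken from the package docs). With an update matrix U of shape (d1, d2) and one accumulator entry v_i per output row:

    v_i \leftarrow \beta_2 v_i + (1 - \beta_2) \cdot \tfrac{1}{d_2} \sum_j U_{ij}^2
    \hat{U}_{ij} = U_{ij} / (\sqrt{v_i} + \epsilon)
    \Delta p = \mathrm{rms\_target} \cdot \mathrm{lr} \cdot \sqrt{\mathrm{numel}(p)} \cdot \hat{U} / (\lVert \hat{U} \rVert_2 + \epsilon)

When use_atan2 is set, the division is replaced by atan2(U_ij, sqrt(v_i)) * 4/pi; the constant 1.2732395 in the code is 4/pi. Unlike the original AdaMuon path, which keeps a full element-wise second moment, only a length-d1 vector is stored per matrix parameter.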
@@ -599,7 +655,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
  state = self.state[p]


-
  # Determine if using Adam or Muon based on state keys
  # We can use optm_type but I see this as a safer way.
  if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
@@ -611,32 +666,41 @@ class AdaMuon_adv(torch.optim.Optimizer):
  is_compiled = group.get('compiled_optimizer', False)

  if use_adam:
+ step = state['step']
+
  if self.kourkoutas_helper:
  # Prepare Kourkoutas-β once per optimizer step.
- self.kourkoutas_helper.maybe_prepare_step(self.global_step)
+ self.kourkoutas_helper.maybe_prepare_step(step)
+
  # Adam-specific setup (bias correction)
  if group['adam_use_bias_correction']:
- current_step = self.global_step + 1
+ current_step = step + 1
  beta1_adam, beta2_adam = group['adam_betas']
  bias_correction1 = 1.0 - beta1_adam ** current_step
  bias_correction2 = 1.0 - beta2_adam ** current_step
  else:
  bias_correction1 = 1.0
  bias_correction2 = 1.0
+
+ self.state[p]['step'] += 1
+
  # Dispatch to compiled or uncompiled Adam step
  if is_compiled and self._compiled_adam_step is not None:
- # Tensors must be used for compiled functions
- lr_tensor = torch.tensor(lr, device=p.device)
- bc1_tensor = torch.tensor(bias_correction1, device=p.device)
- bc2_tensor = torch.tensor(bias_correction2, device=p.device)
- self._compiled_adam_step(p, grad, state, group, lr_tensor, bc1_tensor, bc2_tensor)
+ # convert to tensors for compiled path once a step
+ if not hasattr(self, 'lr_adam_tensor') or self.lr_adam_tensor is None:
+ self.lr_adam_tensor = torch.tensor(group['lr'])
+ self.bc1 = torch.tensor(bias_correction1)
+ self.bc2 = torch.tensor(bias_correction2)
+ self._compiled_adam_step(p, grad, state, group, self.lr_adam_tensor, self.bc1, self.bc2)
  else:
  self._adam_step_parameter(p, grad, state, group, lr, bias_correction1, bias_correction2)
  else: # Muon path
  # Dispatch to compiled or uncompiled Muon step
  if is_compiled and self._compiled_muon_step is not None:
- lr_tensor = torch.tensor(lr, device=p.device)
- self._compiled_muon_step(p, grad, state, group, lr_tensor)
+ # convert to tensors for compiled path once a step
+ if not hasattr(self, 'lr_tensor') or self.lr_tensor is None:
+ self.lr_tensor = torch.tensor(group['lr'])
+ self._compiled_muon_step(p, grad, state, group, self.lr_tensor)
  else:
  self._muon_step_parameter(p, grad, state, group, lr)

@@ -659,6 +723,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

- self.global_step += 1
-
+ if self.param_groups[0].get('compiled_optimizer', False):
+ # Reset compile tensors once a step
+ self.lr_tensor = None
+ self.lr_adam_tensor = None
  return loss
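
Taken together, the AdaMuon_adv changes in this release surface three new constructor options: normuon_variant, accelerated_ns, and cns_a_bound. A minimal usage sketch, assuming the constructor accepts a plain parameter iterable like other torch optimizers and that the remaining arguments keep the defaults shown in the signature hunk above (the model and the lr value are hypothetical, chosen only for illustration):

    import torch
    from adv_optm import AdaMuon_adv

    model = torch.nn.Linear(256, 256)
    optimizer = AdaMuon_adv(
        model.parameters(),
        lr=1e-3,                 # hypothetical value, not a documented default
        normuon_variant=True,    # new: row-wise (NorMuon) second-moment normalization
        accelerated_ns=True,     # new: Chebyshev-accelerated Newton-Schulz (CANS)
        cns_a_bound=1e-4,        # new: initial lower bound on singular values for CANS
    )

    loss = model(torch.randn(8, 256)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Muon_adv below gains the same accelerated_ns / cns_a_bound pair.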
@@ -60,6 +60,9 @@ class Muon_adv(torch.optim.Optimizer):
  normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
  (default: 0.2)
  normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
+ accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
+ dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
+ cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
  --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
  adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
  adam_eps (float): Epsilon for the AdamW optimizer part.
@@ -73,6 +76,7 @@ class Muon_adv(torch.optim.Optimizer):
  adam_beta3_ema (float): Beta3 for AdEMAMix.
  adam_alpha (float): Alpha for AdEMAMix.
  adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
+ adam_nnmf_factor (bool): 1-bit factored for AdamW.
  """

  def __init__(
@@ -100,6 +104,9 @@ class Muon_adv(torch.optim.Optimizer):
  normuon_eps: float = 1e-8,
  normuon_lr_scale: float = 0.2,
  normuon_atan2: bool = False,
+ # CANS
+ accelerated_ns: bool = False,
+ cns_a_bound: float = 1e-4,
  # Compiled
  compiled_optimizer: bool = False,
  # --- AdamW_adv specific parameters ---
@@ -119,6 +126,7 @@ class Muon_adv(torch.optim.Optimizer):
  adam_ema_alpha: float = 0.95,
  adam_tiny_spike: float = 1e-9,
  adam_k_warmup_steps: int = 0,
+ adam_nnmf_factor: bool = False,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -148,6 +156,8 @@ class Muon_adv(torch.optim.Optimizer):
  "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
  "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
  "normuon_atan2": normuon_atan2,
+ # CANS
+ "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
  # AdamW_adv defaults
  "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
  "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
@@ -157,13 +167,13 @@ class Muon_adv(torch.optim.Optimizer):
  "adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
  "adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
  "adam_k_warmup_steps": adam_k_warmup_steps,
+ "adam_nnmf_factor":adam_nnmf_factor,
  }
  self.stochastic_rounding = stochastic_rounding
  self.compiled_optimizer = compiled_optimizer

  super().__init__(params, defaults)

- self.global_step = 0 # For Adam bias correction and Kourkoutas
  self.kourkoutas_helper = None
  if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
  self.kourkoutas_helper = KourkoutasHelper(self)
@@ -214,6 +224,7 @@ class Muon_adv(torch.optim.Optimizer):

  if optim_type == 'muon':

+
  state['factored'] = (
  group['nnmf_factor'] and
  not (len(p.shape) == 1 and not group['vector_reshape'])
@@ -238,9 +249,12 @@ class Muon_adv(torch.optim.Optimizer):
  elif len(p.shape) >= 2:
  state['normuon_v'] = torch.zeros(p.shape[0], device=p.device, dtype=torch.float32)

+ group['adam_kourkoutas_beta'] = False

  elif optim_type == 'adam':

+ state['step'] = 0
+
  state['factored'] = (
  group['adam_nnmf_factor'] and
  not (len(p.shape) == 1 and not group['vector_reshape'])
@@ -319,12 +333,12 @@ class Muon_adv(torch.optim.Optimizer):
  Q, _ = torch.linalg.qr(MG)
  projected_M = Q.T @ M
  ortho_projected_M = _newton_schulz_iteration(
- projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  update = Q @ ortho_projected_M
  else: # Fallback for invalid rank
  update = _newton_schulz_iteration(
- update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  else:
  # Original full Newton-Schulz
@@ -333,10 +347,12 @@ class Muon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )


- if group['normuon_variant'] and 'normuon_v' in state:
+ if group['normuon_variant']:
  v_t = state['normuon_v']
  beta2_normuon = group['beta2_normuon']
  # Update 2nd moment estimate
@@ -414,6 +430,8 @@ class Muon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )

  # 5. Project back to the original space
@@ -424,6 +442,8 @@ class Muon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )
  else:
  # Original NewtonSchulz
@@ -432,10 +452,12 @@ class Muon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )

  # NorMuon Logic
- if group['normuon_variant'] and 'normuon_v' in state:
+ if group['normuon_variant']:
  v_t = state['normuon_v']
  beta2_normuon = group['beta2_normuon']
  # Update 2nd moment estimate
@@ -629,10 +651,6 @@ class Muon_adv(torch.optim.Optimizer):

  state = self.state[p]

- if self.kourkoutas_helper:
- # Prepare Kourkoutas-β once per optimizer step.
- self.kourkoutas_helper.maybe_prepare_step(self.global_step)
-
  # Determine if using Adam or Muon based on state keys
  # We can use optm_type but I see this as a safer way.
  if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
@@ -644,9 +662,15 @@ class Muon_adv(torch.optim.Optimizer):
  is_compiled = group.get('compiled_optimizer', False)

  if use_adam:
+ step = state['step']
+
+ if self.kourkoutas_helper:
+ # Prepare Kourkoutas-β once per optimizer step.
+ self.kourkoutas_helper.maybe_prepare_step(step)
+
  # Adam-specific setup (bias correction)
  if group['adam_use_bias_correction']:
- current_step = self.global_step + 1
+ current_step = step + 1
  beta1_adam, beta2_adam = group['adam_betas']
  bias_correction1 = 1.0 - beta1_adam ** current_step
  bias_correction2 = 1.0 - beta2_adam ** current_step
@@ -654,6 +678,8 @@ class Muon_adv(torch.optim.Optimizer):
  bias_correction1 = 1.0
  bias_correction2 = 1.0

+ self.state[p]['step'] += 1
+
  # Dispatch to compiled or uncompiled Adam step
  if is_compiled and self._compiled_adam_step is not None:
  # Tensors must be used for compiled functions
@@ -690,6 +716,4 @@ class Muon_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

- self.global_step += 1
-
  return loss
@@ -0,0 +1,87 @@
+ import torch
+
+ @torch.no_grad()
+ def _newton_schulz_iteration(
+     G: torch.Tensor,
+     steps: int = 5,
+     eps: float = 1e-7,
+     coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
+     cns: bool = False,
+     cns_a_bound: float = 1e-4,
+ ) -> torch.Tensor:
+     """
+     Performs the Newton-Schulz iteration to find the nearest orthogonal matrix.
+     This is the core computation of the Muon optimizer.
+
+     Args:
+         G (torch.Tensor): The 2D input matrix (momentum-accumulated gradient).
+         steps (int): The number of iterations to run.
+         eps (float): Small constant for numerical stability during normalization.
+         coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
+             quintic polynomial update.
+         cns (bool): If True, enables Chebyshev-accelerated Newton-Schulz (CANS)
+             using an iterative 3rd-order polynomial with optimal coefficients
+             derived at each step.
+         cns_a_bound (float): The initial lower bound for singular values when
+             using CANS. The upper bound is assumed to be 1.0 after normalization.
+     Returns:
+         torch.Tensor: The orthogonalized matrix.
+     """
+     assert G.ndim >= 2
+
+     a, b, c = coeffs
+
+     X = G.to(torch.bfloat16)
+
+     transposed = G.size(-2) > G.size(-1)
+     if transposed:
+         X = X.mT
+
+     X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
+
+     if cns:
+         # Chebyshev-accelerated Newton-Schulz (CANS) from
+         # "Accelerating Newton-Schulz Iteration for Orthogonalization via Chebyshev-type Polynomials"
+         # This implements the iterative scheme from Algorithm 1, using the
+         # closed-form 3rd-order polynomial from Proposition 2.
+         lower_bound = cns_a_bound
+         upper_bound = 1.0  # Matrix is normalized, so largest singular value is approx 1.
+
+         for _ in range(steps):
+             # Calculate optimal 3rd-order coefficients c1, c3 for p(x) = c1*x + c3*x^3
+             # based on the current singular value bounds [lower_bound, upper_bound].
+             # Formulas are derived from Proposition 2 and its proof in Appendix B of the paper.
+             a_bound, b_bound = lower_bound, upper_bound
+             term = a_bound*a_bound + a_bound*b_bound + b_bound*b_bound
+             e_sq = term / 3.0
+
+             # Calculate alpha, which scales the polynomial
+             common_den_part = 2.0 * (e_sq**1.5)
+             ab_part = a_bound*a_bound*b_bound + b_bound*b_bound*a_bound
+             alpha_den = common_den_part + ab_part
+             alpha = 6.0 / alpha_den
+
+             c1 = alpha * e_sq
+             c3 = -alpha / 3.0
+
+             # Apply the 3rd-order Newton-Schulz update
+             A = X @ X.mT
+             X = c1 * X + c3 * (A @ X)
+
+             # Update the singular value bounds for the next iteration based on the error
+             eps_num = common_den_part - ab_part
+             eps_val = eps_num / alpha_den
+             lower_bound = 1.0 - eps_val
+             upper_bound = 1.0 + eps_val
+     else:
+         # Perform the iterative updates
+         for _ in range(steps):
+             A = X @ X.mT
+             B = b * A + c * (A @ A)
+             X = a * X + B @ X
+
+     # Transpose back if necessary
+     if transposed:
+         X = X.mT
+
+     return X.to(G.dtype)
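
For reference, the closed-form coefficients computed inside the CANS loop above correspond to the following (notation mine, transcribed from the code; [a, b] is the current singular-value bracket):

    e^2 = (a^2 + ab + b^2) / 3
    \alpha = 6 / (2e^3 + a^2 b + a b^2)
    p(x) = \alpha e^2 x - (\alpha / 3) x^3
    \varepsilon = (2e^3 - a^2 b - a b^2) / (2e^3 + a^2 b + a b^2),   next bracket [1 - \varepsilon, 1 + \varepsilon]

Since X X^T X shares singular vectors with X and cubes its singular values, each iteration maps every singular value sigma to p(sigma) and tightens the bracket around 1, which is the intended speed-up over the fixed-coefficient quintic in the else branch.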
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.2.dev14
+ Version: 1.2.dev16
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:

  setup(
  name="adv_optm",
- version="1.2.dev14",
+ version="1.2.dev16",
  author="Koratahiu",
  author_email="hiuhonor@gmail.com",
  license='Apache 2.0',
@@ -1,48 +0,0 @@
- import torch
-
- @torch.no_grad()
- def _newton_schulz_iteration(
-     G: torch.Tensor,
-     steps: int = 5,
-     eps: float = 1e-7,
-     coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315)
- ) -> torch.Tensor:
-     """
-     Performs the Newton-Schulz iteration to find the nearest orthogonal matrix.
-     This is the core computation of the Muon optimizer.
-
-     Args:
-         G (torch.Tensor): The 2D input matrix (momentum-accumulated gradient).
-         steps (int): The number of iterations to run.
-         eps (float): Small constant for numerical stability during normalization.
-         coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
-             quintic polynomial update.
-
-     Returns:
-         torch.Tensor: The orthogonalized matrix.
-     """
-     assert G.ndim == 2, "Newton-Schulz iteration only supports 2D matrices."
-
-     a, b, c = coeffs
-
-     X = G.to(torch.bfloat16)
-
-     # Normalize the matrix
-     X.div_(X.norm() + eps)
-
-     # Handle non-square matrices by transposing the taller one
-     transposed = G.size(0) > G.size(1)
-     if transposed:
-         X = X.T
-
-     # Perform the iterative updates
-     for _ in range(steps):
-         A = X @ X.T
-         B = b * A + c * (A @ A)
-         X = a * X + B @ X
-
-     # Transpose back if necessary
-     if transposed:
-         X = X.T
-
-     return X.to(G.dtype)