adv-optm 2.dev1.tar.gz → 2.dev2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of adv-optm might be problematic.

Files changed (29)
  1. {adv_optm-2.dev1 → adv_optm-2.dev2}/PKG-INFO +1 -1
  2. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/optim/AdaMuon_adv.py +132 -66
  4. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/optim/AdamW_adv.py +24 -15
  5. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/optim/Adopt_adv.py +1 -1
  6. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/optim/Lion_Prodigy_adv.py +1 -1
  7. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/optim/Lion_adv.py +22 -1
  8. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/optim/Muon_adv.py +33 -12
  9. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/optim/Prodigy_adv.py +10 -8
  10. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/optim/Simplified_AdEMAMix.py +1 -1
  11. adv_optm-2.dev2/adv_optm/util/Newton_Schulz.py +87 -0
  12. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm.egg-info/PKG-INFO +1 -1
  13. {adv_optm-2.dev1 → adv_optm-2.dev2}/setup.py +1 -1
  14. adv_optm-2.dev1/adv_optm/util/Newton_Schulz.py +0 -47
  15. {adv_optm-2.dev1 → adv_optm-2.dev2}/LICENSE +0 -0
  16. {adv_optm-2.dev1 → adv_optm-2.dev2}/README.md +0 -0
  17. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/optim/__init__.py +0 -0
  18. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  19. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/util/Effective_Shape.py +0 -0
  20. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/util/Kourkoutas.py +0 -0
  21. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/util/NNMF.py +0 -0
  22. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/util/One_Bit_Boolean.py +0 -0
  23. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/util/OrthoGrad.py +0 -0
  24. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm/util/__init__.py +0 -0
  25. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm.egg-info/SOURCES.txt +0 -0
  26. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm.egg-info/dependency_links.txt +0 -0
  27. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm.egg-info/requires.txt +0 -0
  28. {adv_optm-2.dev1 → adv_optm-2.dev2}/adv_optm.egg-info/top_level.txt +0 -0
  29. {adv_optm-2.dev1 → adv_optm-2.dev2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 2.dev1
+ Version: 2.dev2
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -20,4 +20,4 @@ __all__ = [
  "AdaMuon_adv",
  ]

- __version__ = "2.dev1"
+ __version__ = "2.dev2"
@@ -63,6 +63,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
  (default: False)
  ortho_rank (int): The rank for low-rank orthogonalization.
  (default: 128)
+ accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
+ dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
+ cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
  --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
@@ -96,11 +99,16 @@ class AdaMuon_adv(torch.optim.Optimizer):
  nesterov: bool = False,
  Simplified_AdEMAMix: bool = False,
  alpha_grad: float = 100.0,
- vector_reshape: bool = False,
+ normuon_variant: bool = False,
  # Low-rank Muon
  low_rank_ortho: bool = False,
  ortho_rank: int = 128,
+ # Factored
+ vector_reshape: bool = False,
  nnmf_factor: bool = False,
+ # CANS
+ accelerated_ns: bool = False,
+ cns_a_bound: float = 1e-4,
  # Compiled
  compiled_optimizer: bool = False,
  # --- AdamW_adv specific parameters ---
@@ -139,9 +147,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
  "vector_reshape": vector_reshape,
  "nesterov":nesterov, "use_atan2":use_atan2,
  "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+ "normuon_variant": normuon_variant,
  # Low-rank Ortho
  "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
  "compiled_optimizer":compiled_optimizer,
+ # CANS
+ "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
  # AdamW_adv defaults
  "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
  "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
@@ -156,7 +167,6 @@ class AdaMuon_adv(torch.optim.Optimizer):

  super().__init__(params, defaults)

- self.global_step = 0 # For Adam bias correction and Kourkoutas
  self.kourkoutas_helper = None
  if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
  self.kourkoutas_helper = KourkoutasHelper(self)
@@ -214,16 +224,25 @@ class AdaMuon_adv(torch.optim.Optimizer):
  state['mv_mbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign_buf'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
- state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
- state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+ if not group['normuon_variant']:
+ state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  else:
- if len(p.shape) >= 2:
+ if len(p.shape) >= 2 and not group['normuon_variant']:
  state['second_momentum_buffer'] = torch.zeros_like(p)
  state['momentum_buffer'] = torch.zeros_like(p)

+ # NorMuon state initialization
+ if group['normuon_variant']:
+ if state['factored']:
+ state['normuon_v'] = torch.zeros(d1, device=p.device, dtype=torch.float32)
+ elif len(p.shape) >= 2:
+ state['normuon_v'] = torch.zeros(p.shape[0], device=p.device, dtype=torch.float32)

  elif optim_type == 'adam':

+ state['step'] = 0
+
  state['factored'] = (
  group['adam_nnmf_factor'] and
  not (len(p.shape) == 1 and not group['vector_reshape'])
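
Note: a minimal illustrative sketch (not the package's code) of the per-parameter step counter this hunk introduces; each parameter's state carries its own 'step', which then drives bias correction in place of the old global counter.

def bias_corrections(state: dict, betas: tuple[float, float]) -> tuple[float, float]:
    current_step = state['step'] + 1       # step count for this parameter only
    beta1, beta2 = betas
    bc1 = 1.0 - beta1 ** current_step      # first-moment bias correction
    bc2 = 1.0 - beta2 ** current_step      # second-moment bias correction
    state['step'] += 1                     # persist for the next call
    return bc1, bc2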
@@ -301,12 +320,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
  Q, _ = torch.linalg.qr(MG)
  projected_M = Q.T @ M
  ortho_projected_M = _newton_schulz_iteration(
- projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  update = Q @ ortho_projected_M
  else: # Fallback for invalid rank
  update = _newton_schulz_iteration(
- signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  else:
  # Original full Newton-Schulz
@@ -315,41 +334,61 @@ class AdaMuon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )
  del signed_m_buf

- # Reconstruct second momentum from previous step's factors
- vt_buf = _unnmf((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
+ if group['normuon_variant']:
+ v_t = state['normuon_v']
+ # Update 2nd moment estimate
+ mean_squared_update = torch.mean(update.square(), dim=1)
+ v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
+ # Normalize update
+ if group['use_atan2']:
+ a = 1.2732395
+ update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+ else:
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+ # Scale learning rate
+ update_norm = torch.linalg.vector_norm(update)
+ scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
+ update = update.view(p.shape).mul_(scaled_lr)
+ del mean_squared_update, update_norm, scaled_lr
+ else:
+ # Reconstruct second momentum from previous step's factors
+ vt_buf = _unnmf((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))

- # Update second momentum in full-size
- vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+ # Update second momentum in full-size
+ vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)

- # Apply second momentum update (adaptive scaling)
- if group['use_atan2']:
- a = 1.2732395
- denom = vt_buf.sqrt()
- update.atan2_(denom).mul_(a)
- else:
- denom = vt_buf.sqrt().add_(group['eps'])
- update.div_(denom)
- del denom
+ # Apply second momentum update (adaptive scaling)
+ if group['use_atan2']:
+ a = 1.2732395
+ denom = vt_buf.sqrt()
+ update.atan2_(denom).mul_(a)
+ else:
+ denom = vt_buf.sqrt().add_(group['eps'])
+ update.div_(denom)
+ del denom

- # RMS-aligned rescaling
- rms_target = group['rms_target']
- num_elements = update.numel()
- # Add eps to prevent division by zero
- update.mul_(rms_target * (num_elements ** 0.5) / (update.norm() + group['eps']))
+ # RMS-aligned rescaling
+ rms_target = group['rms_target']
+ num_elements = update.numel()
+ # Add eps to prevent division by zero
+ update.mul_(rms_target * (num_elements ** 0.5) / (update.norm().add_(group['eps'])))

- update = update.view(p.shape).mul_(lr)
- del num_elements
+ update = update.view(p.shape).mul_(lr)
+ del num_elements

  # Compress updated moments and store new factors
  state['sign_buf'] = _pack_bools(mt_buf > 0)
  _nnmf(mt_buf.abs(), out=(state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
  del mt_buf

- _nnmf(vt_buf.abs(), out=(state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
- del vt_buf
+ if not group['normuon_variant']:
+ _nnmf(vt_buf.abs(), out=(state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
+ del vt_buf

  else: # Standard AdaMuon logic for non-factored tensors

@@ -387,12 +426,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
  Q, _ = torch.linalg.qr(MG)
  projected_M = Q.T @ M
  ortho_projected_M = _newton_schulz_iteration(
- projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  update = Q @ ortho_projected_M
  else: # Fallback for invalid rank
  update = _newton_schulz_iteration(
- signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  else:
  # Original full Newton-Schulz
@@ -401,42 +440,59 @@ class AdaMuon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )
  del signed_m_buf

  update = update.view(original_shape)

- vt_buf = state['second_momentum_buffer']
- vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
-
- # Apply second momentum update (adaptive scaling)
- if group['use_atan2']:
- a = 1.2732395
- denom = vt_buf.sqrt()
- update.atan2_(denom).mul_(a)
+ if group['normuon_variant']:
+ # NorMuon Logic
+ v_t = state['normuon_v']
+ # Update 2nd moment estimate
+ mean_squared_update = torch.mean(update.square(), dim=1)
+ v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
+ # Normalize update
+ if group['use_atan2']:
+ a = 1.2732395
+ update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+ else:
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+ # Scale learning rate
+ update_norm = torch.linalg.vector_norm(update)
+ scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
+ update.mul_(scaled_lr)
+ del mean_squared_update, update_norm, scaled_lr
  else:
- denom = vt_buf.sqrt().add_(group['eps'])
- update.div_(denom)
- del denom
-
- # RMS-aligned rescaling
- rms_target = group['rms_target']
- num_elements = update.numel()
- # Add eps to prevent division by zero
- update.mul_(rms_target * (num_elements ** 0.5) / (update.norm() + group['eps']))
-
- del num_elements
-
- update.mul_(lr)
+ # Original AdaMuon Logic
+ vt_buf = state['second_momentum_buffer']
+ vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+ # Apply second momentum update (adaptive scaling)
+ if group['use_atan2']:
+ a = 1.2732395
+ denom = vt_buf.sqrt()
+ update.atan2_(denom).mul_(a)
+ else:
+ denom = vt_buf.sqrt().add_(group['eps'])
+ update.div_(denom)
+ del denom
+ # RMS-aligned rescaling
+ rms_target = group['rms_target']
+ num_elements = update.numel()
+ # Add eps to prevent division by zero
+ update.mul_(rms_target * (num_elements ** 0.5) / (update.norm().add_(group['eps'])))
+ del num_elements
+ update.mul_(lr)

  else: # Fallback to standard SGD with momentum for 1D params (biases, etc.)
  # Momentum update
  mt_buf = state['momentum_buffer']
  mt_buf.mul_(beta1).add_(grad)
  if nesterov:
- # Nesterov momentum
  update = grad.add(mt_buf, alpha=beta1)
- # elif Simplified_AdEMAMix: # TODO, it will break SGD since it requires x100 lower LR
+ # FIXME, Simplified_AdEMAMix will break SGD since it requires x100 lower LR
+ # elif Simplified_AdEMAMix:
  # update = mt_buf.add(grad, alpha=alpha_grad)
  else:
  update = mt_buf.clone()
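
Note: a minimal standalone sketch of the NorMuon-style normalization used in the two hunks above, assuming a 2D orthogonalized update and a per-row second-moment buffer; the function and variable names here are illustrative, not the package API.

import torch

def normuon_normalize(update: torch.Tensor, v_row: torch.Tensor,
                      beta2: float = 0.95, eps: float = 1e-8,
                      lr: float = 1e-3, rms_target: float = 0.2) -> torch.Tensor:
    # EMA of the mean squared update, one scalar per output row
    v_row.mul_(beta2).add_(update.square().mean(dim=1), alpha=1 - beta2)
    # Normalize each row by its running second-moment estimate
    normed = update / (v_row.sqrt().unsqueeze(1) + eps)
    # Rescale so the whole update has roughly rms_target RMS, then fold in the LR
    scaled_lr = rms_target * lr * (normed.numel() ** 0.5) / (normed.norm() + eps)
    return normed * scaled_lr

# usage sketch: v = torch.zeros(W.shape[0]); step_update = normuon_normalize(ortho_update, v)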
@@ -599,7 +655,6 @@ class AdaMuon_adv(torch.optim.Optimizer):

  state = self.state[p]

- 
  # Determine if using Adam or Muon based on state keys
  # We can use optm_type but I see this as a safer way.
  if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
@@ -611,32 +666,41 @@ class AdaMuon_adv(torch.optim.Optimizer):
  is_compiled = group.get('compiled_optimizer', False)

  if use_adam:
+ step = state['step']
+
  if self.kourkoutas_helper:
  # Prepare Kourkoutas-β once per optimizer step.
- self.kourkoutas_helper.maybe_prepare_step(self.global_step)
+ self.kourkoutas_helper.maybe_prepare_step(step)
+
  # Adam-specific setup (bias correction)
  if group['adam_use_bias_correction']:
- current_step = self.global_step + 1
+ current_step = step + 1
  beta1_adam, beta2_adam = group['adam_betas']
  bias_correction1 = 1.0 - beta1_adam ** current_step
  bias_correction2 = 1.0 - beta2_adam ** current_step
  else:
  bias_correction1 = 1.0
  bias_correction2 = 1.0
+
+ self.state[p]['step'] += 1
+
  # Dispatch to compiled or uncompiled Adam step
  if is_compiled and self._compiled_adam_step is not None:
- # Tensors must be used for compiled functions
- lr_tensor = torch.tensor(lr, device=p.device)
- bc1_tensor = torch.tensor(bias_correction1, device=p.device)
- bc2_tensor = torch.tensor(bias_correction2, device=p.device)
- self._compiled_adam_step(p, grad, state, group, lr_tensor, bc1_tensor, bc2_tensor)
+ # convert to tensors for compiled path once a step
+ if not hasattr(self, 'lr_adam_tensor') or self.lr_adam_tensor is None:
+ self.lr_adam_tensor = torch.tensor(group['lr'])
+ self.bc1 = torch.tensor(bias_correction1)
+ self.bc2 = torch.tensor(bias_correction2)
+ self._compiled_adam_step(p, grad, state, group, self.lr_adam_tensor, self.bc1, self.bc2)
  else:
  self._adam_step_parameter(p, grad, state, group, lr, bias_correction1, bias_correction2)
  else: # Muon path
  # Dispatch to compiled or uncompiled Muon step
  if is_compiled and self._compiled_muon_step is not None:
- lr_tensor = torch.tensor(lr, device=p.device)
- self._compiled_muon_step(p, grad, state, group, lr_tensor)
+ # convert to tensors for compiled path once a step
+ if not hasattr(self, 'lr_tensor') or self.lr_tensor is None:
+ self.lr_tensor = torch.tensor(group['lr'])
+ self._compiled_muon_step(p, grad, state, group, self.lr_tensor)
  else:
  self._muon_step_parameter(p, grad, state, group, lr)

@@ -659,6 +723,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

- self.global_step += 1
-
+ if self.param_groups[0].get('compiled_optimizer', False):
+ # Reset compile tensors once a step
+ self.lr_tensor = None
+ self.lr_adam_tensor = None
  return loss
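
Note: a hedged sketch of the caching pattern the compiled path above switches to: scalar hyperparameters are wrapped in tensors once per optimizer step, reused for every parameter, and cleared at the end of the step so a scheduler-driven change is picked up next step; the class and attribute names here are assumptions for illustration only.

import torch

class CompiledLRCache:
    lr_tensor = None

    def get_lr_tensor(self, group: dict) -> torch.Tensor:
        # Build the tensor lazily on first use within a step, then reuse it
        if self.lr_tensor is None:
            self.lr_tensor = torch.tensor(group['lr'])
        return self.lr_tensor

    def end_of_step(self) -> None:
        # Reset so the next step re-reads the (possibly updated) learning rate
        self.lr_tensor = None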
@@ -51,6 +51,10 @@ class AdamW_adv(torch.optim.Optimizer):
  momentum. (default: 5.0)
  kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
  If `False`, the optimizer behaves as standard AdamW. (default: False)
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+ and returns a unique, hashable key representing its "layer" or "bucket".
+ If `None`, parameters are bucketed by their shape.
+ (default: None)
  beta2_min (float): The minimum value for dynamic β₂, used during periods of
  high gradient variance ("sunspikes"). Must be less than `betas[1]`.
  (default: 0.88)
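
Note: a hedged example of a custom layer_key_fn as described in the docstring above; the default (None) buckets parameters by their shape, while a callable like this one can bucket them however the caller likes (the grouping below is purely illustrative).

def layer_key_fn(p):
    # Put 2D+ weight matrices and 1D vectors (biases, norms) in separate buckets per shape
    return ('matrix', tuple(p.shape)) if p.ndim >= 2 else ('vector', p.shape[0])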
@@ -89,6 +93,7 @@ class AdamW_adv(torch.optim.Optimizer):
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
  kourkoutas_beta: bool = False,
+ layer_key_fn: Optional[Callable] = None,
  beta2_min: float = 0.9,
  ema_alpha: float = 0.95,
  tiny_spike: float = 1e-9,
@@ -127,6 +132,7 @@ class AdamW_adv(torch.optim.Optimizer):
  self.grams_moment = grams_moment
  self.use_AdEMAMix = use_AdEMAMix
  self.factored = nnmf_factor
+ self.layer_key_fn = layer_key_fn
  self.kourkoutas_beta = kourkoutas_beta

  super().__init__(params, defaults)
@@ -136,8 +142,6 @@ class AdamW_adv(torch.optim.Optimizer):
  if self.kourkoutas_beta:
  self.kourkoutas_helper = KourkoutasHelper(self)

- self.global_step = 0
-
  if compiled_optimizer:
  torch._dynamo.config.cache_size_limit = 8192
  self.compile(fullgraph=True)
@@ -165,6 +169,8 @@ class AdamW_adv(torch.optim.Optimizer):

  if len(state) == 0:

+ state['step'] = 0
+
  state['factored'] = (
  self.factored and
  not (len(p.shape) == 1 and not group['vector_reshape'])
@@ -343,15 +349,13 @@ class AdamW_adv(torch.optim.Optimizer):

  @torch.no_grad()
  def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
- # if 'exp_avg_sq' not in self.state[p] and 'mu_v_nmf' not in self.state[p]:
- # return

- if self.global_step is None and 'step' in self.state[p]:
- # For backward compatibility
- self.global_step = self.state[p]['step']
+ state = self.state[p]
+
+ step = state['step']

  if group['use_bias_correction']:
- current_step = self.global_step + 1
+ current_step = step + 1
  beta1, beta2 = group['betas']
  bias_correction1 = 1.0 - beta1 ** current_step
  bias_correction2 = 1.0 - beta2 ** current_step
@@ -361,15 +365,19 @@ class AdamW_adv(torch.optim.Optimizer):

  if group.get('kourkoutas_beta', False):
  # Prepare Kourkoutas-β once per step using the global step counter.
- self.kourkoutas_helper.maybe_prepare_step(self.global_step)
+ self.kourkoutas_helper.maybe_prepare_step(step)
+
+ self.state[p]['step'] += 1

  if not group.get('compiled_optimizer', False):
  self.__step_parameter(p, group, group['lr'], bias_correction1, bias_correction2)
  else:
- lr_tensor = torch.tensor(group['lr'], device=p.device)
- bias_correction1_tensor = torch.tensor(bias_correction1, device=p.device)
- bias_correction2_tensor = torch.tensor(bias_correction2, device=p.device)
- self._compiled_step_parameter(p, group, lr_tensor, bias_correction1_tensor, bias_correction2_tensor)
+ if not hasattr(self, 'lr_tensor') or self.lr_tensor is None:
+ # convert to tensors for compiled path once a step
+ self.lr_tensor = torch.tensor(group['lr'], device=p.device)
+ self.bc1_tensor = torch.tensor(bias_correction1, device=p.device)
+ self.bc2_tensor = torch.tensor(bias_correction2, device=p.device)
+ self._compiled_step_parameter(p, group, self.lr_tensor, self.bc1_tensor, self.bc2_tensor)

  def compile(self, *args, **kwargs):
  self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)
@@ -386,6 +394,7 @@ class AdamW_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

- self.global_step += 1
-
+ if self.param_groups[0].get('compiled_optimizer', False):
+ # Reset compile tensors once a step
+ self.lr_tensor = None
  return loss
@@ -86,7 +86,7 @@ class Adopt_adv(torch.optim.Optimizer):
  logging (default: 0).
  layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
  and returns a unique, hashable key representing its "layer" or "bucket".
- If `None`, parameters are bucketed by their memory ID (tensor-wise).
+ If `None`, parameters are bucketed by their shape.
  (default: None)
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
@@ -111,7 +111,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):

  if compiled_optimizer:
  torch._dynamo.config.cache_size_limit = 8192
- self.compile(fullgraph=False, dynamic=False) #FIXME
+ self.compile(fullgraph=True)

  @property
  def supports_fused_back_pass(self) -> bool:
@@ -42,6 +42,8 @@ class Lion_adv(torch.optim.Optimizer):
  orthogonal_gradient: bool = False,
  cautious_mask: bool = False,
  nnmf_factor: bool = True,
+ # Compiled
+ compiled_optimizer: bool = False,
  ):
  if not lr > 0.0:
  raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -56,12 +58,20 @@ class Lion_adv(torch.optim.Optimizer):
  weight_decay=weight_decay,
  vector_reshape=vector_reshape,
  orthogonal_gradient=orthogonal_gradient,
+ compiled_optimizer=compiled_optimizer,
  )
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask
  self.factored = nnmf_factor
  super().__init__(params, defaults)

+ self.init_step()
+
+ if compiled_optimizer:
+ torch._dynamo.config.cache_size_limit = 8192
+ self.compile(fullgraph=True)
+
+
  @property
  def supports_fused_back_pass(self) -> bool:
  return True
@@ -118,7 +128,6 @@ class Lion_adv(torch.optim.Optimizer):


  beta1, beta2 = group["betas"]
- lr = group["lr"]

  if state['factored']:
  # Factored Path
@@ -189,6 +198,18 @@ class Lion_adv(torch.optim.Optimizer):

  del update

+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+
+ if not group.get('compiled_optimizer', False):
+ self.__step_parameter(p, group, group["lr"])
+ else:
+ lr_tensor = torch.tensor(group["lr"], device=p.device)
+ self._compiled_step_parameter(p, group, lr_tensor)
+
+ def compile(self, *args, **kwargs):
+ self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)
+
  @torch.no_grad()
  def step(self, closure: Optional[callable] = None):
  """Performs a single optimization step."""
@@ -60,6 +60,9 @@ class Muon_adv(torch.optim.Optimizer):
  normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
  (default: 0.2)
  normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
+ accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
+ dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
+ cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
  --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
  adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
  adam_eps (float): Epsilon for the AdamW optimizer part.
@@ -100,6 +103,9 @@ class Muon_adv(torch.optim.Optimizer):
  normuon_eps: float = 1e-8,
  normuon_lr_scale: float = 0.2,
  normuon_atan2: bool = False,
+ # CANS
+ accelerated_ns: bool = False,
+ cns_a_bound: float = 1e-4,
  # Compiled
  compiled_optimizer: bool = False,
  # --- AdamW_adv specific parameters ---
@@ -148,6 +154,8 @@ class Muon_adv(torch.optim.Optimizer):
  "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
  "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
  "normuon_atan2": normuon_atan2,
+ # CANS
+ "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
  # AdamW_adv defaults
  "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
  "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
@@ -163,7 +171,6 @@ class Muon_adv(torch.optim.Optimizer):

  super().__init__(params, defaults)

- self.global_step = 0 # For Adam bias correction and Kourkoutas
  self.kourkoutas_helper = None
  if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
  self.kourkoutas_helper = KourkoutasHelper(self)
@@ -214,6 +221,7 @@ class Muon_adv(torch.optim.Optimizer):

  if optim_type == 'muon':

+
  state['factored'] = (
  group['nnmf_factor'] and
  not (len(p.shape) == 1 and not group['vector_reshape'])
@@ -238,9 +246,12 @@ class Muon_adv(torch.optim.Optimizer):
  elif len(p.shape) >= 2:
  state['normuon_v'] = torch.zeros(p.shape[0], device=p.device, dtype=torch.float32)

+ group['adam_kourkoutas_beta'] = False

  elif optim_type == 'adam':

+ state['step'] = 0
+
  state['factored'] = (
  group['adam_nnmf_factor'] and
  not (len(p.shape) == 1 and not group['vector_reshape'])
@@ -319,12 +330,12 @@ class Muon_adv(torch.optim.Optimizer):
  Q, _ = torch.linalg.qr(MG)
  projected_M = Q.T @ M
  ortho_projected_M = _newton_schulz_iteration(
- projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  update = Q @ ortho_projected_M
  else: # Fallback for invalid rank
  update = _newton_schulz_iteration(
- update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  else:
  # Original full Newton-Schulz
@@ -333,10 +344,12 @@ class Muon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )


- if group['normuon_variant'] and 'normuon_v' in state:
+ if group['normuon_variant']:
  v_t = state['normuon_v']
  beta2_normuon = group['beta2_normuon']
  # Update 2nd moment estimate
@@ -414,6 +427,8 @@ class Muon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )

  # 5. Project back to the original space
@@ -424,6 +439,8 @@ class Muon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )
  else:
  # Original NewtonSchulz
@@ -432,10 +449,12 @@ class Muon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )

  # NorMuon Logic
- if group['normuon_variant'] and 'normuon_v' in state:
+ if group['normuon_variant']:
  v_t = state['normuon_v']
  beta2_normuon = group['beta2_normuon']
  # Update 2nd moment estimate
@@ -629,10 +648,6 @@ class Muon_adv(torch.optim.Optimizer):

  state = self.state[p]

- if self.kourkoutas_helper:
- # Prepare Kourkoutas-β once per optimizer step.
- self.kourkoutas_helper.maybe_prepare_step(self.global_step)
-
  # Determine if using Adam or Muon based on state keys
  # We can use optm_type but I see this as a safer way.
  if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
@@ -644,9 +659,15 @@ class Muon_adv(torch.optim.Optimizer):
  is_compiled = group.get('compiled_optimizer', False)

  if use_adam:
+ step = state['step']
+
+ if self.kourkoutas_helper:
+ # Prepare Kourkoutas-β once per optimizer step.
+ self.kourkoutas_helper.maybe_prepare_step(step)
+
  # Adam-specific setup (bias correction)
  if group['adam_use_bias_correction']:
- current_step = self.global_step + 1
+ current_step = step + 1
  beta1_adam, beta2_adam = group['adam_betas']
  bias_correction1 = 1.0 - beta1_adam ** current_step
  bias_correction2 = 1.0 - beta2_adam ** current_step
@@ -654,6 +675,8 @@ class Muon_adv(torch.optim.Optimizer):
  bias_correction1 = 1.0
  bias_correction2 = 1.0

+ self.state[p]['step'] += 1
+
  # Dispatch to compiled or uncompiled Adam step
  if is_compiled and self._compiled_adam_step is not None:
  # Tensors must be used for compiled functions
@@ -690,6 +713,4 @@ class Muon_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

- self.global_step += 1
-
  return loss
@@ -105,7 +105,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  logging (default: 0).
  layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
  and returns a unique, hashable key representing its "layer" or "bucket".
- If `None`, parameters are bucketed by their memory ID (tensor-wise).
+ If `None`, parameters are bucketed by their shape.
  (default: None)
  """

@@ -484,16 +484,18 @@ class Prodigy_adv(torch.optim.Optimizer):
  if self.kourkoutas_beta:
  self.kourkoutas_helper.maybe_prepare_step(self.global_step)

- if isinstance(self.d_numerator, float):
- self.d_numerator = torch.tensor(self.d_numerator, device=p.device)
- self.d_denom = torch.tensor(self.d_denom, device=p.device)
-
  if not group.get('compiled_optimizer', False):
+ if isinstance(self.d_numerator, float):
+ self.d_numerator = torch.tensor(self.d_numerator, device=p.device)
+ self.d_denom = torch.tensor(self.d_denom, device=p.device)
  self.__step_parameter(p, group, self.d, self.dlr)
  else:
- d_tensor = torch.tensor(self.d, device=p.device)
- dlr_tensor = torch.tensor(self.dlr, device=p.device)
- self._compiled_step_parameter(p, group, d_tensor, dlr_tensor)
+ if isinstance(self.d_numerator, float):
+ self.d_numerator = torch.tensor(self.d_numerator, device=p.device)
+ self.d_denom = torch.tensor(self.d_denom, device=p.device)
+ self.d_tensor = torch.tensor(self.d, device=p.device)
+ self.dlr_tensor = torch.tensor(self.dlr, device=p.device)
+ self._compiled_step_parameter(p, group, self.d_tensor, self.dlr_tensor)

  def compile(self, *args, **kwargs):
  self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)
@@ -69,7 +69,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  logging (default: 0).
  layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
  and returns a unique, hashable key representing its "layer" or "bucket".
- If `None`, parameters are bucketed by their memory ID (tensor-wise).
+ If `None`, parameters are bucketed by their shape.
  (default: None)
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
@@ -0,0 +1,87 @@
+ import torch
+
+ @torch.no_grad()
+ def _newton_schulz_iteration(
+ G: torch.Tensor,
+ steps: int = 5,
+ eps: float = 1e-7,
+ coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
+ cns: bool = False,
+ cns_a_bound: float = 1e-4,
+ ) -> torch.Tensor:
+ """
+ Performs the Newton-Schulz iteration to find the nearest orthogonal matrix.
+ This is the core computation of the Muon optimizer.
+
+ Args:
+ G (torch.Tensor): The 2D input matrix (momentum-accumulated gradient).
+ steps (int): The number of iterations to run.
+ eps (float): Small constant for numerical stability during normalization.
+ coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
+ quintic polynomial update.
+ cns (bool): If True, enables Chebyshev-accelerated Newton-Schulz (CANS)
+ using an iterative 3rd-order polynomial with optimal coefficients
+ derived at each step.
+ cns_a_bound (float): The initial lower bound for singular values when
+ using CANS. The upper bound is assumed to be 1.0 after normalization.
+ Returns:
+ torch.Tensor: The orthogonalized matrix.
+ """
+ assert G.ndim >= 2
+
+ a, b, c = coeffs
+
+ X = G.to(torch.bfloat16)
+
+ transposed = G.size(-2) > G.size(-1)
+ if transposed:
+ X = X.mT
+
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
+
+ if cns:
+ # Chebyshev-accelerated Newton-Schulz (CANS) from
+ # "Accelerating Newton-Schulz Iteration for Orthogonalization via Chebyshev-type Polynomials"
+ # This implements the iterative scheme from Algorithm 1, using the
+ # closed-form 3rd-order polynomial from Proposition 2.
+ lower_bound = cns_a_bound
+ upper_bound = 1.0 # Matrix is normalized, so largest singular value is approx 1.
+
+ for _ in range(steps):
+ # Calculate optimal 3rd-order coefficients c1, c3 for p(x) = c1*x + c3*x^3
+ # based on the current singular value bounds [lower_bound, upper_bound].
+ # Formulas are derived from Proposition 2 and its proof in Appendix B of the paper.
+ a_bound, b_bound = lower_bound, upper_bound
+ term = a_bound*a_bound + a_bound*b_bound + b_bound*b_bound
+ e_sq = term / 3.0
+
+ # Calculate alpha, which scales the polynomial
+ common_den_part = 2.0 * (e_sq**1.5)
+ ab_part = a_bound*a_bound*b_bound + b_bound*b_bound*a_bound
+ alpha_den = common_den_part + ab_part
+ alpha = 6.0 / alpha_den
+
+ c1 = alpha * e_sq
+ c3 = -alpha / 3.0
+
+ # Apply the 3rd-order Newton-Schulz update
+ A = X @ X.mT
+ X = c1 * X + c3 * (A @ X)
+
+ # Update the singular value bounds for the next iteration based on the error
+ eps_num = common_den_part - ab_part
+ eps_val = eps_num / alpha_den
+ lower_bound = 1.0 - eps_val
+ upper_bound = 1.0 + eps_val
+ else:
+ # Perform the iterative updates
+ for _ in range(steps):
+ A = X @ X.mT
+ B = b * A + c * (A @ A)
+ X = a * X + B @ X
+
+ # Transpose back if necessary
+ if transposed:
+ X = X.mT
+
+ return X.to(G.dtype)
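
Note: a small usage sketch exercising the new module above; the import path mirrors the file location (adv_optm/util/Newton_Schulz.py) and the call signature matches the added function, but treat both as assumptions since the package's public API may differ.

import torch
from adv_optm.util.Newton_Schulz import _newton_schulz_iteration

G = torch.randn(256, 512)
X_quintic = _newton_schulz_iteration(G, steps=5)                           # default quintic iteration
X_cans = _newton_schulz_iteration(G, steps=5, cns=True, cns_a_bound=1e-4)  # Chebyshev-accelerated

# Both results should be approximately row-orthogonal for this wide matrix: X @ X^T ≈ I
for name, X in (('quintic', X_quintic), ('cans', X_cans)):
    err = (X.float() @ X.float().mT - torch.eye(256)).norm() / 256 ** 0.5
    print(name, float(err))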
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 2.dev1
+ Version: 2.dev2
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:

  setup(
  name="adv_optm",
- version="2.dev1",
+ version="2.dev2",
  author="Koratahiu",
  author_email="hiuhonor@gmail.com",
  license='Apache 2.0',
@@ -1,47 +0,0 @@
- import torch
-
- @torch.no_grad()
- def _newton_schulz_iteration(
- G: torch.Tensor,
- steps: int = 5,
- eps: float = 1e-7,
- coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315)
- ) -> torch.Tensor:
- """
- Performs the Newton-Schulz iteration to find the nearest orthogonal matrix.
- This is the core computation of the Muon optimizer.
-
- Args:
- G (torch.Tensor): The 2D input matrix (momentum-accumulated gradient).
- steps (int): The number of iterations to run.
- eps (float): Small constant for numerical stability during normalization.
- coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
- quintic polynomial update.
-
- Returns:
- torch.Tensor: The orthogonalized matrix.
- """
-
- a, b, c = coeffs
-
- X = G.to(torch.bfloat16)
-
- # Normalize the matrix
- X.div_(X.norm() + eps)
-
- # Handle non-square matrices by transposing the taller one
- transposed = G.size(0) > G.size(1)
- if transposed:
- X = X.T
-
- # Perform the iterative updates
- for _ in range(steps):
- A = X @ X.T
- B = b * A + c * (A @ A)
- X = a * X + B @ X
-
- # Transpose back if necessary
- if transposed:
- X = X.T
-
- return X.to(G.dtype)