adv-optm 1.2.dev13-py3-none-any.whl → 1.2.dev15-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm might be problematic.

adv_optm/__init__.py CHANGED
@@ -20,4 +20,4 @@ __all__ = [
  "AdaMuon_adv",
  ]

- __version__ = "1.2.dev13"
+ __version__ = "1.2.dev15"
adv_optm/optim/AdaMuon_adv.py CHANGED

@@ -63,6 +63,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
  (default: False)
  ortho_rank (int): The rank for low-rank orthogonalization.
  (default: 128)
+ accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
+ dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
+ cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
  --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
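For orientation, a minimal usage sketch of the new flags documented above. This is hypothetical: it assumes the usual torch.optim constructor pattern and an lr argument, neither of which this diff shows (the optimizer may instead expect explicit param groups, e.g. with an optim_type key):

```python
# Hypothetical sketch: enabling CANS in AdaMuon_adv. Only accelerated_ns
# and cns_a_bound are confirmed by the diff above; lr and the bare
# parameter list are assumptions.
import torch
from adv_optm import AdaMuon_adv

model = torch.nn.Linear(512, 512)
opt = AdaMuon_adv(
    model.parameters(),
    lr=1e-3,               # assumed argument
    accelerated_ns=True,   # Chebyshev-accelerated Newton-Schulz
    cns_a_bound=1e-4,      # initial lower bound on singular values
)

loss = model(torch.randn(8, 512)).sum()
loss.backward()
opt.step()
opt.zero_grad()
```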
@@ -96,11 +99,16 @@ class AdaMuon_adv(torch.optim.Optimizer):
  nesterov: bool = False,
  Simplified_AdEMAMix: bool = False,
  alpha_grad: float = 100.0,
- vector_reshape: bool = False,
+ normuon_variant: bool = False,
  # Low-rank Muon
  low_rank_ortho: bool = False,
  ortho_rank: int = 128,
+ # Factored
+ vector_reshape: bool = False,
  nnmf_factor: bool = False,
+ # CANS
+ accelerated_ns: bool = False,
+ cns_a_bound: float = 1e-4,
  # Compiled
  compiled_optimizer: bool = False,
  # --- AdamW_adv specific parameters ---
@@ -139,9 +147,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
  "vector_reshape": vector_reshape,
  "nesterov":nesterov, "use_atan2":use_atan2,
  "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+ "normuon_variant": normuon_variant,
  # Low-rank Ortho
  "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
  "compiled_optimizer":compiled_optimizer,
+ # CANS
+ "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
  # AdamW_adv defaults
  "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
  "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
@@ -156,7 +167,6 @@ class AdaMuon_adv(torch.optim.Optimizer):

  super().__init__(params, defaults)

- self.global_step = 0 # For Adam bias correction and Kourkoutas
  self.kourkoutas_helper = None
  if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
  self.kourkoutas_helper = KourkoutasHelper(self)
@@ -214,16 +224,25 @@ class AdaMuon_adv(torch.optim.Optimizer):
  state['mv_mbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign_buf'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
- state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
- state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+ if not group['normuon_variant']:
+ state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  else:
- if len(p.shape) >= 2:
+ if len(p.shape) >= 2 and not group['normuon_variant']:
  state['second_momentum_buffer'] = torch.zeros_like(p)
  state['momentum_buffer'] = torch.zeros_like(p)

+ # NorMuon state initialization
+ if group['normuon_variant']:
+ if state['factored']:
+ state['normuon_v'] = torch.zeros(d1, device=p.device, dtype=torch.float32)
+ elif len(p.shape) >= 2:
+ state['normuon_v'] = torch.zeros(p.shape[0], device=p.device, dtype=torch.float32)

  elif optim_type == 'adam':

+ state['step'] = 0
+
  state['factored'] = (
  group['adam_nnmf_factor'] and
  not (len(p.shape) == 1 and not group['vector_reshape'])
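The buffers initialized above imply the factored storage scheme: a full momentum matrix is kept as packed sign bits (sign_buf) plus a nonnegative rank-1 factorization of its magnitude (the mu_*/mv_* vectors). A rough standalone sketch of that idea follows; the package's _pack_bools/_nnmf may differ in detail:

```python
# Rough sketch (not the package's code) of storing a matrix M as 1-bit
# signs plus rank-1 nonnegative factors, matching the buffer shapes above.
import torch

def pack_signs(mask: torch.Tensor) -> torch.Tensor:
    # Pack a (d1, d2) boolean matrix into uint8, 8 columns per byte,
    # giving the (d1, (d2 + 7) // 8) shape used for sign_buf.
    d1, d2 = mask.shape
    pad = (-d2) % 8
    bits = torch.nn.functional.pad(mask.to(torch.uint8), (0, pad)).view(d1, -1, 8)
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)
    return (bits * weights).sum(dim=-1, dtype=torch.uint8)

def rank1_factors(A: torch.Tensor):
    # One-shot rank-1 fit of a nonnegative matrix A: mu from row means,
    # mv from a least-squares column fit against mu. _nnmf may use a
    # different scheme; this only illustrates the storage cost.
    mu = A.mean(dim=1)
    mv = (A * mu.unsqueeze(1)).sum(dim=0) / (mu.square().sum() + 1e-30)
    return mu, mv

M = torch.randn(4, 6)                    # stand-in momentum matrix
signs = pack_signs(M > 0)                # (4, 1) uint8
mu, mv = rank1_factors(M.abs())          # (4,) and (6,) vectors
M_approx = torch.outer(mu, mv) * torch.where(M > 0, 1.0, -1.0)
```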
@@ -301,12 +320,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
  Q, _ = torch.linalg.qr(MG)
  projected_M = Q.T @ M
  ortho_projected_M = _newton_schulz_iteration(
- projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  update = Q @ ortho_projected_M
  else: # Fallback for invalid rank
  update = _newton_schulz_iteration(
- signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  else:
  # Original full Newton-Schulz
@@ -315,41 +334,61 @@ class AdaMuon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )
  del signed_m_buf

- # Reconstruct second momentum from previous step's factors
- vt_buf = _unnmf((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
+ if group['normuon_variant']:
+ v_t = state['normuon_v']
+ # Update 2nd moment estimate
+ mean_squared_update = torch.mean(update.square(), dim=1)
+ v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
+ # Normalize update
+ if group['use_atan2']:
+ a = 1.2732395
+ update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+ else:
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+ # Scale learning rate
+ update_norm = torch.linalg.vector_norm(update)
+ scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
+ update = update.view(p.shape).mul_(scaled_lr)
+ del mean_squared_update, update_norm, scaled_lr
+ else:
+ # Reconstruct second momentum from previous step's factors
+ vt_buf = _unnmf((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))

- # Update second momentum in full-size
- vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+ # Update second momentum in full-size
+ vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)

- # Apply second momentum update (adaptive scaling)
- if group['use_atan2']:
- a = 1.2732395
- denom = vt_buf.sqrt()
- update.atan2_(denom).mul_(a)
- else:
- denom = vt_buf.sqrt().add_(group['eps'])
- update.div_(denom)
- del denom
+ # Apply second momentum update (adaptive scaling)
+ if group['use_atan2']:
+ a = 1.2732395
+ denom = vt_buf.sqrt()
+ update.atan2_(denom).mul_(a)
+ else:
+ denom = vt_buf.sqrt().add_(group['eps'])
+ update.div_(denom)
+ del denom

- # RMS-aligned rescaling
- rms_target = group['rms_target']
- num_elements = update.numel()
- # Add eps to prevent division by zero
- update.mul_(rms_target * (num_elements ** 0.5) / (update.norm() + group['eps']))
+ # RMS-aligned rescaling
+ rms_target = group['rms_target']
+ num_elements = update.numel()
+ # Add eps to prevent division by zero
+ update.mul_(rms_target * (num_elements ** 0.5) / (update.norm().add_(group['eps'])))

- update = update.view(p.shape).mul_(lr)
- del num_elements
+ update = update.view(p.shape).mul_(lr)
+ del num_elements

  # Compress updated moments and store new factors
  state['sign_buf'] = _pack_bools(mt_buf > 0)
  _nnmf(mt_buf.abs(), out=(state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
  del mt_buf

- _nnmf(vt_buf.abs(), out=(state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
- del vt_buf
+ if not group['normuon_variant']:
+ _nnmf(vt_buf.abs(), out=(state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
+ del vt_buf

  else: # Standard AdaMuon logic for non-factored tensors

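The NorMuon branch above keeps one second-moment entry per output row instead of a full matrix. A standalone sketch of that normalization with illustrative names (non-atan2 branch), written out-of-place where the optimizer works in-place:

```python
# Sketch of the NorMuon-style normalization used above: a per-row
# second-moment EMA of the orthogonalized update, then a rescale so the
# step's RMS matches rms_target times lr. Names are illustrative.
import torch

def normuon_normalize(update, v, beta2=0.95, eps=1e-8, rms_target=0.2, lr=1e-3):
    v.mul_(beta2).add_(update.square().mean(dim=1), alpha=1 - beta2)
    update = update / (v.sqrt().unsqueeze(1) + eps)      # row-wise normalize
    scale = rms_target * lr * update.numel() ** 0.5 / (update.norm() + eps)
    return update * scale                                # RMS ~= rms_target * lr

u = torch.randn(64, 128)      # orthogonalized momentum for a 64x128 weight
v = torch.zeros(64)           # one entry per output row, like normuon_v
step = normuon_normalize(u, v)
```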
@@ -387,12 +426,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
  Q, _ = torch.linalg.qr(MG)
  projected_M = Q.T @ M
  ortho_projected_M = _newton_schulz_iteration(
- projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  update = Q @ ortho_projected_M
  else: # Fallback for invalid rank
  update = _newton_schulz_iteration(
- signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  else:
  # Original full Newton-Schulz
@@ -401,42 +440,59 @@ class AdaMuon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )
  del signed_m_buf

  update = update.view(original_shape)

- vt_buf = state['second_momentum_buffer']
- vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
-
- # Apply second momentum update (adaptive scaling)
- if group['use_atan2']:
- a = 1.2732395
- denom = vt_buf.sqrt()
- update.atan2_(denom).mul_(a)
+ if group['normuon_variant']:
+ # NorMuon Logic
+ v_t = state['normuon_v']
+ # Update 2nd moment estimate
+ mean_squared_update = torch.mean(update.square(), dim=1)
+ v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
+ # Normalize update
+ if group['use_atan2']:
+ a = 1.2732395
+ update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+ else:
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+ # Scale learning rate
+ update_norm = torch.linalg.vector_norm(update)
+ scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
+ update.mul_(scaled_lr)
+ del mean_squared_update, update_norm, scaled_lr
  else:
- denom = vt_buf.sqrt().add_(group['eps'])
- update.div_(denom)
- del denom
-
- # RMS-aligned rescaling
- rms_target = group['rms_target']
- num_elements = update.numel()
- # Add eps to prevent division by zero
- update.mul_(rms_target * (num_elements ** 0.5) / (update.norm() + group['eps']))
-
- del num_elements
-
- update.mul_(lr)
+ # Original AdaMuon Logic
+ vt_buf = state['second_momentum_buffer']
+ vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+ # Apply second momentum update (adaptive scaling)
+ if group['use_atan2']:
+ a = 1.2732395
+ denom = vt_buf.sqrt()
+ update.atan2_(denom).mul_(a)
+ else:
+ denom = vt_buf.sqrt().add_(group['eps'])
+ update.div_(denom)
+ del denom
+ # RMS-aligned rescaling
+ rms_target = group['rms_target']
+ num_elements = update.numel()
+ # Add eps to prevent division by zero
+ update.mul_(rms_target * (num_elements ** 0.5) / (update.norm().add_(group['eps'])))
+ del num_elements
+ update.mul_(lr)

  else: # Fallback to standard SGD with momentum for 1D params (biases, etc.)
  # Momentum update
  mt_buf = state['momentum_buffer']
  mt_buf.mul_(beta1).add_(grad)
  if nesterov:
- # Nesterov momentum
  update = grad.add(mt_buf, alpha=beta1)
- # elif Simplified_AdEMAMix: # TODO, it will break SGD since it requires x100 lower LR
+ # FIXME, Simplified_AdEMAMix will break SGD since it requires x100 lower LR
+ # elif Simplified_AdEMAMix:
  # update = mt_buf.add(grad, alpha=alpha_grad)
  else:
  update = mt_buf.clone()
@@ -599,7 +655,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
  state = self.state[p]


-
  # Determine if using Adam or Muon based on state keys
  # We can use optm_type but I see this as a safer way.
  if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
@@ -611,32 +666,41 @@ class AdaMuon_adv(torch.optim.Optimizer):
  is_compiled = group.get('compiled_optimizer', False)

  if use_adam:
+ step = state['step']
+
  if self.kourkoutas_helper:
  # Prepare Kourkoutas-β once per optimizer step.
- self.kourkoutas_helper.maybe_prepare_step(self.global_step)
+ self.kourkoutas_helper.maybe_prepare_step(step)
+
  # Adam-specific setup (bias correction)
  if group['adam_use_bias_correction']:
- current_step = self.global_step + 1
+ current_step = step + 1
  beta1_adam, beta2_adam = group['adam_betas']
  bias_correction1 = 1.0 - beta1_adam ** current_step
  bias_correction2 = 1.0 - beta2_adam ** current_step
  else:
  bias_correction1 = 1.0
  bias_correction2 = 1.0
+
+ self.state[p]['step'] += 1
+
  # Dispatch to compiled or uncompiled Adam step
  if is_compiled and self._compiled_adam_step is not None:
- # Tensors must be used for compiled functions
- lr_tensor = torch.tensor(lr, device=p.device)
- bc1_tensor = torch.tensor(bias_correction1, device=p.device)
- bc2_tensor = torch.tensor(bias_correction2, device=p.device)
- self._compiled_adam_step(p, grad, state, group, lr_tensor, bc1_tensor, bc2_tensor)
+ # convert to tensors for compiled path once a step
+ if not hasattr(self, 'lr_adam_tensor') or self.lr_adam_tensor is None:
+ self.lr_adam_tensor = torch.tensor(group['lr'])
+ self.bc1 = torch.tensor(bias_correction1)
+ self.bc2 = torch.tensor(bias_correction2)
+ self._compiled_adam_step(p, grad, state, group, self.lr_adam_tensor, self.bc1, self.bc2)
  else:
  self._adam_step_parameter(p, grad, state, group, lr, bias_correction1, bias_correction2)
  else: # Muon path
  # Dispatch to compiled or uncompiled Muon step
  if is_compiled and self._compiled_muon_step is not None:
- lr_tensor = torch.tensor(lr, device=p.device)
- self._compiled_muon_step(p, grad, state, group, lr_tensor)
+ # convert to tensors for compiled path once a step
+ if not hasattr(self, 'lr_tensor') or self.lr_tensor is None:
+ self.lr_tensor = torch.tensor(group['lr'])
+ self._compiled_muon_step(p, grad, state, group, self.lr_tensor)
  else:
  self._muon_step_parameter(p, grad, state, group, lr)

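The compiled-path change above caches the lr and bias-correction tensors on the optimizer once per step instead of allocating fresh ones per parameter. The motivation, sketched under the assumption that torch.compile specializes on Python scalars (so a changing float argument can force recompiles, while a 0-dim tensor is treated as an input):

```python
# Illustrative sketch, not the package's code: pass lr as a tensor so the
# compiled function is not re-specialized each time the value changes.
import torch

@torch.compile
def sgd_step(p: torch.Tensor, g: torch.Tensor, lr: torch.Tensor):
    p.sub_(g * lr)            # toy stand-in for _compiled_adam/_muon_step

p, g = torch.randn(4), torch.randn(4)
lr = torch.tensor(1e-3)       # built once per optimizer step, reused per param
sgd_step(p, g, lr)
lr.fill_(5e-4)                # new value, same tensor: no recompilation
sgd_step(p, g, lr)
```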
@@ -659,6 +723,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

- self.global_step += 1
-
+ if self.param_groups[0].get('compiled_optimizer', False):
+ # Reset compile tensors once a step
+ self.lr_tensor = None
+ self.lr_adam_tensor = None
  return loss
adv_optm/optim/Muon_adv.py CHANGED

@@ -60,6 +60,9 @@ class Muon_adv(torch.optim.Optimizer):
  normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
  (default: 0.2)
  normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
+ accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
+ dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
+ cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
  --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
  adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
  adam_eps (float): Epsilon for the AdamW optimizer part.
@@ -100,6 +103,9 @@ class Muon_adv(torch.optim.Optimizer):
  normuon_eps: float = 1e-8,
  normuon_lr_scale: float = 0.2,
  normuon_atan2: bool = False,
+ # CANS
+ accelerated_ns: bool = False,
+ cns_a_bound: float = 1e-4,
  # Compiled
  compiled_optimizer: bool = False,
  # --- AdamW_adv specific parameters ---
@@ -148,6 +154,8 @@ class Muon_adv(torch.optim.Optimizer):
  "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
  "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
  "normuon_atan2": normuon_atan2,
+ # CANS
+ "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
  # AdamW_adv defaults
  "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
  "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
@@ -163,7 +171,6 @@ class Muon_adv(torch.optim.Optimizer):

  super().__init__(params, defaults)

- self.global_step = 0 # For Adam bias correction and Kourkoutas
  self.kourkoutas_helper = None
  if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
  self.kourkoutas_helper = KourkoutasHelper(self)
@@ -214,6 +221,7 @@ class Muon_adv(torch.optim.Optimizer):

  if optim_type == 'muon':

+
  state['factored'] = (
  group['nnmf_factor'] and
  not (len(p.shape) == 1 and not group['vector_reshape'])
@@ -238,9 +246,12 @@ class Muon_adv(torch.optim.Optimizer):
  elif len(p.shape) >= 2:
  state['normuon_v'] = torch.zeros(p.shape[0], device=p.device, dtype=torch.float32)

+ group['adam_kourkoutas_beta'] = False

  elif optim_type == 'adam':

+ state['step'] = 0
+
  state['factored'] = (
  group['adam_nnmf_factor'] and
  not (len(p.shape) == 1 and not group['vector_reshape'])
@@ -319,12 +330,12 @@ class Muon_adv(torch.optim.Optimizer):
  Q, _ = torch.linalg.qr(MG)
  projected_M = Q.T @ M
  ortho_projected_M = _newton_schulz_iteration(
- projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  update = Q @ ortho_projected_M
  else: # Fallback for invalid rank
  update = _newton_schulz_iteration(
- update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
  )
  else:
  # Original full Newton-Schulz
@@ -333,10 +344,12 @@ class Muon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )


- if group['normuon_variant'] and 'normuon_v' in state:
+ if group['normuon_variant']:
  v_t = state['normuon_v']
  beta2_normuon = group['beta2_normuon']
  # Update 2nd moment estimate
@@ -414,6 +427,8 @@ class Muon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )

  # 5. Project back to the original space
@@ -424,6 +439,8 @@ class Muon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )
  else:
  # Original NewtonSchulz
@@ -432,10 +449,12 @@ class Muon_adv(torch.optim.Optimizer):
  steps=group['ns_steps'],
  eps=group['ns_eps'],
  coeffs=group['ns_coeffs'],
+ cns=group['accelerated_ns'],
+ cns_a_bound=group['cns_a_bound'],
  )

  # NorMuon Logic
- if group['normuon_variant'] and 'normuon_v' in state:
+ if group['normuon_variant']:
  v_t = state['normuon_v']
  beta2_normuon = group['beta2_normuon']
  # Update 2nd moment estimate
@@ -629,10 +648,6 @@ class Muon_adv(torch.optim.Optimizer):

  state = self.state[p]

- if self.kourkoutas_helper:
- # Prepare Kourkoutas-β once per optimizer step.
- self.kourkoutas_helper.maybe_prepare_step(self.global_step)
-
  # Determine if using Adam or Muon based on state keys
  # We can use optm_type but I see this as a safer way.
  if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
@@ -644,9 +659,15 @@ class Muon_adv(torch.optim.Optimizer):
  is_compiled = group.get('compiled_optimizer', False)

  if use_adam:
+ step = state['step']
+
+ if self.kourkoutas_helper:
+ # Prepare Kourkoutas-β once per optimizer step.
+ self.kourkoutas_helper.maybe_prepare_step(step)
+
  # Adam-specific setup (bias correction)
  if group['adam_use_bias_correction']:
- current_step = self.global_step + 1
+ current_step = step + 1
  beta1_adam, beta2_adam = group['adam_betas']
  bias_correction1 = 1.0 - beta1_adam ** current_step
  bias_correction2 = 1.0 - beta2_adam ** current_step
@@ -654,6 +675,8 @@ class Muon_adv(torch.optim.Optimizer):
  bias_correction1 = 1.0
  bias_correction2 = 1.0

+ self.state[p]['step'] += 1
+
  # Dispatch to compiled or uncompiled Adam step
  if is_compiled and self._compiled_adam_step is not None:
  # Tensors must be used for compiled functions
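As in AdaMuon_adv, the release moves Adam bias correction off the removed optimizer-wide global_step and onto a per-parameter state['step'] counter. A minimal sketch of bias correction driven by such a counter, with illustrative names:

```python
# Minimal per-parameter Adam bias-correction sketch (illustrative names,
# not the package's _adam_step_parameter).
import torch

def adam_update(p, g, state, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    m, v = state['m'], state['v']
    m.mul_(betas[0]).add_(g, alpha=1 - betas[0])
    v.mul_(betas[1]).addcmul_(g, g, value=1 - betas[1])
    state['step'] += 1
    t = state['step']                    # this parameter's own step count
    bc1 = 1.0 - betas[0] ** t
    bc2 = 1.0 - betas[1] ** t
    p.sub_(lr * (m / bc1) / ((v / bc2).sqrt() + eps))

p = torch.randn(10)
state = {'m': torch.zeros(10), 'v': torch.zeros(10), 'step': 0}
adam_update(p, torch.randn(10), state)
```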
@@ -690,6 +713,4 @@ class Muon_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

- self.global_step += 1
-
  return loss
adv_optm/util/Kourkoutas.py CHANGED

@@ -71,7 +71,7 @@ class KourkoutasHelper:
  for layer_key, info in self.layer_info.items():
  params, group = info['params'], info['group_ref']

- if not group.get('kourkoutas_beta', False):
+ if not group.get('kourkoutas_beta', False) and not group.get('adam_kourkoutas_beta', False):
  continue

  first_param_in_layer = info['params'][0]
adv_optm/util/Newton_Schulz.py CHANGED

@@ -5,7 +5,9 @@ def _newton_schulz_iteration(
  G: torch.Tensor,
  steps: int = 5,
  eps: float = 1e-7,
- coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315)
+ coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
+ cns: bool = False,
+ cns_a_bound: float = 1e-4,
  ) -> torch.Tensor:
  """
  Performs the Newton-Schulz iteration to find the nearest orthogonal matrix.
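Reading ahead to the body diff below: for singular values assumed to lie in $[a, b]$, each CANS step applies the cubic $p(x) = c_1 x + c_3 x^3$ with

$$
e^2 = \frac{a^2 + ab + b^2}{3}, \qquad
\alpha = \frac{6}{2e^3 + a^2 b + a b^2}, \qquad
c_1 = \alpha e^2, \qquad
c_3 = -\frac{\alpha}{3},
$$

after which the bounds contract to $[\,1 - \varepsilon,\ 1 + \varepsilon\,]$ with

$$
\varepsilon = \frac{2e^3 - a^2 b - a b^2}{2e^3 + a^2 b + a b^2}.
$$

These formulas are transcribed directly from the hunk that follows.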
@@ -17,32 +19,69 @@ def _newton_schulz_iteration(
  eps (float): Small constant for numerical stability during normalization.
  coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
  quintic polynomial update.
-
+ cns (bool): If True, enables Chebyshev-accelerated Newton-Schulz (CANS)
+ using an iterative 3rd-order polynomial with optimal coefficients
+ derived at each step.
+ cns_a_bound (float): The initial lower bound for singular values when
+ using CANS. The upper bound is assumed to be 1.0 after normalization.
  Returns:
  torch.Tensor: The orthogonalized matrix.
  """
- assert G.ndim == 2, "Newton-Schulz iteration only supports 2D matrices."
+ assert G.ndim >= 2

  a, b, c = coeffs

  X = G.to(torch.bfloat16)

- # Normalize the matrix
- X.div_(X.norm() + eps)
-
- # Handle non-square matrices by transposing the taller one
- transposed = G.size(0) > G.size(1)
+ transposed = G.size(-2) > G.size(-1)
  if transposed:
- X = X.T
+ X = X.mT
+
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
+
+ if cns:
+ # Chebyshev-accelerated Newton-Schulz (CANS) from
+ # "Accelerating Newton-Schulz Iteration for Orthogonalization via Chebyshev-type Polynomials"
+ # This implements the iterative scheme from Algorithm 1, using the
+ # closed-form 3rd-order polynomial from Proposition 2.
+ lower_bound = cns_a_bound
+ upper_bound = 1.0 # Matrix is normalized, so largest singular value is approx 1.
+
+ for _ in range(steps):
+ # Calculate optimal 3rd-order coefficients c1, c3 for p(x) = c1*x + c3*x^3
+ # based on the current singular value bounds [lower_bound, upper_bound].
+ # Formulas are derived from Proposition 2 and its proof in Appendix B of the paper.
+ a_bound, b_bound = lower_bound, upper_bound
+ term = a_bound*a_bound + a_bound*b_bound + b_bound*b_bound
+ e_sq = term / 3.0
+
+ # Calculate alpha, which scales the polynomial
+ common_den_part = 2.0 * (e_sq**1.5)
+ ab_part = a_bound*a_bound*b_bound + b_bound*b_bound*a_bound
+ alpha_den = common_den_part + ab_part
+ alpha = 6.0 / alpha_den
+
+ c1 = alpha * e_sq
+ c3 = -alpha / 3.0
+
+ # Apply the 3rd-order Newton-Schulz update
+ A = X @ X.mT
+ X = c1 * X + c3 * (A @ X)

- # Perform the iterative updates
- for _ in range(steps):
- A = X @ X.T
- B = b * A + c * (A @ A)
- X = a * X + B @ X
+ # Update the singular value bounds for the next iteration based on the error
+ eps_num = common_den_part - ab_part
+ eps_val = eps_num / alpha_den
+ lower_bound = 1.0 - eps_val
+ upper_bound = 1.0 + eps_val
+ else:
+ # Perform the iterative updates
+ for _ in range(steps):
+ A = X @ X.mT
+ B = b * A + c * (A @ A)
+ X = a * X + B @ X

  # Transpose back if necessary
  if transposed:
- X = X.T
+ X = X.mT

- return X.to(G.dtype)
+ return X.to(G.dtype)
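A quick sanity check of the new path (a sketch, not part of the package; the private import path is inferred from the RECORD listing further below): orthogonalize a tall random matrix with cns=True and measure how far XᵀX is from the identity. Expect a modest value rather than a tiny one, since the iteration runs in bfloat16 for a few steps:

```python
# Hedged sanity check: CANS output should be near-orthogonal.
import torch
from adv_optm.util.Newton_Schulz import _newton_schulz_iteration

G = torch.randn(256, 128)
X = _newton_schulz_iteration(G, steps=5, cns=True, cns_a_bound=1e-4)

gram = X.float().T @ X.float()                 # tall input: X^T X ~ I
err = (gram - torch.eye(128)).abs().max().item()
print(f"max |X^T X - I| entry: {err:.3f}")
```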
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.2.dev13
+ Version: 1.2.dev15
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -1,23 +1,23 @@
- adv_optm/__init__.py,sha256=b2cBBXd4W_tBKGcxO1SHA5SaQrCl_cjxlHdoZSICo3E,380
- adv_optm/optim/AdaMuon_adv.py,sha256=-UBw_mJj8JzDAi3zQ0nLnSOgzsTzl7b7kVksDRUziEE,30582
+ adv_optm/__init__.py,sha256=I9iRXHonvg_82dEmyKXqt9PyN04Ez8TVbMb1uZgRZAc,380
+ adv_optm/optim/AdaMuon_adv.py,sha256=zjZHFS7ng5KwemQzePjFiGtNZlcgbzmmnqF6A80h_Tg,34652
  adv_optm/optim/AdamW_adv.py,sha256=KL9SCJWZ_ckAQEApB6ofbndVYjancN-v7Us7hJLFf54,17475
  adv_optm/optim/Adopt_adv.py,sha256=S8XI2YA7683jsW8p7igc2YcU30lsN0H18qL02Kpvj8E,21244
  adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
  adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
- adv_optm/optim/Muon_adv.py,sha256=8d99NcXzLyxTbxVVXC8mHyeW7wM8jjK59QoXVTLScQA,32112
+ adv_optm/optim/Muon_adv.py,sha256=QutgiRkDS36O5BQNdcwdIcYBCKPy7U07YYVQT6dq8tc,33165
  adv_optm/optim/Prodigy_adv.py,sha256=lEjbtuQbomsCX39DnTPeI8Z5YG0f2aZPXN_E7-nGgWw,26060
  adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
  adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
  adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
  adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
- adv_optm/util/Kourkoutas.py,sha256=_fq2glPqKmzgWpLedfwq5EqIJAxICUK2fmUP-cdcgq0,7467
+ adv_optm/util/Kourkoutas.py,sha256=SSzhe0B6Zb2AXGwCKpVTLr0aaFfspcFBNZCZG3azI9k,7516
  adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
- adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnCg,1377
+ adv_optm/util/Newton_Schulz.py,sha256=bBboYw_jm5_FMf0Citl79uqNedkHOTjQnUI7rZgLBmY,3341
  adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
  adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
  adv_optm/util/__init__.py,sha256=CXzS703GB4gil85khZi7sgKOnbzXGBOltshIOSPqj18,435
- adv_optm-1.2.dev13.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
- adv_optm-1.2.dev13.dist-info/METADATA,sha256=NW2SU3Uw-ow_Nn7B7VURZ4LmWKsem5KebbDrYystfU4,14023
- adv_optm-1.2.dev13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- adv_optm-1.2.dev13.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
- adv_optm-1.2.dev13.dist-info/RECORD,,
+ adv_optm-1.2.dev15.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+ adv_optm-1.2.dev15.dist-info/METADATA,sha256=CH8IxEUd-TSH5dVzXKR-rl54pTIIU_JTN_MkviBWprs,14023
+ adv_optm-1.2.dev15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ adv_optm-1.2.dev15.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+ adv_optm-1.2.dev15.dist-info/RECORD,,