adv-optm 1.2.dev14__py3-none-any.whl → 1.2.dev15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdaMuon_adv.py +132 -66
- adv_optm/optim/Muon_adv.py +33 -12
- adv_optm/util/Newton_Schulz.py +55 -16
- {adv_optm-1.2.dev14.dist-info → adv_optm-1.2.dev15.dist-info}/METADATA +1 -1
- {adv_optm-1.2.dev14.dist-info → adv_optm-1.2.dev15.dist-info}/RECORD +9 -9
- {adv_optm-1.2.dev14.dist-info → adv_optm-1.2.dev15.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev14.dist-info → adv_optm-1.2.dev15.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev14.dist-info → adv_optm-1.2.dev15.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdaMuon_adv.py
CHANGED
@@ -63,6 +63,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
             (default: False)
         ortho_rank (int): The rank for low-rank orthogonalization.
             (default: 128)
+        accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
+            dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
+        cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
     --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
@@ -96,11 +99,16 @@ class AdaMuon_adv(torch.optim.Optimizer):
         nesterov: bool = False,
         Simplified_AdEMAMix: bool = False,
         alpha_grad: float = 100.0,
-
+        normuon_variant: bool = False,
         # Low-rank Muon
         low_rank_ortho: bool = False,
         ortho_rank: int = 128,
+        # Factored
+        vector_reshape: bool = False,
         nnmf_factor: bool = False,
+        # CANS
+        accelerated_ns: bool = False,
+        cns_a_bound: float = 1e-4,
         # Compiled
         compiled_optimizer: bool = False,
         # --- AdamW_adv specific parameters ---
@@ -139,9 +147,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
             "vector_reshape": vector_reshape,
             "nesterov":nesterov, "use_atan2":use_atan2,
             "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+            "normuon_variant": normuon_variant,
             # Low-rank Ortho
             "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
             "compiled_optimizer":compiled_optimizer,
+            # CANS
+            "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
             # AdamW_adv defaults
             "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
             "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
@@ -156,7 +167,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
 
         super().__init__(params, defaults)
 
-        self.global_step = 0 # For Adam bias correction and Kourkoutas
         self.kourkoutas_helper = None
         if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
             self.kourkoutas_helper = KourkoutasHelper(self)
@@ -214,16 +224,25 @@ class AdaMuon_adv(torch.optim.Optimizer):
                 state['mv_mbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                 packed_d2 = (d2 + 7) // 8
                 state['sign_buf'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
-
-
+                if not group['normuon_variant']:
+                    state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                    state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
             else:
-                if len(p.shape) >= 2:
+                if len(p.shape) >= 2 and not group['normuon_variant']:
                     state['second_momentum_buffer'] = torch.zeros_like(p)
                 state['momentum_buffer'] = torch.zeros_like(p)
 
+            # NorMuon state initialization
+            if group['normuon_variant']:
+                if state['factored']:
+                    state['normuon_v'] = torch.zeros(d1, device=p.device, dtype=torch.float32)
+                elif len(p.shape) >= 2:
+                    state['normuon_v'] = torch.zeros(p.shape[0], device=p.device, dtype=torch.float32)
 
         elif optim_type == 'adam':
 
+            state['step'] = 0
+
             state['factored'] = (
                 group['adam_nnmf_factor'] and
                 not (len(p.shape) == 1 and not group['vector_reshape'])
@@ -301,12 +320,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
                     Q, _ = torch.linalg.qr(MG)
                     projected_M = Q.T @ M
                     ortho_projected_M = _newton_schulz_iteration(
-                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
                     )
                     update = Q @ ortho_projected_M
                 else: # Fallback for invalid rank
                     update = _newton_schulz_iteration(
-                        signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                        signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
                     )
             else:
                 # Original full Newton-Schulz
@@ -315,41 +334,61 @@ class AdaMuon_adv(torch.optim.Optimizer):
                     steps=group['ns_steps'],
                     eps=group['ns_eps'],
                     coeffs=group['ns_coeffs'],
+                    cns=group['accelerated_ns'],
+                    cns_a_bound=group['cns_a_bound'],
                 )
             del signed_m_buf
 
-
-
+            if group['normuon_variant']:
+                v_t = state['normuon_v']
+                # Update 2nd moment estimate
+                mean_squared_update = torch.mean(update.square(), dim=1)
+                v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
+                # Normalize update
+                if group['use_atan2']:
+                    a = 1.2732395
+                    update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+                else:
+                    update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+                # Scale learning rate
+                update_norm = torch.linalg.vector_norm(update)
+                scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
+                update = update.view(p.shape).mul_(scaled_lr)
+                del mean_squared_update, update_norm, scaled_lr
+            else:
+                # Reconstruct second momentum from previous step's factors
+                vt_buf = _unnmf((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
 
-
-
+                # Update second momentum in full-size
+                vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
 
-
-
-
-
-
-
-
-
-
+                # Apply second momentum update (adaptive scaling)
+                if group['use_atan2']:
+                    a = 1.2732395
+                    denom = vt_buf.sqrt()
+                    update.atan2_(denom).mul_(a)
+                else:
+                    denom = vt_buf.sqrt().add_(group['eps'])
+                    update.div_(denom)
+                del denom
 
-
-
-
-
-
+                # RMS-aligned rescaling
+                rms_target = group['rms_target']
+                num_elements = update.numel()
+                # Add eps to prevent division by zero
+                update.mul_(rms_target * (num_elements ** 0.5) / (update.norm().add_(group['eps'])))
 
-
-
+                update = update.view(p.shape).mul_(lr)
+                del num_elements
 
             # Compress updated moments and store new factors
             state['sign_buf'] = _pack_bools(mt_buf > 0)
             _nnmf(mt_buf.abs(), out=(state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
             del mt_buf
 
-
-
+            if not group['normuon_variant']:
+                _nnmf(vt_buf.abs(), out=(state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
+                del vt_buf
 
        else: # Standard AdaMuon logic for non-factored tensors
@@ -387,12 +426,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
                    Q, _ = torch.linalg.qr(MG)
                    projected_M = Q.T @ M
                    ortho_projected_M = _newton_schulz_iteration(
-                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
                    )
                    update = Q @ ortho_projected_M
                else: # Fallback for invalid rank
                    update = _newton_schulz_iteration(
-                        signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                        signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
                    )
            else:
                # Original full Newton-Schulz
@@ -401,42 +440,59 @@ class AdaMuon_adv(torch.optim.Optimizer):
                    steps=group['ns_steps'],
                    eps=group['ns_eps'],
                    coeffs=group['ns_coeffs'],
+                    cns=group['accelerated_ns'],
+                    cns_a_bound=group['cns_a_bound'],
                )
            del signed_m_buf
 
            update = update.view(original_shape)
 
-
-
-
-
-
-
-
-
+            if group['normuon_variant']:
+                # NorMuon Logic
+                v_t = state['normuon_v']
+                # Update 2nd moment estimate
+                mean_squared_update = torch.mean(update.square(), dim=1)
+                v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
+                # Normalize update
+                if group['use_atan2']:
+                    a = 1.2732395
+                    update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+                else:
+                    update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+                # Scale learning rate
+                update_norm = torch.linalg.vector_norm(update)
+                scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
+                update.mul_(scaled_lr)
+                del mean_squared_update, update_norm, scaled_lr
            else:
-
-
-
-
-
-
-
-
-
-
-
-
-
+                # Original AdaMuon Logic
+                vt_buf = state['second_momentum_buffer']
+                vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+                # Apply second momentum update (adaptive scaling)
+                if group['use_atan2']:
+                    a = 1.2732395
+                    denom = vt_buf.sqrt()
+                    update.atan2_(denom).mul_(a)
+                else:
+                    denom = vt_buf.sqrt().add_(group['eps'])
+                    update.div_(denom)
+                del denom
+                # RMS-aligned rescaling
+                rms_target = group['rms_target']
+                num_elements = update.numel()
+                # Add eps to prevent division by zero
+                update.mul_(rms_target * (num_elements ** 0.5) / (update.norm().add_(group['eps'])))
+                del num_elements
+                update.mul_(lr)
 
        else: # Fallback to standard SGD with momentum for 1D params (biases, etc.)
            # Momentum update
            mt_buf = state['momentum_buffer']
            mt_buf.mul_(beta1).add_(grad)
            if nesterov:
-                # Nesterov momentum
                update = grad.add(mt_buf, alpha=beta1)
-            #
+            # FIXME, Simplified_AdEMAMix will break SGD since it requires x100 lower LR
+            # elif Simplified_AdEMAMix:
            # update = mt_buf.add(grad, alpha=alpha_grad)
            else:
                update = mt_buf.clone()
@@ -599,7 +655,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
            state = self.state[p]
 
 
-
            # Determine if using Adam or Muon based on state keys
            # We can use optm_type but I see this as a safer way.
            if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
@@ -611,32 +666,41 @@ class AdaMuon_adv(torch.optim.Optimizer):
            is_compiled = group.get('compiled_optimizer', False)
 
            if use_adam:
+                step = state['step']
+
                if self.kourkoutas_helper:
                    # Prepare Kourkoutas-β once per optimizer step.
-                    self.kourkoutas_helper.maybe_prepare_step(
+                    self.kourkoutas_helper.maybe_prepare_step(step)
+
                # Adam-specific setup (bias correction)
                if group['adam_use_bias_correction']:
-                    current_step =
+                    current_step = step + 1
                    beta1_adam, beta2_adam = group['adam_betas']
                    bias_correction1 = 1.0 - beta1_adam ** current_step
                    bias_correction2 = 1.0 - beta2_adam ** current_step
                else:
                    bias_correction1 = 1.0
                    bias_correction2 = 1.0
+
+                self.state[p]['step'] += 1
+
                # Dispatch to compiled or uncompiled Adam step
                if is_compiled and self._compiled_adam_step is not None:
-                    #
-
-
-
-
+                    # convert to tensors for compiled path once a step
+                    if not hasattr(self, 'lr_adam_tensor') or self.lr_adam_tensor is None:
+                        self.lr_adam_tensor = torch.tensor(group['lr'])
+                        self.bc1 = torch.tensor(bias_correction1)
+                        self.bc2 = torch.tensor(bias_correction2)
+                    self._compiled_adam_step(p, grad, state, group, self.lr_adam_tensor, self.bc1, self.bc2)
                else:
                    self._adam_step_parameter(p, grad, state, group, lr, bias_correction1, bias_correction2)
            else: # Muon path
                # Dispatch to compiled or uncompiled Muon step
                if is_compiled and self._compiled_muon_step is not None:
-
-                    self
+                    # convert to tensors for compiled path once a step
+                    if not hasattr(self, 'lr_tensor') or self.lr_tensor is None:
+                        self.lr_tensor = torch.tensor(group['lr'])
+                    self._compiled_muon_step(p, grad, state, group, self.lr_tensor)
                else:
                    self._muon_step_parameter(p, grad, state, group, lr)
 
@@ -659,6 +723,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
            for i, p in enumerate(group['params']):
                self.step_parameter(p, group, i)
 
-        self.
-
+        if self.param_groups[0].get('compiled_optimizer', False):
+            # Reset compile tensors once a step
+            self.lr_tensor = None
+            self.lr_adam_tensor = None
        return loss
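Note on the normuon_variant branch above: it keeps one running second-moment scalar per output row instead of a full (d1, d2) buffer, which also lets the factored path skip the mu_vbuf/mv_vbuf NNMF factors entirely. A minimal standalone sketch of that normalization step, assuming a 2D update; the function and variable names here are illustrative, not the package's internals:

import torch

def normuon_normalize(update, v, beta2, eps=1e-8, use_atan2=False):
    # v holds one second-moment scalar per row: a (d1, d2) update
    # needs only d1 extra floats instead of a full (d1, d2) buffer.
    v.mul_(beta2).add_(update.square().mean(dim=1), alpha=1 - beta2)
    if use_atan2:
        # atan2 is bounded and needs no eps; 1.2732395 ~= 4/pi restores
        # unit slope for small arguments.
        return update.atan2(v.sqrt().unsqueeze(1)).mul(1.2732395)
    return update / (v.sqrt().unsqueeze(1) + eps)

# Toy usage: a 4x8 "orthogonalized momentum" and its per-row state.
u = torch.randn(4, 8)
v = torch.zeros(4)
out = normuon_normalize(u, v, beta2=0.95)

The diff then rescales the normalized update to a target RMS via rms_target * lr * p.numel()**0.5 / update_norm before it is applied to the parameter.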
adv_optm/optim/Muon_adv.py
CHANGED
@@ -60,6 +60,9 @@ class Muon_adv(torch.optim.Optimizer):
         normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
             (default: 0.2)
         normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
+        accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
+            dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
+        cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
     --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
         adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
         adam_eps (float): Epsilon for the AdamW optimizer part.
@@ -100,6 +103,9 @@ class Muon_adv(torch.optim.Optimizer):
         normuon_eps: float = 1e-8,
         normuon_lr_scale: float = 0.2,
         normuon_atan2: bool = False,
+        # CANS
+        accelerated_ns: bool = False,
+        cns_a_bound: float = 1e-4,
         # Compiled
         compiled_optimizer: bool = False,
         # --- AdamW_adv specific parameters ---
@@ -148,6 +154,8 @@ class Muon_adv(torch.optim.Optimizer):
             "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
             "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
             "normuon_atan2": normuon_atan2,
+            # CANS
+            "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
             # AdamW_adv defaults
             "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
             "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
@@ -163,7 +171,6 @@ class Muon_adv(torch.optim.Optimizer):
 
         super().__init__(params, defaults)
 
-        self.global_step = 0 # For Adam bias correction and Kourkoutas
         self.kourkoutas_helper = None
         if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
             self.kourkoutas_helper = KourkoutasHelper(self)
@@ -214,6 +221,7 @@ class Muon_adv(torch.optim.Optimizer):
 
        if optim_type == 'muon':
 
+
            state['factored'] = (
                group['nnmf_factor'] and
                not (len(p.shape) == 1 and not group['vector_reshape'])
@@ -238,9 +246,12 @@ class Muon_adv(torch.optim.Optimizer):
            elif len(p.shape) >= 2:
                state['normuon_v'] = torch.zeros(p.shape[0], device=p.device, dtype=torch.float32)
 
+            group['adam_kourkoutas_beta'] = False
 
        elif optim_type == 'adam':
 
+            state['step'] = 0
+
            state['factored'] = (
                group['adam_nnmf_factor'] and
                not (len(p.shape) == 1 and not group['vector_reshape'])
@@ -319,12 +330,12 @@ class Muon_adv(torch.optim.Optimizer):
                    Q, _ = torch.linalg.qr(MG)
                    projected_M = Q.T @ M
                    ortho_projected_M = _newton_schulz_iteration(
-                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
                    )
                    update = Q @ ortho_projected_M
                else: # Fallback for invalid rank
                    update = _newton_schulz_iteration(
-                        update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                        update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
                    )
            else:
                # Original full Newton-Schulz
@@ -333,10 +344,12 @@ class Muon_adv(torch.optim.Optimizer):
                    steps=group['ns_steps'],
                    eps=group['ns_eps'],
                    coeffs=group['ns_coeffs'],
+                    cns=group['accelerated_ns'],
+                    cns_a_bound=group['cns_a_bound'],
                )
 
 
-            if group['normuon_variant']
+            if group['normuon_variant']:
                v_t = state['normuon_v']
                beta2_normuon = group['beta2_normuon']
                # Update 2nd moment estimate
@@ -414,6 +427,8 @@ class Muon_adv(torch.optim.Optimizer):
                    steps=group['ns_steps'],
                    eps=group['ns_eps'],
                    coeffs=group['ns_coeffs'],
+                    cns=group['accelerated_ns'],
+                    cns_a_bound=group['cns_a_bound'],
                )
 
                # 5. Project back to the original space
@@ -424,6 +439,8 @@ class Muon_adv(torch.optim.Optimizer):
                    steps=group['ns_steps'],
                    eps=group['ns_eps'],
                    coeffs=group['ns_coeffs'],
+                    cns=group['accelerated_ns'],
+                    cns_a_bound=group['cns_a_bound'],
                )
            else:
                # Original NewtonSchulz
@@ -432,10 +449,12 @@ class Muon_adv(torch.optim.Optimizer):
                    steps=group['ns_steps'],
                    eps=group['ns_eps'],
                    coeffs=group['ns_coeffs'],
+                    cns=group['accelerated_ns'],
+                    cns_a_bound=group['cns_a_bound'],
                )
 
            # NorMuon Logic
-            if group['normuon_variant']
+            if group['normuon_variant']:
                v_t = state['normuon_v']
                beta2_normuon = group['beta2_normuon']
                # Update 2nd moment estimate
@@ -629,10 +648,6 @@ class Muon_adv(torch.optim.Optimizer):
 
            state = self.state[p]
 
-            if self.kourkoutas_helper:
-                # Prepare Kourkoutas-β once per optimizer step.
-                self.kourkoutas_helper.maybe_prepare_step(self.global_step)
-
            # Determine if using Adam or Muon based on state keys
            # We can use optm_type but I see this as a safer way.
            if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
@@ -644,9 +659,15 @@ class Muon_adv(torch.optim.Optimizer):
            is_compiled = group.get('compiled_optimizer', False)
 
            if use_adam:
+                step = state['step']
+
+                if self.kourkoutas_helper:
+                    # Prepare Kourkoutas-β once per optimizer step.
+                    self.kourkoutas_helper.maybe_prepare_step(step)
+
                # Adam-specific setup (bias correction)
                if group['adam_use_bias_correction']:
-                    current_step =
+                    current_step = step + 1
                    beta1_adam, beta2_adam = group['adam_betas']
                    bias_correction1 = 1.0 - beta1_adam ** current_step
                    bias_correction2 = 1.0 - beta2_adam ** current_step
@@ -654,6 +675,8 @@ class Muon_adv(torch.optim.Optimizer):
                    bias_correction1 = 1.0
                    bias_correction2 = 1.0
 
+                self.state[p]['step'] += 1
+
                # Dispatch to compiled or uncompiled Adam step
                if is_compiled and self._compiled_adam_step is not None:
                    # Tensors must be used for compiled functions
@@ -690,6 +713,4 @@ class Muon_adv(torch.optim.Optimizer):
            for i, p in enumerate(group['params']):
                self.step_parameter(p, group, i)
 
-        self.global_step += 1
-
        return loss
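Both optimizers thread the new CANS switches through their param-group defaults, so enabling them is a constructor-level change. A hedged usage sketch; it assumes Muon_adv is re-exported from the package root (as the __init__ entries in RECORD suggest), that a plain parameter iterable is routed to the 'muon' path by default, and the learning rate shown is illustrative:

import torch
from adv_optm import Muon_adv  # assumed re-export; adjust the import path if needed

model = torch.nn.Linear(64, 64)
opt = Muon_adv(
    model.parameters(),
    lr=0.02,              # illustrative value, not a package default
    accelerated_ns=True,  # switch Newton-Schulz to the CANS scheme
    cns_a_bound=1e-4,     # documented default lower bound on singular values
)

loss = model(torch.randn(8, 64)).square().mean()
loss.backward()
opt.step()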
adv_optm/util/Newton_Schulz.py
CHANGED
@@ -5,7 +5,9 @@ def _newton_schulz_iteration(
     G: torch.Tensor,
     steps: int = 5,
     eps: float = 1e-7,
-    coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315)
+    coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
+    cns: bool = False,
+    cns_a_bound: float = 1e-4,
 ) -> torch.Tensor:
     """
     Performs the Newton-Schulz iteration to find the nearest orthogonal matrix.
@@ -17,32 +19,69 @@ def _newton_schulz_iteration(
         eps (float): Small constant for numerical stability during normalization.
         coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
             quintic polynomial update.
-
+        cns (bool): If True, enables Chebyshev-accelerated Newton-Schulz (CANS)
+            using an iterative 3rd-order polynomial with optimal coefficients
+            derived at each step.
+        cns_a_bound (float): The initial lower bound for singular values when
+            using CANS. The upper bound is assumed to be 1.0 after normalization.
     Returns:
         torch.Tensor: The orthogonalized matrix.
     """
-    assert G.ndim
+    assert G.ndim >= 2
 
     a, b, c = coeffs
 
     X = G.to(torch.bfloat16)
 
-
-    X.div_(X.norm() + eps)
-
-    # Handle non-square matrices by transposing the taller one
-    transposed = G.size(0) > G.size(1)
+    transposed = G.size(-2) > G.size(-1)
     if transposed:
-        X = X.
+        X = X.mT
+
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
+
+    if cns:
+        # Chebyshev-accelerated Newton-Schulz (CANS) from
+        # "Accelerating Newton-Schulz Iteration for Orthogonalization via Chebyshev-type Polynomials"
+        # This implements the iterative scheme from Algorithm 1, using the
+        # closed-form 3rd-order polynomial from Proposition 2.
+        lower_bound = cns_a_bound
+        upper_bound = 1.0 # Matrix is normalized, so largest singular value is approx 1.
+
+        for _ in range(steps):
+            # Calculate optimal 3rd-order coefficients c1, c3 for p(x) = c1*x + c3*x^3
+            # based on the current singular value bounds [lower_bound, upper_bound].
+            # Formulas are derived from Proposition 2 and its proof in Appendix B of the paper.
+            a_bound, b_bound = lower_bound, upper_bound
+            term = a_bound*a_bound + a_bound*b_bound + b_bound*b_bound
+            e_sq = term / 3.0
+
+            # Calculate alpha, which scales the polynomial
+            common_den_part = 2.0 * (e_sq**1.5)
+            ab_part = a_bound*a_bound*b_bound + b_bound*b_bound*a_bound
+            alpha_den = common_den_part + ab_part
+            alpha = 6.0 / alpha_den
+
+            c1 = alpha * e_sq
+            c3 = -alpha / 3.0
+
+            # Apply the 3rd-order Newton-Schulz update
+            A = X @ X.mT
+            X = c1 * X + c3 * (A @ X)
 
-
-
-
-
-
+            # Update the singular value bounds for the next iteration based on the error
+            eps_num = common_den_part - ab_part
+            eps_val = eps_num / alpha_den
+            lower_bound = 1.0 - eps_val
+            upper_bound = 1.0 + eps_val
+    else:
+        # Perform the iterative updates
+        for _ in range(steps):
+            A = X @ X.mT
+            B = b * A + c * (A @ A)
+            X = a * X + B @ X
 
     # Transpose back if necessary
     if transposed:
-        X = X.
+        X = X.mT
 
-    return X.to(G.dtype)
+    return X.to(G.dtype)
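The closed-form coefficients in the new cns branch admit a quick scalar sanity check: each step's polynomial p(x) = c1*x + c3*x^3 maps every singular value in the current interval [a, b] into [1 - eps, 1 + eps], which the code then adopts as the next bounds (the endpoints land exactly on 1 - eps and the interior maximum at x = sqrt(e_sq) on 1 + eps). A small sketch reproducing the diff's formulas on plain floats:

def cans_interval(a=1e-4, b=1.0, steps=5):
    # Trace the singular-value bounds through the CANS recurrence.
    for i in range(steps):
        e_sq = (a*a + a*b + b*b) / 3.0
        common = 2.0 * e_sq**1.5
        ab = a*a*b + a*b*b
        alpha = 6.0 / (common + ab)
        c1, c3 = alpha * e_sq, -alpha / 3.0
        eps_val = (common - ab) / (common + ab)
        p = lambda x: c1 * x + c3 * x**3  # this step's 3rd-order polynomial
        print(f"step {i}: p(a)={p(a):.3e}  p(b)={p(b):.3e}  eps={eps_val:.3e}")
        a, b = 1.0 - eps_val, 1.0 + eps_val  # bounds for the next step

cans_interval()

For bounds already near 1 (a = 1 - d, b = 1 + d) the recurrence gives eps ~= 3*d*d/4, i.e. quadratic contraction, while a tiny cns_a_bound means the first few steps widen the lower bound only gradually.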
{adv_optm-1.2.dev14.dist-info → adv_optm-1.2.dev15.dist-info}/RECORD
CHANGED
@@ -1,10 +1,10 @@
-adv_optm/__init__.py,sha256=
-adv_optm/optim/AdaMuon_adv.py,sha256
+adv_optm/__init__.py,sha256=I9iRXHonvg_82dEmyKXqt9PyN04Ez8TVbMb1uZgRZAc,380
+adv_optm/optim/AdaMuon_adv.py,sha256=zjZHFS7ng5KwemQzePjFiGtNZlcgbzmmnqF6A80h_Tg,34652
 adv_optm/optim/AdamW_adv.py,sha256=KL9SCJWZ_ckAQEApB6ofbndVYjancN-v7Us7hJLFf54,17475
 adv_optm/optim/Adopt_adv.py,sha256=S8XI2YA7683jsW8p7igc2YcU30lsN0H18qL02Kpvj8E,21244
 adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
 adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
-adv_optm/optim/Muon_adv.py,sha256=
+adv_optm/optim/Muon_adv.py,sha256=QutgiRkDS36O5BQNdcwdIcYBCKPy7U07YYVQT6dq8tc,33165
 adv_optm/optim/Prodigy_adv.py,sha256=lEjbtuQbomsCX39DnTPeI8Z5YG0f2aZPXN_E7-nGgWw,26060
 adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
 adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
@@ -12,12 +12,12 @@ adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRd
 adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
 adv_optm/util/Kourkoutas.py,sha256=SSzhe0B6Zb2AXGwCKpVTLr0aaFfspcFBNZCZG3azI9k,7516
 adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
-adv_optm/util/Newton_Schulz.py,sha256=
+adv_optm/util/Newton_Schulz.py,sha256=bBboYw_jm5_FMf0Citl79uqNedkHOTjQnUI7rZgLBmY,3341
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=CXzS703GB4gil85khZi7sgKOnbzXGBOltshIOSPqj18,435
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
+adv_optm-1.2.dev15.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.2.dev15.dist-info/METADATA,sha256=CH8IxEUd-TSH5dVzXKR-rl54pTIIU_JTN_MkviBWprs,14023
+adv_optm-1.2.dev15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.2.dev15.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.2.dev15.dist-info/RECORD,,
{adv_optm-1.2.dev14.dist-info → adv_optm-1.2.dev15.dist-info}/WHEEL
File without changes
{adv_optm-1.2.dev14.dist-info → adv_optm-1.2.dev15.dist-info}/licenses/LICENSE
File without changes
{adv_optm-1.2.dev14.dist-info → adv_optm-1.2.dev15.dist-info}/top_level.txt
File without changes