adv-optm 1.2.dev14__tar.gz → 1.2.dev16__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/PKG-INFO +1 -1
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/__init__.py +1 -1
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/AdaMuon_adv.py +132 -66
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/Muon_adv.py +36 -12
- adv_optm-1.2.dev16/adv_optm/util/Newton_Schulz.py +87 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/setup.py +1 -1
- adv_optm-1.2.dev14/adv_optm/util/Newton_Schulz.py +0 -48
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/LICENSE +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/README.md +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm/util/__init__.py +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-1.2.dev14 → adv_optm-1.2.dev16}/setup.cfg +0 -0
adv_optm/optim/AdaMuon_adv.py

@@ -63,6 +63,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
             (default: False)
         ortho_rank (int): The rank for low-rank orthogonalization.
             (default: 128)
+        accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
+            dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
+        cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
         --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
@@ -96,11 +99,16 @@ class AdaMuon_adv(torch.optim.Optimizer):
         nesterov: bool = False,
         Simplified_AdEMAMix: bool = False,
         alpha_grad: float = 100.0,
-
+        normuon_variant: bool = False,
         # Low-rank Muon
         low_rank_ortho: bool = False,
         ortho_rank: int = 128,
+        # Factored
+        vector_reshape: bool = False,
         nnmf_factor: bool = False,
+        # CANS
+        accelerated_ns: bool = False,
+        cns_a_bound: float = 1e-4,
         # Compiled
         compiled_optimizer: bool = False,
         # --- AdamW_adv specific parameters ---
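The two new CANS knobs thread through the whole optimizer. A minimal usage sketch (the module path comes from the file list above; only accelerated_ns and cns_a_bound are confirmed by this diff, lr and the rest are assumed to keep their usual meaning):

import torch
from adv_optm.optim.AdaMuon_adv import AdaMuon_adv  # module path taken from the file list

model = torch.nn.Linear(512, 512)
opt = AdaMuon_adv(
    model.parameters(),
    lr=1e-3,              # illustrative value, not from this diff
    accelerated_ns=True,  # switch Newton-Schulz to the Chebyshev-accelerated variant
    cns_a_bound=1e-4,     # initial lower bound on the singular values
)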
@@ -139,9 +147,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
             "vector_reshape": vector_reshape,
             "nesterov":nesterov, "use_atan2":use_atan2,
             "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+            "normuon_variant": normuon_variant,
             # Low-rank Ortho
             "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
             "compiled_optimizer":compiled_optimizer,
+            # CANS
+            "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
             # AdamW_adv defaults
             "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
             "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
@@ -156,7 +167,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
 
         super().__init__(params, defaults)
 
-        self.global_step = 0 # For Adam bias correction and Kourkoutas
         self.kourkoutas_helper = None
         if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
             self.kourkoutas_helper = KourkoutasHelper(self)
@@ -214,16 +224,25 @@ class AdaMuon_adv(torch.optim.Optimizer):
                     state['mv_mbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                     packed_d2 = (d2 + 7) // 8
                     state['sign_buf'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
-
-
+                    if not group['normuon_variant']:
+                        state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                        state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                 else:
-                    if len(p.shape) >= 2:
+                    if len(p.shape) >= 2 and not group['normuon_variant']:
                         state['second_momentum_buffer'] = torch.zeros_like(p)
                     state['momentum_buffer'] = torch.zeros_like(p)
 
+                # NorMuon state initialization
+                if group['normuon_variant']:
+                    if state['factored']:
+                        state['normuon_v'] = torch.zeros(d1, device=p.device, dtype=torch.float32)
+                    elif len(p.shape) >= 2:
+                        state['normuon_v'] = torch.zeros(p.shape[0], device=p.device, dtype=torch.float32)
 
             elif optim_type == 'adam':
 
+                state['step'] = 0
+
                 state['factored'] = (
                     group['adam_nnmf_factor'] and
                     not (len(p.shape) == 1 and not group['vector_reshape'])
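On both paths, NorMuon keeps a single float32 accumulator per output row: size d1 for factored parameters, p.shape[0] otherwise. A standalone sketch of the per-row second-moment update this buffer feeds (illustrative shapes and beta2, not the package's code):

import torch

d1, d2, beta2 = 4, 8, 0.95                        # illustrative values only
update = torch.randn(d1, d2)
normuon_v = torch.zeros(d1, dtype=torch.float32)  # one accumulator per output row
normuon_v.mul_(beta2).add_(update.square().mean(dim=1), alpha=1 - beta2)
print(normuon_v.shape)                            # torch.Size([4])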
@@ -301,12 +320,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
                     Q, _ = torch.linalg.qr(MG)
                     projected_M = Q.T @ M
                     ortho_projected_M = _newton_schulz_iteration(
-                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
                     )
                     update = Q @ ortho_projected_M
                 else: # Fallback for invalid rank
                     update = _newton_schulz_iteration(
-                        signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                        signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
                     )
             else:
                 # Original full Newton-Schulz
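The low-rank path orthogonalizes an ortho_rank x d2 projection instead of the full matrix. A minimal sketch of the pattern (names M, MG and ortho_rank follow the diff; how MG is built is outside this hunk, so a random stand-in is used purely for illustration):

import torch

d1, d2, ortho_rank = 1024, 512, 128
M = torch.randn(d1, d2)
MG = torch.randn(d1, ortho_rank)   # placeholder for the rank-r sketch used by the code
Q, _ = torch.linalg.qr(MG)         # d1 x r orthonormal basis
projected_M = Q.T @ M              # r x d2: Newton-Schulz now runs on a small matrix
# ortho_projected_M = _newton_schulz_iteration(projected_M, ...)
# update = Q @ ortho_projected_M   # lift the orthogonalized factor back to d1 x d2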
@@ -315,41 +334,61 @@ class AdaMuon_adv(torch.optim.Optimizer):
                     steps=group['ns_steps'],
                     eps=group['ns_eps'],
                     coeffs=group['ns_coeffs'],
+                    cns=group['accelerated_ns'],
+                    cns_a_bound=group['cns_a_bound'],
                 )
             del signed_m_buf
 
-
-
+            if group['normuon_variant']:
+                v_t = state['normuon_v']
+                # Update 2nd moment estimate
+                mean_squared_update = torch.mean(update.square(), dim=1)
+                v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
+                # Normalize update
+                if group['use_atan2']:
+                    a = 1.2732395
+                    update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+                else:
+                    update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+                # Scale learning rate
+                update_norm = torch.linalg.vector_norm(update)
+                scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
+                update = update.view(p.shape).mul_(scaled_lr)
+                del mean_squared_update, update_norm, scaled_lr
+            else:
+                # Reconstruct second momentum from previous step's factors
+                vt_buf = _unnmf((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
 
-
-
+                # Update second momentum in full-size
+                vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
 
-
-
-
-
-
-
-
-
-
+                # Apply second momentum update (adaptive scaling)
+                if group['use_atan2']:
+                    a = 1.2732395
+                    denom = vt_buf.sqrt()
+                    update.atan2_(denom).mul_(a)
+                else:
+                    denom = vt_buf.sqrt().add_(group['eps'])
+                    update.div_(denom)
+                del denom
 
-
-
-
-
-
+                # RMS-aligned rescaling
+                rms_target = group['rms_target']
+                num_elements = update.numel()
+                # Add eps to prevent division by zero
+                update.mul_(rms_target * (num_elements ** 0.5) / (update.norm().add_(group['eps'])))
 
-
-
+                update = update.view(p.shape).mul_(lr)
+                del num_elements
 
             # Compress updated moments and store new factors
             state['sign_buf'] = _pack_bools(mt_buf > 0)
             _nnmf(mt_buf.abs(), out=(state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
             del mt_buf
 
-
-
+            if not group['normuon_variant']:
+                _nnmf(vt_buf.abs(), out=(state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
+                del vt_buf
 
         else: # Standard AdaMuon logic for non-factored tensors
 
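The constant 1.2732395 that appears in both normalization paths is 4/pi. With use_atan2, the division m / (sqrt(v) + eps) is replaced by atan2(m, sqrt(v)) * (4/pi): when |m| is much smaller than sqrt(v) the two agree up to that factor, and the output stays bounded in (-2, 2) even when v underflows, which removes the need for eps. A standalone sketch of the effect:

import math
import torch

m = torch.tensor([1e-4, 1.0, 1e4])
v_sqrt = torch.tensor([1.0, 1.0, 1e-8])
print(torch.atan2(m, v_sqrt) * (4 / math.pi))  # bounded; no epsilon required
print(m / (v_sqrt + 1e-8))                     # blows up as v approaches zero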
@@ -387,12 +426,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
                     Q, _ = torch.linalg.qr(MG)
                     projected_M = Q.T @ M
                     ortho_projected_M = _newton_schulz_iteration(
-                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
                     )
                     update = Q @ ortho_projected_M
                 else: # Fallback for invalid rank
                     update = _newton_schulz_iteration(
-                        signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                        signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
                     )
             else:
                 # Original full Newton-Schulz
@@ -401,42 +440,59 @@ class AdaMuon_adv(torch.optim.Optimizer):
                     steps=group['ns_steps'],
                     eps=group['ns_eps'],
                     coeffs=group['ns_coeffs'],
+                    cns=group['accelerated_ns'],
+                    cns_a_bound=group['cns_a_bound'],
                 )
             del signed_m_buf
 
             update = update.view(original_shape)
 
-
-
-
-
-
-
-
-
+            if group['normuon_variant']:
+                # NorMuon Logic
+                v_t = state['normuon_v']
+                # Update 2nd moment estimate
+                mean_squared_update = torch.mean(update.square(), dim=1)
+                v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
+                # Normalize update
+                if group['use_atan2']:
+                    a = 1.2732395
+                    update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+                else:
+                    update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+                # Scale learning rate
+                update_norm = torch.linalg.vector_norm(update)
+                scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
+                update.mul_(scaled_lr)
+                del mean_squared_update, update_norm, scaled_lr
             else:
-
-
-
-
-
-
-
-
-
-
-
-
-
+                # Original AdaMuon Logic
+                vt_buf = state['second_momentum_buffer']
+                vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+                # Apply second momentum update (adaptive scaling)
+                if group['use_atan2']:
+                    a = 1.2732395
+                    denom = vt_buf.sqrt()
+                    update.atan2_(denom).mul_(a)
+                else:
+                    denom = vt_buf.sqrt().add_(group['eps'])
+                    update.div_(denom)
+                del denom
+                # RMS-aligned rescaling
+                rms_target = group['rms_target']
+                num_elements = update.numel()
+                # Add eps to prevent division by zero
+                update.mul_(rms_target * (num_elements ** 0.5) / (update.norm().add_(group['eps'])))
+                del num_elements
+                update.mul_(lr)
 
         else: # Fallback to standard SGD with momentum for 1D params (biases, etc.)
             # Momentum update
             mt_buf = state['momentum_buffer']
             mt_buf.mul_(beta1).add_(grad)
             if nesterov:
-                # Nesterov momentum
                 update = grad.add(mt_buf, alpha=beta1)
-            #
+            # FIXME, Simplified_AdEMAMix will break SGD since it requires x100 lower LR
+            # elif Simplified_AdEMAMix:
             # update = mt_buf.add(grad, alpha=alpha_grad)
             else:
                 update = mt_buf.clone()
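The "RMS-aligned rescaling" in both branches relies on the identity ||u||_2 = rms(u) * sqrt(numel(u)), so multiplying by rms_target * sqrt(numel) / ||u|| pins the update's root-mean-square at rms_target before the learning rate is applied. A standalone sketch (illustrative values):

import torch

rms_target = 0.2                 # illustrative; the real value comes from group['rms_target']
u = torch.randn(64, 64)
u.mul_(rms_target * (u.numel() ** 0.5) / (u.norm() + 1e-8))
print(u.square().mean().sqrt())  # approximately rms_target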
@@ -599,7 +655,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
         state = self.state[p]
 
 
-
         # Determine if using Adam or Muon based on state keys
         # We can use optm_type but I see this as a safer way.
         if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
@@ -611,32 +666,41 @@ class AdaMuon_adv(torch.optim.Optimizer):
         is_compiled = group.get('compiled_optimizer', False)
 
         if use_adam:
+            step = state['step']
+
             if self.kourkoutas_helper:
                 # Prepare Kourkoutas-β once per optimizer step.
-                self.kourkoutas_helper.maybe_prepare_step(
+                self.kourkoutas_helper.maybe_prepare_step(step)
+
             # Adam-specific setup (bias correction)
             if group['adam_use_bias_correction']:
-                current_step =
+                current_step = step + 1
                 beta1_adam, beta2_adam = group['adam_betas']
                 bias_correction1 = 1.0 - beta1_adam ** current_step
                 bias_correction2 = 1.0 - beta2_adam ** current_step
             else:
                 bias_correction1 = 1.0
                 bias_correction2 = 1.0
+
+            self.state[p]['step'] += 1
+
             # Dispatch to compiled or uncompiled Adam step
             if is_compiled and self._compiled_adam_step is not None:
-                #
-
-
-
-
+                # convert to tensors for compiled path once a step
+                if not hasattr(self, 'lr_adam_tensor') or self.lr_adam_tensor is None:
+                    self.lr_adam_tensor = torch.tensor(group['lr'])
+                    self.bc1 = torch.tensor(bias_correction1)
+                    self.bc2 = torch.tensor(bias_correction2)
+                self._compiled_adam_step(p, grad, state, group, self.lr_adam_tensor, self.bc1, self.bc2)
             else:
                 self._adam_step_parameter(p, grad, state, group, lr, bias_correction1, bias_correction2)
         else: # Muon path
             # Dispatch to compiled or uncompiled Muon step
             if is_compiled and self._compiled_muon_step is not None:
-
-                self
+                # convert to tensors for compiled path once a step
+                if not hasattr(self, 'lr_tensor') or self.lr_tensor is None:
+                    self.lr_tensor = torch.tensor(group['lr'])
+                self._compiled_muon_step(p, grad, state, group, self.lr_tensor)
             else:
                 self._muon_step_parameter(p, grad, state, group, lr)
 
@@ -659,6 +723,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)
 
-        self.
-
+        if self.param_groups[0].get('compiled_optimizer', False):
+            # Reset compile tensors once a step
+            self.lr_tensor = None
+            self.lr_adam_tensor = None
         return loss
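Both compiled paths now cache lr (and the bias corrections) in tensors that are rebuilt once per step and reset at the end of step(). A plausible motivation, sketched below: torch.compile specializes on Python scalars, so passing a changing float lr would force recompilation, while a tensor is treated as data whose value may change freely (a sketch under these assumptions, not the package's code):

import torch

@torch.compile
def apply_update(p, grad, lr):
    p.sub_(grad * lr)  # lr enters the graph as data, not as a baked-in constant

p, g = torch.ones(4), torch.full((4,), 0.5)
lr_t = torch.tensor(1e-3)
apply_update(p, g, lr_t)
lr_t.fill_(5e-4)          # e.g. after a scheduler step
apply_update(p, g, lr_t)  # reuses the compiled graph instead of recompiling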
adv_optm/optim/Muon_adv.py

@@ -60,6 +60,9 @@ class Muon_adv(torch.optim.Optimizer):
         normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
             (default: 0.2)
         normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
+        accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
+            dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
+        cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
         --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
         adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
         adam_eps (float): Epsilon for the AdamW optimizer part.
@@ -73,6 +76,7 @@ class Muon_adv(torch.optim.Optimizer):
         adam_beta3_ema (float): Beta3 for AdEMAMix.
         adam_alpha (float): Alpha for AdEMAMix.
         adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
+        adam_nnmf_factor (bool): 1-bit factored for AdamW.
     """
 
     def __init__(
@@ -100,6 +104,9 @@ class Muon_adv(torch.optim.Optimizer):
         normuon_eps: float = 1e-8,
         normuon_lr_scale: float = 0.2,
         normuon_atan2: bool = False,
+        # CANS
+        accelerated_ns: bool = False,
+        cns_a_bound: float = 1e-4,
         # Compiled
         compiled_optimizer: bool = False,
         # --- AdamW_adv specific parameters ---
@@ -119,6 +126,7 @@ class Muon_adv(torch.optim.Optimizer):
         adam_ema_alpha: float = 0.95,
         adam_tiny_spike: float = 1e-9,
         adam_k_warmup_steps: int = 0,
+        adam_nnmf_factor: bool = False,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -148,6 +156,8 @@ class Muon_adv(torch.optim.Optimizer):
             "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
             "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
             "normuon_atan2": normuon_atan2,
+            # CANS
+            "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
             # AdamW_adv defaults
             "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
             "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
@@ -157,13 +167,13 @@ class Muon_adv(torch.optim.Optimizer):
             "adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
             "adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
             "adam_k_warmup_steps": adam_k_warmup_steps,
+            "adam_nnmf_factor":adam_nnmf_factor,
         }
         self.stochastic_rounding = stochastic_rounding
         self.compiled_optimizer = compiled_optimizer
 
         super().__init__(params, defaults)
 
-        self.global_step = 0 # For Adam bias correction and Kourkoutas
         self.kourkoutas_helper = None
         if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
             self.kourkoutas_helper = KourkoutasHelper(self)
@@ -214,6 +224,7 @@ class Muon_adv(torch.optim.Optimizer):
 
             if optim_type == 'muon':
 
+
                 state['factored'] = (
                     group['nnmf_factor'] and
                     not (len(p.shape) == 1 and not group['vector_reshape'])
@@ -238,9 +249,12 @@ class Muon_adv(torch.optim.Optimizer):
                     elif len(p.shape) >= 2:
                         state['normuon_v'] = torch.zeros(p.shape[0], device=p.device, dtype=torch.float32)
 
+                group['adam_kourkoutas_beta'] = False
 
             elif optim_type == 'adam':
 
+                state['step'] = 0
+
                 state['factored'] = (
                     group['adam_nnmf_factor'] and
                     not (len(p.shape) == 1 and not group['vector_reshape'])
@@ -319,12 +333,12 @@ class Muon_adv(torch.optim.Optimizer):
                     Q, _ = torch.linalg.qr(MG)
                     projected_M = Q.T @ M
                     ortho_projected_M = _newton_schulz_iteration(
-                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
                     )
                     update = Q @ ortho_projected_M
                 else: # Fallback for invalid rank
                     update = _newton_schulz_iteration(
-                        update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                        update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
                     )
             else:
                 # Original full Newton-Schulz
@@ -333,10 +347,12 @@ class Muon_adv(torch.optim.Optimizer):
                     steps=group['ns_steps'],
                     eps=group['ns_eps'],
                     coeffs=group['ns_coeffs'],
+                    cns=group['accelerated_ns'],
+                    cns_a_bound=group['cns_a_bound'],
                 )
 
 
-            if group['normuon_variant']
+            if group['normuon_variant']:
                 v_t = state['normuon_v']
                 beta2_normuon = group['beta2_normuon']
                 # Update 2nd moment estimate
@@ -414,6 +430,8 @@ class Muon_adv(torch.optim.Optimizer):
                         steps=group['ns_steps'],
                         eps=group['ns_eps'],
                         coeffs=group['ns_coeffs'],
+                        cns=group['accelerated_ns'],
+                        cns_a_bound=group['cns_a_bound'],
                     )
 
                     # 5. Project back to the original space
@@ -424,6 +442,8 @@ class Muon_adv(torch.optim.Optimizer):
                         steps=group['ns_steps'],
                         eps=group['ns_eps'],
                         coeffs=group['ns_coeffs'],
+                        cns=group['accelerated_ns'],
+                        cns_a_bound=group['cns_a_bound'],
                     )
                 else:
                     # Original NewtonSchulz
@@ -432,10 +452,12 @@ class Muon_adv(torch.optim.Optimizer):
                     steps=group['ns_steps'],
                     eps=group['ns_eps'],
                     coeffs=group['ns_coeffs'],
+                    cns=group['accelerated_ns'],
+                    cns_a_bound=group['cns_a_bound'],
                 )
 
             # NorMuon Logic
-            if group['normuon_variant']
+            if group['normuon_variant']:
                 v_t = state['normuon_v']
                 beta2_normuon = group['beta2_normuon']
                 # Update 2nd moment estimate
@@ -629,10 +651,6 @@ class Muon_adv(torch.optim.Optimizer):
 
         state = self.state[p]
 
-        if self.kourkoutas_helper:
-            # Prepare Kourkoutas-β once per optimizer step.
-            self.kourkoutas_helper.maybe_prepare_step(self.global_step)
-
         # Determine if using Adam or Muon based on state keys
         # We can use optm_type but I see this as a safer way.
         if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
@@ -644,9 +662,15 @@ class Muon_adv(torch.optim.Optimizer):
         is_compiled = group.get('compiled_optimizer', False)
 
         if use_adam:
+            step = state['step']
+
+            if self.kourkoutas_helper:
+                # Prepare Kourkoutas-β once per optimizer step.
+                self.kourkoutas_helper.maybe_prepare_step(step)
+
             # Adam-specific setup (bias correction)
             if group['adam_use_bias_correction']:
-                current_step =
+                current_step = step + 1
                 beta1_adam, beta2_adam = group['adam_betas']
                 bias_correction1 = 1.0 - beta1_adam ** current_step
                 bias_correction2 = 1.0 - beta2_adam ** current_step
@@ -654,6 +678,8 @@ class Muon_adv(torch.optim.Optimizer):
                 bias_correction1 = 1.0
                 bias_correction2 = 1.0
 
+            self.state[p]['step'] += 1
+
             # Dispatch to compiled or uncompiled Adam step
             if is_compiled and self._compiled_adam_step is not None:
                 # Tensors must be used for compiled functions
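Moving the counter from a self.global_step attribute into state['step'] also makes it survive checkpointing: Optimizer.state_dict() serializes self.state and self.param_groups but not custom attributes. A small demonstration of that behavior (illustrative only, using plain SGD):

import torch

p = torch.nn.Parameter(torch.zeros(2))
opt = torch.optim.SGD([p], lr=0.1)
opt.state[p]['step'] = 41                    # per-parameter state, like the diff uses
snapshot = opt.state_dict()                  # includes the 'step' entry

opt2 = torch.optim.SGD([p], lr=0.1)
opt2.load_state_dict(snapshot)
print(list(opt2.state.values())[0]['step'])  # 41, restored; a plain attribute would reset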
@@ -690,6 +716,4 @@ class Muon_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)
 
-        self.global_step += 1
-
         return loss
adv_optm/util/Newton_Schulz.py (new file in 1.2.dev16)

@@ -0,0 +1,87 @@
+import torch
+
+@torch.no_grad()
+def _newton_schulz_iteration(
+    G: torch.Tensor,
+    steps: int = 5,
+    eps: float = 1e-7,
+    coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
+    cns: bool = False,
+    cns_a_bound: float = 1e-4,
+) -> torch.Tensor:
+    """
+    Performs the Newton-Schulz iteration to find the nearest orthogonal matrix.
+    This is the core computation of the Muon optimizer.
+
+    Args:
+        G (torch.Tensor): The 2D input matrix (momentum-accumulated gradient).
+        steps (int): The number of iterations to run.
+        eps (float): Small constant for numerical stability during normalization.
+        coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
+            quintic polynomial update.
+        cns (bool): If True, enables Chebyshev-accelerated Newton-Schulz (CANS)
+            using an iterative 3rd-order polynomial with optimal coefficients
+            derived at each step.
+        cns_a_bound (float): The initial lower bound for singular values when
+            using CANS. The upper bound is assumed to be 1.0 after normalization.
+    Returns:
+        torch.Tensor: The orthogonalized matrix.
+    """
+    assert G.ndim >= 2
+
+    a, b, c = coeffs
+
+    X = G.to(torch.bfloat16)
+
+    transposed = G.size(-2) > G.size(-1)
+    if transposed:
+        X = X.mT
+
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
+
+    if cns:
+        # Chebyshev-accelerated Newton-Schulz (CANS) from
+        # "Accelerating Newton-Schulz Iteration for Orthogonalization via Chebyshev-type Polynomials"
+        # This implements the iterative scheme from Algorithm 1, using the
+        # closed-form 3rd-order polynomial from Proposition 2.
+        lower_bound = cns_a_bound
+        upper_bound = 1.0  # Matrix is normalized, so largest singular value is approx 1.
+
+        for _ in range(steps):
+            # Calculate optimal 3rd-order coefficients c1, c3 for p(x) = c1*x + c3*x^3
+            # based on the current singular value bounds [lower_bound, upper_bound].
+            # Formulas are derived from Proposition 2 and its proof in Appendix B of the paper.
+            a_bound, b_bound = lower_bound, upper_bound
+            term = a_bound*a_bound + a_bound*b_bound + b_bound*b_bound
+            e_sq = term / 3.0
+
+            # Calculate alpha, which scales the polynomial
+            common_den_part = 2.0 * (e_sq**1.5)
+            ab_part = a_bound*a_bound*b_bound + b_bound*b_bound*a_bound
+            alpha_den = common_den_part + ab_part
+            alpha = 6.0 / alpha_den
+
+            c1 = alpha * e_sq
+            c3 = -alpha / 3.0
+
+            # Apply the 3rd-order Newton-Schulz update
+            A = X @ X.mT
+            X = c1 * X + c3 * (A @ X)
+
+            # Update the singular value bounds for the next iteration based on the error
+            eps_num = common_den_part - ab_part
+            eps_val = eps_num / alpha_den
+            lower_bound = 1.0 - eps_val
+            upper_bound = 1.0 + eps_val
+    else:
+        # Perform the iterative updates
+        for _ in range(steps):
+            A = X @ X.mT
+            B = b * A + c * (A @ A)
+            X = a * X + B @ X
+
+    # Transpose back if necessary
+    if transposed:
+        X = X.mT
+
+    return X.to(G.dtype)
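In the CANS branch, the closed-form coefficients make p(x) = c1*x + c3*x^3 an optimal degree-3 approximation of 1 on the current singular-value interval [a, b]: with e^2 = (a^2 + ab + b^2)/3, the code sets alpha = 6 / (2e^3 + a^2 b + a b^2), c1 = alpha*e^2, c3 = -alpha/3, and the interval then contracts to [1 - eps_val, 1 + eps_val] for the next step. A quick sanity sketch (illustrative only; the import path is inferred from the file list): both settings should drive the singular values of a random matrix toward 1.

import torch
from adv_optm.util.Newton_Schulz import _newton_schulz_iteration

G = torch.randn(64, 128)
for cns in (False, True):
    X = _newton_schulz_iteration(G, steps=5, cns=cns)
    sv = torch.linalg.svdvals(X.float())
    print(f"cns={cns}: singular values in [{sv.min():.3f}, {sv.max():.3f}]")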
adv_optm/util/Newton_Schulz.py (removed from 1.2.dev14)

@@ -1,48 +0,0 @@
-import torch
-
-@torch.no_grad()
-def _newton_schulz_iteration(
-    G: torch.Tensor,
-    steps: int = 5,
-    eps: float = 1e-7,
-    coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315)
-) -> torch.Tensor:
-    """
-    Performs the Newton-Schulz iteration to find the nearest orthogonal matrix.
-    This is the core computation of the Muon optimizer.
-
-    Args:
-        G (torch.Tensor): The 2D input matrix (momentum-accumulated gradient).
-        steps (int): The number of iterations to run.
-        eps (float): Small constant for numerical stability during normalization.
-        coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
-            quintic polynomial update.
-
-    Returns:
-        torch.Tensor: The orthogonalized matrix.
-    """
-    assert G.ndim == 2, "Newton-Schulz iteration only supports 2D matrices."
-
-    a, b, c = coeffs
-
-    X = G.to(torch.bfloat16)
-
-    # Normalize the matrix
-    X.div_(X.norm() + eps)
-
-    # Handle non-square matrices by transposing the taller one
-    transposed = G.size(0) > G.size(1)
-    if transposed:
-        X = X.T
-
-    # Perform the iterative updates
-    for _ in range(steps):
-        A = X @ X.T
-        B = b * A + c * (A @ A)
-        X = a * X + B @ X
-
-    # Transpose back if necessary
-    if transposed:
-        X = X.T
-
-    return X.to(G.dtype)