adv-optm 2.4.dev6.tar.gz → 2.4.dev8.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/__init__.py +3 -1
  3. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/optim/AdaMuon_adv.py +101 -64
  4. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/optim/AdamW_adv.py +111 -75
  5. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/optim/Adopt_adv.py +118 -85
  6. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/optim/Lion_adv.py +13 -10
  7. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/optim/Muon_adv.py +56 -53
  8. adv_optm-2.4.dev8/adv_optm/optim/SGD_adv.py +283 -0
  9. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/optim/SignSGD_adv.py +79 -28
  10. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/optim/Simplified_AdEMAMix.py +7 -7
  11. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/optim/__init__.py +2 -0
  12. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/util/Kourkoutas.py +64 -8
  13. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/util/Muon_util.py +3 -43
  14. adv_optm-2.4.dev8/adv_optm/util/OrthoGrad.py +19 -0
  15. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/util/centered_decay.py +9 -2
  16. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/util/param_update.py +226 -70
  17. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/util/scaled_optm.py +57 -47
  18. adv_optm-2.4.dev8/adv_optm/util/sinkhorn.py +42 -0
  19. adv_optm-2.4.dev8/adv_optm/util/state_util.py +289 -0
  20. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/util/update_util.py +10 -52
  21. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm.egg-info/PKG-INFO +1 -1
  22. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm.egg-info/SOURCES.txt +3 -0
  23. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/setup.py +1 -1
  24. adv_optm-2.4.dev6/adv_optm/util/OrthoGrad.py +0 -50
  25. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/LICENSE +0 -0
  26. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/README.md +0 -0
  27. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  28. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/optim/Prodigy_adv.py +0 -0
  29. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/util/Muon_AuxAdam.py +0 -0
  30. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/util/__init__.py +0 -0
  31. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/util/factorization_util.py +0 -0
  32. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/util/lion_k.py +0 -0
  33. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm/util/signed_util.py +0 -0
  34. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm.egg-info/dependency_links.txt +0 -0
  35. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm.egg-info/requires.txt +0 -0
  36. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/adv_optm.egg-info/top_level.txt +0 -0
  37. {adv_optm-2.4.dev6 → adv_optm-2.4.dev8}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 2.4.dev6
+ Version: 2.4.dev8
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -8,6 +8,7 @@ from .optim import (
      Muon_adv,
      AdaMuon_adv,
      SignSGD_adv,
+     SGD_adv,
  )
  
  __all__ = [
@@ -20,6 +21,7 @@ __all__ = [
      "Muon_adv",
      "AdaMuon_adv",
      "SignSGD_adv",
+     "SGD_adv",
  ]
  
- __version__ = "2.4.dev6"
+ __version__ = "2.4.dev8"
@@ -3,12 +3,15 @@ import torch
  import math
  
  from ..util import param_update
- from ..util.Muon_util import newton_schulz, _is_suitable_for_muon, rms_adjustment, normuon_update, approx_mars, _auto_projection_for_adamuon, spectral_norm_update, get_spectral_scaling
+ from ..util.Muon_util import newton_schulz, _is_suitable_for_muon, rms_adjustment, normuon_update, approx_mars, _auto_projection_for_adamuon, get_spectral_scaling
+ from ..util.scaled_optm import spectral_normalization, init_spectral_norm
  from ..util.factorization_util import _get_effective_shape, _factorize_state, _reconstruct_state
  from ..util.OrthoGrad import _orthogonalize_gradient
  from ..util.Kourkoutas import KourkoutasHelper
  from ..util import Muon_AuxAdam
  from ..util.centered_decay import _init_anchor
+ from typing import Optional
+ from ..util.state_util import init_state_tensor, get_state, set_state, upcast_grad_for_precision
  
  A = 4 / math.pi
@@ -101,6 +104,13 @@ class AdaMuon_adv(torch.optim.Optimizer):
          the uncompressed optimizer. (default: False)
      use_muon (bool | None): whether to use Muon or AuxAdamW. MUST be provided
          either here or via `optim_type` in parameter groups. (default: None)
+     state_precision (str): Precision for Muon optimizer states. Options: 'auto' (parameter dtype),
+         'fp32', 'bf16_sr' (BF16 with stochastic rounding), 'fp8_sr', 'int8_sr'.
+         (default: 'auto')
+     factored_2nd (bool): Factorize only the second moment (v_t) using SMMF
+         low-rank compression while keeping the first moment (momentum_buffer)
+         dense. Ignored when `nnmf_factor=True` (full SMMF) or `normuon_variant=True`.
+         Combines well with `state_precision` on the first moment. (default: False)
      n_layers (int): The depth of the network (L). Required for optimal epsilon scaling. (default: 1)
      spectral_normalization (bool): Enable explicit spectral normalization using power iteration. (default: False)
      --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
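
The two options documented above are wired through the constructor, the `defaults` dict, state initialization, and the step path in the hunks that follow. A minimal usage sketch, assuming only names visible in this diff (the module and hyperparameters are illustrative, not from the package):

    import torch
    from adv_optm import AdaMuon_adv

    model = torch.nn.Linear(1024, 1024)          # illustrative module
    opt = AdaMuon_adv(
        model.parameters(),
        lr=1e-3,
        use_muon=True,               # must be set here or per parameter group
        state_precision="bf16_sr",   # momentum stored in BF16 with stochastic rounding
        factored_2nd=True,           # v_t compressed to a rank-1 SMMF factorization
    )
    model(torch.randn(8, 1024)).sum().backward()
    opt.step()

Both knobs degrade gracefully: per the state-initialization hunk below, parameters with fewer than 10,000 elements or a single dimension fall back to dense FP32 states.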
@@ -129,7 +139,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
          weight_decay: float = 0,
          cautious_wd: bool = False,
          # Nesterov momentum
-         nesterov: bool = False,
+         nesterov: bool = True,
          # RMS Rescaling
          rms_rescaling: bool = True,
          # Newton Schulz
@@ -149,6 +159,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
          normuon_variant: bool = False,
          # Boolean to spilt param
          use_muon: bool | None = None,
+         # States precision (Muon path)
+         state_precision: str = "auto",  # 'fp32', 'bf16_sr', 'fp8_sr', 'int8_sr'
+         # Factorized second moment only
+         factored_2nd: bool = False,
          # Update geometry parameters
          kappa_p: float = 1.0,
          auto_projection: bool = True,
@@ -174,7 +188,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
          compiled_optimizer: bool = False,
          # --- AdamW_adv specific parameters ---
          adam_betas: tuple[float, float] = (0.9, 0.99),
-         adam_eps: float = 1e-8,
+         adam_eps: float | None = 1e-8,
          adam_weight_decay: float = 0.0,
          adam_use_bias_correction: bool = True,
          adam_use_atan2: bool = False,
@@ -200,15 +214,17 @@ class AdaMuon_adv(torch.optim.Optimizer):
          if Simplified_AdEMAMix and nesterov:
              print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling nesterov.")
              nesterov = False
-         if normuon_variant and use_atan2:
-             print("Warning: AdaMuon atan2 is incompatible with NorMuon, Disabling AdaMuon atan2.")
-             use_atan2 = False
          if spectral_normalization and rms_rescaling:
              print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
              rms_rescaling = False
          if spectral_normalization and accelerated_ns:
              ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
  
+         state_precision = state_precision.lower()
+         valid_precisions = {"auto", "fp32", "bf16_sr", "fp8_sr", "int8_sr"}
+         if state_precision not in valid_precisions:
+             raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
+
          defaults = {
              "lr": lr, "betas": betas, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
              "eps": eps, "rms_rescaling": rms_rescaling, "ns_steps": ns_steps,
@@ -219,6 +235,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
              "normuon_variant": normuon_variant, "orthogonal_gradient": orthogonal_gradient,
              "compiled_optimizer":compiled_optimizer,
              "use_muon": use_muon,
+             # States precision (Muon path)
+             "state_precision": state_precision,
+             # Factorized second moment only (Muon path)
+             "factored_2nd": factored_2nd,
              # Lion-K
              "kappa_p": kappa_p, "auto_projection": auto_projection,
              # Low-rank Ortho
@@ -335,9 +355,32 @@ class AdaMuon_adv(torch.optim.Optimizer):
              state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
              state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
          else:
-             if not group['normuon_variant']:
-                 state['second_momentum_buffer'] = torch.zeros_like(p)
-             state['momentum_buffer'] = torch.zeros_like(p)
+             # Determine effective state precision (small tensors always use fp32)
+             req_precision = group.get('state_precision', 'auto')
+             actual_precision = req_precision
+             if actual_precision != 'auto' and (p.numel() < 10000 or p.ndim == 1):
+                 actual_precision = 'fp32'
+             group['actual_state_precision'] = actual_precision
+
+             # factored_2nd: factorize v_t only; ignored for NorMuon (no v_t) and tiny params
+             use_factored_2nd = (
+                 group.get('factored_2nd', False)
+                 and not group['normuon_variant']
+                 and p.numel() >= 10000
+                 and p.ndim > 1
+             )
+             state['factored_2nd'] = use_factored_2nd
+
+             default_dtype = p.dtype
+             init_state_tensor(state, 'momentum_buffer', p.shape, actual_precision, p.device, default_dtype)
+
+             if use_factored_2nd:
+                 state['effective_shape'] = _get_effective_shape(p.numel())
+                 d1, d2 = state['effective_shape']
+                 state['mu_vbuf_nmf'] = torch.zeros(d1, device=p.device, dtype=torch.float32)
+                 state['mv_vbuf_nmf'] = torch.zeros(d2, device=p.device, dtype=torch.float32)
+             elif not group['normuon_variant']:
+                 init_state_tensor(state, 'second_momentum_buffer', p.shape, actual_precision, p.device, default_dtype, non_neg=True)
  
          # NorMuon state initialization
          if group['normuon_variant']:
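
`init_state_tensor`, `get_state`, and `set_state` come from the new `adv_optm/util/state_util.py` (+289 lines), whose body is not part of this diff. Inferred purely from the call sites in this file, the helpers plausibly behave like the following sketch; the signatures match the calls, but the bodies are an assumption, reduced to the 'auto'/'fp32'/'bf16_sr' cases:

    import torch

    def init_state_tensor(state, name, shape, precision, device, dtype, non_neg=False):
        # Assumed: allocate the buffer in the storage dtype implied by `precision`.
        storage = {"auto": dtype, "fp32": torch.float32, "bf16_sr": torch.bfloat16}
        state[name] = torch.zeros(shape, device=device, dtype=storage.get(precision, torch.uint8))

    def get_state(state, name, precision):
        # Assumed: hand back an FP32 working copy for low-precision storage;
        # for 'auto'/'fp32' the buffer itself is returned and updated in place.
        buf = state[name]
        return buf if precision in ("auto", "fp32") else buf.float()

    def set_state(state, name, tensor, precision, random_int_tensor, non_neg=False):
        # Assumed: write the working copy back. A real implementation would
        # stochastically round using `random_int_tensor` (see the BF16 sketch
        # further below) and quantize for the fp8/int8 modes.
        if precision not in ("auto", "fp32"):
            state[name].copy_(tensor.to(state[name].dtype))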
@@ -349,25 +392,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
  
          # Spectral Normalization
          if group.get('spectral_normalization', False):
-             gen = param_update.get_generator(device)
-
-             # Case A: Factored Muon
-             if state['factored']:
-                 d1, d2 = state['effective_shape']
-                 # We need a vector matching the 'inner' dimension d2
-                 state['spectral_v'] = torch.randn(d2, device=device, dtype=dtype, generator=gen)
-
-             # Case B: Standard Muon (Linear, Conv2d, etc.)
-             elif len(p.shape) >= 2:
-                 # Since Muon performs `update.flatten(1)`, the matrix becomes
-                 # (p.shape[0], product_of_rest).
-                 d_in_flat = p.numel() // p.shape[0]
-
-                 state['spectral_v'] = torch.randn(d_in_flat, device=device, dtype=dtype, generator=gen)
-
-             # Normalize initial vector for stability
-             if 'spectral_v' in state:
-                 state['spectral_v'].div_(state['spectral_v'].norm())
+             init_spectral_norm(group, state, p)
  
          # MARS-M state initialization
          if group.get('approx_mars', False):
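
Spectral-norm setup is now delegated to `init_spectral_norm` in `scaled_optm.py`, and the update path below calls `spectral_normalization(update, state['spectral_u'], state['spectral_v'], step_scale)`, so the new code evidently tracks both left and right singular-vector estimates rather than the single `spectral_v` the removed block allocated. For reference, one round of the power iteration such an estimator is built on looks like this sketch (`scaled_optm.py` is not shown in this diff, so this is an assumption, not the package's code):

    import torch

    def power_iteration_step(m, u, v, eps=1e-8):
        # One power-iteration round on matrix m (d1, d2): refine the left/right
        # singular-vector estimates u (d1,) and v (d2,) in place and return the
        # current estimate of the spectral norm sigma_max(m).
        u.copy_(m @ v)
        u.div_(u.norm() + eps)
        v_raw = m.t() @ u
        sigma = v_raw.norm()
        v.copy_(v_raw / (sigma + eps))
        return sigma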
@@ -436,18 +461,31 @@ class AdaMuon_adv(torch.optim.Optimizer):
          if is_compiled:
              lr = torch.as_tensor(group['lr'])
              muon_step_param = self._compiled_muon_step_parameter
+
+             # Generate state SR random tensor when compiled
+             actual_precision = group['actual_state_precision']
+             random_int_state_tensor = random_int_tensor
+             if actual_precision == 'bf16_sr' and random_int_state_tensor is not None:
+                 random_int_state_tensor = param_update._get_random_int_for_sr(p)
+             elif actual_precision == 'int8_sr':
+                 random_int_state_tensor = param_update._get_random_int_for_int8_sr(p)
+             elif actual_precision == 'fp8_sr':
+                 random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
          else:
              lr = group['lr']
              muon_step_param = self._muon_step_parameter
+             random_int_state_tensor = None
  
-         muon_step_param(p, grad, state, group, lr, random_int_tensor)
+         muon_step_param(p, grad, state, group, lr, random_int_tensor, random_int_state_tensor)
  
      def compile(self, *args, **kwargs):
          self._compiled_muon_step_parameter = torch.compile(self._muon_step_parameter, *args, **kwargs)
          self._compiled_adam_step_parameter = torch.compile(Muon_AuxAdam._adam_step_parameter, *args, **kwargs)
  
      @torch.no_grad()
-     def _muon_step_parameter(self, p, grad, state, group, lr, random_int_tensor):
+     def _muon_step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_int_state_tensor=None):
+         # Upcast grad for low-precision state modes (non-factored path)
+         grad = upcast_grad_for_precision(grad, state, group.get('state_precision', 'auto'))
          beta1, beta2 = group['betas']
          nesterov = group['nesterov']
          Simplified_AdEMAMix = group['Simplified_AdEMAMix']
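
The `_get_random_int_for_*` helpers pre-sample the integer noise outside the compiled step so the stochastic-rounding cast stays graph-friendly under `torch.compile`. Their bodies are not in this diff; the standard BF16 bit trick they would feed works like this sketch (an assumption, with Inf/NaN edge cases ignored):

    import torch

    def bf16_stochastic_round(x, rand_int=None):
        # Stochastically round FP32 -> BF16: add uniform noise in [0, 2^16)
        # to the raw bits, then truncate the low 16 mantissa bits.
        bits = x.float().contiguous().view(torch.int32)
        if rand_int is None:
            rand_int = torch.randint(0, 1 << 16, bits.shape, device=bits.device, dtype=torch.int32)
        truncated = (bits + rand_int) & -65536   # -65536 == 0xFFFF0000 as int32
        return truncated.view(torch.float32).to(torch.bfloat16)

Adding noise below the cut point makes the round-up probability proportional to the discarded fraction, so the rounded state is correct in expectation even when individual updates are smaller than BF16 resolution.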
@@ -465,21 +503,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
          kappa_p = 1.0
  
          if group.get('spectral_normalization', False):
-             # Compute Scaling Factors
-             if state['factored']:
-                 shape_for_scaling = torch.Size(state['effective_shape'])
-             else:
-                 shape_for_scaling = p.shape
-
-             scaled_eps, adaptive_eps, spectral_target, wd_scale = get_spectral_scaling(p, shape_for_scaling, group['n_layers'])
  
-             weight_decay = group['weight_decay'] * wd_scale
+             ns_eps, adaptive_eps, _, _ = get_spectral_scaling(p, p.shape, group.get('n_layers', 1))
              decoupled_wd = True
-
-
-             ns_eps = scaled_eps
          else:
-             weight_decay = group['weight_decay']
              decoupled_wd = False
              ns_eps = group['ns_eps']
              adaptive_eps = group['eps']
@@ -488,8 +515,6 @@
          if group.get('approx_mars', False):
              grad = approx_mars(grad, state['last_grad'], group['mars_gamma'], beta1, Simplified_AdEMAMix=Simplified_AdEMAMix)
  
-         if grad.dtype != torch.float32 and state.get('factored', False):
-             grad = grad.float()
  
          if group.get("orthogonal_gradient"):
              grad = _orthogonalize_gradient(p, grad)
@@ -552,22 +577,15 @@ class AdaMuon_adv(torch.optim.Optimizer):
              update.div_(denom)
              del denom, vt_buf
  
-             # RMS-aligned scaling
-             step_scale = lr * A if group['use_atan2'] and not group['normuon_variant'] else lr
-             # Spectral Normalization
-             if group.get('spectral_normalization', False):
-                 spectral_norm_update(update, state['spectral_v'], spectral_target, step_scale)
-             else:
-                 # Factored RMS-aligned scaling
-                 rms_adjustment(update, group['rms_rescaling'], step_scale)
-
              update = update.reshape(p.shape)
  
          else: # Standard AdaMuon logic for non-factored tensors
              original_shape = p.shape
+             actual_precision = group['actual_state_precision']
+             factored_2nd = state.get('factored_2nd', False)
  
              # Momentum update
-             mt_buf = state['momentum_buffer']
+             mt_buf = get_state(state, 'momentum_buffer', actual_precision)
              if not Simplified_AdEMAMix:
                  mt_buf.lerp_(grad, 1 - beta1)
              else:
@@ -580,6 +598,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
              else:
                  update = mt_buf.clone()
  
+             set_state(state, 'momentum_buffer', mt_buf, actual_precision, random_int_state_tensor)
+
              # Apply update projection
              update = _auto_projection_for_adamuon(update, kappa_p)
  
@@ -603,10 +623,26 @@ class AdaMuon_adv(torch.optim.Optimizer):
              # NorMuon Logic
              if group['normuon_variant']:
                  normuon_update(update, state['normuon_v'], beta2, group['eps'])
+             elif factored_2nd:
+                 # Factorized second moment: reconstruct → update → re-factorize
+                 d1, d2 = state['effective_shape']
+                 update = update.view(original_shape)
+                 update_f32 = update.float()
+                 vt_buf = _reconstruct_state((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']), signed=False)
+                 vt_buf.mul_(beta2).addcmul_(update_f32.view(d1, d2), update_f32.view(d1, d2), value=1 - beta2)
+                 state['mu_vbuf_nmf'], state['mv_vbuf_nmf'] = _factorize_state(vt_buf, signed=False)
+                 # Apply second moment scaling
+                 if group['use_atan2']:
+                     denom = vt_buf.sqrt_().view(original_shape)
+                     update.atan2_(denom.to(update.dtype))
+                 else:
+                     denom = vt_buf.sqrt_().add_(adaptive_eps).view(original_shape)
+                     update.div_(denom.to(update.dtype))
+                 del denom, vt_buf, update_f32
              else:
                  # Original AdaMuon Logic
                  update = update.view(original_shape)
-                 vt_buf = state['second_momentum_buffer']
+                 vt_buf = get_state(state, 'second_momentum_buffer', actual_precision)
                  vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
                  # Apply second momentum update (adaptive scaling)
                  if group['use_atan2']:
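
`_factorize_state` and `_reconstruct_state` live in `factorization_util.py`, which is unchanged in this release and not shown here (the `signed` flag suggests an extra sign component for signed buffers). For a non-negative matrix such as v_t (`signed=False`), an SMMF/Adafactor-style rank-1 round trip works like this sketch, an assumed simplification of the actual helpers:

    import torch

    def factorize_nonneg(v):
        # Keep only row sums and column sums: O(d1 + d2) memory
        # instead of O(d1 * d2) for the dense second moment.
        return v.sum(dim=1), v.sum(dim=0)

    def reconstruct_nonneg(row, col):
        # Adafactor-style rank-1 reconstruction: outer(row, col) / total mass.
        return torch.outer(row, col) / row.sum().clamp_min(1e-30)

    v_t = torch.rand(128, 256)                # illustrative (d1, d2)
    row, col = factorize_nonneg(v_t)
    v_hat = reconstruct_nonneg(row, col)      # low-rank approximation of v_t

This is why the branch above reconstructs the dense v_t, applies the `addcmul_` update, and immediately re-factorizes: only the two vectors persist between steps.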
@@ -615,20 +651,21 @@ class AdaMuon_adv(torch.optim.Optimizer):
              else:
                  denom = vt_buf.sqrt().add_(adaptive_eps)
                  update.div_(denom)
+             set_state(state, 'second_momentum_buffer', vt_buf, actual_precision, random_int_state_tensor, non_neg=True)
              del denom
  
-         step_scale = lr * A if group['use_atan2'] and not group['normuon_variant'] else lr
+             step_scale = lr * A if group['use_atan2'] and not group['normuon_variant'] else lr
  
-         if group.get('spectral_normalization', False):
-             # Spectral Normalization
-             spectral_norm_update(update, state['spectral_v'], spectral_target, step_scale)
-         else:
-             # RMS-aligned rescaling
-             rms_adjustment(update, group['rms_rescaling'], step_scale)
+             if group.get('spectral_normalization', False):
+                 # Spectral Normalization
+                 spectral_normalization(update, state['spectral_u'], state['spectral_v'], step_scale)
+             else:
+                 # RMS-aligned rescaling
+                 rms_adjustment(update, group['rms_rescaling'], step_scale)
  
-         update = update.reshape(original_shape)
+             update = update.reshape(original_shape)
  
-         param_update.apply_parameter_update(self, p, group, update, lr, wd=weight_decay, random_int_tensor=random_int_tensor, decoupled=decoupled_wd)
+         param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, decoupled=decoupled_wd)
  
      @torch.no_grad()
      def step(self, closure=None):