adv-optm 1.2.dev6.tar.gz → 1.2.dev8.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm might be problematic.

Files changed (29)
  1. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/PKG-INFO +1 -1
  2. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/__init__.py +1 -1
  3. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/AdaMuon_adv.py +74 -68
  4. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/Muon_adv.py +166 -19
  5. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/Kourkoutas.py +1 -1
  6. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm.egg-info/PKG-INFO +1 -1
  7. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/setup.py +1 -1
  8. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/LICENSE +0 -0
  9. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/README.md +0 -0
  10. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/AdamW_adv.py +0 -0
  11. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/Adopt_adv.py +0 -0
  12. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  13. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/Lion_adv.py +0 -0
  14. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/Prodigy_adv.py +0 -0
  15. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  16. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/__init__.py +0 -0
  17. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  18. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/Effective_Shape.py +0 -0
  19. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/MuonAdam_helper.py +0 -0
  20. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/NNMF.py +0 -0
  21. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/Newton_Schulz.py +0 -0
  22. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/One_Bit_Boolean.py +0 -0
  23. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/OrthoGrad.py +0 -0
  24. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/__init__.py +0 -0
  25. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm.egg-info/SOURCES.txt +0 -0
  26. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm.egg-info/dependency_links.txt +0 -0
  27. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm.egg-info/requires.txt +0 -0
  28. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm.egg-info/top_level.txt +0 -0
  29. {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/setup.cfg +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.2.dev6
+ Version: 1.2.dev8
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu

adv_optm/__init__.py
@@ -20,4 +20,4 @@ __all__ = [
  "AdaMuon_adv",
  ]
 
- __version__ = "1.2.dev6"
+ __version__ = "1.2.dev8"
adv_optm/optim/AdaMuon_adv.py
@@ -3,7 +3,6 @@ from typing import Optional, Callable
 
  from .AdamW_adv import AdamW_adv
  from ..util.MuonAdam_helper import MuonAdamHelper
- from ..util.Kourkoutas import KourkoutasHelper
 
  from ..util.BF16_Stochastic_Rounding import add_stochastic_
  from ..util.Newton_Schulz import _newton_schulz_iteration
@@ -64,22 +63,13 @@ class AdaMuon_adv(torch.optim.Optimizer):
  matrices for muon NewtonSchulz (default: False).
  vector_reshape (bool): whether to reshape 1D vectors into 2D
  matrices to apply low-rank compression (default: True).
+ low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
+ projects the update to a lower rank before orthogonalization.
+ (default: False)
+ ortho_rank (int): The rank for low-rank orthogonalization.
+ (default: 128)
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
- kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
- If `False`, the optimizer behaves as standard AdamW. (default: False)
- beta2_min (float): The minimum value for dynamic β₂, used during periods of
- high gradient variance ("sunspikes"). Must be less than `betas[1]`.
- (default: 0.88)
- ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
- the pooled gradient norms. Corresponds to `α` in the paper.
- (default: 0.93)
- tiny_spike (float): A small constant added to the denominator of the
- "sunspike" ratio calculation to prevent division by zero. Corresponds
- to `ε_spike` in the paper. (default: 1e-9)
- k_warmup_steps (int): The number of initial steps during which β₂ is held
- at a fixed beta2 value before the
- dynamic logic activates. (default: 0)
  MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
  Parameters designated by `layer_key_fn` will be optimized with
  AdamW_adv instead of Muon. (default: False)
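For orientation, a minimal usage sketch of the flags introduced in this release (not part of the diff): it assumes AdaMuon_adv is importable from the adv_optm package top level, as the __init__.py __all__ list suggests, and that it accepts a standard lr keyword alongside the new low_rank_ortho / ortho_rank arguments. Note that the kourkoutas_beta-related keywords removed here would now raise a TypeError if still passed.

import torch
from adv_optm import AdaMuon_adv  # assumed public import path

model = torch.nn.Linear(1024, 1024)
optimizer = AdaMuon_adv(
    model.parameters(),
    lr=1e-3,
    low_rank_ortho=True,   # new in this release: sketch the update before Newton-Schulz
    ortho_rank=128,        # rank of the Gaussian sketch (default shown in the diff)
)

loss = model(torch.randn(8, 1024)).square().mean()
loss.backward()
optimizer.step()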
@@ -110,15 +100,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
  alpha_grad: float = 100.0,
  vector_reshape_muon: bool = False,
  vector_reshape: bool = False,
+ # Low-rank Muon
+ low_rank_ortho: bool = False,
+ ortho_rank: int = 128,
  nnmf_factor: bool = False,
- # K-b parameters
- kourkoutas_beta: bool = False,
- beta2_min: float = 0.9,
- ema_alpha: float = 0.95,
- tiny_spike: float = 1e-9,
- k_warmup_steps: int = 0,
- k_logging: int = 0,
- layer_key_kb_fn: Optional[Callable] = None,
  # hybrid optimizer mode
  MuonWithAuxAdam: bool = False,
  layer_key_fn: Optional[Callable] = None,
@@ -142,14 +127,11 @@ class AdaMuon_adv(torch.optim.Optimizer):
  "vector_reshape": vector_reshape,
  "vector_reshape_muon": vector_reshape_muon,
  "nesterov":nesterov, "use_atan2":use_atan2,
- "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
- "_kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
- "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+ "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+ # Low-rank Ortho
+ "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
  }
  self.stochastic_rounding = stochastic_rounding
- self._kourkoutas_beta = kourkoutas_beta
- self._kourkoutas_helper = None
- self.layer_key_kb_fn = layer_key_kb_fn
  self.MuonWithAuxAdam = MuonWithAuxAdam
  self.helper = None
  self.aux_adam = None
@@ -182,14 +164,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
 
  for key, value in defaults_to_use.items():
  new_group.setdefault(key, value)
- if '_kourkoutas_beta' not in new_group:
- if optim_type == 'adam':
- new_group['_kourkoutas_beta'] = False
- else:
- new_group['_kourkoutas_beta'] = muon_defaults['_kourkoutas_beta']
  final_param_groups.append(new_group)
 
- super().__init__(final_param_groups, {})
+ super().__init__(final_param_groups, muon_defaults)
 
  # Now that self is initialized, create the helper
  self.helper = MuonAdamHelper(self, layer_key_fn)
@@ -219,9 +196,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
 
  @torch.no_grad()
  def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
- if group['_kourkoutas_beta'] and self._kourkoutas_helper is None:
- self._kourkoutas_helper = KourkoutasHelper(self)
-
  if self.MuonWithAuxAdam:
  optim_type = self.helper.get_optimizer_type(p)
  if optim_type == 'adam':
@@ -277,7 +251,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
 
  # Retrieve hyperparameters
  beta1, beta2 = group['betas']
- current_step = state['step']
  nesterov = group['nesterov']
  Simplified_AdEMAMix = group['Simplified_AdEMAMix']
  alpha_grad = group['alpha_grad']
@@ -303,20 +276,37 @@ class AdaMuon_adv(torch.optim.Optimizer):
  signed_m_buf = torch.sign(mt_buf)
  del grad_reshaped
 
- update = _newton_schulz_iteration(
- signed_m_buf,
- steps=group['ns_steps'],
- eps=group['ns_eps'],
- coeffs=group['ns_coeffs'],
- )
-
- if group['_kourkoutas_beta']:
- # Call prepare_step() once at the beginning of the step for all params
- self._kourkoutas_helper.maybe_prepare_step(current_step)
- # Accumulate current sign-stabilized orthogonal update's norm for the *next* step
- self._kourkoutas_helper.accumulate_gradient_sq_norm(p, update.view(p.shape))
- # Get the dynamic beta2 calculated in prepare_step()
- beta2 = self._kourkoutas_helper.get_beta2(p, group, current_step)
+ # Orthogonalization step
+ if group['low_rank_ortho']:
+ # Low-Rank Orthogonalization on the reconstructed matrix
+ M = signed_m_buf
+ r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+ if r > 0:
+ G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+ MG = M @ G_sketch
+ if MG.dtype != torch.float32:
+ MG_dtype = M.dtype
+ Q, _ = torch.linalg.qr(MG.float())
+ Q = Q.to(MG_dtype)
+ else:
+ Q, _ = torch.linalg.qr(MG)
+ projected_M = Q.T @ M
+ ortho_projected_M = _newton_schulz_iteration(
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ )
+ update = Q @ ortho_projected_M
+ else: # Fallback for invalid rank
+ update = _newton_schulz_iteration(
+ signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ )
+ else:
+ # Original full Newton-Schulz
+ update = _newton_schulz_iteration(
+ signed_m_buf,
+ steps=group['ns_steps'],
+ eps=group['ns_eps'],
+ coeffs=group['ns_coeffs'],
+ )
 
  # Reconstruct second momentum from previous step's factors
  vt_buf = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
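The new low_rank_ortho branch above is a sketch-and-project scheme: sketch M with a Gaussian matrix, take an orthonormal basis Q of the sketch via QR, orthogonalize the small r x n projection Q.T @ M, and lift the result back with Q (falling back to the full path when the rank is invalid). Below is a self-contained sketch of the same pipeline; it substitutes an SVD-based orthogonalization for the package's private _newton_schulz_iteration, so treat it as an illustration of the shapes and cost rather than the exact numerics.

import torch

def low_rank_orthogonalize(M: torch.Tensor, rank: int) -> torch.Tensor:
    # Orthogonalize an (m x n) update through a rank-r Gaussian sketch.
    r = min(rank, M.shape[0], M.shape[1])
    G = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)  # 1. sketch
    Q, _ = torch.linalg.qr((M @ G).float())                         # 2. orthonormal basis (m x r)
    Q = Q.to(M.dtype)
    P = Q.T @ M                                                     # 3. project to r x n
    U, _, Vh = torch.linalg.svd(P.float(), full_matrices=False)     # 4. orthogonalize the small matrix
    return Q @ (U @ Vh).to(M.dtype)                                 # 5. lift back to m x n

M = torch.randn(512, 768)
approx = low_rank_orthogonalize(M, rank=128)
print(approx.shape)  # torch.Size([512, 768])

The payoff is that the orthogonalization (Newton-Schulz in the package) runs on an r x n matrix instead of m x n, and the QR costs roughly O(m*r^2), which matters when r is well below min(m, n).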
@@ -381,25 +371,41 @@ class AdaMuon_adv(torch.optim.Optimizer):
  if len(p.shape) > 2:
  signed_m_buf = signed_m_buf.view(p.shape[0], -1)
 
- # NewtonSchulz
- update = _newton_schulz_iteration(
- signed_m_buf,
- steps=group['ns_steps'],
- eps=group['ns_eps'],
- coeffs=group['ns_coeffs'],
- )
+ # Orthogonalization step
+ if group['low_rank_ortho']:
+ # Low-Rank Orthogonalization on the reconstructed matrix
+ M = signed_m_buf
+ r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+ if r > 0:
+ G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+ MG = M @ G_sketch
+ if MG.dtype != torch.float32:
+ MG_dtype = M.dtype
+ Q, _ = torch.linalg.qr(MG.float())
+ Q = Q.to(MG_dtype)
+ else:
+ Q, _ = torch.linalg.qr(MG)
+ projected_M = Q.T @ M
+ ortho_projected_M = _newton_schulz_iteration(
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ )
+ update = Q @ ortho_projected_M
+ else: # Fallback for invalid rank
+ update = _newton_schulz_iteration(
+ signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ )
+ else:
+ # Original full Newton-Schulz
+ update = _newton_schulz_iteration(
+ signed_m_buf,
+ steps=group['ns_steps'],
+ eps=group['ns_eps'],
+ coeffs=group['ns_coeffs'],
+ )
 
  if len(p.shape) > 2 or state['reshaped_1d_muon']:
  update = update.view(p.shape)
 
- if group['_kourkoutas_beta']:
- # Call prepare_step() once at the beginning of the step for all params
- self._kourkoutas_helper.maybe_prepare_step(current_step)
- # Accumulate current sign-stabilized orthogonal update's norm for the *next* step
- self._kourkoutas_helper.accumulate_gradient_sq_norm(p, update)
- # Get the dynamic beta2 calculated in prepare_step()
- beta2 = self._kourkoutas_helper.get_beta2(p, group, current_step)
-
  vt_buf = state['second_momentum_buffer']
  vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
adv_optm/optim/Muon_adv.py
@@ -18,6 +18,10 @@ class Muon_adv(torch.optim.Optimizer):
  the hidden layers of neural networks. It applies SGD with momentum and then
  orthogonalizes the resulting update matrix using a Newton-Schulz iteration.
 
+ NorMuon (Neuron-wise Normalized Muon) extends this by adding neuron-level
+ adaptive learning rates, combining the benefits of orthogonalization with
+ second-order momentum statistics.
+
  This implementation is designed for 2D parameters (e.g., linear layers) and
  can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
  flattening/reshaping them.
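A minimal sketch of the neuron-wise statistics this NorMuon variant maintains. The helper name below is illustrative and not part of the package API; the actual code further down keeps the per-row second moment in state['normuon_v'] and applies it after the Newton-Schulz step.

import torch

def normuon_postprocess(update, v, lr, beta2=0.95, eps=1e-8, lr_scale=0.2):
    # update: (rows x cols) orthogonalized update; v: per-row (per-neuron) second-moment EMA
    v.mul_(beta2).add_(update.square().mean(dim=1), alpha=1 - beta2)
    update = update / (v.sqrt().unsqueeze(1) + eps)      # neuron-wise normalization
    norm = torch.linalg.vector_norm(update).item()
    scaled_lr = lr_scale * lr * (update.numel() ** 0.5) / norm if norm > 1e-12 else 0.0
    return update, scaled_lr

U = torch.randn(256, 512)   # stand-in for a Newton-Schulz output
v = torch.zeros(256)        # one entry per output neuron (row)
U, scaled_lr = normuon_postprocess(U, v, lr=0.02)

Because the step is rescaled by sqrt(numel) / ||update||, the applied update has an RMS of roughly lr_scale * lr regardless of layer size (before any weight decay or other scaling elsewhere in the optimizer), which is how the variant keeps per-layer step sizes comparable.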
@@ -54,6 +58,19 @@ class Muon_adv(torch.optim.Optimizer):
  matrices to apply low-rank compression (default: True).
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
+ low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
+ projects the update to a lower rank before orthogonalization.
+ (default: False)
+ ortho_rank (int): The rank for low-rank orthogonalization.
+ (default: 128)
+ normuon_variant (bool): If True, enables the NorMuon update rule, which adds
+ neuron-wise normalization. (default: False)
+ beta2_normuon (float): The exponential decay rate for the second moment estimates
+ used in NorMuon. (default: 0.95)
+ normuon_eps (float): Epsilon for NorMuon normalization stability. (default: 1e-8)
+ normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
+ (default: 0.2)
+ normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
  MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
  Parameters designated by `layer_key_fn` will be optimized with
  AdamW_adv instead of Muon. (default: False)
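An illustrative constructor call exercising the parameters documented above (an assumption-level sketch: it presumes Muon_adv is importable from the package top level and that unspecified arguments keep the defaults shown in the signature below).

import torch
from adv_optm import Muon_adv  # assumed public import path

model = torch.nn.Linear(2048, 2048)
optimizer = Muon_adv(
    model.parameters(),
    lr=0.02,
    normuon_variant=True,     # enable neuron-wise normalization (NorMuon)
    beta2_normuon=0.95,
    normuon_eps=1e-8,
    normuon_lr_scale=0.2,
    normuon_atan2=False,
    low_rank_ortho=True,      # optional: combine with low-rank orthogonalization
    ortho_rank=128,
)

model(torch.randn(4, 2048)).sum().backward()
optimizer.step()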
@@ -82,6 +99,15 @@ class Muon_adv(torch.optim.Optimizer):
  vector_reshape_muon: bool = False,
  vector_reshape: bool = False,
  nnmf_factor: bool = False,
+ # Low-rank Muon
+ low_rank_ortho: bool = False,
+ ortho_rank: int = 128,
+ # NorMuon additions
+ normuon_variant: bool = False,
+ beta2_normuon: float = 0.95,
+ normuon_eps: float = 1e-8,
+ normuon_lr_scale: float = 0.2,
+ normuon_atan2: bool = False,
  # hybrid optimizer mode
  MuonWithAuxAdam: bool = False,
  layer_key_fn: Optional[Callable] = None,
@@ -92,6 +118,8 @@ class Muon_adv(torch.optim.Optimizer):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
  if not (0.0 <= beta1 < 1.0):
  raise ValueError(f"beta1 should be in [0.0, 1.0). Got {beta1}")
+ if normuon_variant and not (0.0 <= beta2_normuon < 1.0):
+ raise ValueError(f"beta2_normuon should be in [0.0, 1.0) for NorMuon. Got {beta2_normuon}")
  if not (weight_decay >= 0.0):
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
  if not (ns_steps > 0):
@@ -106,10 +134,16 @@ class Muon_adv(torch.optim.Optimizer):
  "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
  "vector_reshape": vector_reshape,
  "vector_reshape_muon": vector_reshape_muon,
- "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+ "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+ # Low-rank Ortho
+ "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
+ # NorMuon
+ "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
+ "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
+ "normuon_atan2": normuon_atan2,
  }
  self.stochastic_rounding = stochastic_rounding
-
+
  self.MuonWithAuxAdam = MuonWithAuxAdam
  self.helper = None
  self.aux_adam = None
@@ -144,7 +178,7 @@ class Muon_adv(torch.optim.Optimizer):
 
  final_param_groups.append(new_group)
 
- super().__init__(final_param_groups, {})
+ super().__init__(final_param_groups, muon_defaults)
 
  # Now that self is initialized, create the helper
  self.helper = MuonAdamHelper(self, layer_key_fn)
@@ -223,6 +257,12 @@ class Muon_adv(torch.optim.Optimizer):
  elif len(p.shape) == 1:
  state['momentum_buffer'] = torch.zeros_like(p)
 
+ # NorMuon state initialization
+ if group['normuon_variant']:
+ if len(p.shape) >= 2 or state['reshaped_1d_muon']:
+ num_rows = p.shape[0] if len(p.shape) >= 2 else state['effective_shape'][0]
+ state['normuon_v'] = torch.zeros(num_rows, device=p.device, dtype=torch.float32)
+
  beta1 = group['beta1']
  nesterov = group['nesterov']
  Simplified_AdEMAMix = group['Simplified_AdEMAMix']
@@ -251,14 +291,60 @@ class Muon_adv(torch.optim.Optimizer):
  update = mt_buf.clone()
  del grad_reshaped
 
- update = _newton_schulz_iteration(
- update,
- steps=group['ns_steps'],
- eps=group['ns_eps'],
- coeffs=group['ns_coeffs'],
- )
+ # Orthogonalization step
+ if group['low_rank_ortho']:
+ # Low-Rank Orthogonalization on the reconstructed matrix
+ M = update
+ r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+ if r > 0:
+ G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+ MG = M @ G_sketch
+ if MG.dtype != torch.float32:
+ MG_dtype = M.dtype
+ Q, _ = torch.linalg.qr(MG.float())
+ Q = Q.to(MG_dtype)
+ else:
+ Q, _ = torch.linalg.qr(MG)
+ projected_M = Q.T @ M
+ ortho_projected_M = _newton_schulz_iteration(
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ )
+ update = Q @ ortho_projected_M
+ else: # Fallback for invalid rank
+ update = _newton_schulz_iteration(
+ update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ )
+ else:
+ # Original full Newton-Schulz
+ update = _newton_schulz_iteration(
+ update,
+ steps=group['ns_steps'],
+ eps=group['ns_eps'],
+ coeffs=group['ns_coeffs'],
+ )
 
- update = update.view(p.shape).mul_(group['lr'])
+
+ if group['normuon_variant'] and 'normuon_v' in state:
+ v_t = state['normuon_v']
+ beta2_normuon = group['beta2_normuon']
+ # Update 2nd moment estimate
+ mean_squared_update = torch.mean(update.square(), dim=1)
+ v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
+ # Normalize update
+ if group['normuon_atan2']:
+ a = 1.2732395
+ update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+ else:
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
+ # Scale learning rate
+ update_norm = torch.linalg.vector_norm(update)
+ if update_norm > 1e-12:
+ scaled_lr = group['normuon_lr_scale'] * group['lr'] * (p.numel()**0.5) / update_norm
+ else:
+ scaled_lr = 0.0
+ update = update.view(p.shape).mul_(scaled_lr)
+ else: # Original Muon learning rate application
+ update = update.view(p.shape).mul_(group['lr'])
 
  state['sign'] = _pack_bools(mt_buf > 0)
  _nnmf(mt_buf.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
@@ -298,19 +384,80 @@ class Muon_adv(torch.optim.Optimizer):
  if len(p.shape) > 2:
  update = update.view(p.shape[0], -1)
 
- # NewtonSchulz
- update = _newton_schulz_iteration(
- update,
- steps=group['ns_steps'],
- eps=group['ns_eps'],
- coeffs=group['ns_coeffs'],
- )
+ # Orthogonalization step
+ if group['low_rank_ortho']:
+ # Low-Rank Orthogonalization based on Gaussian Sketching
+ M = update
+ r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+
+ if r > 0:
+ # 1. Sketch the matrix
+ G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+ MG = M @ G_sketch
+
+ # 2. QR decomposition to get orthogonal basis Q
+ if MG.dtype != torch.float32:
+ MG_dtype = M.dtype
+ Q, _ = torch.linalg.qr(MG.float())
+ Q = Q.to(MG_dtype)
+ else:
+ Q, _ = torch.linalg.qr(MG)
+
+ # 3. Project M onto the basis
+ projected_M = Q.T @ M
+
+ # 4. Orthogonalize the smaller projected matrix
+ ortho_projected_M = _newton_schulz_iteration(
+ projected_M,
+ steps=group['ns_steps'],
+ eps=group['ns_eps'],
+ coeffs=group['ns_coeffs'],
+ )
+
+ # 5. Project back to the original space
+ update = Q @ ortho_projected_M
+ else: # Fallback for invalid rank
+ update = _newton_schulz_iteration(
+ update,
+ steps=group['ns_steps'],
+ eps=group['ns_eps'],
+ coeffs=group['ns_coeffs'],
+ )
+ else:
+ # Original NewtonSchulz
+ update = _newton_schulz_iteration(
+ update,
+ steps=group['ns_steps'],
+ eps=group['ns_eps'],
+ coeffs=group['ns_coeffs'],
+ )
+
+ # NorMuon Logic
+ if group['normuon_variant'] and 'normuon_v' in state:
+ v_t = state['normuon_v']
+ beta2_normuon = group['beta2_normuon']
+ # Update 2nd moment estimate
+ mean_squared_update = torch.mean(update.square(), dim=1)
+ v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
+ # Normalize update
+ if group['normuon_atan2']:
+ a = 1.2732395
+ update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+ else:
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
+ # Scale learning rate
+ update_norm = torch.linalg.vector_norm(update)
+ if update_norm > 1e-12:
+ scaled_lr = group['normuon_lr_scale'] * group['lr'] * (p.numel()**0.5) / update_norm
+ else:
+ scaled_lr = 0.0
+ update.mul_(scaled_lr)
+ else: # Original Muon learning rate application
+ update.mul_(group['lr'])
 
  # Reshape back to original if we flattened or reshaped
  if len(p.shape) > 2 or state['reshaped_1d_muon']:
  update = update.view(p.shape)
-
- update.mul_(group['lr'])
 
  else: # Fallback to standard SGD with momentum for 1D params (biases, etc.) when not reshaped
  # Momentum update
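One detail in the NorMuon branch above: with normuon_atan2=True the division (plus normuon_eps) is replaced by atan2 against sqrt(v), scaled by 1.2732395, which matches 4/pi; the atan2 form is bounded and needs no epsilon. A small stand-alone comparison of the two normalizations (the reading of the constant as 4/pi is an inference, not stated in the diff):

import torch

U = torch.randn(128, 256)                       # orthogonalized update
v = U.square().mean(dim=1)                      # stand-in for state['normuon_v']
denom = v.sqrt().unsqueeze(1)

div_form   = U / (denom + 1e-8)                 # normuon_atan2=False path
atan2_form = torch.atan2(U, denom) * 1.2732395  # normuon_atan2=True path

print(div_form.abs().mean().item(), atan2_form.abs().mean().item())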
adv_optm/util/Kourkoutas.py
@@ -94,7 +94,7 @@ class KourkoutasHelper:
  for layer_key, info in self.layer_info.items():
  params, group = info['params'], info['group_ref']
 
- if not group.get('kourkoutas_beta', False):
+ if not group.get('kourkoutas_beta', False) and not group.get('_kourkoutas_beta', False):
  continue
 
  first_param_in_layer = info['params'][0]
adv_optm.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.2.dev6
+ Version: 1.2.dev8
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
 
  setup(
  name="adv_optm",
- version="1.2.dev6",
+ version="1.2.dev8",
  author="Koratahiu",
  author_email="hiuhonor@gmail.com",
  license='Apache 2.0',