adv-optm 1.2.dev5.tar.gz → 1.2.dev7.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of adv-optm might be problematic.

Files changed (29)
  1. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/PKG-INFO +1 -1
  2. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/__init__.py +1 -1
  3. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/Muon_adv.py +165 -18
  4. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/Kourkoutas.py +21 -5
  5. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm.egg-info/PKG-INFO +1 -1
  6. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/setup.py +1 -1
  7. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/LICENSE +0 -0
  8. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/README.md +0 -0
  9. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/AdaMuon_adv.py +0 -0
  10. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/AdamW_adv.py +0 -0
  11. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/Adopt_adv.py +0 -0
  12. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  13. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/Lion_adv.py +0 -0
  14. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/Prodigy_adv.py +0 -0
  15. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  16. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/__init__.py +0 -0
  17. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  18. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/Effective_Shape.py +0 -0
  19. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/MuonAdam_helper.py +0 -0
  20. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/NNMF.py +0 -0
  21. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/Newton_Schulz.py +0 -0
  22. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/One_Bit_Boolean.py +0 -0
  23. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/OrthoGrad.py +0 -0
  24. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/__init__.py +0 -0
  25. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm.egg-info/SOURCES.txt +0 -0
  26. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm.egg-info/dependency_links.txt +0 -0
  27. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm.egg-info/requires.txt +0 -0
  28. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm.egg-info/top_level.txt +0 -0
  29. {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/setup.cfg +0 -0
{adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 1.2.dev5
+Version: 1.2.dev7
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu
{adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/__init__.py
@@ -20,4 +20,4 @@ __all__ = [
     "AdaMuon_adv",
 ]

-__version__ = "1.2.dev5"
+__version__ = "1.2.dev7"
{adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/Muon_adv.py
@@ -18,6 +18,10 @@ class Muon_adv(torch.optim.Optimizer):
     the hidden layers of neural networks. It applies SGD with momentum and then
     orthogonalizes the resulting update matrix using a Newton-Schulz iteration.

+    NorMuon (Neuron-wise Normalized Muon) extends this by adding neuron-level
+    adaptive learning rates, combining the benefits of orthogonalization with
+    second-order momentum statistics.
+
     This implementation is designed for 2D parameters (e.g., linear layers) and
     can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
     flattening/reshaping them.
@@ -54,6 +58,19 @@ class Muon_adv(torch.optim.Optimizer):
             matrices to apply low-rank compression (default: True).
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
+        low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
+            projects the update to a lower rank before orthogonalization.
+            (default: False)
+        ortho_rank (int): The rank for low-rank orthogonalization.
+            (default: 128)
+        normuon_variant (bool): If True, enables the NorMuon update rule, which adds
+            neuron-wise normalization. (default: False)
+        beta2_normuon (float): The exponential decay rate for the second moment estimates
+            used in NorMuon. (default: 0.95)
+        normuon_eps (float): Epsilon for NorMuon normalization stability. (default: 1e-8)
+        normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
+            (default: 0.2)
+        normuon_atan2 (bool): whether to use atan2 for NorMuon. (default: False)
         MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
             Parameters designated by `layer_key_fn` will be optimized with
             AdamW_adv instead of Muon. (default: False)
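
To make the new options concrete, here is a minimal usage sketch. It assumes Muon_adv is importable from the adv_optm package (consistent with the file layout above) and that the constructor accepts params plus the keyword arguments documented in this diff; everything not shown is left at its defaults.

    import torch
    from adv_optm import Muon_adv  # assumed import path

    model = torch.nn.Linear(512, 512)
    opt = Muon_adv(
        model.parameters(),
        lr=0.02,
        # Low-rank orthogonalization: project the update to rank 128 before Newton-Schulz.
        low_rank_ortho=True,
        ortho_rank=128,
        # NorMuon: neuron-wise second-moment normalization of the orthogonalized update.
        normuon_variant=True,
        beta2_normuon=0.95,
        normuon_eps=1e-8,
        normuon_lr_scale=0.2,
    )

    loss = model(torch.randn(8, 512)).sum()
    loss.backward()
    opt.step()
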
@@ -82,6 +99,15 @@ class Muon_adv(torch.optim.Optimizer):
         vector_reshape_muon: bool = False,
         vector_reshape: bool = False,
         nnmf_factor: bool = False,
+        # Low-rank Muon
+        low_rank_ortho: bool = False,
+        ortho_rank: int = 128,
+        # NorMuon additions
+        normuon_variant: bool = False,
+        beta2_normuon: float = 0.95,
+        normuon_eps: float = 1e-8,
+        normuon_lr_scale: float = 0.2,
+        normuon_atan2: bool = False,
         # hybrid optimizer mode
         MuonWithAuxAdam: bool = False,
         layer_key_fn: Optional[Callable] = None,
@@ -92,6 +118,8 @@ class Muon_adv(torch.optim.Optimizer):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
         if not (0.0 <= beta1 < 1.0):
             raise ValueError(f"beta1 should be in [0.0, 1.0). Got {beta1}")
+        if normuon_variant and not (0.0 <= beta2_normuon < 1.0):
+            raise ValueError(f"beta2_normuon should be in [0.0, 1.0) for NorMuon. Got {beta2_normuon}")
         if not (weight_decay >= 0.0):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
         if not (ns_steps > 0):
@@ -106,10 +134,16 @@ class Muon_adv(torch.optim.Optimizer):
             "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
             "vector_reshape": vector_reshape,
             "vector_reshape_muon": vector_reshape_muon,
-            "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+            "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+            # Low-rank Ortho
+            "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
+            # NorMuon
+            "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
+            "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
+            "normuon_atan2": normuon_atan2,
         }
         self.stochastic_rounding = stochastic_rounding
-
+
         self.MuonWithAuxAdam = MuonWithAuxAdam
         self.helper = None
         self.aux_adam = None
@@ -223,6 +257,12 @@ class Muon_adv(torch.optim.Optimizer):
                 elif len(p.shape) == 1:
                     state['momentum_buffer'] = torch.zeros_like(p)

+                # NorMuon state initialization
+                if group['normuon_variant']:
+                    if len(p.shape) >= 2 or state['reshaped_1d_muon']:
+                        num_rows = p.shape[0] if len(p.shape) >= 2 else state['effective_shape'][0]
+                        state['normuon_v'] = torch.zeros(num_rows, device=p.device, dtype=torch.float32)
+
                 beta1 = group['beta1']
                 nesterov = group['nesterov']
                 Simplified_AdEMAMix = group['Simplified_AdEMAMix']
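
For intuition, the new state entry is just one float32 second-moment accumulator per output row (neuron). A tiny illustrative sketch, with names of our own choosing rather than the package's:

    import torch

    p = torch.empty(256, 512)   # a 2D weight: 256 output neurons, 512 inputs each
    normuon_v = torch.zeros(p.shape[0], dtype=torch.float32)
    print(normuon_v.shape)      # torch.Size([256]) -- one entry per output neuron
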
@@ -251,14 +291,60 @@ class Muon_adv(torch.optim.Optimizer):
                 update = mt_buf.clone()
                 del grad_reshaped

-                update = _newton_schulz_iteration(
-                    update,
-                    steps=group['ns_steps'],
-                    eps=group['ns_eps'],
-                    coeffs=group['ns_coeffs'],
-                )
+                # Orthogonalization step
+                if group['low_rank_ortho']:
+                    # Low-Rank Orthogonalization on the reconstructed matrix
+                    M = update
+                    r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+                    if r > 0:
+                        G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+                        MG = M @ G_sketch
+                        if MG.dtype != torch.float32:
+                            MG_dtype = M.dtype
+                            Q, _ = torch.linalg.qr(MG.float())
+                            Q = Q.to(MG_dtype)
+                        else:
+                            Q, _ = torch.linalg.qr(MG)
+                        projected_M = Q.T @ M
+                        ortho_projected_M = _newton_schulz_iteration(
+                            projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                        )
+                        update = Q @ ortho_projected_M
+                    else:  # Fallback for invalid rank
+                        update = _newton_schulz_iteration(
+                            update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                        )
+                else:
+                    # Original full Newton-Schulz
+                    update = _newton_schulz_iteration(
+                        update,
+                        steps=group['ns_steps'],
+                        eps=group['ns_eps'],
+                        coeffs=group['ns_coeffs'],
+                    )

-                update = update.view(p.shape).mul_(group['lr'])
+
+                if group['normuon_variant'] and 'normuon_v' in state:
+                    v_t = state['normuon_v']
+                    beta2_normuon = group['beta2_normuon']
+                    # Update 2nd moment estimate
+                    mean_squared_update = torch.mean(update.square(), dim=1)
+                    v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
+                    # Normalize update
+                    if group['normuon_atan2']:
+                        a = 1.2732395
+                        update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+                    else:
+                        update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
+                    # Scale learning rate
+                    update_norm = torch.linalg.vector_norm(update)
+                    if update_norm > 1e-12:
+                        scaled_lr = group['normuon_lr_scale'] * group['lr'] * (p.numel()**0.5) / update_norm
+                    else:
+                        scaled_lr = 0.0
+                    update = update.view(p.shape).mul_(scaled_lr)
+                else:  # Original Muon learning rate application
+                    update = update.view(p.shape).mul_(group['lr'])

                 state['sign'] = _pack_bools(mt_buf > 0)
                 _nnmf(mt_buf.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
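
The Gaussian-sketch low-rank orthogonalization above reappears in the non-factored path below; here is a self-contained sketch of the idea. The SVD-based orthogonalize_exact is only a stand-in for the package's _newton_schulz_iteration so the snippet runs on its own; it is not the package's implementation.

    import torch

    def orthogonalize_exact(m: torch.Tensor) -> torch.Tensor:
        # Stand-in for the Newton-Schulz iteration: set every singular value to 1.
        u, _, vh = torch.linalg.svd(m, full_matrices=False)
        return u @ vh

    def low_rank_orthogonalize(m: torch.Tensor, rank: int) -> torch.Tensor:
        r = min(rank, m.shape[0], m.shape[1])
        if r <= 0:
            return orthogonalize_exact(m)                 # fallback for an invalid rank
        g = torch.randn(m.shape[1], r, dtype=m.dtype)     # 1. Gaussian sketch
        q, _ = torch.linalg.qr((m @ g).float())           # 2. orthonormal basis of the sketched range
        q = q.to(m.dtype)
        projected = q.T @ m                               # 3. project into the rank-r subspace
        return q @ orthogonalize_exact(projected)         # 4./5. orthogonalize there, lift back

    update = low_rank_orthogonalize(torch.randn(1024, 1024), rank=128)

The payoff is that the orthogonalization then runs on an r x n matrix instead of the full m x n one, at the cost of the randomized projection.
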
@@ -298,19 +384,80 @@ class Muon_adv(torch.optim.Optimizer):
                 if len(p.shape) > 2:
                     update = update.view(p.shape[0], -1)

-                # NewtonSchulz
-                update = _newton_schulz_iteration(
-                    update,
-                    steps=group['ns_steps'],
-                    eps=group['ns_eps'],
-                    coeffs=group['ns_coeffs'],
-                )
+                # Orthogonalization step
+                if group['low_rank_ortho']:
+                    # Low-Rank Orthogonalization based on Gaussian Sketching
+                    M = update
+                    r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+
+                    if r > 0:
+                        # 1. Sketch the matrix
+                        G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+                        MG = M @ G_sketch
+
+                        # 2. QR decomposition to get orthogonal basis Q
+                        if MG.dtype != torch.float32:
+                            MG_dtype = M.dtype
+                            Q, _ = torch.linalg.qr(MG.float())
+                            Q = Q.to(MG_dtype)
+                        else:
+                            Q, _ = torch.linalg.qr(MG)
+
+                        # 3. Project M onto the basis
+                        projected_M = Q.T @ M
+
+                        # 4. Orthogonalize the smaller projected matrix
+                        ortho_projected_M = _newton_schulz_iteration(
+                            projected_M,
+                            steps=group['ns_steps'],
+                            eps=group['ns_eps'],
+                            coeffs=group['ns_coeffs'],
+                        )
+
+                        # 5. Project back to the original space
+                        update = Q @ ortho_projected_M
+                    else:  # Fallback for invalid rank
+                        update = _newton_schulz_iteration(
+                            update,
+                            steps=group['ns_steps'],
+                            eps=group['ns_eps'],
+                            coeffs=group['ns_coeffs'],
+                        )
+                else:
+                    # Original NewtonSchulz
+                    update = _newton_schulz_iteration(
+                        update,
+                        steps=group['ns_steps'],
+                        eps=group['ns_eps'],
+                        coeffs=group['ns_coeffs'],
+                    )
+
+                # NorMuon Logic
+                if group['normuon_variant'] and 'normuon_v' in state:
+                    v_t = state['normuon_v']
+                    beta2_normuon = group['beta2_normuon']
+                    # Update 2nd moment estimate
+                    mean_squared_update = torch.mean(update.square(), dim=1)
+                    v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
+                    # Normalize update
+                    if group['normuon_atan2']:
+                        a = 1.2732395
+                        update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+                    else:
+                        update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
+                    # Scale learning rate
+                    update_norm = torch.linalg.vector_norm(update)
+                    if update_norm > 1e-12:
+                        scaled_lr = group['normuon_lr_scale'] * group['lr'] * (p.numel()**0.5) / update_norm
+                    else:
+                        scaled_lr = 0.0
+                    update.mul_(scaled_lr)
+                else:  # Original Muon learning rate application
+                    update.mul_(group['lr'])

                 # Reshape back to original if we flattened or reshaped
                 if len(p.shape) > 2 or state['reshaped_1d_muon']:
                     update = update.view(p.shape)
-
-                update.mul_(group['lr'])

             else:  # Fallback to standard SGD with momentum for 1D params (biases, etc.) when not reshaped
                 # Momentum update
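
Condensed, the NorMuon branch (identical in both update paths) amounts to the following standalone sketch; variable names mirror the diff, but the function itself is illustrative rather than part of the package.

    import torch

    def normuon_step(update: torch.Tensor, v: torch.Tensor, lr: float,
                     beta2: float = 0.95, eps: float = 1e-8, lr_scale: float = 0.2) -> torch.Tensor:
        # v holds one running second-moment estimate per output row (neuron).
        v.mul_(beta2).add_(update.square().mean(dim=1), alpha=1 - beta2)
        update = update / (v.sqrt().unsqueeze(1) + eps)    # neuron-wise normalization
        norm = torch.linalg.vector_norm(update)
        # RMS-matched step size, mirroring scaled_lr above.
        scale = lr_scale * lr * (update.numel() ** 0.5) / norm if norm > 1e-12 else 0.0
        return update * scale

    W = torch.randn(256, 512)
    v = torch.zeros(256)
    delta = normuon_step(torch.randn_like(W), v, lr=0.02)
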
{adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/Kourkoutas.py
@@ -86,9 +86,17 @@ class KourkoutasHelper:
         # These are just for the sample log, initialize them
         sun, pooled_grad_norm, prev_r_ema_val, r_ema_tensor = (torch.tensor(0.0),)*4

+        # The optimizer that owns this helper holds the master defaults for K-b.
+        # This is crucial in hybrid optimizers where some param_groups might not
+        # have all K-b keys populated, preventing KeyErrors.
+        master_defaults = self.optimizer.defaults
+
         for layer_key, info in self.layer_info.items():
             params, group = info['params'], info['group_ref']

+            if not group.get('kourkoutas_beta', False):
+                continue
+
             first_param_in_layer = info['params'][0]
             param_state = self.optimizer.state[first_param_in_layer]

@@ -100,6 +108,15 @@
             if 'kourkoutas_r_ema' not in param_state:
                 param_state['kourkoutas_r_ema'] = torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)

+            # Use group-specific K-b settings, falling back to the optimizer's master defaults.
+            # This makes the helper robust against param groups that enable kourkoutas_beta
+            # but are missing the other required hyperparameters.
+            ema_alpha = group.get('ema_alpha', master_defaults['ema_alpha'])
+            beta2_max = group.get('betas', master_defaults['betas'])[1]
+            beta2_min = group.get('beta2_min', master_defaults['beta2_min'])
+            tiny_spike = group.get('tiny_spike', master_defaults['tiny_spike'])
+            k_warmup_steps = group.get('k_warmup_steps', master_defaults['k_warmup_steps'])
+
             r_ema_tensor = param_state['kourkoutas_r_ema']
             accumulator = self.layer_state[layer_key]['sum_sq_accumulator']

@@ -107,17 +124,16 @@
             prev_r_ema_val = r_ema_tensor.item() # for logging

             # Update the persistent EMA tensor in-place.
-            r_ema_tensor.mul_(group['ema_alpha']).add_(pooled_grad_norm, alpha=1.0 - group['ema_alpha'])
+            r_ema_tensor.mul_(ema_alpha).add_(pooled_grad_norm, alpha=1.0 - ema_alpha)

-            beta2_max = group['betas'][1]
             sun = torch.tensor(0.0, device=r_ema_tensor.device) # Default sun to 0 for warmup

-            if current_step < group['k_warmup_steps']:
+            if current_step < k_warmup_steps:
                 beta2 = beta2_max
             else:
-                raw = pooled_grad_norm / (r_ema_tensor + group['tiny_spike'])
+                raw = pooled_grad_norm / (r_ema_tensor + tiny_spike)
                 sun = raw / (1.0 + raw)
-                beta2 = beta2_max - (beta2_max - group['beta2_min']) * sun
+                beta2 = beta2_max - (beta2_max - beta2_min) * sun

             # Store the final calculated beta2 in the helper's transient state for this step.
             self.layer_state[layer_key]['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
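
Pulled out of the layer loop, the dynamic beta2 schedule above can be summarized by this standalone sketch; the default values here are illustrative, not necessarily the package's.

    import torch

    def kourkoutas_beta2(pooled_grad_norm, r_ema, step,
                         beta2_max=0.999, beta2_min=0.9, ema_alpha=0.95,
                         tiny_spike=1e-9, k_warmup_steps=100):
        # EMA of the pooled per-layer gradient norm, updated in place.
        r_ema.mul_(ema_alpha).add_(pooled_grad_norm, alpha=1.0 - ema_alpha)
        if step < k_warmup_steps:
            return beta2_max                    # hold beta2 at its maximum during warmup
        raw = pooled_grad_norm / (r_ema + tiny_spike)
        sun = raw / (1.0 + raw)                 # "sunspike" ratio in [0, 1)
        return float(beta2_max - (beta2_max - beta2_min) * sun)

    r_ema = torch.tensor(0.0)
    print(kourkoutas_beta2(torch.tensor(2.5), r_ema, step=500))   # ~0.905: a spike lowers beta2
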
{adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 1.2.dev5
+Version: 1.2.dev7
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu
{adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:

 setup(
     name="adv_optm",
-    version="1.2.dev5",
+    version="1.2.dev7",
     author="Koratahiu",
     author_email="hiuhonor@gmail.com",
     license='Apache 2.0',