adv-optm 1.2.dev12__py3-none-any.whl → 1.2.dev14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

adv_optm/__init__.py CHANGED
@@ -20,4 +20,4 @@ __all__ = [
  "AdaMuon_adv",
  ]
 
- __version__ = "1.2.dev12"
+ __version__ = "1.2.dev14"
adv_optm/optim/AdaMuon_adv.py CHANGED
@@ -1,13 +1,12 @@
  import torch
- from typing import Optional
-
- from .AdamW_adv import AdamW_adv
 
  from ..util.BF16_Stochastic_Rounding import add_stochastic_
  from ..util.Newton_Schulz import _newton_schulz_iteration
  from ..util.Effective_Shape import _get_effective_shape
  from ..util.NNMF import _nnmf,_unnmf
  from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+ from ..util.OrthoGrad import _orthogonalize_gradient
+ from ..util.Kourkoutas import KourkoutasHelper
 
  class AdaMuon_adv(torch.optim.Optimizer):
  """
@@ -25,6 +24,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
  3. An RMS-aligned rescaling strategy to match the update magnitude of Adam,
  allowing for reuse of learning rate schedules.
 
+ When `MuonWithAuxAdam` is enabled, this single optimizer class handles both
+ 'muon' and 'adam' parameter groups, dispatching to the appropriate logic internally.
 
  Args:
  params (iterable): iterable of parameters to optimize or dicts defining
@@ -55,8 +56,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
  current gradient. For small batch sizes, use high values (e.g., 10-100) to be
  more responsive. For large batch sizes, use low values (e.g., 0-1) for
  stability. (default: 100.0)
- vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
- matrices for muon NewtonSchulz (default: False).
  vector_reshape (bool): whether to reshape 1D vectors into 2D
  matrices to apply low-rank compression (default: True).
  low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
@@ -66,6 +65,19 @@ class AdaMuon_adv(torch.optim.Optimizer):
  (default: 128)
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
+ --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
+ adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
+ adam_eps (float): Epsilon for the AdamW optimizer part.
+ adam_weight_decay (float): Weight decay for the AdamW optimizer part.
+ adam_use_bias_correction (bool): Bias correction for AdamW.
+ adam_use_atan2 (bool): Atan2 update rule for AdamW.
+ adam_cautious_mask (bool): Cautious masking for AdamW.
+ adam_grams_moment (bool): Grams-style updates for AdamW.
+ adam_orthogonal_gradient (bool): OrthoGrad for AdamW.
+ adam_use_AdEMAMix (bool): AdEMAMix for AdamW.
+ adam_beta3_ema (float): Beta3 for AdEMAMix.
+ adam_alpha (float): Alpha for AdEMAMix.
+ adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
  """
 
  def __init__(
@@ -79,17 +91,36 @@ class AdaMuon_adv(torch.optim.Optimizer):
  ns_steps: int = 5,
  ns_eps: float = 1e-7,
  ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
- stochastic_rounding: bool = True,
+ stochastic_rounding: bool = False,
  use_atan2: bool = False,
  nesterov: bool = False,
  Simplified_AdEMAMix: bool = False,
  alpha_grad: float = 100.0,
- vector_reshape_muon: bool = False,
  vector_reshape: bool = False,
  # Low-rank Muon
  low_rank_ortho: bool = False,
  ortho_rank: int = 128,
  nnmf_factor: bool = False,
+ # Compiled
+ compiled_optimizer: bool = False,
+ # --- AdamW_adv specific parameters ---
+ adam_betas: tuple[float, float] = (0.9, 0.99),
+ adam_eps: float = 1e-8,
+ adam_weight_decay: float = 0.0,
+ adam_use_bias_correction: bool = True,
+ adam_use_atan2: bool = False,
+ adam_cautious_mask: bool = False,
+ adam_grams_moment: bool = False,
+ adam_orthogonal_gradient: bool = False,
+ adam_use_AdEMAMix: bool = False,
+ adam_beta3_ema: float = 0.9999,
+ adam_alpha: float = 5.0,
+ adam_kourkoutas_beta: bool = False,
+ adam_beta2_min: float = 0.9,
+ adam_ema_alpha: float = 0.95,
+ adam_tiny_spike: float = 1e-9,
+ adam_k_warmup_steps: int = 0,
+ adam_nnmf_factor: bool = False,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -98,7 +129,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
  if not (ns_steps > 0):
  raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
  if Simplified_AdEMAMix and nesterov:
- print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling cautious.")
+ print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling nesterov.")
  nesterov = False
 
  defaults = {
@@ -106,16 +137,40 @@ class AdaMuon_adv(torch.optim.Optimizer):
  "eps": eps, "rms_target": rms_target, "ns_steps": ns_steps,
  "ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
  "vector_reshape": vector_reshape,
- "vector_reshape_muon": vector_reshape_muon,
  "nesterov":nesterov, "use_atan2":use_atan2,
  "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
  # Low-rank Ortho
  "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
+ "compiled_optimizer":compiled_optimizer,
+ # AdamW_adv defaults
+ "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
+ "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
+ "adam_cautious_mask": adam_cautious_mask, "adam_grams_moment": adam_grams_moment,
+ "adam_orthogonal_gradient": adam_orthogonal_gradient,
+ "adam_use_AdEMAMix": adam_use_AdEMAMix, "adam_beta3_ema": adam_beta3_ema, "adam_alpha": adam_alpha,
+ "adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
+ "adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
+ "adam_k_warmup_steps": adam_k_warmup_steps, "adam_nnmf_factor": adam_nnmf_factor,
  }
  self.stochastic_rounding = stochastic_rounding
 
  super().__init__(params, defaults)
-
+
+ self.global_step = 0 # For Adam bias correction and Kourkoutas
+ self.kourkoutas_helper = None
+ if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
+ self.kourkoutas_helper = KourkoutasHelper(self)
+
+ self.init_step()
+
+ # Initialize compiled functions to None
+ self._compiled_muon_step = None
+ self._compiled_adam_step = None
+
+ if compiled_optimizer:
+ print("Compiling AdaMuon_adv optimizer paths...")
+ torch._dynamo.config.cache_size_limit = 8192
+ self.compile(fullgraph=True)
 
  @property
  def supports_fused_back_pass(self):
@@ -129,50 +184,80 @@ class AdaMuon_adv(torch.optim.Optimizer):
129
184
  def supports_flat_params(self):
130
185
  return False
131
186
 
132
- @torch.no_grad()
133
- def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
134
- if p.grad is None:
135
- return
187
+ def init_step(self):
188
+ for group in self.param_groups:
189
+ for i, p in enumerate(group['params']):
190
+ self.__init_state(p, group)
136
191
 
137
- grad = p.grad
192
+ @torch.no_grad()
193
+ def __init_state(self, p, group):
138
194
  state = self.state[p]
139
195
 
196
+ if len(state) > 0:
197
+ return
140
198
 
141
- # State Initialization
142
- if 'step' not in state:
143
- state['step'] = 0
199
+ optim_type = group.get('optim_type', 'muon')
144
200
 
145
- should_factor = (
201
+ if optim_type == 'muon':
202
+
203
+ state['factored'] = (
146
204
  group['nnmf_factor'] and
147
205
  not (len(p.shape) == 1 and not group['vector_reshape'])
148
206
  )
149
-
150
- state['factored'] = should_factor
151
-
152
- state['reshaped_1d_muon'] = len(p.shape) == 1 and group['vector_reshape_muon']
153
-
154
- dtype = torch.float32 if group['nnmf_factor'] else p.dtype
207
+ dtype = torch.float32 if state['factored'] else p.dtype
155
208
  device = p.device
156
- if state['factored'] or state['reshaped_1d_muon']:
209
+
210
+ if state['factored']:
157
211
  state['effective_shape'] = _get_effective_shape(p.numel())
158
212
  d1, d2 = state['effective_shape']
159
- if state['factored']:
160
- state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
161
- state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
213
+ state['mu_mbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
214
+ state['mv_mbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
162
215
  packed_d2 = (d2 + 7) // 8
163
- state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
164
- state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
165
- state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
216
+ state['sign_buf'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
217
+ state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
218
+ state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
166
219
  else:
167
220
  if len(p.shape) >= 2:
168
- state['momentum_buffer'] = torch.zeros_like(p)
169
221
  state['second_momentum_buffer'] = torch.zeros_like(p)
170
- if state['reshaped_1d_muon']:
171
- state['momentum_buffer'] = torch.zeros((d1, d2), device=device, dtype=dtype)
172
- state['second_momentum_buffer'] = torch.zeros((d1, d2), device=device, dtype=dtype)
173
- elif len(p.shape) == 1:
174
- state['momentum_buffer'] = torch.zeros_like(p)
222
+ state['momentum_buffer'] = torch.zeros_like(p)
223
+
224
+
225
+ elif optim_type == 'adam':
226
+
227
+ state['factored'] = (
228
+ group['adam_nnmf_factor'] and
229
+ not (len(p.shape) == 1 and not group['vector_reshape'])
230
+ )
231
+ dtype = torch.float32 if state['factored'] else p.dtype
232
+ device = p.device
175
233
 
234
+ if state['factored']:
235
+ state['effective_shape'] = _get_effective_shape(p.numel())
236
+ d1, d2 = state['effective_shape']
237
+ # First moment (m)
238
+ if group['adam_betas'][0] > 0:
239
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
240
+ state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
241
+ if not group.get('adam_grams_moment'):
242
+ packed_d2 = (d2 + 7) // 8
243
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
244
+ if group.get('adam_use_AdEMAMix'):
245
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
246
+ state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
247
+ packed_d2 = (d2 + 7) // 8
248
+ state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
249
+ # Second moment (v)
250
+ state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
251
+ state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
252
+ else: # Fallback to standard AdamW for non-factored tensors
253
+ if group['adam_betas'][0] > 0:
254
+ state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
255
+ if group.get('adam_use_AdEMAMix'):
256
+ state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
257
+ state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
258
+
259
+ @torch.no_grad()
260
+ def _muon_step_parameter(self, p, grad, state, group, lr):
176
261
  # Retrieve hyperparameters
177
262
  beta1, beta2 = group['betas']
178
263
  nesterov = group['nesterov']
@@ -183,8 +268,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
183
268
 
184
269
  # Reconstruct momentum from previous step's factors & sign
185
270
  d1, d2 = state['effective_shape']
186
- mt_buf = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
187
- unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
271
+ mt_buf = _unnmf((state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
272
+ unpacked_sign = _unpack_bools(state['sign_buf'], original_m=d2)
188
273
  torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
189
274
  del unpacked_sign
190
275
 
@@ -231,9 +316,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
231
316
  eps=group['ns_eps'],
232
317
  coeffs=group['ns_coeffs'],
233
318
  )
319
+ del signed_m_buf
234
320
 
235
321
  # Reconstruct second momentum from previous step's factors
236
- vt_buf = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
322
+ vt_buf = _unnmf((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
237
323
 
238
324
  # Update second momentum in full-size
239
325
  vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
@@ -252,49 +338,38 @@ class AdaMuon_adv(torch.optim.Optimizer):
252
338
  rms_target = group['rms_target']
253
339
  num_elements = update.numel()
254
340
  # Add eps to prevent division by zero
255
- scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm() + group['eps'])
341
+ update.mul_(rms_target * (num_elements ** 0.5) / (update.norm() + group['eps']))
256
342
 
257
- update.mul_(scaling_factor)
258
- update = update.view(p.shape).mul_(group['lr'])
259
- del num_elements, scaling_factor
343
+ update = update.view(p.shape).mul_(lr)
344
+ del num_elements
260
345
 
261
346
  # Compress updated moments and store new factors
262
- state['sign'] = _pack_bools(mt_buf > 0)
263
- _nnmf(mt_buf.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
347
+ state['sign_buf'] = _pack_bools(mt_buf > 0)
348
+ _nnmf(mt_buf.abs(), out=(state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
264
349
  del mt_buf
265
350
 
266
- _nnmf(vt_buf.abs(), out=(state['mu_v_nmf'], state['mv_v_nmf']))
351
+ _nnmf(vt_buf.abs(), out=(state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
267
352
  del vt_buf
268
353
 
269
354
  else: # Standard AdaMuon logic for non-factored tensors
270
355
 
271
- if len(p.shape) >= 2 or state['reshaped_1d_muon']:
356
+ if len(p.shape) >= 2:
357
+
358
+ original_shape = p.shape
272
359
 
273
360
  # Momentum update
274
361
  mt_buf = state['momentum_buffer']
275
- if state['reshaped_1d_muon']:
276
- d1, d2 = state['effective_shape']
277
- grad_reshaped = grad.view(d1, d2)
278
- mt_buf.mul_(beta1).add_(grad_reshaped)
279
- if nesterov:
280
- signed_m_buf = torch.sign(grad_reshaped.add(mt_buf, alpha=beta1))
281
- elif Simplified_AdEMAMix:
282
- signed_m_buf = torch.sign(mt_buf.add(grad_reshaped, alpha=alpha_grad))
283
- else:
284
- signed_m_buf = torch.sign(mt_buf)
285
- del grad_reshaped
362
+ mt_buf.mul_(beta1).add_(grad)
363
+
364
+ if nesterov:
365
+ signed_m_buf = torch.sign(grad.add(mt_buf, alpha=beta1))
366
+ elif Simplified_AdEMAMix:
367
+ signed_m_buf = torch.sign(mt_buf.add(grad, alpha=alpha_grad))
286
368
  else:
287
- mt_buf.mul_(beta1).add_(grad)
288
- if nesterov:
289
- signed_m_buf = torch.sign(grad.add(mt_buf, alpha=beta1))
290
- elif Simplified_AdEMAMix:
291
- signed_m_buf = torch.sign(mt_buf.add(grad, alpha=alpha_grad))
292
- else:
293
- signed_m_buf = torch.sign(mt_buf)
369
+ signed_m_buf = torch.sign(mt_buf)
294
370
 
295
371
  # Flatten if necessary (e.g., for Conv layers)
296
- if len(p.shape) > 2:
297
- signed_m_buf = signed_m_buf.view(p.shape[0], -1)
372
+ signed_m_buf = signed_m_buf.view(original_shape[0], -1)
298
373
 
299
374
  # Orthogonalization step
300
375
  if group['low_rank_ortho']:
@@ -327,9 +402,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
327
402
  eps=group['ns_eps'],
328
403
  coeffs=group['ns_coeffs'],
329
404
  )
405
+ del signed_m_buf
330
406
 
331
- if len(p.shape) > 2 or state['reshaped_1d_muon']:
332
- update = update.view(p.shape)
407
+ update = update.view(original_shape)
333
408
 
334
409
  vt_buf = state['second_momentum_buffer']
335
410
  vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
@@ -348,31 +423,167 @@ class AdaMuon_adv(torch.optim.Optimizer):
348
423
  rms_target = group['rms_target']
349
424
  num_elements = update.numel()
350
425
  # Add eps to prevent division by zero
351
- scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm() + group['eps'])
426
+ update.mul_(rms_target * (num_elements ** 0.5) / (update.norm() + group['eps']))
352
427
 
353
- update.mul_(scaling_factor)
354
- del num_elements, scaling_factor
428
+ del num_elements
355
429
 
356
- update.mul_(group['lr'])
430
+ update.mul_(lr)
357
431
 
358
- else: # Fallback to standard SGD with momentum for 1D params (biases, etc.) when not reshaped
432
+ else: # Fallback to standard SGD with momentum for 1D params (biases, etc.)
359
433
  # Momentum update
360
434
  mt_buf = state['momentum_buffer']
361
435
  mt_buf.mul_(beta1).add_(grad)
362
436
  if nesterov:
437
+ # Nesterov momentum
363
438
  update = grad.add(mt_buf, alpha=beta1)
364
- elif Simplified_AdEMAMix:
365
- signed_m_buf = torch.sign(mt_buf.add(grad, alpha=alpha_grad))
439
+ # elif Simplified_AdEMAMix: # TODO, it will break SGD since it requires x100 lower LR
440
+ # update = mt_buf.add(grad, alpha=alpha_grad)
366
441
  else:
367
442
  update = mt_buf.clone()
368
- update.mul_(group['lr'])
443
+ update.mul_(lr)
369
444
 
370
445
  # Decoupled weight decay
371
446
  if group["weight_decay"] != 0:
372
447
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
373
- add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
448
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
449
+ else:
450
+ p.data.add_(p.data, alpha=-group["weight_decay"] * lr)
451
+
452
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
453
+ add_stochastic_(p.data, -update)
454
+ else:
455
+ p.data.add_(-update)
456
+ del update
457
+
458
+ @torch.no_grad()
459
+ def _adam_step_parameter(self, p, grad, state, group, lr, bias_correction1, bias_correction2):
460
+ if grad.dtype != torch.float32 and state.get('factored', False):
461
+ grad = grad.float()
462
+ if group.get("adam_orthogonal_gradient"):
463
+ grad = _orthogonalize_gradient(p, grad)
464
+
465
+ beta1_adam, beta2_adam = group['adam_betas']
466
+
467
+ if group.get('adam_kourkoutas_beta', False):
468
+ # Accumulate current grad's norm for the *next* step
469
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
470
+ # Get the dynamic beta2_adam calculated in prepare_step()
471
+ beta2_adam = self.kourkoutas_helper.get_beta2(p, group)
472
+
473
+ step_size = lr / bias_correction1
474
+
475
+ if group.get('adam_use_AdEMAMix'):
476
+ beta3_ema = group['adam_beta3_ema']
477
+ alpha = group['adam_alpha']
478
+
479
+ if state['factored']:
480
+ d1, d2 = state['effective_shape']
481
+ grad_reshaped = grad.view(d1, d2)
482
+
483
+ # Reconstruct momentum from previous step's factors
484
+ if beta1_adam > 0:
485
+ mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
486
+ if not group.get('adam_grams_moment'):
487
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
488
+ torch.where(unpacked_sign, mt, -mt, out=mt)
489
+ del unpacked_sign
490
+ # Update momentum in full-size
491
+ mt.mul_(beta1_adam).add_(grad_reshaped, alpha=1.0 - beta1_adam)
492
+ if group.get('adam_grams_moment'):
493
+ mt = (grad_reshaped.sign().mul_(mt.abs()))
494
+ elif group.get('adam_cautious_mask'):
495
+ mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
496
+ mask.div_(mask.mean().clamp_(min=1e-3))
497
+ mt.mul_(mask)
498
+ del mask
499
+
500
+ vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
501
+ vt.mul_(beta2_adam).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2_adam)
502
+
503
+ if group.get('adam_use_AdEMAMix'):
504
+ mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
505
+ if state['sign_slow'].dtype != torch.uint8:
506
+ state['sign_slow'] = state['sign_slow'].to(torch.uint8)
507
+ unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
508
+ torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
509
+ del unpacked_sign_slow
510
+
511
+ mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
512
+ if beta1_adam > 0:
513
+ update = torch.add(mt, mt_slow, alpha=alpha)
514
+ else:
515
+ update = torch.add(grad_reshaped, mt_slow, alpha=alpha)
516
+ else:
517
+ update = mt.clone() if beta1_adam > 0 else grad_reshaped.clone()
518
+ del grad_reshaped
519
+
520
+ if group['adam_use_atan2']:
521
+ a = 1.2732395
522
+ denom = (vt.sqrt() / (bias_correction2**0.5))
523
+ update.atan2_(denom).mul_(a)
524
+ else:
525
+ denom = (vt.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
526
+ update.div_(denom)
527
+ del denom
528
+
529
+ update = update.view(p.shape).mul_(step_size)
530
+
531
+ # Compress updated moments and store new factors
532
+ if beta1_adam > 0:
533
+ if not group.get('adam_grams_moment'):
534
+ state['sign'] = _pack_bools(mt > 0)
535
+ _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
536
+ del mt
537
+ if group.get('adam_use_AdEMAMix'):
538
+ state['sign_slow'] = _pack_bools(mt_slow > 0)
539
+ _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
540
+ del mt_slow
541
+ _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
542
+ del vt
543
+
544
+ else: # Standard AdamW logic for non-factored tensors
545
+ exp_avg_sq = state['exp_avg_sq']
546
+
547
+ if beta1_adam > 0:
548
+ exp_avg = state['exp_avg']
549
+ exp_avg.mul_(beta1_adam).add_(grad, alpha=1 - beta1_adam)
550
+ if group.get('adam_grams_moment'):
551
+ exp_avg = grad.sign().mul_(exp_avg.abs())
552
+ elif group.get('adam_cautious_mask'):
553
+ mask = (exp_avg * grad > 0).to(grad.dtype)
554
+ mask.div_(mask.mean().clamp_(min=1e-3))
555
+ exp_avg.mul_(mask)
556
+ del mask
557
+
558
+ if group.get('adam_use_AdEMAMix'):
559
+ exp_avg_slow = state['exp_avg_slow']
560
+ exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
561
+ if beta1_adam > 0:
562
+ update = torch.add(exp_avg, exp_avg_slow, alpha=alpha)
563
+ else:
564
+ update = torch.add(grad, exp_avg_slow, alpha=alpha)
565
+ else:
566
+ update = exp_avg.clone() if beta1_adam > 0 else grad.clone()
567
+
568
+ exp_avg_sq.mul_(beta2_adam).addcmul_(grad, grad.conj(), value=1 - beta2_adam)
569
+
570
+ if group.get('adam_use_atan2'):
571
+ a = 1.2732395
572
+ denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5))
573
+ update.atan2_(denom).mul_(a)
574
+ else:
575
+ denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
576
+ update.div_(denom)
577
+ del denom
578
+
579
+ update.mul_(step_size)
580
+
581
+ # Decoupled weight decay
582
+ if group["adam_weight_decay"] != 0:
583
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
584
+ add_stochastic_(p.data, p.data, alpha=-group["adam_weight_decay"] * lr)
374
585
  else:
375
- p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
586
+ p.data.add_(p.data, alpha=-group["adam_weight_decay"] * lr)
376
587
 
377
588
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
378
589
  add_stochastic_(p.data, -update)
@@ -380,7 +591,61 @@ class AdaMuon_adv(torch.optim.Optimizer):
380
591
  p.data.add_(-update)
381
592
  del update
382
593
 
383
- state['step'] += 1
594
+ @torch.no_grad()
595
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
596
+ grad = p.grad
597
+ if grad is None:
598
+ return
599
+ state = self.state[p]
600
+
601
+
602
+
603
+ # Determine if using Adam or Muon based on state keys
604
+ # We can use optm_type but I see this as a safer way.
605
+ if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
606
+ use_adam = False
607
+ else:
608
+ use_adam = True
609
+
610
+ lr = group['lr']
611
+ is_compiled = group.get('compiled_optimizer', False)
612
+
613
+ if use_adam:
614
+ if self.kourkoutas_helper:
615
+ # Prepare Kourkoutas-β once per optimizer step.
616
+ self.kourkoutas_helper.maybe_prepare_step(self.global_step)
617
+ # Adam-specific setup (bias correction)
618
+ if group['adam_use_bias_correction']:
619
+ current_step = self.global_step + 1
620
+ beta1_adam, beta2_adam = group['adam_betas']
621
+ bias_correction1 = 1.0 - beta1_adam ** current_step
622
+ bias_correction2 = 1.0 - beta2_adam ** current_step
623
+ else:
624
+ bias_correction1 = 1.0
625
+ bias_correction2 = 1.0
626
+ # Dispatch to compiled or uncompiled Adam step
627
+ if is_compiled and self._compiled_adam_step is not None:
628
+ # Tensors must be used for compiled functions
629
+ lr_tensor = torch.tensor(lr, device=p.device)
630
+ bc1_tensor = torch.tensor(bias_correction1, device=p.device)
631
+ bc2_tensor = torch.tensor(bias_correction2, device=p.device)
632
+ self._compiled_adam_step(p, grad, state, group, lr_tensor, bc1_tensor, bc2_tensor)
633
+ else:
634
+ self._adam_step_parameter(p, grad, state, group, lr, bias_correction1, bias_correction2)
635
+ else: # Muon path
636
+ # Dispatch to compiled or uncompiled Muon step
637
+ if is_compiled and self._compiled_muon_step is not None:
638
+ lr_tensor = torch.tensor(lr, device=p.device)
639
+ self._compiled_muon_step(p, grad, state, group, lr_tensor)
640
+ else:
641
+ self._muon_step_parameter(p, grad, state, group, lr)
642
+
643
+
644
+ def compile(self, *args, **kwargs):
645
+ print("Compiling AdaMuon step path...")
646
+ self._compiled_muon_step = torch.compile(self._muon_step_parameter, *args, **kwargs)
647
+ print("Compiling AuxAdam step path...")
648
+ self._compiled_adam_step = torch.compile(self._adam_step_parameter, *args, **kwargs)
384
649
 
385
650
  @torch.no_grad()
386
651
  def step(self, closure=None):
@@ -394,4 +659,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)
 
+ self.global_step += 1
+
  return loss
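The dual-path dispatch added above keys off an `optim_type` entry in each parameter group (`group.get('optim_type', 'muon')` in `__init_state`). A minimal usage sketch, assuming the public `adv_optm` import shown in `__init__.py`; the muon/adam parameter split and all hyperparameter values are illustrative placeholders, not recommendations from this release:

import torch
from adv_optm import AdaMuon_adv  # exported via adv_optm/__init__.py in this release

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 10))

# Illustrative split: matrices go to the Muon path, 1D params to the auxiliary Adam path.
muon_params = [p for p in model.parameters() if p.ndim >= 2]
adam_params = [p for p in model.parameters() if p.ndim < 2]

optimizer = AdaMuon_adv(
    [
        {"params": muon_params, "optim_type": "muon"},
        {"params": adam_params, "optim_type": "adam"},
    ],
    lr=1e-3,                    # placeholder value
    adam_betas=(0.9, 0.99),     # auxiliary AdamW_adv settings introduced in this version
    adam_weight_decay=0.01,
    compiled_optimizer=False,   # True would torch.compile both step paths in __init__
)

loss = model(torch.randn(8, 64)).sum()
loss.backward()
optimizer.step()   # dispatches per group to _muon_step_parameter / _adam_step_parameter
optimizer.zero_grad()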
adv_optm/optim/Muon_adv.py CHANGED
@@ -1,28 +1,23 @@
  import torch
- from typing import Optional
- from .AdamW_adv import AdamW_adv
 
  from ..util.BF16_Stochastic_Rounding import add_stochastic_
  from ..util.Newton_Schulz import _newton_schulz_iteration
  from ..util.Effective_Shape import _get_effective_shape
  from ..util.NNMF import _nnmf,_unnmf
  from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+ from ..util.OrthoGrad import _orthogonalize_gradient
+ from ..util.Kourkoutas import KourkoutasHelper
 
  class Muon_adv(torch.optim.Optimizer):
  """
- Implements an advanced Muon algorithm.
+ Implements an advanced Muon algorithm, with an integrated auxiliary AdamW optimizer.
 
  Muon (MomentUm Orthogonalized by Newton-Schulz) is an optimizer designed for
  the hidden layers of neural networks. It applies SGD with momentum and then
  orthogonalizes the resulting update matrix using a Newton-Schulz iteration.
 
- NorMuon (Neuron-wise Normalized Muon) extends this by adding neuron-level
- adaptive learning rates, combining the benefits of orthogonalization with
- second-order momentum statistics.
-
- This implementation is designed for 2D parameters (e.g., linear layers) and
- can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
- flattening/reshaping them.
+ When `MuonWithAuxAdam` is enabled, this single optimizer class handles both
+ 'muon' and 'adam' parameter groups, dispatching to the appropriate logic internally.
 
  Args:
  params (iterable): iterable of parameters to optimize or dicts defining
@@ -65,6 +60,19 @@ class Muon_adv(torch.optim.Optimizer):
  normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
  (default: 0.2)
  normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
+ --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
+ adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
+ adam_eps (float): Epsilon for the AdamW optimizer part.
+ adam_weight_decay (float): Weight decay for the AdamW optimizer part.
+ adam_use_bias_correction (bool): Bias correction for AdamW.
+ adam_use_atan2 (bool): Atan2 update rule for AdamW.
+ adam_cautious_mask (bool): Cautious masking for AdamW.
+ adam_grams_moment (bool): Grams-style updates for AdamW.
+ adam_orthogonal_gradient (bool): OrthoGrad for AdamW.
+ adam_use_AdEMAMix (bool): AdEMAMix for AdamW.
+ adam_beta3_ema (float): Beta3 for AdEMAMix.
+ adam_alpha (float): Alpha for AdEMAMix.
+ adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
  """
 
  def __init__(
@@ -92,6 +100,25 @@ class Muon_adv(torch.optim.Optimizer):
  normuon_eps: float = 1e-8,
  normuon_lr_scale: float = 0.2,
  normuon_atan2: bool = False,
+ # Compiled
+ compiled_optimizer: bool = False,
+ # --- AdamW_adv specific parameters ---
+ adam_betas: tuple[float, float] = (0.9, 0.99),
+ adam_eps: float = 1e-8,
+ adam_weight_decay: float = 0.0,
+ adam_use_bias_correction: bool = True,
+ adam_use_atan2: bool = False,
+ adam_cautious_mask: bool = False,
+ adam_grams_moment: bool = False,
+ adam_orthogonal_gradient: bool = False,
+ adam_use_AdEMAMix: bool = False,
+ adam_beta3_ema: float = 0.9999,
+ adam_alpha: float = 5.0,
+ adam_kourkoutas_beta: bool = False,
+ adam_beta2_min: float = 0.9,
+ adam_ema_alpha: float = 0.95,
+ adam_tiny_spike: float = 1e-9,
+ adam_k_warmup_steps: int = 0,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -104,7 +131,7 @@ class Muon_adv(torch.optim.Optimizer):
  if not (ns_steps > 0):
  raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
  if Simplified_AdEMAMix and nesterov:
- print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling cautious.")
+ print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling nesterov.")
  nesterov = False
 
  defaults = {
@@ -114,17 +141,43 @@ class Muon_adv(torch.optim.Optimizer):
  "vector_reshape": vector_reshape,
  "vector_reshape_muon": vector_reshape_muon,
  "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+ 'compiled_optimizer': compiled_optimizer,
  # Low-rank Ortho
  "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
  # NorMuon
  "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
  "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
  "normuon_atan2": normuon_atan2,
+ # AdamW_adv defaults
+ "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
+ "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
+ "adam_cautious_mask": adam_cautious_mask, "adam_grams_moment": adam_grams_moment,
+ "adam_orthogonal_gradient": adam_orthogonal_gradient,
+ "adam_use_AdEMAMix": adam_use_AdEMAMix, "adam_beta3_ema": adam_beta3_ema, "adam_alpha": adam_alpha,
+ "adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
+ "adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
+ "adam_k_warmup_steps": adam_k_warmup_steps,
  }
  self.stochastic_rounding = stochastic_rounding
+ self.compiled_optimizer = compiled_optimizer
 
  super().__init__(params, defaults)
 
+ self.global_step = 0 # For Adam bias correction and Kourkoutas
+ self.kourkoutas_helper = None
+ if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
+ self.kourkoutas_helper = KourkoutasHelper(self)
+
+ self.init_step()
+
+ # Initialize compiled functions to None
+ self._compiled_muon_step = None
+ self._compiled_adam_step = None
+
+ if compiled_optimizer:
+ print("Compiling Muon_adv optimizer paths...")
+ torch._dynamo.config.cache_size_limit = 8192
+ self.compile(fullgraph=True)
 
  @property
  def supports_fused_back_pass(self):
@@ -138,53 +191,90 @@ class Muon_adv(torch.optim.Optimizer):
138
191
  def supports_flat_params(self):
139
192
  return False
140
193
 
194
+ def init_step(self):
195
+ for group in self.param_groups:
196
+ for i, p in enumerate(group['params']):
197
+ self.__init_state(p, group)
198
+
141
199
  @torch.no_grad()
142
- def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
143
- if p.grad is None:
200
+ def __init_state(self, p, group):
201
+ state = self.state[p]
202
+
203
+ if len(state) > 0:
144
204
  return
145
205
 
146
- grad = p.grad
147
- state = self.state[p]
206
+ optim_type = group.get('optim_type', 'muon')
148
207
 
149
- # State Initialization
150
- if 'step' not in state:
151
- state['step'] = 0
208
+ state['factored'] = (
209
+ group['nnmf_factor'] and
210
+ not (len(p.shape) == 1 and not group['vector_reshape'])
211
+ )
212
+ dtype = torch.float32 if state['factored'] else p.dtype
213
+ device = p.device
152
214
 
153
- should_factor = (
215
+ if optim_type == 'muon':
216
+
217
+ state['factored'] = (
154
218
  group['nnmf_factor'] and
155
219
  not (len(p.shape) == 1 and not group['vector_reshape'])
156
220
  )
157
-
158
- state['factored'] = should_factor
159
-
160
- state['reshaped_1d_muon'] = len(p.shape) == 1 and group['vector_reshape_muon']
161
-
162
- dtype = torch.float32 if group['nnmf_factor'] else p.dtype
221
+ dtype = torch.float32 if state['factored'] else p.dtype
163
222
  device = p.device
164
- if state['factored'] or state['reshaped_1d_muon']:
223
+
224
+ if state['factored']:
165
225
  state['effective_shape'] = _get_effective_shape(p.numel())
166
226
  d1, d2 = state['effective_shape']
167
- if state['factored']:
168
- state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
169
- state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
227
+ state['mu_mbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
228
+ state['mv_mbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
170
229
  packed_d2 = (d2 + 7) // 8
171
- state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
230
+ state['sign_buf'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
172
231
  else:
173
- if len(p.shape) >= 2:
174
- state['momentum_buffer'] = torch.zeros_like(p)
175
- if state['reshaped_1d_muon']:
176
- state['momentum_buffer'] = torch.zeros((d1, d2), device=device, dtype=dtype)
177
- elif len(p.shape) == 1:
178
- state['momentum_buffer'] = torch.zeros_like(p)
232
+ state['momentum_buffer'] = torch.zeros_like(p)
179
233
 
180
234
  # NorMuon state initialization
181
235
  if group['normuon_variant']:
182
236
  if state['factored']:
183
237
  state['normuon_v'] = torch.zeros(d1, device=p.device, dtype=torch.float32)
184
- elif len(p.shape) >= 2 or state['reshaped_1d_muon']:
185
- num_rows = p.shape[0] if len(p.shape) >= 2 else state['effective_shape'][0]
186
- state['normuon_v'] = torch.zeros(num_rows, device=p.device, dtype=torch.float32)
238
+ elif len(p.shape) >= 2:
239
+ state['normuon_v'] = torch.zeros(p.shape[0], device=p.device, dtype=torch.float32)
240
+
241
+
242
+ elif optim_type == 'adam':
243
+
244
+ state['factored'] = (
245
+ group['adam_nnmf_factor'] and
246
+ not (len(p.shape) == 1 and not group['vector_reshape'])
247
+ )
248
+ dtype = torch.float32 if state['factored'] else p.dtype
249
+ device = p.device
250
+
251
+ if state['factored']:
252
+ state['effective_shape'] = _get_effective_shape(p.numel())
253
+ d1, d2 = state['effective_shape']
254
+ # First moment (m)
255
+ if group['adam_betas'][0] > 0:
256
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
257
+ state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
258
+ if not group.get('adam_grams_moment'):
259
+ packed_d2 = (d2 + 7) // 8
260
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
261
+ if group.get('adam_use_AdEMAMix'):
262
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
263
+ state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
264
+ packed_d2 = (d2 + 7) // 8
265
+ state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
266
+ # Second moment (v)
267
+ state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
268
+ state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
269
+ else: # Fallback to standard AdamW for non-factored tensors
270
+ if group['adam_betas'][0] > 0:
271
+ state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
272
+ if group.get('adam_use_AdEMAMix'):
273
+ state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
274
+ state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
187
275
 
276
+ @torch.no_grad()
277
+ def _muon_step_parameter(self, p, grad, state, group, lr):
188
278
  beta1 = group['beta1']
189
279
  nesterov = group['nesterov']
190
280
  Simplified_AdEMAMix = group['Simplified_AdEMAMix']
@@ -194,8 +284,8 @@ class Muon_adv(torch.optim.Optimizer):
194
284
 
195
285
  # Reconstruct momentum from previous step's factors & sign
196
286
  d1, d2 = state['effective_shape']
197
- mt_buf = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
198
- unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
287
+ mt_buf = _unnmf((state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
288
+ unpacked_sign = _unpack_bools(state['sign_buf'], original_m=d2)
199
289
  torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
200
290
  del unpacked_sign
201
291
 
@@ -260,51 +350,41 @@ class Muon_adv(torch.optim.Optimizer):
260
350
  update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
261
351
  # Scale learning rate
262
352
  update_norm = torch.linalg.vector_norm(update)
263
- if update_norm > 1e-12:
264
- scaled_lr = group['normuon_lr_scale'] * group['lr'] * (p.numel()**0.5) / update_norm
265
- else:
266
- scaled_lr = 0.0
353
+
354
+ scaled_lr = group['normuon_lr_scale'] * lr * (p.numel()**0.5) / update_norm.add_(group['normuon_eps'])
355
+
267
356
  update = update.view(p.shape).mul_(scaled_lr)
357
+ del update_norm, scaled_lr
268
358
  else: # Original Muon learning rate application
269
- update = update.view(p.shape).mul_(group['lr'])
359
+ update = update.view(p.shape).mul_(lr)
270
360
 
271
- state['sign'] = _pack_bools(mt_buf > 0)
272
- _nnmf(mt_buf.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
361
+ state['sign_buf'] = _pack_bools(mt_buf > 0)
362
+ _nnmf(mt_buf.abs(), out=(state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
273
363
  del mt_buf
274
364
 
275
365
  else: # Standard Muon logic for non-factored tensors
276
366
 
277
- if len(p.shape) >= 2 or state['reshaped_1d_muon']:
367
+ if len(p.shape) >= 2:
368
+
369
+ original_shape = p.shape
278
370
 
279
371
  # Momentum update
280
372
  mt_buf = state['momentum_buffer']
281
- if state['reshaped_1d_muon']:
282
- d1, d2 = state['effective_shape']
283
- grad_reshaped = grad.view(d1, d2)
284
- mt_buf.mul_(beta1).add_(grad_reshaped)
285
- else:
286
- mt_buf.mul_(beta1).add_(grad)
373
+ mt_buf.mul_(beta1).add_(grad)
287
374
 
288
375
  if nesterov:
289
376
  # Nesterov momentum
290
- if state['reshaped_1d_muon']:
291
- update = grad_reshaped.add(mt_buf, alpha=beta1)
292
- del grad_reshaped
293
- else:
294
- update = grad.add(mt_buf, alpha=beta1)
377
+ update = grad.add(mt_buf, alpha=beta1)
295
378
  elif Simplified_AdEMAMix:
296
- if state['reshaped_1d_muon']:
297
- update = torch.add(mt_buf, grad_reshaped, alpha=alpha_grad)
298
- del grad_reshaped
299
- else:
300
- update = torch.add(mt_buf, grad, alpha=alpha_grad)
379
+ update = torch.add(mt_buf, grad, alpha=alpha_grad)
301
380
  else:
302
381
  # Standard momentum
303
382
  update = mt_buf.clone()
304
383
 
305
- # For Conv layers (4D) or other high-dim tensors, flatten to 2D
306
- if len(p.shape) > 2:
307
- update = update.view(p.shape[0], -1)
384
+ # flatten to 2D for orthogonalization.
385
+ # This is a no-op for 2D tensors and correctly flattens 4D+ tensors.
386
+ # This removes the dynamic control flow that breaks torch.compile.
387
+ update = update.view(original_shape[0], -1)
308
388
 
309
389
  # Orthogonalization step
310
390
  if group['low_rank_ortho']:
@@ -316,7 +396,7 @@ class Muon_adv(torch.optim.Optimizer):
316
396
  # 1. Sketch the matrix
317
397
  G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
318
398
  MG = M @ G_sketch
319
-
399
+
320
400
  # 2. QR decomposition to get orthogonal basis Q
321
401
  if MG.dtype != torch.float32:
322
402
  MG_dtype = M.dtype
@@ -324,10 +404,10 @@ class Muon_adv(torch.optim.Optimizer):
324
404
  Q = Q.to(MG_dtype)
325
405
  else:
326
406
  Q, _ = torch.linalg.qr(MG)
327
-
407
+
328
408
  # 3. Project M onto the basis
329
409
  projected_M = Q.T @ M
330
-
410
+
331
411
  # 4. Orthogonalize the smaller projected matrix
332
412
  ortho_projected_M = _newton_schulz_iteration(
333
413
  projected_M,
@@ -335,7 +415,7 @@ class Muon_adv(torch.optim.Optimizer):
335
415
  eps=group['ns_eps'],
336
416
  coeffs=group['ns_coeffs'],
337
417
  )
338
-
418
+
339
419
  # 5. Project back to the original space
340
420
  update = Q @ ortho_projected_M
341
421
  else: # Fallback for invalid rank
@@ -369,19 +449,16 @@ class Muon_adv(torch.optim.Optimizer):
369
449
  update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
370
450
  # Scale learning rate
371
451
  update_norm = torch.linalg.vector_norm(update)
372
- if update_norm > 1e-12:
373
- scaled_lr = group['normuon_lr_scale'] * group['lr'] * (p.numel()**0.5) / update_norm
374
- else:
375
- scaled_lr = 0.0
452
+ scaled_lr = group['normuon_lr_scale'] * lr * (p.numel()**0.5) / update_norm.add_(group['normuon_eps'])
376
453
  update.mul_(scaled_lr)
377
454
  else: # Original Muon learning rate application
378
- update.mul_(group['lr'])
455
+ update.mul_(lr)
456
+
457
+ # reshape back to the original shape.
458
+ update = update.view(original_shape)
379
459
 
380
- # Reshape back to original if we flattened or reshaped
381
- if len(p.shape) > 2 or state['reshaped_1d_muon']:
382
- update = update.view(p.shape)
383
-
384
- else: # Fallback to standard SGD with momentum for 1D params (biases, etc.) when not reshaped
460
+
461
+ else: # Fallback to standard SGD with momentum for 1D params (biases, etc.)
385
462
  # Momentum update
386
463
  mt_buf = state['momentum_buffer']
387
464
  mt_buf.mul_(beta1).add_(grad)
@@ -393,14 +470,14 @@ class Muon_adv(torch.optim.Optimizer):
393
470
  else:
394
471
  # Standard momentum
395
472
  update = mt_buf.clone()
396
- update.mul_(group['lr'])
473
+ update.mul_(lr)
397
474
 
398
475
  # Decoupled weight decay
399
476
  if group["weight_decay"] != 0:
400
477
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
401
- add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
478
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
402
479
  else:
403
- p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
480
+ p.data.add_(p.data, alpha=-group["weight_decay"] * lr)
404
481
 
405
482
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
406
483
  add_stochastic_(p.data, -update)
@@ -408,7 +485,198 @@ class Muon_adv(torch.optim.Optimizer):
408
485
  p.data.add_(-update)
409
486
  del update
410
487
 
411
- state['step'] += 1
488
+ @torch.no_grad()
489
+ def _adam_step_parameter(self, p, grad, state, group, lr, bias_correction1, bias_correction2):
490
+ if grad.dtype != torch.float32 and state.get('factored', False):
491
+ grad = grad.float()
492
+ if group.get("adam_orthogonal_gradient"):
493
+ grad = _orthogonalize_gradient(p, grad)
494
+
495
+ beta1_adam, beta2_adam = group['adam_betas']
496
+
497
+ if group.get('adam_kourkoutas_beta', False):
498
+ # Accumulate current grad's norm for the *next* step
499
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
500
+ # Get the dynamic beta2_adam calculated in prepare_step()
501
+ beta2_adam = self.kourkoutas_helper.get_beta2(p, group)
502
+
503
+ step_size = lr / bias_correction1
504
+
505
+ if group.get('adam_use_AdEMAMix'):
506
+ beta3_ema = group['adam_beta3_ema']
507
+ alpha = group['adam_alpha']
508
+
509
+ if state['factored']:
510
+ d1, d2 = state['effective_shape']
511
+ grad_reshaped = grad.view(d1, d2)
512
+
513
+ # Reconstruct momentum from previous step's factors
514
+ if beta1_adam > 0:
515
+ mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
516
+ if not group.get('adam_grams_moment'):
517
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
518
+ torch.where(unpacked_sign, mt, -mt, out=mt)
519
+ del unpacked_sign
520
+ # Update momentum in full-size
521
+ mt.mul_(beta1_adam).add_(grad_reshaped, alpha=1.0 - beta1_adam)
522
+ if group.get('adam_grams_moment'):
523
+ mt = (grad_reshaped.sign().mul_(mt.abs()))
524
+ elif group.get('adam_cautious_mask'):
525
+ mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
526
+ mask.div_(mask.mean().clamp_(min=1e-3))
527
+ mt.mul_(mask)
528
+ del mask
529
+
530
+ vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
531
+ vt.mul_(beta2_adam).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2_adam)
532
+
533
+ if group.get('adam_use_AdEMAMix'):
534
+ mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
535
+ if state['sign_slow'].dtype != torch.uint8:
536
+ state['sign_slow'] = state['sign_slow'].to(torch.uint8)
537
+ unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
538
+ torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
539
+ del unpacked_sign_slow
540
+
541
+ mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
542
+ if beta1_adam > 0:
543
+ update = torch.add(mt, mt_slow, alpha=alpha)
544
+ else:
545
+ update = torch.add(grad_reshaped, mt_slow, alpha=alpha)
546
+ else:
547
+ update = mt.clone() if beta1_adam > 0 else grad_reshaped.clone()
548
+ del grad_reshaped
549
+
550
+ if group['adam_use_atan2']:
551
+ a = 1.2732395
552
+ denom = (vt.sqrt() / (bias_correction2**0.5))
553
+ update.atan2_(denom).mul_(a)
554
+ else:
555
+ denom = (vt.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
556
+ update.div_(denom)
557
+ del denom
558
+
559
+ update = update.view(p.shape).mul_(step_size)
560
+
561
+ # Compress updated moments and store new factors
562
+ if beta1_adam > 0:
563
+ if not group.get('adam_grams_moment'):
564
+ state['sign'] = _pack_bools(mt > 0)
565
+ _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
566
+ del mt
567
+ if group.get('adam_use_AdEMAMix'):
568
+ state['sign_slow'] = _pack_bools(mt_slow > 0)
569
+ _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
570
+ del mt_slow
571
+ _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
572
+ del vt
573
+
574
+ else: # Standard AdamW logic for non-factored tensors
575
+ exp_avg_sq = state['exp_avg_sq']
576
+
577
+ if beta1_adam > 0:
578
+ exp_avg = state['exp_avg']
579
+ exp_avg.mul_(beta1_adam).add_(grad, alpha=1 - beta1_adam)
580
+ if group.get('adam_grams_moment'):
581
+ exp_avg = grad.sign().mul_(exp_avg.abs())
582
+ elif group.get('adam_cautious_mask'):
583
+ mask = (exp_avg * grad > 0).to(grad.dtype)
584
+ mask.div_(mask.mean().clamp_(min=1e-3))
585
+ exp_avg.mul_(mask)
586
+ del mask
587
+
588
+ if group.get('adam_use_AdEMAMix'):
589
+ exp_avg_slow = state['exp_avg_slow']
590
+ exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
591
+ if beta1_adam > 0:
592
+ update = torch.add(exp_avg, exp_avg_slow, alpha=alpha)
593
+ else:
594
+ update = torch.add(grad, exp_avg_slow, alpha=alpha)
595
+ else:
596
+ update = exp_avg.clone() if beta1_adam > 0 else grad.clone()
597
+
598
+ exp_avg_sq.mul_(beta2_adam).addcmul_(grad, grad.conj(), value=1 - beta2_adam)
599
+
600
+ if group.get('adam_use_atan2'):
601
+ a = 1.2732395
602
+ denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5))
603
+ update.atan2_(denom).mul_(a)
604
+ else:
605
+ denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
606
+ update.div_(denom)
607
+ del denom
608
+
609
+ update.mul_(step_size)
610
+
611
+ # Decoupled weight decay
612
+ if group["adam_weight_decay"] != 0:
613
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
614
+ add_stochastic_(p.data, p.data, alpha=-group["adam_weight_decay"] * lr)
615
+ else:
616
+ p.data.add_(p.data, alpha=-group["adam_weight_decay"] * lr)
617
+
618
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
619
+ add_stochastic_(p.data, -update)
620
+ else:
621
+ p.data.add_(-update)
622
+ del update
623
+
624
+ @torch.no_grad()
625
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
626
+ grad = p.grad
627
+ if grad is None:
628
+ return
629
+
630
+ state = self.state[p]
631
+
632
+ if self.kourkoutas_helper:
633
+ # Prepare Kourkoutas-β once per optimizer step.
634
+ self.kourkoutas_helper.maybe_prepare_step(self.global_step)
635
+
636
+ # Determine if using Adam or Muon based on state keys
637
+ # We can use optm_type but I see this as a safer way.
638
+ if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
639
+ use_adam = False
640
+ else:
641
+ use_adam = True
642
+
643
+ lr = group['lr']
644
+ is_compiled = group.get('compiled_optimizer', False)
645
+
646
+ if use_adam:
647
+ # Adam-specific setup (bias correction)
648
+ if group['adam_use_bias_correction']:
649
+ current_step = self.global_step + 1
650
+ beta1_adam, beta2_adam = group['adam_betas']
651
+ bias_correction1 = 1.0 - beta1_adam ** current_step
652
+ bias_correction2 = 1.0 - beta2_adam ** current_step
653
+ else:
654
+ bias_correction1 = 1.0
655
+ bias_correction2 = 1.0
656
+
657
+ # Dispatch to compiled or uncompiled Adam step
658
+ if is_compiled and self._compiled_adam_step is not None:
659
+ # Tensors must be used for compiled functions
660
+ lr_tensor = torch.tensor(lr, device=p.device)
661
+ bc1_tensor = torch.tensor(bias_correction1, device=p.device)
662
+ bc2_tensor = torch.tensor(bias_correction2, device=p.device)
663
+ self._compiled_adam_step(p, grad, state, group, lr_tensor, bc1_tensor, bc2_tensor)
664
+ else:
665
+ self._adam_step_parameter(p, grad, state, group, lr, bias_correction1, bias_correction2)
666
+ else: # Muon path
667
+ # Dispatch to compiled or uncompiled Muon step
668
+ if is_compiled and self._compiled_muon_step is not None:
669
+ lr_tensor = torch.tensor(lr, device=p.device)
670
+ self._compiled_muon_step(p, grad, state, group, lr_tensor)
671
+ else:
672
+ self._muon_step_parameter(p, grad, state, group, lr)
673
+
674
+
675
+ def compile(self, *args, **kwargs):
676
+ print("Compiling Muon step path...")
677
+ self._compiled_muon_step = torch.compile(self._muon_step_parameter, *args, **kwargs)
678
+ print("Compiling AuxAdam step path...")
679
+ self._compiled_adam_step = torch.compile(self._adam_step_parameter, *args, **kwargs)
412
680
 
413
681
  @torch.no_grad()
414
682
  def step(self, closure=None):
@@ -422,4 +690,6 @@ class Muon_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)
 
- return loss
+ self.global_step += 1
+
+ return loss
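The same group-level dispatch applies to Muon_adv. A hedged sketch of the per-group Kourkoutas-β switch, which is what makes `__init__` build `KourkoutasHelper(self)`; note that this release's Muon_adv defaults do not appear to define `adam_nnmf_factor` (which `__init_state` reads for 'adam' groups), so the sketch supplies it at group level. Import path and values are assumptions for illustration:

import torch
from adv_optm import Muon_adv  # assumed to be exported like the other optimizers

model = torch.nn.Linear(64, 10)
matrices = [p for p in model.parameters() if p.ndim >= 2]
vectors = [p for p in model.parameters() if p.ndim < 2]

optimizer = Muon_adv(
    [
        {"params": matrices, "optim_type": "muon"},
        {
            "params": vectors,
            "optim_type": "adam",
            "adam_kourkoutas_beta": True,  # per-group override; triggers KourkoutasHelper
            "adam_nnmf_factor": False,     # not in Muon_adv's defaults dict, so set explicitly
        },
    ],
    lr=1e-3,                  # placeholder value
    adam_betas=(0.9, 0.99),
    adam_k_warmup_steps=100,  # assumed meaning: steps before the dynamic beta2 takes effect
)
# Because one group enables adam_kourkoutas_beta, __init__ creates KourkoutasHelper(self) and
# _adam_step_parameter queries it for a dynamic beta2 on every step; training then proceeds
# with the usual loss.backward() / optimizer.step() / optimizer.zero_grad() loop.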
adv_optm/util/Kourkoutas.py CHANGED
@@ -71,7 +71,7 @@ class KourkoutasHelper:
  for layer_key, info in self.layer_info.items():
  params, group = info['params'], info['group_ref']
 
- if not group.get('kourkoutas_beta', False):
+ if not group.get('kourkoutas_beta', False) and not group.get('adam_kourkoutas_beta', False):
  continue
 
  first_param_in_layer = info['params'][0]
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.2.dev12
+ Version: 1.2.dev14
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -1,23 +1,23 @@
1
- adv_optm/__init__.py,sha256=eREtFkRqgGhb8_duC4ZROpRguYIQZiJfws1MJbrBn8c,380
2
- adv_optm/optim/AdaMuon_adv.py,sha256=828WtdsaKXJqlZqFXE2yrsxY3Erxn-6N7CxV9jBXiaI,17880
1
+ adv_optm/__init__.py,sha256=D5arg90L2AukHVLCuo7eEbYCh1KtUMOnCwrxsBQgA18,380
2
+ adv_optm/optim/AdaMuon_adv.py,sha256=-UBw_mJj8JzDAi3zQ0nLnSOgzsTzl7b7kVksDRUziEE,30582
3
3
  adv_optm/optim/AdamW_adv.py,sha256=KL9SCJWZ_ckAQEApB6ofbndVYjancN-v7Us7hJLFf54,17475
4
4
  adv_optm/optim/Adopt_adv.py,sha256=S8XI2YA7683jsW8p7igc2YcU30lsN0H18qL02Kpvj8E,21244
5
5
  adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
6
6
  adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
7
- adv_optm/optim/Muon_adv.py,sha256=uo9CBUI_5hZuuKbEmqHqozvS1_d3rKK0NKlyuv_0XxU,19518
7
+ adv_optm/optim/Muon_adv.py,sha256=8d99NcXzLyxTbxVVXC8mHyeW7wM8jjK59QoXVTLScQA,32112
8
8
  adv_optm/optim/Prodigy_adv.py,sha256=lEjbtuQbomsCX39DnTPeI8Z5YG0f2aZPXN_E7-nGgWw,26060
9
9
  adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
10
10
  adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
11
11
  adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
12
12
  adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
13
- adv_optm/util/Kourkoutas.py,sha256=_fq2glPqKmzgWpLedfwq5EqIJAxICUK2fmUP-cdcgq0,7467
13
+ adv_optm/util/Kourkoutas.py,sha256=SSzhe0B6Zb2AXGwCKpVTLr0aaFfspcFBNZCZG3azI9k,7516
14
14
  adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
15
15
  adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnCg,1377
16
16
  adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
17
17
  adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
18
18
  adv_optm/util/__init__.py,sha256=CXzS703GB4gil85khZi7sgKOnbzXGBOltshIOSPqj18,435
19
- adv_optm-1.2.dev12.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
20
- adv_optm-1.2.dev12.dist-info/METADATA,sha256=RC88vvcjd7LgqF6wuHsBLQyoyDwm87ECoTkat7NOs5Y,14023
21
- adv_optm-1.2.dev12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
22
- adv_optm-1.2.dev12.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
23
- adv_optm-1.2.dev12.dist-info/RECORD,,
19
+ adv_optm-1.2.dev14.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
20
+ adv_optm-1.2.dev14.dist-info/METADATA,sha256=g017hnuxrm1a34pjnXUDZlnUify9xQtX_ZkrbMEXLLY,14023
21
+ adv_optm-1.2.dev14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
22
+ adv_optm-1.2.dev14.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
23
+ adv_optm-1.2.dev14.dist-info/RECORD,,