adv-optm 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,723 @@
1
+ import torch
2
+
3
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
4
+ from ..util.Newton_Schulz import _newton_schulz_iteration
5
+ from ..util.Effective_Shape import _get_effective_shape
6
+ from ..util.NNMF import _nnmf,_unnmf
7
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
8
+ from ..util.OrthoGrad import _orthogonalize_gradient
9
+ from ..util.Kourkoutas import KourkoutasHelper
10
+
11
+ class Muon_adv(torch.optim.Optimizer):
12
+ """
13
+ Implements an advanced Muon algorithm with an integrated auxiliary AdamW optimizer.
14
+
15
+ Muon (MomentUm Orthogonalized by Newton-Schulz) is an optimizer designed for
16
+ the hidden layers of neural networks. It applies SGD with momentum and then
17
+ orthogonalizes the resulting update matrix using a Newton-Schulz iteration.
18
+
19
+ When `MuonWithAuxAdam` is enabled, this single optimizer class handles both
20
+ 'muon' and 'adam' parameter groups, dispatching to the appropriate logic internally.
21
+
22
+ Args:
23
+ params (iterable): iterable of parameters to optimize or dicts defining
24
+ parameter groups.
25
+ lr (float): learning rate (default: 1e-3).
26
+ beta1 (float): momentum factor (default: 0.9).
27
+ weight_decay (float): weight decay (L2 penalty) (default: 0).
28
+ nesterov (bool): enables Nesterov momentum (default: True).
29
+ ns_steps (int): number of Newton-Schulz iterations to perform (default: 5).
30
+ ns_eps (float): epsilon for Newton-Schulz normalization stability (default: 1e-7).
31
+ ns_coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
32
+ quintic polynomial in the Newton-Schulz iteration.
33
+ (default: (3.4445, -4.7750, 2.0315)).
34
+ Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
35
+ This changes the update to `alpha_grad * grad + mt`, which can be
36
+ more responsive, especially for small batch sizes. (default: False)
37
+ alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
38
+ (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
39
+ current gradient. For small batch sizes, use high values (e.g., 10-100) to be
40
+ more responsive. For large batch sizes, use low values (e.g., 0-1) for
41
+ stability. (default: 100.0)
42
+ stochastic_rounding (bool): whether to use stochastic rounding for
43
+ BF16 parameter updates (default: True).
44
+ orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
45
+ vector_reshape (bool): whether to reshape 1D vectors into 2D
46
+ matrices to apply low-rank compression (default: False).
47
+ nnmf_factor (bool): whether to compress the momentum buffer with a rank-1 NNMF of its
48
+ magnitude plus a bit-packed sign matrix; if False, the uncompressed optimizer is used. (default: False)
49
+ low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
50
+ projects the update to a lower rank before orthogonalization.
51
+ (default: False)
52
+ ortho_rank (int): The rank for low-rank orthogonalization.
53
+ (default: 128)
54
+ normuon_variant (bool): If True, enables the NorMuon update rule, which adds
55
+ neuron-wise normalization. (default: False)
56
+ beta2_normuon (float): The exponential decay rate for the second moment estimates
57
+ used in NorMuon. (default: 0.95)
58
+ normuon_eps (float): Epsilon for NorMuon normalization stability. (default: 1e-8)
59
+ rms_rescaling (bool): rescale the final update to a target root-mean-square (RMS)
60
+ value, which allows reusing existing Adam learning rate schedules.
61
+ (default: True)
62
+ accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
63
+ dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
64
+ cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
65
+ --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
66
+ adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
67
+ adam_eps (float): Epsilon for the AdamW optimizer part.
68
+ adam_weight_decay (float): Weight decay for the AdamW optimizer part.
69
+ adam_use_bias_correction (bool): Bias correction for AdamW.
70
+ adam_use_atan2 (bool): Atan2 update rule for AdamW.
71
+ adam_cautious_mask (bool): Cautious masking for AdamW.
72
+ adam_grams_moment (bool): Grams-style updates for AdamW.
73
+ adam_orthogonal_gradient (bool): OrthoGrad for AdamW.
74
+ adam_use_AdEMAMix (bool): AdEMAMix for AdamW.
75
+ adam_beta3_ema (float): Beta3 for AdEMAMix.
76
+ adam_alpha (float): Alpha for AdEMAMix.
77
+ adam_kourkoutas_beta (bool): Kourkoutas-β (dynamic beta2) for AdamW; configured via adam_beta2_min, adam_ema_alpha, adam_tiny_spike, and adam_k_warmup_steps.
78
+ adam_nnmf_factor (bool): 1-bit factored for AdamW.
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ params,
84
+ lr: float = 1e-3,
85
+ beta1: float = 0.9,
86
+ weight_decay: float = 0.0,
87
+ nesterov: bool = True,
88
+ ns_steps: int = 5,
89
+ ns_eps: float = 1e-7,
90
+ ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
91
+ Simplified_AdEMAMix: bool = False,
92
+ alpha_grad: float = 100.0,
93
+ stochastic_rounding: bool = True,
94
+ orthogonal_gradient: bool = False,
95
+ rms_rescaling: bool = True,
96
+ vector_reshape: bool = False,
97
+ nnmf_factor: bool = False,
98
+ # Low-rank Muon
99
+ low_rank_ortho: bool = False,
100
+ ortho_rank: int = 128,
101
+ # NorMuon additions
102
+ normuon_variant: bool = False,
103
+ beta2_normuon: float = 0.95,
104
+ normuon_eps: float = 1e-8,
105
+ # CANS
106
+ accelerated_ns: bool = False,
107
+ cns_a_bound: float = 1e-4,
108
+ # Compiled
109
+ compiled_optimizer: bool = False,
110
+ # --- AdamW_adv specific parameters ---
111
+ adam_betas: tuple[float, float] = (0.9, 0.99),
112
+ adam_eps: float = 1e-8,
113
+ adam_weight_decay: float = 0.0,
114
+ adam_use_bias_correction: bool = True,
115
+ adam_use_atan2: bool = False,
116
+ adam_cautious_mask: bool = False,
117
+ adam_grams_moment: bool = False,
118
+ adam_orthogonal_gradient: bool = False,
119
+ adam_use_AdEMAMix: bool = False,
120
+ adam_beta3_ema: float = 0.9999,
121
+ adam_alpha: float = 5.0,
122
+ adam_kourkoutas_beta: bool = False,
123
+ adam_beta2_min: float = 0.9,
124
+ adam_ema_alpha: float = 0.95,
125
+ adam_tiny_spike: float = 1e-9,
126
+ adam_k_warmup_steps: int = 0,
127
+ adam_nnmf_factor: bool = False,
128
+ ):
129
+ if not (lr >= 0.0):
130
+ raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
131
+ if not (0.0 <= beta1 < 1.0):
132
+ raise ValueError(f"beta1 should be in [0.0, 1.0). Got {beta1}")
133
+ if normuon_variant and not (0.0 <= beta2_normuon < 1.0):
134
+ raise ValueError(f"beta2_normuon should be in [0.0, 1.0) for NorMuon. Got {beta2_normuon}")
135
+ if not (weight_decay >= 0.0):
136
+ raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
137
+ if not (ns_steps > 0):
138
+ raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
139
+ if Simplified_AdEMAMix and nesterov:
140
+ print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling nesterov.")
141
+ nesterov = False
142
+
143
+ defaults = {
144
+ "lr": lr, "beta1": beta1, "weight_decay": weight_decay,
145
+ "nesterov": nesterov, "ns_steps": ns_steps, "ns_eps": ns_eps,
146
+ "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
147
+ "vector_reshape": vector_reshape, "rms_rescaling": rms_rescaling,
148
+ "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
149
+ "orthogonal_gradient": orthogonal_gradient,
150
+ 'compiled_optimizer': compiled_optimizer,
151
+ # Low-rank Ortho
152
+ "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
153
+ # NorMuon
154
+ "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
155
+ "normuon_eps": normuon_eps,
156
+ # CANS
157
+ "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
158
+ # AdamW_adv defaults
159
+ "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
160
+ "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
161
+ "adam_cautious_mask": adam_cautious_mask, "adam_grams_moment": adam_grams_moment,
162
+ "adam_orthogonal_gradient": adam_orthogonal_gradient,
163
+ "adam_use_AdEMAMix": adam_use_AdEMAMix, "adam_beta3_ema": adam_beta3_ema, "adam_alpha": adam_alpha,
164
+ "adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
165
+ "adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
166
+ "adam_k_warmup_steps": adam_k_warmup_steps,
167
+ "adam_nnmf_factor":adam_nnmf_factor,
168
+ }
169
+ self.stochastic_rounding = stochastic_rounding
170
+ self.compiled_optimizer = compiled_optimizer
171
+
172
+ super().__init__(params, defaults)
173
+
174
+ self.kourkoutas_helper = None
175
+ if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
176
+ self.kourkoutas_helper = KourkoutasHelper(self)
177
+
178
+ self.init_step()
179
+
180
+ # Initialize compiled functions to None
181
+ self._compiled_muon_step = None
182
+ self._compiled_adam_step = None
183
+
184
+ if compiled_optimizer:
185
+ print("Compiling Muon_adv optimizer paths...")
186
+ torch._dynamo.config.cache_size_limit = 8192
187
+ self.compile(fullgraph=True)
188
+
189
+ @property
190
+ def supports_fused_back_pass(self):
191
+ return True
192
+
193
+ @property
194
+ def supports_memory_efficient_fp16(self):
195
+ return True
196
+
197
+ @property
198
+ def supports_flat_params(self):
199
+ return False
200
+
201
+ def init_step(self):
202
+ for group in self.param_groups:
203
+ for i, p in enumerate(group['params']):
204
+ self.__init_state(p, group)
205
+
206
+ @torch.no_grad()
207
+ def __init_state(self, p, group):
208
+ state = self.state[p]
209
+
210
+ if len(state) > 0:
211
+ return
212
+
213
+ optim_type = group.get('optim_type', 'muon')
214
+
215
+ state['factored'] = (
216
+ group['nnmf_factor'] and
217
+ not (len(p.shape) == 1 and not group['vector_reshape'])
218
+ )
219
+ dtype = torch.float32 if state['factored'] else p.dtype
220
+ device = p.device
221
+
222
+ if optim_type == 'muon':
232
+ if state['factored']:
233
+ state['effective_shape'] = _get_effective_shape(p.numel())
234
+ d1, d2 = state['effective_shape']
235
+ state['mu_mbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
236
+ state['mv_mbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
237
+ packed_d2 = (d2 + 7) // 8
238
+ state['sign_buf'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
239
+ else:
240
+ state['momentum_buffer'] = torch.zeros_like(p)
241
+
242
+ # NorMuon state initialization
243
+ if group['normuon_variant']:
244
+ if state['factored']:
245
+ state['normuon_v'] = torch.zeros(d1, device=p.device, dtype=torch.float32)
246
+ elif len(p.shape) >= 2:
247
+ state['normuon_v'] = torch.zeros(p.shape[0], device=p.device, dtype=torch.float32)
248
+
249
+ group['adam_kourkoutas_beta'] = False  # Kourkoutas-β only applies to 'adam' groups
250
+
251
+ elif optim_type == 'adam':
252
+
253
+ state['step'] = 0
254
+
255
+ state['factored'] = (
256
+ group['adam_nnmf_factor'] and
257
+ not (len(p.shape) == 1 and not group['vector_reshape'])
258
+ )
259
+ dtype = torch.float32 if state['factored'] else p.dtype
260
+ device = p.device
261
+
262
+ if state['factored']:
263
+ state['effective_shape'] = _get_effective_shape(p.numel())
264
+ d1, d2 = state['effective_shape']
265
+ # First moment (m)
266
+ if group['adam_betas'][0] > 0:
267
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
268
+ state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
269
+ if not group.get('adam_grams_moment'):
270
+ packed_d2 = (d2 + 7) // 8
271
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
272
+ if group.get('adam_use_AdEMAMix'):
273
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
274
+ state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
275
+ packed_d2 = (d2 + 7) // 8
276
+ state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
277
+ # Second moment (v)
278
+ state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
279
+ state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
280
+ else: # Fallback to standard AdamW for non-factored tensors
281
+ if group['adam_betas'][0] > 0:
282
+ state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
283
+ if group.get('adam_use_AdEMAMix'):
284
+ state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
285
+ state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
286
+
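The `(d2 + 7) // 8` sizing above stores one sign bit per element, packed into uint8 bytes. The packing helpers live in `..util.One_Bit_Boolean` and are not part of this file; a rough sketch of what `_pack_bools` / `_unpack_bools` presumably do (the LSB-first bit order is an assumption):

import torch

def pack_bools_sketch(b: torch.Tensor) -> torch.Tensor:
    # (d1, d2) bool -> (d1, ceil(d2 / 8)) uint8, eight sign bits per byte.
    d1, d2 = b.shape
    padded = torch.zeros(d1, ((d2 + 7) // 8) * 8, dtype=torch.bool, device=b.device)
    padded[:, :d2] = b
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=b.device)
    return (padded.view(d1, -1, 8).to(torch.uint8) * weights).sum(dim=-1).to(torch.uint8)

def unpack_bools_sketch(packed: torch.Tensor, original_m: int) -> torch.Tensor:
    # Inverse: (d1, ceil(m / 8)) uint8 -> (d1, m) bool, dropping the padding columns.
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=packed.device)
    bits = (packed.unsqueeze(-1) & weights) != 0
    return bits.view(packed.shape[0], -1)[:, :original_m]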
287
+ @torch.no_grad()
288
+ def _muon_step_parameter(self, p, grad, state, group, lr):
289
+ beta1 = group['beta1']
290
+ nesterov = group['nesterov']
291
+ Simplified_AdEMAMix = group['Simplified_AdEMAMix']
292
+ alpha_grad = group['alpha_grad']
293
+ if grad.dtype != torch.float32 and state.get('factored', False):
294
+ grad = grad.float()
295
+ if group.get("orthogonal_gradient"):
296
+ grad = _orthogonalize_gradient(p, grad)
297
+
298
+ if state['factored']: # Factored Muon
299
+
300
+ # Reconstruct momentum from previous step's factors & sign
301
+ d1, d2 = state['effective_shape']
302
+ mt_buf = _unnmf((state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
303
+ unpacked_sign = _unpack_bools(state['sign_buf'], original_m=d2)
304
+ torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
305
+ del unpacked_sign
306
+
307
+ # Update momentum in full-size
308
+ grad_reshaped = grad.view(d1, d2)
309
+ mt_buf.mul_(beta1).add_(grad_reshaped)
310
+
311
+ if nesterov:
312
+ # Nesterov momentum
313
+ update = grad_reshaped.add(mt_buf, alpha=beta1)
314
+ elif Simplified_AdEMAMix:
315
+ update = torch.add(mt_buf, grad_reshaped, alpha=alpha_grad)
316
+ else:
317
+ # Standard momentum
318
+ update = mt_buf.clone()
319
+ del grad_reshaped
320
+
321
+ # Orthogonalization step
322
+ if group['low_rank_ortho']:
323
+ # Low-Rank Orthogonalization on the reconstructed matrix
324
+ M = update
325
+ r = min(group['ortho_rank'], M.shape[0], M.shape[1])
326
+ if r > 0:
327
+ G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
328
+ MG = M @ G_sketch
329
+ if MG.dtype != torch.float32:
330
+ MG_dtype = M.dtype
331
+ Q, _ = torch.linalg.qr(MG.float())
332
+ Q = Q.to(MG_dtype)
333
+ else:
334
+ Q, _ = torch.linalg.qr(MG)
335
+ projected_M = Q.T @ M
336
+ ortho_projected_M = _newton_schulz_iteration(
337
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
338
+ )
339
+ update = Q @ ortho_projected_M
340
+ else: # Fallback for invalid rank
341
+ update = _newton_schulz_iteration(
342
+ update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
343
+ )
344
+ else:
345
+ # Original full Newton-Schulz
346
+ update = _newton_schulz_iteration(
347
+ update,
348
+ steps=group['ns_steps'],
349
+ eps=group['ns_eps'],
350
+ coeffs=group['ns_coeffs'],
351
+ cns=group['accelerated_ns'],
352
+ cns_a_bound=group['cns_a_bound'],
353
+ )
354
+
355
+
356
+ if group['normuon_variant']:
357
+ v_t = state['normuon_v']
358
+ beta2_normuon = group['beta2_normuon']
359
+ # Update 2nd moment estimate
360
+ mean_squared_update = torch.mean(update.square(), dim=1)
361
+ v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
362
+ # Normalize update
363
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
364
+ del mean_squared_update
365
+
366
+ # RMS-aligned rescaling
367
+ if group['rms_rescaling']:
368
+ rms_target = 0.2 # default (Adam) value for RMS
369
+ update_norm = torch.linalg.vector_norm(update)
370
+ update = update.view(p.shape).mul_(rms_target * lr * (p.numel()**0.5) / update_norm.add_(1e-8))
371
+ del update_norm
372
+ else:
373
+ update = update.view(p.shape).mul_(lr)
374
+
375
+ state['sign_buf'] = _pack_bools(mt_buf > 0)
376
+ _nnmf(mt_buf.abs(), out=(state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
377
+ del mt_buf
378
+
379
+ else: # Standard Muon logic for non-factored tensors
380
+
381
+ if len(p.shape) >= 2:
382
+
383
+ original_shape = p.shape
384
+
385
+ # Momentum update
386
+ mt_buf = state['momentum_buffer']
387
+ mt_buf.mul_(beta1).add_(grad)
388
+
389
+ if nesterov:
390
+ # Nesterov momentum
391
+ update = grad.add(mt_buf, alpha=beta1)
392
+ elif Simplified_AdEMAMix:
393
+ update = torch.add(mt_buf, grad, alpha=alpha_grad)
394
+ else:
395
+ # Standard momentum
396
+ update = mt_buf.clone()
397
+
398
+ # flatten to 2D for orthogonalization.
399
+ # This is a no-op for 2D tensors and correctly flattens 4D+ tensors.
400
+ # This removes the dynamic control flow that breaks torch.compile.
401
+ update = update.view(original_shape[0], -1)
402
+
403
+ # Orthogonalization step
404
+ if group['low_rank_ortho']:
405
+ # Low-Rank Orthogonalization based on Gaussian Sketching
406
+ M = update
407
+ r = min(group['ortho_rank'], M.shape[0], M.shape[1])
408
+
409
+ if r > 0:
410
+ # 1. Sketch the matrix
411
+ G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
412
+ MG = M @ G_sketch
413
+
414
+ # 2. QR decomposition to get orthogonal basis Q
415
+ if MG.dtype != torch.float32:
416
+ MG_dtype = M.dtype
417
+ Q, _ = torch.linalg.qr(MG.float())
418
+ Q = Q.to(MG_dtype)
419
+ else:
420
+ Q, _ = torch.linalg.qr(MG)
421
+
422
+ # 3. Project M onto the basis
423
+ projected_M = Q.T @ M
424
+
425
+ # 4. Orthogonalize the smaller projected matrix
426
+ ortho_projected_M = _newton_schulz_iteration(
427
+ projected_M,
428
+ steps=group['ns_steps'],
429
+ eps=group['ns_eps'],
430
+ coeffs=group['ns_coeffs'],
431
+ cns=group['accelerated_ns'],
432
+ cns_a_bound=group['cns_a_bound'],
433
+ )
434
+
435
+ # 5. Project back to the original space
436
+ update = Q @ ortho_projected_M
437
+ else: # Fallback for invalid rank
438
+ update = _newton_schulz_iteration(
439
+ update,
440
+ steps=group['ns_steps'],
441
+ eps=group['ns_eps'],
442
+ coeffs=group['ns_coeffs'],
443
+ cns=group['accelerated_ns'],
444
+ cns_a_bound=group['cns_a_bound'],
445
+ )
446
+ else:
447
+ # Original NewtonSchulz
448
+ update = _newton_schulz_iteration(
449
+ update,
450
+ steps=group['ns_steps'],
451
+ eps=group['ns_eps'],
452
+ coeffs=group['ns_coeffs'],
453
+ cns=group['accelerated_ns'],
454
+ cns_a_bound=group['cns_a_bound'],
455
+ )
456
+
457
+ # NorMuon Logic
458
+ if group['normuon_variant']:
459
+ v_t = state['normuon_v']
460
+ beta2_normuon = group['beta2_normuon']
461
+ # Update 2nd moment estimate
462
+ mean_squared_update = torch.mean(update.square(), dim=1)
463
+ v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
464
+ # Normalize update
465
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
466
+
467
+ # RMS-aligned rescaling
468
+ if group['rms_rescaling']:
469
+ rms_target = 0.2 # default (Adam) value for RMS
470
+ update_norm = torch.linalg.vector_norm(update)
471
+ update = update.view(original_shape).mul_(rms_target * lr * (p.numel()**0.5) / update_norm.add_(1e-8))
472
+ del update_norm
473
+ else:
474
+ update = update.view(original_shape).mul_(lr)
475
+
476
+
477
+ else: # Fallback to standard SGD with momentum for 1D params (biases, etc.)
478
+ # Momentum update
479
+ mt_buf = state['momentum_buffer']
480
+ mt_buf.mul_(beta1).add_(grad)
481
+ if nesterov:
482
+ # Nesterov momentum
483
+ update = grad.add(mt_buf, alpha=beta1)
484
+ # FIXME: Simplified_AdEMAMix is skipped here; it would break SGD since it requires a ~100x lower LR
485
+ # elif Simplified_AdEMAMix:
486
+ # update = torch.add(mt_buf, grad, alpha=alpha_grad)
487
+ else:
488
+ # Standard momentum
489
+ update = mt_buf.clone()
490
+ update.mul_(lr)
491
+
492
+ # Decoupled weight decay
493
+ if group["weight_decay"] != 0:
494
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
495
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
496
+ else:
497
+ p.data.add_(p.data, alpha=-group["weight_decay"] * lr)
498
+
499
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
500
+ add_stochastic_(p.data, -update)
501
+ else:
502
+ p.data.add_(-update)
503
+ del update
504
+
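`_newton_schulz_iteration` itself is imported from `..util.Newton_Schulz` and is not shown in this file. For reference, a minimal sketch of the standard quintic Newton-Schulz orthogonalization that the default coefficients (3.4445, -4.7750, 2.0315) correspond to, ignoring the CANS / accelerated path:

import torch

def newton_schulz_sketch(G: torch.Tensor, steps: int = 5, eps: float = 1e-7,
                         coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315)) -> torch.Tensor:
    # Drives the singular values of G toward 1, approximately orthogonalizing the update.
    a, b, c = coeffs
    X = G / (G.norm() + eps)             # Frobenius norm bound => spectral norm <= 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T                           # iterate on the wide orientation
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X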
505
+ @torch.no_grad()
506
+ def _adam_step_parameter(self, p, grad, state, group, lr, bias_correction1, bias_correction2):
507
+ if grad.dtype != torch.float32 and state.get('factored', False):
508
+ grad = grad.float()
509
+ if group.get("adam_orthogonal_gradient"):
510
+ grad = _orthogonalize_gradient(p, grad)
511
+
512
+ beta1_adam, beta2_adam = group['adam_betas']
513
+
514
+ if group.get('adam_kourkoutas_beta', False):
515
+ # Accumulate current grad's norm for the *next* step
516
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
517
+ # Get the dynamic beta2_adam calculated in prepare_step()
518
+ beta2_adam = self.kourkoutas_helper.get_beta2(p, group)
519
+
520
+ step_size = lr / bias_correction1
521
+
522
+ if group.get('adam_use_AdEMAMix'):
523
+ beta3_ema = group['adam_beta3_ema']
524
+ alpha = group['adam_alpha']
525
+
526
+ if state['factored']:
527
+ d1, d2 = state['effective_shape']
528
+ grad_reshaped = grad.view(d1, d2)
529
+
530
+ # Reconstruct momentum from previous step's factors
531
+ if beta1_adam > 0:
532
+ mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
533
+ if not group.get('adam_grams_moment'):
534
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
535
+ torch.where(unpacked_sign, mt, -mt, out=mt)
536
+ del unpacked_sign
537
+ # Update momentum in full-size
538
+ mt.mul_(beta1_adam).add_(grad_reshaped, alpha=1.0 - beta1_adam)
539
+ if group.get('adam_grams_moment'):
540
+ mt = (grad_reshaped.sign().mul_(mt.abs()))
541
+ elif group.get('adam_cautious_mask'):
542
+ mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
543
+ mask.div_(mask.mean().clamp_(min=1e-3))
544
+ mt.mul_(mask)
545
+ del mask
546
+
547
+ vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
548
+ vt.mul_(beta2_adam).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2_adam)
549
+
550
+ if group.get('adam_use_AdEMAMix'):
551
+ mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
552
+ if state['sign_slow'].dtype != torch.uint8:
553
+ state['sign_slow'] = state['sign_slow'].to(torch.uint8)
554
+ unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
555
+ torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
556
+ del unpacked_sign_slow
557
+
558
+ mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
559
+ if beta1_adam > 0:
560
+ update = torch.add(mt, mt_slow, alpha=alpha)
561
+ else:
562
+ update = torch.add(grad_reshaped, mt_slow, alpha=alpha)
563
+ else:
564
+ update = mt.clone() if beta1_adam > 0 else grad_reshaped.clone()
565
+ del grad_reshaped
566
+
567
+ if group['adam_use_atan2']:
568
+ a = 1.2732395
569
+ denom = (vt.sqrt() / (bias_correction2**0.5))
570
+ update.atan2_(denom).mul_(a)
571
+ else:
572
+ denom = (vt.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
573
+ update.div_(denom)
574
+ del denom
575
+
576
+ update = update.view(p.shape).mul_(step_size)
577
+
578
+ # Compress updated moments and store new factors
579
+ if beta1_adam > 0:
580
+ if not group.get('adam_grams_moment'):
581
+ state['sign'] = _pack_bools(mt > 0)
582
+ _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
583
+ del mt
584
+ if group.get('adam_use_AdEMAMix'):
585
+ state['sign_slow'] = _pack_bools(mt_slow > 0)
586
+ _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
587
+ del mt_slow
588
+ _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
589
+ del vt
590
+
591
+ else: # Standard AdamW logic for non-factored tensors
592
+ exp_avg_sq = state['exp_avg_sq']
593
+
594
+ if beta1_adam > 0:
595
+ exp_avg = state['exp_avg']
596
+ exp_avg.mul_(beta1_adam).add_(grad, alpha=1 - beta1_adam)
597
+ if group.get('adam_grams_moment'):
598
+ exp_avg = grad.sign().mul_(exp_avg.abs())
599
+ elif group.get('adam_cautious_mask'):
600
+ mask = (exp_avg * grad > 0).to(grad.dtype)
601
+ mask.div_(mask.mean().clamp_(min=1e-3))
602
+ exp_avg.mul_(mask)
603
+ del mask
604
+
605
+ if group.get('adam_use_AdEMAMix'):
606
+ exp_avg_slow = state['exp_avg_slow']
607
+ exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
608
+ if beta1_adam > 0:
609
+ update = torch.add(exp_avg, exp_avg_slow, alpha=alpha)
610
+ else:
611
+ update = torch.add(grad, exp_avg_slow, alpha=alpha)
612
+ else:
613
+ update = exp_avg.clone() if beta1_adam > 0 else grad.clone()
614
+
615
+ exp_avg_sq.mul_(beta2_adam).addcmul_(grad, grad.conj(), value=1 - beta2_adam)
616
+
617
+ if group.get('adam_use_atan2'):
618
+ a = 1.2732395
619
+ denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5))
620
+ update.atan2_(denom).mul_(a)
621
+ else:
622
+ denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
623
+ update.div_(denom)
624
+ del denom
625
+
626
+ update.mul_(step_size)
627
+
628
+ # Decoupled weight decay
629
+ if group["adam_weight_decay"] != 0:
630
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
631
+ add_stochastic_(p.data, p.data, alpha=-group["adam_weight_decay"] * lr)
632
+ else:
633
+ p.data.add_(p.data, alpha=-group["adam_weight_decay"] * lr)
634
+
635
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
636
+ add_stochastic_(p.data, -update)
637
+ else:
638
+ p.data.add_(-update)
639
+ del update
640
+
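The factored branches above rebuild each moment from two vectors with `_unnmf` and re-compress it with `_nnmf` after the update. Those helpers are defined in `..util.NNMF` and are not shown here; a rough sketch of the Adafactor-style rank-1 non-negative factorization they presumably implement (the real helpers additionally write into preallocated `out=` buffers):

import torch

def nnmf_rank1_sketch(M: torch.Tensor, eps: float = 1e-30):
    # M is non-negative (the code factorizes absolute values / squared gradients).
    # Row sums and normalized column sums give a rank-1 approximation M ~ outer(mu, mv).
    mu = M.sum(dim=1)
    mv = M.sum(dim=0) / M.sum().clamp_min(eps)
    return mu, mv

def unnmf_rank1_sketch(mu: torch.Tensor, mv: torch.Tensor) -> torch.Tensor:
    # Reconstruct the (d1, d2) approximation from the two stored vectors.
    return torch.outer(mu, mv)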
641
+ @torch.no_grad()
642
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
643
+ grad = p.grad
644
+ if grad is None:
645
+ return
646
+
647
+ state = self.state[p]
648
+
649
+ # Determine if using Adam or Muon based on state keys
650
+ # We could dispatch on group['optim_type'], but checking the state keys is safer.
651
+ if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
652
+ use_adam = False
653
+ else:
654
+ use_adam = True
655
+
656
+ lr = group['lr']
657
+ is_compiled = group.get('compiled_optimizer', False)
658
+
659
+ if use_adam:
660
+ step = state['step']
661
+
662
+ if self.kourkoutas_helper:
663
+ # Prepare Kourkoutas-β once per optimizer step.
664
+ self.kourkoutas_helper.maybe_prepare_step(step)
665
+
666
+ # Adam-specific setup (bias correction)
667
+ if group['adam_use_bias_correction']:
668
+ current_step = step + 1
669
+ beta1_adam, beta2_adam = group['adam_betas']
670
+ bias_correction1 = 1.0 - beta1_adam ** current_step
671
+ bias_correction2 = 1.0 - beta2_adam ** current_step
672
+ else:
673
+ bias_correction1 = 1.0
674
+ bias_correction2 = 1.0
675
+
676
+ self.state[p]['step'] += 1
677
+
678
+ # Dispatch to compiled or uncompiled Adam step
679
+ if is_compiled and self._compiled_adam_step is not None:
680
+ # convert to tensors for compiled path once per step
681
+ if not hasattr(self, 'lr_adam_tensor') or self.lr_adam_tensor is None:
682
+ self.lr_adam_tensor = torch.tensor(group['lr'])
683
+ self.bc1 = torch.tensor(bias_correction1)
684
+ self.bc2 = torch.tensor(bias_correction2)
685
+ self._compiled_adam_step(p, grad, state, group, self.lr_adam_tensor, self.bc1, self.bc2)
686
+ else:
687
+ self._adam_step_parameter(p, grad, state, group, lr, bias_correction1, bias_correction2)
688
+ else: # Muon path
689
+ # Dispatch to compiled or uncompiled Muon step
690
+ if is_compiled and self._compiled_muon_step is not None:
693
+ # convert to tensors for compiled path once per step
694
+ if not hasattr(self, 'lr_tensor') or self.lr_tensor is None:
695
+ self.lr_tensor = torch.tensor(group['lr'])
696
+ self._compiled_muon_step(p, grad, state, group, self.lr_tensor)
697
+ else:
698
+ self._muon_step_parameter(p, grad, state, group, lr)
699
+
700
+
701
+ def compile(self, *args, **kwargs):
702
+ print("Compiling Muon step path...")
703
+ self._compiled_muon_step = torch.compile(self._muon_step_parameter, *args, **kwargs)
704
+ print("Compiling AuxAdam step path...")
705
+ self._compiled_adam_step = torch.compile(self._adam_step_parameter, *args, **kwargs)
706
+
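Both step paths call `add_stochastic_` when updating bfloat16 parameters. That helper comes from `..util.BF16_Stochastic_Rounding` and is not shown here; a rough sketch of one common way to implement BF16 stochastic rounding (accumulate in fp32, then randomize the truncated mantissa bits):

import torch

def add_stochastic_sketch(target_bf16: torch.Tensor, other: torch.Tensor, alpha: float = 1.0) -> None:
    # target_bf16 += alpha * other, rounded to bf16 stochastically instead of to-nearest,
    # so small updates are not systematically lost to rounding.
    result = target_bf16.float().add_(other.float(), alpha=alpha)
    bits = result.view(torch.int32)
    bits.add_(torch.randint_like(bits, 0, 1 << 16))  # random carry into the kept bits
    bits.bitwise_and_(-65536)                        # zero the 16 truncated mantissa bits
    target_bf16.copy_(bits.view(torch.float32))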
707
+ @torch.no_grad()
708
+ def step(self, closure=None):
709
+ """Performs a single optimization step."""
710
+ loss = None
711
+ if closure is not None:
712
+ with torch.enable_grad():
713
+ loss = closure()
714
+
715
+ for group in self.param_groups:
716
+ for i, p in enumerate(group['params']):
717
+ self.step_parameter(p, group, i)
718
+
719
+ if self.param_groups[0].get('compiled_optimizer', False):
720
+ # Reset compiled-path tensors once per step
721
+ self.lr_tensor = None
722
+ self.lr_adam_tensor = None
723
+ return loss
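
Because `supports_fused_back_pass` is True and `step_parameter` applies the update for a single tensor, the optimizer can also be driven directly from gradient hooks so that full-model gradients never need to be held at once. A minimal sketch, assuming PyTorch 2.1+ `register_post_accumulate_grad_hook` and an already-constructed `optimizer` as in the earlier usage sketch (uncompiled path; the compiled path caches lr tensors that `step()` normally resets):

def attach_fused_backward(optimizer):
    # Bind each parameter to its group so the hook can dispatch immediately.
    for group in optimizer.param_groups:
        for p in group["params"]:
            def hook(param, group=group):
                optimizer.step_parameter(param, group)  # update this tensor right away
                param.grad = None                       # and free its gradient
            p.register_post_accumulate_grad_hook(hook)

# After attach_fused_backward(optimizer), a plain loss.backward() performs the
# whole update; no separate optimizer.step() call is needed.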