adv-optm 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,720 @@
1
+ import torch
2
+
3
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
4
+ from ..util.Newton_Schulz import _newton_schulz_iteration
5
+ from ..util.Effective_Shape import _get_effective_shape
6
+ from ..util.NNMF import _nnmf,_unnmf
7
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
8
+ from ..util.OrthoGrad import _orthogonalize_gradient
9
+ from ..util.Kourkoutas import KourkoutasHelper
10
+
11
+ class AdaMuon_adv(torch.optim.Optimizer):
12
+ """
13
+ Implements an advanced AdaMuon optimizer algorithm.
14
+
15
+ AdaMuon combines the geometry-aware updates of Muon with the element-wise
16
+ adaptivity of Adam. It is designed for 2D parameters (e.g., linear layers)
17
+ and can handle higher-dimensional parameters by flattening.
18
+
19
+ The algorithm incorporates three key mechanisms:
20
+ 1. A sign-stabilized orthogonal update, where the sign of the momentum is
21
+ orthogonalized instead of the momentum itself.
22
+ 2. An element-wise second momentum estimator applied to the orthogonalized
23
+ update directions.
24
+ 3. An RMS-aligned rescaling strategy to match the update magnitude of Adam,
25
+ allowing for reuse of learning rate schedules.
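+ 
+ Concretely (see the `rms_rescaling` branch in the step below), the
+ orthogonalized, adaptively scaled update `u` is applied as
+ `p -= 0.2 * lr * sqrt(p.numel()) * u / (||u|| + 1e-8)`, so the applied step
+ has an RMS of roughly `0.2 * lr`, the value the implementation assumes for Adam.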
26
+
27
+ When `MuonWithAuxAdam`-style parameter groups are used (each group tagged with
28
+ an `optim_type`), this single optimizer class handles both 'muon' and 'adam'
+ parameter groups, dispatching to the appropriate logic internally.
29
+
30
+ Args:
31
+ params (iterable): iterable of parameters to optimize or dicts defining
32
+ parameter groups.
33
+ lr (float): learning rate (default: 1e-3).
34
+ betas (tuple[float, float]): coefficients used for both first and second moment
35
+ estimation (default: (0.95, 0.95))
36
+ weight_decay (float): decoupled weight decay (default: 0).
37
+ eps (float): term added to the denominator for adaptive scaling to improve
38
+ numerical stability (default: 1e-8).
39
+ rms_rescaling (bool): rescale the final update so its RMS matches Adam's
40
+ typical update magnitude (RMS-aligned rescaling), allowing existing Adam
41
+ learning rate schedules to be reused. (default: True)
42
+ ns_steps (int): number of Newton-Schulz iterations to perform (default: 5).
43
+ ns_eps (float): epsilon for Newton-Schulz normalization stability (default: 1e-7).
44
+ ns_coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
45
+ quintic polynomial in the Newton-Schulz iteration.
46
+ (default: (3.4445, -4.7750, 2.0315)).
47
+ stochastic_rounding (bool): whether to use stochastic rounding for
48
+ BF16 parameter updates (default: True).
49
+ orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
50
+ nesterov (bool): enables Nesterov momentum (default: False).
51
+ use_atan2 (bool): whether to use the atan2 update rule. (default: False)
52
+ Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
53
+ This changes the update to `alpha_grad * grad + mt`, which can be
54
+ more responsive, especially for small batch sizes. (default: False)
55
+ alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
56
+ (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
57
+ current gradient. For small batch sizes, use high values (e.g., 10-100) to be
58
+ more responsive. For large batch sizes, use low values (e.g., 0-1) for
59
+ stability. (default: 100.0)
60
+ vector_reshape (bool): whether to reshape 1D vectors into 2D
61
+ matrices to apply low-rank compression (default: False).
62
+ low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
63
+ projects the update to a lower rank before orthogonalization.
64
+ (default: False)
65
+ ortho_rank (int): The rank for low-rank orthogonalization.
66
+ (default: 128)
67
+ accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
68
+ dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
69
+ cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
70
+ nnmf_factor (bool): whether to use NNMF factorization for the optimizer
71
+ state; if False, the uncompressed optimizer state is used. (default: False)
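+ normuon_variant (bool): use the NorMuon variant, which normalizes the
+ orthogonalized update with a per-row second-moment estimate instead of the
+ element-wise estimator. (default: False)
+ compiled_optimizer (bool): compile the Muon and Adam step functions with
+ `torch.compile` when the optimizer is constructed. (default: False)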
72
+ --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
73
+ adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
74
+ adam_eps (float): Epsilon for the AdamW optimizer part.
75
+ adam_weight_decay (float): Weight decay for the AdamW optimizer part.
76
+ adam_use_bias_correction (bool): Bias correction for AdamW.
77
+ adam_use_atan2 (bool): Atan2 update rule for AdamW.
78
+ adam_cautious_mask (bool): Cautious masking for AdamW.
79
+ adam_grams_moment (bool): Grams-style updates for AdamW.
80
+ adam_orthogonal_gradient (bool): OrthoGrad for AdamW.
81
+ adam_use_AdEMAMix (bool): AdEMAMix for AdamW.
82
+ adam_beta3_ema (float): Beta3 for AdEMAMix.
83
+ adam_alpha (float): Alpha for AdEMAMix.
84
+ adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
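+ adam_beta2_min (float): lower bound of the dynamic beta2 used when
+ `adam_kourkoutas_beta` is enabled. (default: 0.9)
+ adam_ema_alpha (float), adam_tiny_spike (float), adam_k_warmup_steps (int):
+ additional Kourkoutas-β tuning values stored in the group defaults for
+ use by `KourkoutasHelper`. (defaults: 0.95, 1e-9, 0)
+ adam_nnmf_factor (bool): NNMF factorization for the AdamW state. (default: False)
+ 
+ Example (illustrative sketch; `model`, `batch`, and `compute_loss` are
+ stand-ins, and the `optim_type` group key follows the dispatch in
+ `__init_state`):
+ 
+ muon_params = [p for p in model.parameters() if p.ndim >= 2]
+ adam_params = [p for p in model.parameters() if p.ndim < 2]
+ optimizer = AdaMuon_adv(
+ [{'params': muon_params, 'optim_type': 'muon'},
+ {'params': adam_params, 'optim_type': 'adam'}],
+ lr=1e-3,
+ )
+ loss = compute_loss(model, batch)
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()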
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ params,
90
+ lr: float = 1e-3,
91
+ betas: tuple[float, float] = (0.95, 0.95),
92
+ weight_decay: float = 0,
93
+ eps: float = 1e-8,
94
+ rms_rescaling: bool = True,
95
+ ns_steps: int = 5,
96
+ ns_eps: float = 1e-7,
97
+ ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
98
+ stochastic_rounding: bool = True,
99
+ orthogonal_gradient: bool = False,
100
+ use_atan2: bool = False,
101
+ nesterov: bool = False,
102
+ Simplified_AdEMAMix: bool = False,
103
+ alpha_grad: float = 100.0,
104
+ normuon_variant: bool = False,
105
+ # Low-rank Muon
106
+ low_rank_ortho: bool = False,
107
+ ortho_rank: int = 128,
108
+ # Factored
109
+ vector_reshape: bool = False,
110
+ nnmf_factor: bool = False,
111
+ # CANS
112
+ accelerated_ns: bool = False,
113
+ cns_a_bound: float = 1e-4,
114
+ # Compiled
115
+ compiled_optimizer: bool = False,
116
+ # --- AdamW_adv specific parameters ---
117
+ adam_betas: tuple[float, float] = (0.9, 0.99),
118
+ adam_eps: float = 1e-8,
119
+ adam_weight_decay: float = 0.0,
120
+ adam_use_bias_correction: bool = True,
121
+ adam_use_atan2: bool = False,
122
+ adam_cautious_mask: bool = False,
123
+ adam_grams_moment: bool = False,
124
+ adam_orthogonal_gradient: bool = False,
125
+ adam_use_AdEMAMix: bool = False,
126
+ adam_beta3_ema: float = 0.9999,
127
+ adam_alpha: float = 5.0,
128
+ adam_kourkoutas_beta: bool = False,
129
+ adam_beta2_min: float = 0.9,
130
+ adam_ema_alpha: float = 0.95,
131
+ adam_tiny_spike: float = 1e-9,
132
+ adam_k_warmup_steps: int = 0,
133
+ adam_nnmf_factor: bool = False,
134
+ ):
135
+ if not (lr >= 0.0):
136
+ raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
137
+ if not (weight_decay >= 0.0):
138
+ raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
139
+ if not (ns_steps > 0):
140
+ raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
141
+ if Simplified_AdEMAMix and nesterov:
142
+ print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling nesterov.")
143
+ nesterov = False
144
+
145
+ defaults = {
146
+ "lr": lr, "betas": betas, "weight_decay": weight_decay,
147
+ "eps": eps, "rms_rescaling": rms_rescaling, "ns_steps": ns_steps,
148
+ "ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
149
+ "vector_reshape": vector_reshape,
150
+ "nesterov":nesterov, "use_atan2":use_atan2,
151
+ "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
152
+ "normuon_variant": normuon_variant, "orthogonal_gradient": orthogonal_gradient,
153
+ # Low-rank Ortho
154
+ "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
155
+ "compiled_optimizer":compiled_optimizer,
156
+ # CANS
157
+ "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
158
+ # AdamW_adv defaults
159
+ "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
160
+ "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
161
+ "adam_cautious_mask": adam_cautious_mask, "adam_grams_moment": adam_grams_moment,
162
+ "adam_orthogonal_gradient": adam_orthogonal_gradient,
163
+ "adam_use_AdEMAMix": adam_use_AdEMAMix, "adam_beta3_ema": adam_beta3_ema, "adam_alpha": adam_alpha,
164
+ "adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
165
+ "adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
166
+ "adam_k_warmup_steps": adam_k_warmup_steps, "adam_nnmf_factor": adam_nnmf_factor,
167
+ }
168
+ self.stochastic_rounding = stochastic_rounding
169
+
170
+ super().__init__(params, defaults)
171
+
172
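+ # Optional Kourkoutas-beta support: the helper derives a dynamic beta2 for
+ # 'adam' groups from gradient-norm statistics accumulated during each step.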
+ self.kourkoutas_helper = None
173
+ if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
174
+ self.kourkoutas_helper = KourkoutasHelper(self)
175
+
176
+ self.init_step()
177
+
178
+ # Initialize compiled functions to None
179
+ self._compiled_muon_step = None
180
+ self._compiled_adam_step = None
181
+
182
+ if compiled_optimizer:
183
+ print("Compiling AdaMuon_adv optimizer paths...")
184
+ torch._dynamo.config.cache_size_limit = 8192
185
+ self.compile(fullgraph=True)
186
+
187
+ @property
188
+ def supports_fused_back_pass(self):
189
+ return True
190
+
191
+ @property
192
+ def supports_memory_efficient_fp16(self):
193
+ return True
194
+
195
+ @property
196
+ def supports_flat_params(self):
197
+ return False
198
+
199
+ def init_step(self):
200
+ for group in self.param_groups:
201
+ for i, p in enumerate(group['params']):
202
+ self.__init_state(p, group)
203
+
204
+ @torch.no_grad()
205
+ def __init_state(self, p, group):
206
+ state = self.state[p]
207
+
208
+ if len(state) > 0:
209
+ return
210
+
211
+ optim_type = group.get('optim_type', 'muon')
212
+
213
+ if optim_type == 'muon':
214
+
215
+ state['factored'] = (
216
+ group['nnmf_factor'] and
217
+ not (len(p.shape) == 1 and not group['vector_reshape'])
218
+ )
219
+ dtype = torch.float32 if state['factored'] else p.dtype
220
+ device = p.device
221
+
222
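+ # Factored state: instead of a full momentum buffer, store a rank-1
+ # nonnegative factorization (row/column vectors) of its absolute value
+ # plus a 1-bit-per-element packed sign mask.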
+ if state['factored']:
223
+ state['effective_shape'] = _get_effective_shape(p.numel())
224
+ d1, d2 = state['effective_shape']
225
+ state['mu_mbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
226
+ state['mv_mbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
227
+ packed_d2 = (d2 + 7) // 8
228
+ state['sign_buf'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
229
+ if not group['normuon_variant']:
230
+ state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
231
+ state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
232
+ else:
233
+ if len(p.shape) >= 2 and not group['normuon_variant']:
234
+ state['second_momentum_buffer'] = torch.zeros_like(p)
235
+ state['momentum_buffer'] = torch.zeros_like(p)
236
+
237
+ # NorMuon state initialization
238
+ if group['normuon_variant']:
239
+ if state['factored']:
240
+ state['normuon_v'] = torch.zeros(d1, device=p.device, dtype=torch.float32)
241
+ elif len(p.shape) >= 2:
242
+ state['normuon_v'] = torch.zeros(p.shape[0], device=p.device, dtype=torch.float32)
243
+
244
+ elif optim_type == 'adam':
245
+
246
+ state['step'] = 0
247
+
248
+ state['factored'] = (
249
+ group['adam_nnmf_factor'] and
250
+ not (len(p.shape) == 1 and not group['vector_reshape'])
251
+ )
252
+ dtype = torch.float32 if state['factored'] else p.dtype
253
+ device = p.device
254
+
255
+ if state['factored']:
256
+ state['effective_shape'] = _get_effective_shape(p.numel())
257
+ d1, d2 = state['effective_shape']
258
+ # First moment (m)
259
+ if group['adam_betas'][0] > 0:
260
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
261
+ state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
262
+ if not group.get('adam_grams_moment'):
263
+ packed_d2 = (d2 + 7) // 8
264
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
265
+ if group.get('adam_use_AdEMAMix'):
266
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
267
+ state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
268
+ packed_d2 = (d2 + 7) // 8
269
+ state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
270
+ # Second moment (v)
271
+ state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
272
+ state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
273
+ else: # Fallback to standard AdamW for non-factored tensors
274
+ if group['adam_betas'][0] > 0:
275
+ state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
276
+ if group.get('adam_use_AdEMAMix'):
277
+ state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
278
+ state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
279
+
280
+ @torch.no_grad()
281
+ def _muon_step_parameter(self, p, grad, state, group, lr):
282
+ # Retrieve hyperparameters
283
+ beta1, beta2 = group['betas']
284
+ nesterov = group['nesterov']
285
+ Simplified_AdEMAMix = group['Simplified_AdEMAMix']
286
+ alpha_grad = group['alpha_grad']
287
+ if grad.dtype != torch.float32 and state.get('factored', False):
288
+ grad = grad.float()
289
+ if group.get("orthogonal_gradient"):
290
+ grad = _orthogonalize_gradient(p, grad)
291
+
292
+ if state['factored']: # Factored AdaMuon
293
+
294
+ # Reconstruct momentum from previous step's factors & sign
295
+ d1, d2 = state['effective_shape']
296
+ mt_buf = _unnmf((state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
297
+ unpacked_sign = _unpack_bools(state['sign_buf'], original_m=d2)
298
+ torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
299
+ del unpacked_sign
300
+
301
+ # Update momentum in full-size
302
+ grad_reshaped = grad.view(d1, d2)
303
+ mt_buf.mul_(beta1).add_(grad_reshaped)
304
+
305
+ if nesterov:
306
+ signed_m_buf = torch.sign(grad_reshaped.add(mt_buf, alpha=beta1))
307
+ elif Simplified_AdEMAMix:
308
+ signed_m_buf = torch.sign(mt_buf.add(grad_reshaped, alpha=alpha_grad))
309
+ else:
310
+ signed_m_buf = torch.sign(mt_buf)
311
+ del grad_reshaped
312
+
313
+ # Orthogonalization step
314
+ if group['low_rank_ortho']:
315
+ # Low-Rank Orthogonalization on the reconstructed matrix
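+ # Sketch-and-project: a random Gaussian sketch gives a rank-r column basis Q
+ # (via QR); Newton-Schulz orthogonalizes the small projected matrix Q^T @ M,
+ # and Q lifts the result back to the full shape.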
316
+ M = signed_m_buf
317
+ r = min(group['ortho_rank'], M.shape[0], M.shape[1])
318
+ if r > 0:
319
+ G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
320
+ MG = M @ G_sketch
321
+ if MG.dtype != torch.float32:
322
+ MG_dtype = M.dtype
323
+ Q, _ = torch.linalg.qr(MG.float())
324
+ Q = Q.to(MG_dtype)
325
+ else:
326
+ Q, _ = torch.linalg.qr(MG)
327
+ projected_M = Q.T @ M
328
+ ortho_projected_M = _newton_schulz_iteration(
329
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
330
+ )
331
+ update = Q @ ortho_projected_M
332
+ else: # Fallback for invalid rank
333
+ update = _newton_schulz_iteration(
334
+ signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
335
+ )
336
+ else:
337
+ # Original full Newton-Schulz
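+ # (iteratively pushes sign(m_t) toward the nearest semi-orthogonal matrix
+ # using the quintic polynomial defined by ns_coeffs)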
338
+ update = _newton_schulz_iteration(
339
+ signed_m_buf,
340
+ steps=group['ns_steps'],
341
+ eps=group['ns_eps'],
342
+ coeffs=group['ns_coeffs'],
343
+ cns=group['accelerated_ns'],
344
+ cns_a_bound=group['cns_a_bound'],
345
+ )
346
+ del signed_m_buf
347
+
348
+ if group['normuon_variant']:
349
+ v_t = state['normuon_v']
350
+ # Update 2nd moment estimate
351
+ mean_squared_update = torch.mean(update.square(), dim=1)
352
+ v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
353
+ # Normalize update
354
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
355
+ del mean_squared_update
356
+ else:
357
+ # Reconstruct second momentum from previous step's factors
358
+ vt_buf = _unnmf((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
359
+ # Update second momentum in full-size
360
+ vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
361
+ # Apply second momentum update (adaptive scaling)
362
+ if group['use_atan2']:
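+ # atan2 form: an eps-free alternative to dividing by (sqrt(v) + eps); since
+ # sqrt(v) >= 0, a * atan2(u, sqrt(v)) is bounded to [-2, 2] for a = 4/pi.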
363
+ a = 1.2732395
364
+ denom = vt_buf.sqrt()
365
+ update.atan2_(denom).mul_(a)
366
+ else:
367
+ denom = vt_buf.sqrt().add_(group['eps'])
368
+ update.div_(denom)
369
+ del denom
370
+
371
+ # RMS-aligned rescaling
372
+ if group['rms_rescaling']:
373
+ rms_target = 0.2 # default (Adam) value for RMS
374
+ update_norm = torch.linalg.vector_norm(update)
375
+ update = update.view(p.shape).mul_(rms_target * lr * (p.numel()**0.5) / update_norm.add_(1e-8))
376
+ del update_norm
377
+ else:
378
+ update = update.view(p.shape).mul_(lr)
379
+
380
+ # Compress updated moments and store new factors
381
+ state['sign_buf'] = _pack_bools(mt_buf > 0)
382
+ _nnmf(mt_buf.abs(), out=(state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
383
+ del mt_buf
384
+
385
+ if not group['normuon_variant']:
386
+ _nnmf(vt_buf.abs(), out=(state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
387
+ del vt_buf
388
+
389
+ else: # Standard AdaMuon logic for non-factored tensors
390
+
391
+ if len(p.shape) >= 2:
392
+
393
+ original_shape = p.shape
394
+
395
+ # Momentum update
396
+ mt_buf = state['momentum_buffer']
397
+ mt_buf.mul_(beta1).add_(grad)
398
+
399
+ if nesterov:
400
+ signed_m_buf = torch.sign(grad.add(mt_buf, alpha=beta1))
401
+ elif Simplified_AdEMAMix:
402
+ signed_m_buf = torch.sign(mt_buf.add(grad, alpha=alpha_grad))
403
+ else:
404
+ signed_m_buf = torch.sign(mt_buf)
405
+
406
+ # Flatten if necessary (e.g., for Conv layers)
407
+ signed_m_buf = signed_m_buf.view(original_shape[0], -1)
408
+
409
+ # Orthogonalization step
410
+ if group['low_rank_ortho']:
411
+ # Low-Rank Orthogonalization on the reconstructed matrix
412
+ M = signed_m_buf
413
+ r = min(group['ortho_rank'], M.shape[0], M.shape[1])
414
+ if r > 0:
415
+ G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
416
+ MG = M @ G_sketch
417
+ if MG.dtype != torch.float32:
418
+ MG_dtype = M.dtype
419
+ Q, _ = torch.linalg.qr(MG.float())
420
+ Q = Q.to(MG_dtype)
421
+ else:
422
+ Q, _ = torch.linalg.qr(MG)
423
+ projected_M = Q.T @ M
424
+ ortho_projected_M = _newton_schulz_iteration(
425
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
426
+ )
427
+ update = Q @ ortho_projected_M
428
+ else: # Fallback for invalid rank
429
+ update = _newton_schulz_iteration(
430
+ signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'], cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound']
431
+ )
432
+ else:
433
+ # Original full Newton-Schulz
434
+ update = _newton_schulz_iteration(
435
+ signed_m_buf,
436
+ steps=group['ns_steps'],
437
+ eps=group['ns_eps'],
438
+ coeffs=group['ns_coeffs'],
439
+ cns=group['accelerated_ns'],
440
+ cns_a_bound=group['cns_a_bound'],
441
+ )
442
+ del signed_m_buf
443
+
444
+ update = update.view(original_shape)
445
+
446
+ if group['normuon_variant']:
447
+ # NorMuon Logic
448
+ v_t = state['normuon_v']
449
+ # Update 2nd moment estimate
450
+ mean_squared_update = torch.mean(update.square(), dim=1)
451
+ v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
452
+ # Normalize update
453
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
454
+ del mean_squared_update
455
+ else:
456
+ # Original AdaMuon Logic
457
+ vt_buf = state['second_momentum_buffer']
458
+ vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
459
+ # Apply second momentum update (adaptive scaling)
460
+ if group['use_atan2']:
461
+ a = 1.2732395
462
+ denom = vt_buf.sqrt()
463
+ update.atan2_(denom).mul_(a)
464
+ else:
465
+ denom = vt_buf.sqrt().add_(group['eps'])
466
+ update.div_(denom)
467
+ del denom
468
+
469
+ # RMS-aligned rescaling
470
+ if group['rms_rescaling']:
471
+ rms_target = 0.2 # default (Adam) value for RMS
472
+ update_norm = torch.linalg.vector_norm(update)
473
+ update = update.view(p.shape).mul_(rms_target * lr * (p.numel()**0.5) / update_norm.add_(1e-8))
474
+ del update_norm
475
+ else:
476
+ update = update.view(p.shape).mul_(lr)
477
+
478
+ else: # Fallback to standard SGD with momentum for 1D params (biases, etc.)
479
+ # Momentum update
480
+ mt_buf = state['momentum_buffer']
481
+ mt_buf.mul_(beta1).add_(grad)
482
+ if nesterov:
483
+ update = grad.add(mt_buf, alpha=beta1)
484
+ # FIXME, Simplified_AdEMAMix will break SGD since it requires x100 lower LR
485
+ # elif Simplified_AdEMAMix:
486
+ # update = mt_buf.add(grad, alpha=alpha_grad)
487
+ else:
488
+ update = mt_buf.clone()
489
+ update.mul_(lr)
490
+
491
+ # Decoupled weight decay
492
+ if group["weight_decay"] != 0:
493
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
494
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
495
+ else:
496
+ p.data.add_(p.data, alpha=-group["weight_decay"] * lr)
497
+
498
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
499
+ add_stochastic_(p.data, -update)
500
+ else:
501
+ p.data.add_(-update)
502
+ del update
503
+
504
+ @torch.no_grad()
505
+ def _adam_step_parameter(self, p, grad, state, group, lr, bias_correction1, bias_correction2):
506
+ if grad.dtype != torch.float32 and state.get('factored', False):
507
+ grad = grad.float()
508
+ if group.get("adam_orthogonal_gradient"):
509
+ grad = _orthogonalize_gradient(p, grad)
510
+
511
+ beta1_adam, beta2_adam = group['adam_betas']
512
+
513
+ if group.get('adam_kourkoutas_beta', False):
514
+ # Accumulate current grad's norm for the *next* step
515
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
516
+ # Get the dynamic beta2_adam calculated in prepare_step()
517
+ beta2_adam = self.kourkoutas_helper.get_beta2(p, group)
518
+
519
+ step_size = lr / bias_correction1
520
+
521
+ if group.get('adam_use_AdEMAMix'):
522
+ beta3_ema = group['adam_beta3_ema']
523
+ alpha = group['adam_alpha']
524
+
525
+ if state['factored']:
526
+ d1, d2 = state['effective_shape']
527
+ grad_reshaped = grad.view(d1, d2)
528
+
529
+ # Reconstruct momentum from previous step's factors
530
+ if beta1_adam > 0:
531
+ mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
532
+ if not group.get('adam_grams_moment'):
533
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
534
+ torch.where(unpacked_sign, mt, -mt, out=mt)
535
+ del unpacked_sign
536
+ # Update momentum in full-size
537
+ mt.mul_(beta1_adam).add_(grad_reshaped, alpha=1.0 - beta1_adam)
538
+ if group.get('adam_grams_moment'):
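+ # Grams-style update: keep the momentum's magnitude but take the sign
+ # from the current gradient.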
539
+ mt = (grad_reshaped.sign().mul_(mt.abs()))
540
+ elif group.get('adam_cautious_mask'):
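+ # Cautious mask: zero components where momentum and gradient disagree in
+ # sign, then rescale by the inverse mean of the mask to preserve the
+ # overall update scale.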
541
+ mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
542
+ mask.div_(mask.mean().clamp_(min=1e-3))
543
+ mt.mul_(mask)
544
+ del mask
545
+
546
+ vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
547
+ vt.mul_(beta2_adam).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2_adam)
548
+
549
+ if group.get('adam_use_AdEMAMix'):
550
+ mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
551
+ if state['sign_slow'].dtype != torch.uint8:
552
+ state['sign_slow'] = state['sign_slow'].to(torch.uint8)
553
+ unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
554
+ torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
555
+ del unpacked_sign_slow
556
+
557
+ mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
558
+ if beta1_adam > 0:
559
+ update = torch.add(mt, mt_slow, alpha=alpha)
560
+ else:
561
+ update = torch.add(grad_reshaped, mt_slow, alpha=alpha)
562
+ else:
563
+ update = mt.clone() if beta1_adam > 0 else grad_reshaped.clone()
564
+ del grad_reshaped
565
+
566
+ if group['adam_use_atan2']:
567
+ a = 1.2732395
568
+ denom = (vt.sqrt() / (bias_correction2**0.5))
569
+ update.atan2_(denom).mul_(a)
570
+ else:
571
+ denom = (vt.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
572
+ update.div_(denom)
573
+ del denom
574
+
575
+ update = update.view(p.shape).mul_(step_size)
576
+
577
+ # Compress updated moments and store new factors
578
+ if beta1_adam > 0:
579
+ if not group.get('adam_grams_moment'):
580
+ state['sign'] = _pack_bools(mt > 0)
581
+ _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
582
+ del mt
583
+ if group.get('adam_use_AdEMAMix'):
584
+ state['sign_slow'] = _pack_bools(mt_slow > 0)
585
+ _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
586
+ del mt_slow
587
+ _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
588
+ del vt
589
+
590
+ else: # Standard AdamW logic for non-factored tensors
591
+ exp_avg_sq = state['exp_avg_sq']
592
+
593
+ if beta1_adam > 0:
594
+ exp_avg = state['exp_avg']
595
+ exp_avg.mul_(beta1_adam).add_(grad, alpha=1 - beta1_adam)
596
+ if group.get('adam_grams_moment'):
597
+ exp_avg = grad.sign().mul_(exp_avg.abs())
598
+ elif group.get('adam_cautious_mask'):
599
+ mask = (exp_avg * grad > 0).to(grad.dtype)
600
+ mask.div_(mask.mean().clamp_(min=1e-3))
601
+ exp_avg.mul_(mask)
602
+ del mask
603
+
604
+ if group.get('adam_use_AdEMAMix'):
605
+ exp_avg_slow = state['exp_avg_slow']
606
+ exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
607
+ if beta1_adam > 0:
608
+ update = torch.add(exp_avg, exp_avg_slow, alpha=alpha)
609
+ else:
610
+ update = torch.add(grad, exp_avg_slow, alpha=alpha)
611
+ else:
612
+ update = exp_avg.clone() if beta1_adam > 0 else grad.clone()
613
+
614
+ exp_avg_sq.mul_(beta2_adam).addcmul_(grad, grad.conj(), value=1 - beta2_adam)
615
+
616
+ if group.get('adam_use_atan2'):
617
+ a = 1.2732395
618
+ denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5))
619
+ update.atan2_(denom).mul_(a)
620
+ else:
621
+ denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
622
+ update.div_(denom)
623
+ del denom
624
+
625
+ update.mul_(step_size)
626
+
627
+ # Decoupled weight decay
628
+ if group["adam_weight_decay"] != 0:
629
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
630
+ add_stochastic_(p.data, p.data, alpha=-group["adam_weight_decay"] * lr)
631
+ else:
632
+ p.data.add_(p.data, alpha=-group["adam_weight_decay"] * lr)
633
+
634
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
635
+ add_stochastic_(p.data, -update)
636
+ else:
637
+ p.data.add_(-update)
638
+ del update
639
+
640
+ @torch.no_grad()
641
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
642
+ grad = p.grad
643
+ if grad is None:
644
+ return
645
+ state = self.state[p]
646
+
647
+
648
+ # Determine if using Adam or Muon based on state keys
649
+ # We could check group['optim_type'] here, but inspecting the state keys is safer.
650
+ if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
651
+ use_adam = False
652
+ else:
653
+ use_adam = True
654
+
655
+ lr = group['lr']
656
+ is_compiled = group.get('compiled_optimizer', False)
657
+
658
+ if use_adam:
659
+ step = state['step']
660
+
661
+ if self.kourkoutas_helper:
662
+ # Prepare Kourkoutas-β once per optimizer step.
663
+ self.kourkoutas_helper.maybe_prepare_step(step)
664
+
665
+ # Adam-specific setup (bias correction)
666
+ if group['adam_use_bias_correction']:
667
+ current_step = step + 1
668
+ beta1_adam, beta2_adam = group['adam_betas']
669
+ bias_correction1 = 1.0 - beta1_adam ** current_step
670
+ bias_correction2 = 1.0 - beta2_adam ** current_step
671
+ else:
672
+ bias_correction1 = 1.0
673
+ bias_correction2 = 1.0
674
+
675
+ self.state[p]['step'] += 1
676
+
677
+ # Dispatch to compiled or uncompiled Adam step
678
+ if is_compiled and self._compiled_adam_step is not None:
679
+ # convert to tensors for compiled path once per step
680
+ if not hasattr(self, 'lr_adam_tensor') or self.lr_adam_tensor is None:
681
+ self.lr_adam_tensor = torch.tensor(group['lr'])
682
+ self.bc1 = torch.tensor(bias_correction1)
683
+ self.bc2 = torch.tensor(bias_correction2)
684
+ self._compiled_adam_step(p, grad, state, group, self.lr_adam_tensor, self.bc1, self.bc2)
685
+ else:
686
+ self._adam_step_parameter(p, grad, state, group, lr, bias_correction1, bias_correction2)
687
+ else: # Muon path
688
+ # Dispatch to compiled or uncompiled Muon step
689
+ if is_compiled and self._compiled_muon_step is not None:
690
+ # convert to tensors for compiled path once per step
691
+ if not hasattr(self, 'lr_tensor') or self.lr_tensor is None:
692
+ self.lr_tensor = torch.tensor(group['lr'])
693
+ self._compiled_muon_step(p, grad, state, group, self.lr_tensor)
694
+ else:
695
+ self._muon_step_parameter(p, grad, state, group, lr)
696
+
697
+
698
+ def compile(self, *args, **kwargs):
699
+ print("Compiling AdaMuon step path...")
700
+ self._compiled_muon_step = torch.compile(self._muon_step_parameter, *args, **kwargs)
701
+ print("Compiling AuxAdam step path...")
702
+ self._compiled_adam_step = torch.compile(self._adam_step_parameter, *args, **kwargs)
703
+
704
+ @torch.no_grad()
705
+ def step(self, closure=None):
706
+ """Performs a single optimization step."""
707
+ loss = None
708
+ if closure is not None:
709
+ with torch.enable_grad():
710
+ loss = closure()
711
+
712
+ for group in self.param_groups:
713
+ for i, p in enumerate(group['params']):
714
+ self.step_parameter(p, group, i)
715
+
716
+ if self.param_groups[0].get('compiled_optimizer', False):
717
+ # Reset compile tensors once per step
718
+ self.lr_tensor = None
719
+ self.lr_adam_tensor = None
720
+ return loss