adv-optm 1.2.dev11__tar.gz → 1.2.dev13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of adv-optm might be problematic.

Files changed (30)
  1. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/PKG-INFO +1 -1
  2. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/__init__.py +1 -1
  3. adv_optm-1.2.dev13/adv_optm/optim/AdaMuon_adv.py +664 -0
  4. adv_optm-1.2.dev13/adv_optm/optim/Muon_adv.py +695 -0
  5. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm.egg-info/PKG-INFO +1 -1
  6. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/setup.py +1 -1
  7. adv_optm-1.2.dev11/adv_optm/optim/AdaMuon_adv.py +0 -397
  8. adv_optm-1.2.dev11/adv_optm/optim/Muon_adv.py +0 -423
  9. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/LICENSE +0 -0
  10. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/README.md +0 -0
  11. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/optim/AdamW_adv.py +0 -0
  12. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/optim/Adopt_adv.py +0 -0
  13. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  14. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/optim/Lion_adv.py +0 -0
  15. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/optim/Prodigy_adv.py +0 -0
  16. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  17. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/optim/__init__.py +0 -0
  18. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  19. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/util/Effective_Shape.py +0 -0
  20. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/util/Kourkoutas.py +0 -0
  21. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/util/NNMF.py +0 -0
  22. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/util/Newton_Schulz.py +0 -0
  23. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/util/One_Bit_Boolean.py +0 -0
  24. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/util/OrthoGrad.py +0 -0
  25. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/util/__init__.py +0 -0
  26. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm.egg-info/SOURCES.txt +0 -0
  27. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm.egg-info/dependency_links.txt +0 -0
  28. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm.egg-info/requires.txt +0 -0
  29. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm.egg-info/top_level.txt +0 -0
  30. {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.2.dev11
+ Version: 1.2.dev13
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -20,4 +20,4 @@ __all__ = [
  "AdaMuon_adv",
  ]

- __version__ = "1.2.dev11"
+ __version__ = "1.2.dev13"
@@ -0,0 +1,664 @@
+ import torch
+
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
+ from ..util.Newton_Schulz import _newton_schulz_iteration
+ from ..util.Effective_Shape import _get_effective_shape
+ from ..util.NNMF import _nnmf,_unnmf
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+ from ..util.OrthoGrad import _orthogonalize_gradient
+ from ..util.Kourkoutas import KourkoutasHelper
+
+ class AdaMuon_adv(torch.optim.Optimizer):
+ """
+ Implements an advanced AdaMuon optimizer algorithm.
+
+ AdaMuon combines the geometry-aware updates of Muon with the element-wise
+ adaptivity of Adam. It is designed for 2D parameters (e.g., linear layers)
+ and can handle higher-dimensional parameters by flattening.
+
+ The algorithm incorporates three key mechanisms:
+ 1. A sign-stabilized orthogonal update, where the sign of the momentum is
+ orthogonalized instead of the momentum itself.
+ 2. An element-wise second momentum estimator applied to the orthogonalized
+ update directions.
+ 3. An RMS-aligned rescaling strategy to match the update magnitude of Adam,
+ allowing for reuse of learning rate schedules.
+
+ When `MuonWithAuxAdam` is enabled, this single optimizer class handles both
+ 'muon' and 'adam' parameter groups, dispatching to the appropriate logic internally.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups.
+ lr (float): learning rate (default: 1e-3).
+ betas (tuple[float, float]): coefficients used for both first and second moment
+ estimation (default: (0.95, 0.95)).
+ weight_decay (float): weight decay (L2 penalty) (default: 0.1).
+ eps (float): term added to the denominator for adaptive scaling to improve
+ numerical stability (default: 1e-8).
+ rms_target (float): The target Root-Mean-Square value for the final update
+ vector, used for RMS-aligned rescaling. Allows for the reuse of existing Adam
+ learning rate schedules. (default: 0.2).
+ ns_steps (int): number of Newton-Schulz iterations to perform (default: 5).
+ ns_eps (float): epsilon for Newton-Schulz normalization stability (default: 1e-7).
+ ns_coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
+ quintic polynomial in the Newton-Schulz iteration.
+ (default: (3.4445, -4.7750, 2.0315)).
+ stochastic_rounding (bool): whether to use stochastic rounding for
+ BF16 parameter updates (default: False).
+ nesterov (bool): enables Nesterov momentum (default: False).
+ use_atan2 (bool): whether to use the atan2 update rule. (default: False)
+ Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
+ This changes the update to `alpha_grad * grad + mt`, which can be
+ more responsive, especially for small batch sizes. (default: False)
+ alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
+ (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
+ current gradient. For small batch sizes, use high values (e.g., 10-100) to be
+ more responsive. For large batch sizes, use low values (e.g., 0-1) for
+ stability. (default: 100.0)
+ vector_reshape (bool): whether to reshape 1D vectors into 2D
+ matrices to apply low-rank compression (default: False).
+ low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
+ projects the update to a lower rank before orthogonalization.
+ (default: False)
+ ortho_rank (int): The rank for low-rank orthogonalization.
+ (default: 128)
+ nnmf_factor (bool): whether to use NNMF factorization for the optimizer state;
+ if False, the uncompressed optimizer is used. (default: False)
+ --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
+ adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
+ adam_eps (float): Epsilon for the AdamW optimizer part.
+ adam_weight_decay (float): Weight decay for the AdamW optimizer part.
+ adam_use_bias_correction (bool): Bias correction for AdamW.
+ adam_use_atan2 (bool): Atan2 update rule for AdamW.
+ adam_cautious_mask (bool): Cautious masking for AdamW.
+ adam_grams_moment (bool): Grams-style updates for AdamW.
+ adam_orthogonal_gradient (bool): OrthoGrad for AdamW.
+ adam_use_AdEMAMix (bool): AdEMAMix for AdamW.
+ adam_beta3_ema (float): Beta3 for AdEMAMix.
+ adam_alpha (float): Alpha for AdEMAMix.
+ adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
+ """
+
+ def __init__(
+ self,
+ params,
+ lr: float = 1e-3,
+ betas: tuple[float, float] = (0.95, 0.95),
+ weight_decay: float = 0.1,
+ eps: float = 1e-8,
+ rms_target: float = 0.2,
+ ns_steps: int = 5,
+ ns_eps: float = 1e-7,
+ ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
+ stochastic_rounding: bool = False,
+ use_atan2: bool = False,
+ nesterov: bool = False,
+ Simplified_AdEMAMix: bool = False,
+ alpha_grad: float = 100.0,
+ vector_reshape: bool = False,
+ # Low-rank Muon
+ low_rank_ortho: bool = False,
+ ortho_rank: int = 128,
+ nnmf_factor: bool = False,
+ # Compiled
+ compiled_optimizer: bool = False,
+ # --- AdamW_adv specific parameters ---
+ adam_betas: tuple[float, float] = (0.9, 0.99),
+ adam_eps: float = 1e-8,
+ adam_weight_decay: float = 0.0,
+ adam_use_bias_correction: bool = True,
+ adam_use_atan2: bool = False,
+ adam_cautious_mask: bool = False,
+ adam_grams_moment: bool = False,
+ adam_orthogonal_gradient: bool = False,
+ adam_use_AdEMAMix: bool = False,
+ adam_beta3_ema: float = 0.9999,
+ adam_alpha: float = 5.0,
+ adam_kourkoutas_beta: bool = False,
+ adam_beta2_min: float = 0.9,
+ adam_ema_alpha: float = 0.95,
+ adam_tiny_spike: float = 1e-9,
+ adam_k_warmup_steps: int = 0,
+ adam_nnmf_factor: bool = False,
+ ):
+ if not (lr >= 0.0):
+ raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+ if not (weight_decay >= 0.0):
+ raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+ if not (ns_steps > 0):
+ raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
+ if Simplified_AdEMAMix and nesterov:
+ print("Warning: nesterov is incompatible with Simplified_AdEMAMix; disabling nesterov.")
+ nesterov = False
+
+ defaults = {
+ "lr": lr, "betas": betas, "weight_decay": weight_decay,
+ "eps": eps, "rms_target": rms_target, "ns_steps": ns_steps,
+ "ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
+ "vector_reshape": vector_reshape,
+ "nesterov":nesterov, "use_atan2":use_atan2,
+ "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+ # Low-rank Ortho
+ "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
+ "compiled_optimizer":compiled_optimizer,
+ # AdamW_adv defaults
+ "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
+ "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
+ "adam_cautious_mask": adam_cautious_mask, "adam_grams_moment": adam_grams_moment,
+ "adam_orthogonal_gradient": adam_orthogonal_gradient,
+ "adam_use_AdEMAMix": adam_use_AdEMAMix, "adam_beta3_ema": adam_beta3_ema, "adam_alpha": adam_alpha,
+ "adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
+ "adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
+ "adam_k_warmup_steps": adam_k_warmup_steps, "adam_nnmf_factor": adam_nnmf_factor,
+ }
+ self.stochastic_rounding = stochastic_rounding
+
+ super().__init__(params, defaults)
+
+ self.global_step = 0 # For Adam bias correction and Kourkoutas
+ self.kourkoutas_helper = None
+ if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
+ self.kourkoutas_helper = KourkoutasHelper(self)
+
+ self.init_step()
+
+ # Initialize compiled functions to None
+ self._compiled_muon_step = None
+ self._compiled_adam_step = None
+
+ if compiled_optimizer:
+ print("Compiling AdaMuon_adv optimizer paths...")
+ torch._dynamo.config.cache_size_limit = 8192
+ self.compile(fullgraph=True)
+
+ @property
+ def supports_fused_back_pass(self):
+ return True
+
+ @property
+ def supports_memory_efficient_fp16(self):
+ return True
+
+ @property
+ def supports_flat_params(self):
+ return False
+
+ def init_step(self):
+ for group in self.param_groups:
+ for i, p in enumerate(group['params']):
+ self.__init_state(p, group)
+
+ @torch.no_grad()
+ def __init_state(self, p, group):
+ state = self.state[p]
+
+ if len(state) > 0:
+ return
+
+ optim_type = group.get('optim_type', 'muon')
+
+ if optim_type == 'muon':
+
+ state['factored'] = (
+ group['nnmf_factor'] and
+ not (len(p.shape) == 1 and not group['vector_reshape'])
+ )
+ dtype = torch.float32 if state['factored'] else p.dtype
+ device = p.device
+
+ if state['factored']:
+ state['effective_shape'] = _get_effective_shape(p.numel())
+ d1, d2 = state['effective_shape']
+ state['mu_mbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mv_mbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+ packed_d2 = (d2 + 7) // 8
+ state['sign_buf'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+ state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+ else:
+ if len(p.shape) >= 2:
+ state['second_momentum_buffer'] = torch.zeros_like(p)
+ state['momentum_buffer'] = torch.zeros_like(p)
+
+
+ elif optim_type == 'adam':
+
+ state['factored'] = (
+ group['adam_nnmf_factor'] and
+ not (len(p.shape) == 1 and not group['vector_reshape'])
+ )
+ dtype = torch.float32 if state['factored'] else p.dtype
+ device = p.device
+
+ if state['factored']:
+ state['effective_shape'] = _get_effective_shape(p.numel())
+ d1, d2 = state['effective_shape']
+ # First moment (m)
+ if group['adam_betas'][0] > 0:
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+ if not group.get('adam_grams_moment'):
+ packed_d2 = (d2 + 7) // 8
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+ if group.get('adam_use_AdEMAMix'):
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+ packed_d2 = (d2 + 7) // 8
+ state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+ # Second moment (v)
+ state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+ else: # Fallback to standard AdamW for non-factored tensors
+ if group['adam_betas'][0] > 0:
+ state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+ if group.get('adam_use_AdEMAMix'):
+ state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
+ state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
+
+ @torch.no_grad()
+ def _muon_step_parameter(self, p, grad, state, group, lr):
+ # Retrieve hyperparameters
+ beta1, beta2 = group['betas']
+ nesterov = group['nesterov']
+ Simplified_AdEMAMix = group['Simplified_AdEMAMix']
+ alpha_grad = group['alpha_grad']
+
+ if state['factored']: # Factored AdaMuon
+
+ # Reconstruct momentum from previous step's factors & sign
+ d1, d2 = state['effective_shape']
+ mt_buf = _unnmf((state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
+ unpacked_sign = _unpack_bools(state['sign_buf'], original_m=d2)
+ torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
+ del unpacked_sign
+
+ # Update momentum in full-size
+ grad_reshaped = grad.view(d1, d2)
+ mt_buf.mul_(beta1).add_(grad_reshaped)
+
+ if nesterov:
+ signed_m_buf = torch.sign(grad_reshaped.add(mt_buf, alpha=beta1))
+ elif Simplified_AdEMAMix:
+ signed_m_buf = torch.sign(mt_buf.add(grad_reshaped, alpha=alpha_grad))
+ else:
+ signed_m_buf = torch.sign(mt_buf)
+ del grad_reshaped
+
+ # Orthogonalization step
+ if group['low_rank_ortho']:
+ # Low-Rank Orthogonalization on the reconstructed matrix
+ M = signed_m_buf
+ r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+ if r > 0:
+ G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+ MG = M @ G_sketch
+ if MG.dtype != torch.float32:
+ MG_dtype = M.dtype
+ Q, _ = torch.linalg.qr(MG.float())
+ Q = Q.to(MG_dtype)
+ else:
+ Q, _ = torch.linalg.qr(MG)
+ projected_M = Q.T @ M
+ ortho_projected_M = _newton_schulz_iteration(
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ )
+ update = Q @ ortho_projected_M
+ else: # Fallback for invalid rank
+ update = _newton_schulz_iteration(
+ signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ )
+ else:
+ # Original full Newton-Schulz
+ update = _newton_schulz_iteration(
+ signed_m_buf,
+ steps=group['ns_steps'],
+ eps=group['ns_eps'],
+ coeffs=group['ns_coeffs'],
+ )
+ del signed_m_buf
+
+ # Reconstruct second momentum from previous step's factors
+ vt_buf = _unnmf((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
+
+ # Update second momentum in full-size
+ vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+
+ # Apply second momentum update (adaptive scaling)
+ if group['use_atan2']:
+ a = 1.2732395
+ denom = vt_buf.sqrt()
+ update.atan2_(denom).mul_(a)
+ else:
+ denom = vt_buf.sqrt().add_(group['eps'])
+ update.div_(denom)
+ del denom
+
+ # RMS-aligned rescaling
+ rms_target = group['rms_target']
+ num_elements = update.numel()
+ # Add eps to prevent division by zero
+ update.mul_(rms_target * (num_elements ** 0.5) / (update.norm() + group['eps']))
+
+ update = update.view(p.shape).mul_(lr)
+ del num_elements
+
+ # Compress updated moments and store new factors
+ state['sign_buf'] = _pack_bools(mt_buf > 0)
+ _nnmf(mt_buf.abs(), out=(state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
+ del mt_buf
+
+ _nnmf(vt_buf.abs(), out=(state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
+ del vt_buf
+
+ else: # Standard AdaMuon logic for non-factored tensors
+
+ if len(p.shape) >= 2:
+
+ original_shape = p.shape
+
+ # Momentum update
+ mt_buf = state['momentum_buffer']
+ mt_buf.mul_(beta1).add_(grad)
+
+ if nesterov:
+ signed_m_buf = torch.sign(grad.add(mt_buf, alpha=beta1))
+ elif Simplified_AdEMAMix:
+ signed_m_buf = torch.sign(mt_buf.add(grad, alpha=alpha_grad))
+ else:
+ signed_m_buf = torch.sign(mt_buf)
+
+ # Flatten if necessary (e.g., for Conv layers)
+ signed_m_buf = signed_m_buf.view(original_shape[0], -1)
+
+ # Orthogonalization step
+ if group['low_rank_ortho']:
+ # Low-Rank Orthogonalization on the reconstructed matrix
+ M = signed_m_buf
+ r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+ if r > 0:
+ G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+ MG = M @ G_sketch
+ if MG.dtype != torch.float32:
+ MG_dtype = M.dtype
+ Q, _ = torch.linalg.qr(MG.float())
+ Q = Q.to(MG_dtype)
+ else:
+ Q, _ = torch.linalg.qr(MG)
+ projected_M = Q.T @ M
+ ortho_projected_M = _newton_schulz_iteration(
+ projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ )
+ update = Q @ ortho_projected_M
+ else: # Fallback for invalid rank
+ update = _newton_schulz_iteration(
+ signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+ )
+ else:
+ # Original full Newton-Schulz
+ update = _newton_schulz_iteration(
+ signed_m_buf,
+ steps=group['ns_steps'],
+ eps=group['ns_eps'],
+ coeffs=group['ns_coeffs'],
+ )
+ del signed_m_buf
+
+ update = update.view(original_shape)
+
+ vt_buf = state['second_momentum_buffer']
+ vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+
+ # Apply second momentum update (adaptive scaling)
+ if group['use_atan2']:
+ a = 1.2732395
+ denom = vt_buf.sqrt()
+ update.atan2_(denom).mul_(a)
+ else:
+ denom = vt_buf.sqrt().add_(group['eps'])
+ update.div_(denom)
+ del denom
+
+ # RMS-aligned rescaling
+ rms_target = group['rms_target']
+ num_elements = update.numel()
+ # Add eps to prevent division by zero
+ update.mul_(rms_target * (num_elements ** 0.5) / (update.norm() + group['eps']))
+
+ del num_elements
+
+ update.mul_(lr)
+
+ else: # Fallback to standard SGD with momentum for 1D params (biases, etc.)
+ # Momentum update
+ mt_buf = state['momentum_buffer']
+ mt_buf.mul_(beta1).add_(grad)
+ if nesterov:
+ # Nesterov momentum
+ update = grad.add(mt_buf, alpha=beta1)
+ # elif Simplified_AdEMAMix: # TODO, it will break SGD since it requires x100 lower LR
+ # update = mt_buf.add(grad, alpha=alpha_grad)
+ else:
+ update = mt_buf.clone()
+ update.mul_(lr)
+
+ # Decoupled weight decay
+ if group["weight_decay"] != 0:
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
+ else:
+ p.data.add_(p.data, alpha=-group["weight_decay"] * lr)
+
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, -update)
+ else:
+ p.data.add_(-update)
+ del update
+
+ @torch.no_grad()
+ def _adam_step_parameter(self, p, grad, state, group, lr, bias_correction1, bias_correction2):
+ if grad.dtype != torch.float32 and state.get('factored', False):
+ grad = grad.float()
+ if group.get("adam_orthogonal_gradient"):
+ grad = _orthogonalize_gradient(p, grad)
+
+ beta1_adam, beta2_adam = group['adam_betas']
+
+ if group.get('adam_kourkoutas_beta', False):
+ # Accumulate current grad's norm for the *next* step
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+ # Get the dynamic beta2_adam calculated in prepare_step()
+ beta2_adam = self.kourkoutas_helper.get_beta2(p, group)
+
+ step_size = lr / bias_correction1
+
+ if group.get('adam_use_AdEMAMix'):
+ beta3_ema = group['adam_beta3_ema']
+ alpha = group['adam_alpha']
+
+ if state['factored']:
+ d1, d2 = state['effective_shape']
+ grad_reshaped = grad.view(d1, d2)
+
+ # Reconstruct momentum from previous step's factors
+ if beta1_adam > 0:
+ mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+ if not group.get('adam_grams_moment'):
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+ torch.where(unpacked_sign, mt, -mt, out=mt)
+ del unpacked_sign
+ # Update momentum in full-size
+ mt.mul_(beta1_adam).add_(grad_reshaped, alpha=1.0 - beta1_adam)
+ if group.get('adam_grams_moment'):
+ mt = (grad_reshaped.sign().mul_(mt.abs()))
+ elif group.get('adam_cautious_mask'):
+ mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+ mask.div_(mask.mean().clamp_(min=1e-3))
+ mt.mul_(mask)
+ del mask
+
+ vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+ vt.mul_(beta2_adam).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2_adam)
+
+ if group.get('adam_use_AdEMAMix'):
+ mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+ if state['sign_slow'].dtype != torch.uint8:
+ state['sign_slow'] = state['sign_slow'].to(torch.uint8)
+ unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
+ torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
+ del unpacked_sign_slow
+
+ mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
+ if beta1_adam > 0:
+ update = torch.add(mt, mt_slow, alpha=alpha)
+ else:
+ update = torch.add(grad_reshaped, mt_slow, alpha=alpha)
+ else:
+ update = mt.clone() if beta1_adam > 0 else grad_reshaped.clone()
+ del grad_reshaped
+
+ if group['adam_use_atan2']:
+ a = 1.2732395
+ denom = (vt.sqrt() / (bias_correction2**0.5))
+ update.atan2_(denom).mul_(a)
+ else:
+ denom = (vt.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
+ update.div_(denom)
+ del denom
+
+ update = update.view(p.shape).mul_(step_size)
+
+ # Compress updated moments and store new factors
+ if beta1_adam > 0:
+ if not group.get('adam_grams_moment'):
+ state['sign'] = _pack_bools(mt > 0)
+ _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+ del mt
+ if group.get('adam_use_AdEMAMix'):
+ state['sign_slow'] = _pack_bools(mt_slow > 0)
+ _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+ del mt_slow
+ _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+ del vt
+
+ else: # Standard AdamW logic for non-factored tensors
+ exp_avg_sq = state['exp_avg_sq']
+
+ if beta1_adam > 0:
+ exp_avg = state['exp_avg']
+ exp_avg.mul_(beta1_adam).add_(grad, alpha=1 - beta1_adam)
+ if group.get('adam_grams_moment'):
+ exp_avg = grad.sign().mul_(exp_avg.abs())
+ elif group.get('adam_cautious_mask'):
+ mask = (exp_avg * grad > 0).to(grad.dtype)
+ mask.div_(mask.mean().clamp_(min=1e-3))
+ exp_avg.mul_(mask)
+ del mask
+
+ if group.get('adam_use_AdEMAMix'):
+ exp_avg_slow = state['exp_avg_slow']
+ exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
+ if beta1_adam > 0:
+ update = torch.add(exp_avg, exp_avg_slow, alpha=alpha)
+ else:
+ update = torch.add(grad, exp_avg_slow, alpha=alpha)
+ else:
+ update = exp_avg.clone() if beta1_adam > 0 else grad.clone()
+
+ exp_avg_sq.mul_(beta2_adam).addcmul_(grad, grad.conj(), value=1 - beta2_adam)
+
+ if group.get('adam_use_atan2'):
+ a = 1.2732395
+ denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5))
+ update.atan2_(denom).mul_(a)
+ else:
+ denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
+ update.div_(denom)
+ del denom
+
+ update.mul_(step_size)
+
+ # Decoupled weight decay
+ if group["adam_weight_decay"] != 0:
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, p.data, alpha=-group["adam_weight_decay"] * lr)
+ else:
+ p.data.add_(p.data, alpha=-group["adam_weight_decay"] * lr)
+
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, -update)
+ else:
+ p.data.add_(-update)
+ del update
+
+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+ grad = p.grad
+ if grad is None:
+ return
+ state = self.state[p]
+
+
+
+ # Determine if using Adam or Muon based on state keys
+ # We could use 'optim_type' here, but checking the state keys is safer.
+ if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
+ use_adam = False
+ else:
+ use_adam = True
+
+ lr = group['lr']
+ is_compiled = group.get('compiled_optimizer', False)
+
+ if use_adam:
+ if self.kourkoutas_helper:
+ # Prepare Kourkoutas-β once per optimizer step.
+ self.kourkoutas_helper.maybe_prepare_step(self.global_step)
+ # Adam-specific setup (bias correction)
+ if group['adam_use_bias_correction']:
+ current_step = self.global_step + 1
+ beta1_adam, beta2_adam = group['adam_betas']
+ bias_correction1 = 1.0 - beta1_adam ** current_step
+ bias_correction2 = 1.0 - beta2_adam ** current_step
+ else:
+ bias_correction1 = 1.0
+ bias_correction2 = 1.0
+ # Dispatch to compiled or uncompiled Adam step
+ if is_compiled and self._compiled_adam_step is not None:
+ # Tensors must be used for compiled functions
+ lr_tensor = torch.tensor(lr, device=p.device)
+ bc1_tensor = torch.tensor(bias_correction1, device=p.device)
+ bc2_tensor = torch.tensor(bias_correction2, device=p.device)
+ self._compiled_adam_step(p, grad, state, group, lr_tensor, bc1_tensor, bc2_tensor)
+ else:
+ self._adam_step_parameter(p, grad, state, group, lr, bias_correction1, bias_correction2)
+ else: # Muon path
+ # Dispatch to compiled or uncompiled Muon step
+ if is_compiled and self._compiled_muon_step is not None:
+ lr_tensor = torch.tensor(lr, device=p.device)
+ self._compiled_muon_step(p, grad, state, group, lr_tensor)
+ else:
+ self._muon_step_parameter(p, grad, state, group, lr)
+
+
+ def compile(self, *args, **kwargs):
+ print("Compiling AdaMuon step path...")
+ self._compiled_muon_step = torch.compile(self._muon_step_parameter, *args, **kwargs)
+ print("Compiling AuxAdam step path...")
+ self._compiled_adam_step = torch.compile(self._adam_step_parameter, *args, **kwargs)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step."""
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for i, p in enumerate(group['params']):
+ self.step_parameter(p, group, i)
+
+ self.global_step += 1
+
+ return loss
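
For orientation, a minimal usage sketch of the new class follows. It is not part of the released diff; it assumes AdaMuon_adv is importable from the package root (as the __all__ entry above suggests) and uses the optional 'optim_type' group key that __init_state reads via group.get('optim_type', 'muon'). The split of 2D weights vs. other tensors is illustrative only.

import torch
from adv_optm import AdaMuon_adv  # assumed import path

model = torch.nn.Sequential(
    torch.nn.Linear(64, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)

# Route 2D weight matrices through the Muon path and everything else
# (biases and other 1D tensors) through the auxiliary Adam path.
muon_params = [p for p in model.parameters() if p.ndim >= 2]
adam_params = [p for p in model.parameters() if p.ndim < 2]

opt = AdaMuon_adv(
    [
        {"params": muon_params, "optim_type": "muon"},
        {"params": adam_params, "optim_type": "adam", "lr": 3e-4},
    ],
    lr=1e-3,
    weight_decay=0.1,
)

x, y = torch.randn(8, 64), torch.randint(0, 10, (8,))
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()

Note that step_parameter dispatches on the state buffers created at initialization, so any group whose parameters lack Muon momentum buffers falls through to the auxiliary Adam path regardless of the 'optim_type' key.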