adv_optm-1.2.dev12-py3-none-any.whl → adv_optm-1.2.dev13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of adv-optm may be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdaMuon_adv.py +349 -82
- adv_optm/optim/Muon_adv.py +357 -87
- {adv_optm-1.2.dev12.dist-info → adv_optm-1.2.dev13.dist-info}/METADATA +1 -1
- {adv_optm-1.2.dev12.dist-info → adv_optm-1.2.dev13.dist-info}/RECORD +8 -8
- {adv_optm-1.2.dev12.dist-info → adv_optm-1.2.dev13.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev12.dist-info → adv_optm-1.2.dev13.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev12.dist-info → adv_optm-1.2.dev13.dist-info}/top_level.txt +0 -0
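Both optimizer files below fold an auxiliary AdamW path into the Muon-style classes: parameter groups may carry an `optim_type` key ('muon' or 'adam'), the constructors gain `adam_*` hyperparameters plus a `compiled_optimizer` flag, and the per-parameter step dispatches per group. A minimal usage sketch based only on the signatures visible in this diff; the import path and the model are assumptions, not something the diff shows:

    import torch
    from adv_optm import AdaMuon_adv  # assumed export; the diff only shows adv_optm/optim/AdaMuon_adv.py

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
    hidden = [p for p in model.parameters() if p.ndim >= 2]  # handled by the Muon path
    other = [p for p in model.parameters() if p.ndim < 2]    # handled by the auxiliary Adam path

    opt = AdaMuon_adv(
        [
            {"params": hidden, "optim_type": "muon"},
            {"params": other, "optim_type": "adam"},
        ],
        lr=1e-2,
        adam_betas=(0.9, 0.99),    # forwarded to the auxiliary AdamW logic
        compiled_optimizer=False,  # True would torch.compile both step paths
    )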
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdaMuon_adv.py
CHANGED
@@ -1,13 +1,12 @@
 import torch
-from typing import Optional
-
-from .AdamW_adv import AdamW_adv
 
 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Newton_Schulz import _newton_schulz_iteration
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+from ..util.OrthoGrad import _orthogonalize_gradient
+from ..util.Kourkoutas import KourkoutasHelper
 
 class AdaMuon_adv(torch.optim.Optimizer):
     """
@@ -25,6 +24,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
     3. An RMS-aligned rescaling strategy to match the update magnitude of Adam,
        allowing for reuse of learning rate schedules.
 
+    When `MuonWithAuxAdam` is enabled, this single optimizer class handles both
+    'muon' and 'adam' parameter groups, dispatching to the appropriate logic internally.
 
     Args:
         params (iterable): iterable of parameters to optimize or dicts defining
@@ -55,8 +56,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
             current gradient. For small batch sizes, use high values (e.g., 10-100) to be
             more responsive. For large batch sizes, use low values (e.g., 0-1) for
             stability. (default: 100.0)
-        vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
-            matrices for muon NewtonSchulz (default: False).
         vector_reshape (bool): whether to reshape 1D vectors into 2D
             matrices to apply low-rank compression (default: True).
         low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
@@ -66,6 +65,19 @@ class AdaMuon_adv(torch.optim.Optimizer):
             (default: 128)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
+        --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
+        adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
+        adam_eps (float): Epsilon for the AdamW optimizer part.
+        adam_weight_decay (float): Weight decay for the AdamW optimizer part.
+        adam_use_bias_correction (bool): Bias correction for AdamW.
+        adam_use_atan2 (bool): Atan2 update rule for AdamW.
+        adam_cautious_mask (bool): Cautious masking for AdamW.
+        adam_grams_moment (bool): Grams-style updates for AdamW.
+        adam_orthogonal_gradient (bool): OrthoGrad for AdamW.
+        adam_use_AdEMAMix (bool): AdEMAMix for AdamW.
+        adam_beta3_ema (float): Beta3 for AdEMAMix.
+        adam_alpha (float): Alpha for AdEMAMix.
+        adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
     """
 
     def __init__(
@@ -79,17 +91,36 @@ class AdaMuon_adv(torch.optim.Optimizer):
         ns_steps: int = 5,
         ns_eps: float = 1e-7,
         ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
-        stochastic_rounding: bool =
+        stochastic_rounding: bool = False,
         use_atan2: bool = False,
         nesterov: bool = False,
         Simplified_AdEMAMix: bool = False,
         alpha_grad: float = 100.0,
-        vector_reshape_muon: bool = False,
         vector_reshape: bool = False,
         # Low-rank Muon
         low_rank_ortho: bool = False,
         ortho_rank: int = 128,
         nnmf_factor: bool = False,
+        # Compiled
+        compiled_optimizer: bool = False,
+        # --- AdamW_adv specific parameters ---
+        adam_betas: tuple[float, float] = (0.9, 0.99),
+        adam_eps: float = 1e-8,
+        adam_weight_decay: float = 0.0,
+        adam_use_bias_correction: bool = True,
+        adam_use_atan2: bool = False,
+        adam_cautious_mask: bool = False,
+        adam_grams_moment: bool = False,
+        adam_orthogonal_gradient: bool = False,
+        adam_use_AdEMAMix: bool = False,
+        adam_beta3_ema: float = 0.9999,
+        adam_alpha: float = 5.0,
+        adam_kourkoutas_beta: bool = False,
+        adam_beta2_min: float = 0.9,
+        adam_ema_alpha: float = 0.95,
+        adam_tiny_spike: float = 1e-9,
+        adam_k_warmup_steps: int = 0,
+        adam_nnmf_factor: bool = False,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -98,7 +129,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
         if not (ns_steps > 0):
             raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
         if Simplified_AdEMAMix and nesterov:
-            print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling
+            print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling nesterov.")
             nesterov = False
 
         defaults = {
@@ -106,16 +137,40 @@ class AdaMuon_adv(torch.optim.Optimizer):
             "eps": eps, "rms_target": rms_target, "ns_steps": ns_steps,
             "ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
             "vector_reshape": vector_reshape,
-            "vector_reshape_muon": vector_reshape_muon,
             "nesterov":nesterov, "use_atan2":use_atan2,
             "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
             # Low-rank Ortho
             "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
+            "compiled_optimizer":compiled_optimizer,
+            # AdamW_adv defaults
+            "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
+            "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
+            "adam_cautious_mask": adam_cautious_mask, "adam_grams_moment": adam_grams_moment,
+            "adam_orthogonal_gradient": adam_orthogonal_gradient,
+            "adam_use_AdEMAMix": adam_use_AdEMAMix, "adam_beta3_ema": adam_beta3_ema, "adam_alpha": adam_alpha,
+            "adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
+            "adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
+            "adam_k_warmup_steps": adam_k_warmup_steps, "adam_nnmf_factor": adam_nnmf_factor,
         }
         self.stochastic_rounding = stochastic_rounding
 
         super().__init__(params, defaults)
-
+
+        self.global_step = 0 # For Adam bias correction and Kourkoutas
+        self.kourkoutas_helper = None
+        if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
+            self.kourkoutas_helper = KourkoutasHelper(self)
+
+        self.init_step()
+
+        # Initialize compiled functions to None
+        self._compiled_muon_step = None
+        self._compiled_adam_step = None
+
+        if compiled_optimizer:
+            print("Compiling AdaMuon_adv optimizer paths...")
+            torch._dynamo.config.cache_size_limit = 8192
+            self.compile(fullgraph=True)
 
     @property
     def supports_fused_back_pass(self):
@@ -129,50 +184,80 @@ class AdaMuon_adv(torch.optim.Optimizer):
     def supports_flat_params(self):
         return False
 
-
-
-
-
+    def init_step(self):
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.__init_state(p, group)
 
-
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]
 
+        if len(state) > 0:
+            return
 
-
-        if 'step' not in state:
-            state['step'] = 0
+        optim_type = group.get('optim_type', 'muon')
 
-
+        if optim_type == 'muon':
+
+            state['factored'] = (
                 group['nnmf_factor'] and
                 not (len(p.shape) == 1 and not group['vector_reshape'])
             )
-
-        state['factored'] = should_factor
-
-        state['reshaped_1d_muon'] = len(p.shape) == 1 and group['vector_reshape_muon']
-
-        dtype = torch.float32 if group['nnmf_factor'] else p.dtype
+            dtype = torch.float32 if state['factored'] else p.dtype
             device = p.device
-
+
+            if state['factored']:
                 state['effective_shape'] = _get_effective_shape(p.numel())
                 d1, d2 = state['effective_shape']
-
-                state['
-                state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                state['mu_mbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_mbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                 packed_d2 = (d2 + 7) // 8
-                state['
-                state['
-                state['
+                state['sign_buf'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
             else:
                 if len(p.shape) >= 2:
-                    state['momentum_buffer'] = torch.zeros_like(p)
                     state['second_momentum_buffer'] = torch.zeros_like(p)
-
-
-
-
-
+                    state['momentum_buffer'] = torch.zeros_like(p)
+
+
+        elif optim_type == 'adam':
+
+            state['factored'] = (
+                group['adam_nnmf_factor'] and
+                not (len(p.shape) == 1 and not group['vector_reshape'])
+            )
+            dtype = torch.float32 if state['factored'] else p.dtype
+            device = p.device
 
+            if state['factored']:
+                state['effective_shape'] = _get_effective_shape(p.numel())
+                d1, d2 = state['effective_shape']
+                # First moment (m)
+                if group['adam_betas'][0] > 0:
+                    state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                    state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                    if not group.get('adam_grams_moment'):
+                        packed_d2 = (d2 + 7) // 8
+                        state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                if group.get('adam_use_AdEMAMix'):
+                    state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                    state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                    packed_d2 = (d2 + 7) // 8
+                    state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                # Second moment (v)
+                state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+            else: # Fallback to standard AdamW for non-factored tensors
+                if group['adam_betas'][0] > 0:
+                    state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+                if group.get('adam_use_AdEMAMix'):
+                    state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
+                state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
+
+    @torch.no_grad()
+    def _muon_step_parameter(self, p, grad, state, group, lr):
         # Retrieve hyperparameters
         beta1, beta2 = group['betas']
         nesterov = group['nesterov']
@@ -183,8 +268,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
 
             # Reconstruct momentum from previous step's factors & sign
             d1, d2 = state['effective_shape']
-            mt_buf = _unnmf((state['
-            unpacked_sign = _unpack_bools(state['
+            mt_buf = _unnmf((state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
+            unpacked_sign = _unpack_bools(state['sign_buf'], original_m=d2)
             torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
             del unpacked_sign
 
@@ -231,9 +316,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
                 eps=group['ns_eps'],
                 coeffs=group['ns_coeffs'],
             )
+            del signed_m_buf
 
             # Reconstruct second momentum from previous step's factors
-            vt_buf = _unnmf((state['
+            vt_buf = _unnmf((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
 
             # Update second momentum in full-size
             vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
@@ -252,49 +338,38 @@ class AdaMuon_adv(torch.optim.Optimizer):
             rms_target = group['rms_target']
             num_elements = update.numel()
             # Add eps to prevent division by zero
-
+            update.mul_(rms_target * (num_elements ** 0.5) / (update.norm() + group['eps']))
 
-            update.mul_(
-
-            del num_elements, scaling_factor
+            update = update.view(p.shape).mul_(lr)
+            del num_elements
 
             # Compress updated moments and store new factors
-            state['
-            _nnmf(mt_buf.abs(), out=(state['
+            state['sign_buf'] = _pack_bools(mt_buf > 0)
+            _nnmf(mt_buf.abs(), out=(state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
             del mt_buf
 
-            _nnmf(vt_buf.abs(), out=(state['
+            _nnmf(vt_buf.abs(), out=(state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
             del vt_buf
 
         else: # Standard AdaMuon logic for non-factored tensors
 
-            if len(p.shape) >= 2
+            if len(p.shape) >= 2:
+
+                original_shape = p.shape
 
                 # Momentum update
                 mt_buf = state['momentum_buffer']
-
-
-
-
-
-
-                elif Simplified_AdEMAMix:
-                    signed_m_buf = torch.sign(mt_buf.add(grad_reshaped, alpha=alpha_grad))
-                else:
-                    signed_m_buf = torch.sign(mt_buf)
-                del grad_reshaped
+                mt_buf.mul_(beta1).add_(grad)
+
+                if nesterov:
+                    signed_m_buf = torch.sign(grad.add(mt_buf, alpha=beta1))
+                elif Simplified_AdEMAMix:
+                    signed_m_buf = torch.sign(mt_buf.add(grad, alpha=alpha_grad))
                 else:
-
-                if nesterov:
-                    signed_m_buf = torch.sign(grad.add(mt_buf, alpha=beta1))
-                elif Simplified_AdEMAMix:
-                    signed_m_buf = torch.sign(mt_buf.add(grad, alpha=alpha_grad))
-                else:
-                    signed_m_buf = torch.sign(mt_buf)
+                    signed_m_buf = torch.sign(mt_buf)
 
                 # Flatten if necessary (e.g., for Conv layers)
-
-                signed_m_buf = signed_m_buf.view(p.shape[0], -1)
+                signed_m_buf = signed_m_buf.view(original_shape[0], -1)
 
                 # Orthogonalization step
                 if group['low_rank_ortho']:
@@ -327,9 +402,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
                     eps=group['ns_eps'],
                     coeffs=group['ns_coeffs'],
                 )
+                del signed_m_buf
 
-
-                update = update.view(p.shape)
+                update = update.view(original_shape)
 
                 vt_buf = state['second_momentum_buffer']
                 vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
@@ -348,31 +423,167 @@ class AdaMuon_adv(torch.optim.Optimizer):
                 rms_target = group['rms_target']
                 num_elements = update.numel()
                 # Add eps to prevent division by zero
-
+                update.mul_(rms_target * (num_elements ** 0.5) / (update.norm() + group['eps']))
 
-
-                del num_elements, scaling_factor
+                del num_elements
 
-                update.mul_(
+                update.mul_(lr)
 
-            else: # Fallback to standard SGD with momentum for 1D params (biases, etc.)
+            else: # Fallback to standard SGD with momentum for 1D params (biases, etc.)
                 # Momentum update
                 mt_buf = state['momentum_buffer']
                 mt_buf.mul_(beta1).add_(grad)
                 if nesterov:
+                    # Nesterov momentum
                     update = grad.add(mt_buf, alpha=beta1)
-
-
+                # elif Simplified_AdEMAMix: # TODO, it will break SGD since it requires x100 lower LR
+                #     update = mt_buf.add(grad, alpha=alpha_grad)
                 else:
                     update = mt_buf.clone()
-                update.mul_(
+                update.mul_(lr)
 
         # Decoupled weight decay
         if group["weight_decay"] != 0:
            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
-                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] *
+                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
+            else:
+                p.data.add_(p.data, alpha=-group["weight_decay"] * lr)
+
+        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+            add_stochastic_(p.data, -update)
+        else:
+            p.data.add_(-update)
+        del update
+
+    @torch.no_grad()
+    def _adam_step_parameter(self, p, grad, state, group, lr, bias_correction1, bias_correction2):
+        if grad.dtype != torch.float32 and state.get('factored', False):
+            grad = grad.float()
+        if group.get("adam_orthogonal_gradient"):
+            grad = _orthogonalize_gradient(p, grad)
+
+        beta1_adam, beta2_adam = group['adam_betas']
+
+        if group.get('adam_kourkoutas_beta', False):
+            # Accumulate current grad's norm for the *next* step
+            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+            # Get the dynamic beta2_adam calculated in prepare_step()
+            beta2_adam = self.kourkoutas_helper.get_beta2(p, group)
+
+        step_size = lr / bias_correction1
+
+        if group.get('adam_use_AdEMAMix'):
+            beta3_ema = group['adam_beta3_ema']
+            alpha = group['adam_alpha']
+
+        if state['factored']:
+            d1, d2 = state['effective_shape']
+            grad_reshaped = grad.view(d1, d2)
+
+            # Reconstruct momentum from previous step's factors
+            if beta1_adam > 0:
+                mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+                if not group.get('adam_grams_moment'):
+                    unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+                    torch.where(unpacked_sign, mt, -mt, out=mt)
+                    del unpacked_sign
+                # Update momentum in full-size
+                mt.mul_(beta1_adam).add_(grad_reshaped, alpha=1.0 - beta1_adam)
+                if group.get('adam_grams_moment'):
+                    mt = (grad_reshaped.sign().mul_(mt.abs()))
+                elif group.get('adam_cautious_mask'):
+                    mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    mt.mul_(mask)
+                    del mask
+
+            vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+            vt.mul_(beta2_adam).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2_adam)
+
+            if group.get('adam_use_AdEMAMix'):
+                mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                if state['sign_slow'].dtype != torch.uint8:
+                    state['sign_slow'] = state['sign_slow'].to(torch.uint8)
+                unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
+                torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
+                del unpacked_sign_slow
+
+                mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
+                if beta1_adam > 0:
+                    update = torch.add(mt, mt_slow, alpha=alpha)
+                else:
+                    update = torch.add(grad_reshaped, mt_slow, alpha=alpha)
+            else:
+                update = mt.clone() if beta1_adam > 0 else grad_reshaped.clone()
+            del grad_reshaped
+
+            if group['adam_use_atan2']:
+                a = 1.2732395
+                denom = (vt.sqrt() / (bias_correction2**0.5))
+                update.atan2_(denom).mul_(a)
+            else:
+                denom = (vt.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
+                update.div_(denom)
+            del denom
+
+            update = update.view(p.shape).mul_(step_size)
+
+            # Compress updated moments and store new factors
+            if beta1_adam > 0:
+                if not group.get('adam_grams_moment'):
+                    state['sign'] = _pack_bools(mt > 0)
+                _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+                del mt
+            if group.get('adam_use_AdEMAMix'):
+                state['sign_slow'] = _pack_bools(mt_slow > 0)
+                _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                del mt_slow
+            _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+            del vt
+
+        else: # Standard AdamW logic for non-factored tensors
+            exp_avg_sq = state['exp_avg_sq']
+
+            if beta1_adam > 0:
+                exp_avg = state['exp_avg']
+                exp_avg.mul_(beta1_adam).add_(grad, alpha=1 - beta1_adam)
+                if group.get('adam_grams_moment'):
+                    exp_avg = grad.sign().mul_(exp_avg.abs())
+                elif group.get('adam_cautious_mask'):
+                    mask = (exp_avg * grad > 0).to(grad.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    exp_avg.mul_(mask)
+                    del mask
+
+            if group.get('adam_use_AdEMAMix'):
+                exp_avg_slow = state['exp_avg_slow']
+                exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
+                if beta1_adam > 0:
+                    update = torch.add(exp_avg, exp_avg_slow, alpha=alpha)
+                else:
+                    update = torch.add(grad, exp_avg_slow, alpha=alpha)
+            else:
+                update = exp_avg.clone() if beta1_adam > 0 else grad.clone()
+
+            exp_avg_sq.mul_(beta2_adam).addcmul_(grad, grad.conj(), value=1 - beta2_adam)
+
+            if group.get('adam_use_atan2'):
+                a = 1.2732395
+                denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5))
+                update.atan2_(denom).mul_(a)
+            else:
+                denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
+                update.div_(denom)
+            del denom
+
+            update.mul_(step_size)
+
+        # Decoupled weight decay
+        if group["adam_weight_decay"] != 0:
+            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                add_stochastic_(p.data, p.data, alpha=-group["adam_weight_decay"] * lr)
             else:
-                p.data.add_(p.data, alpha=-group["
+                p.data.add_(p.data, alpha=-group["adam_weight_decay"] * lr)
 
         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
             add_stochastic_(p.data, -update)
@@ -380,7 +591,61 @@ class AdaMuon_adv(torch.optim.Optimizer):
             p.data.add_(-update)
         del update
 
-
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        grad = p.grad
+        if grad is None:
+            return
+        state = self.state[p]
+
+
+
+        # Determine if using Adam or Muon based on state keys
+        # We can use optm_type but I see this as a safer way.
+        if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
+            use_adam = False
+        else:
+            use_adam = True
+
+        lr = group['lr']
+        is_compiled = group.get('compiled_optimizer', False)
+
+        if use_adam:
+            if self.kourkoutas_helper:
+                # Prepare Kourkoutas-β once per optimizer step.
+                self.kourkoutas_helper.maybe_prepare_step(self.global_step)
+            # Adam-specific setup (bias correction)
+            if group['adam_use_bias_correction']:
+                current_step = self.global_step + 1
+                beta1_adam, beta2_adam = group['adam_betas']
+                bias_correction1 = 1.0 - beta1_adam ** current_step
+                bias_correction2 = 1.0 - beta2_adam ** current_step
+            else:
+                bias_correction1 = 1.0
+                bias_correction2 = 1.0
+            # Dispatch to compiled or uncompiled Adam step
+            if is_compiled and self._compiled_adam_step is not None:
+                # Tensors must be used for compiled functions
+                lr_tensor = torch.tensor(lr, device=p.device)
+                bc1_tensor = torch.tensor(bias_correction1, device=p.device)
+                bc2_tensor = torch.tensor(bias_correction2, device=p.device)
+                self._compiled_adam_step(p, grad, state, group, lr_tensor, bc1_tensor, bc2_tensor)
+            else:
+                self._adam_step_parameter(p, grad, state, group, lr, bias_correction1, bias_correction2)
+        else: # Muon path
+            # Dispatch to compiled or uncompiled Muon step
+            if is_compiled and self._compiled_muon_step is not None:
+                lr_tensor = torch.tensor(lr, device=p.device)
+                self._compiled_muon_step(p, grad, state, group, lr_tensor)
+            else:
+                self._muon_step_parameter(p, grad, state, group, lr)
+
+
+    def compile(self, *args, **kwargs):
+        print("Compiling AdaMuon step path...")
+        self._compiled_muon_step = torch.compile(self._muon_step_parameter, *args, **kwargs)
+        print("Compiling AuxAdam step path...")
+        self._compiled_adam_step = torch.compile(self._adam_step_parameter, *args, **kwargs)
 
     @torch.no_grad()
     def step(self, closure=None):
@@ -394,4 +659,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)
 
+        self.global_step += 1
+
         return loss
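Both the factored and non-factored branches above now apply the RMS-aligned rescaling as one fused expression, `update.mul_(rms_target * (num_elements ** 0.5) / (update.norm() + eps))`, so Adam-style learning-rate schedules can be reused. A standalone illustration of that formula; the helper name and the target value here are only for this example, not part of the package:

    import torch

    def rms_align(update: torch.Tensor, rms_target: float, eps: float = 1e-8) -> torch.Tensor:
        # Scale so that the root-mean-square of the update is approximately rms_target:
        # ||u'|| / sqrt(n) = rms_target * ||u|| / (||u|| + eps) ≈ rms_target
        return update * (rms_target * update.numel() ** 0.5 / (update.norm() + eps))

    u = rms_align(torch.randn(128, 64), rms_target=0.2)
    print(u.pow(2).mean().sqrt())  # ≈ 0.2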
adv_optm/optim/Muon_adv.py
CHANGED
@@ -1,28 +1,23 @@
 import torch
-from typing import Optional
-from .AdamW_adv import AdamW_adv
 
 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Newton_Schulz import _newton_schulz_iteration
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+from ..util.OrthoGrad import _orthogonalize_gradient
+from ..util.Kourkoutas import KourkoutasHelper
 
 class Muon_adv(torch.optim.Optimizer):
     """
-    Implements an advanced Muon algorithm.
+    Implements an advanced Muon algorithm, with an integrated auxiliary AdamW optimizer.
 
     Muon (MomentUm Orthogonalized by Newton-Schulz) is an optimizer designed for
     the hidden layers of neural networks. It applies SGD with momentum and then
     orthogonalizes the resulting update matrix using a Newton-Schulz iteration.
 
-
-
-    second-order momentum statistics.
-
-    This implementation is designed for 2D parameters (e.g., linear layers) and
-    can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
-    flattening/reshaping them.
+    When `MuonWithAuxAdam` is enabled, this single optimizer class handles both
+    'muon' and 'adam' parameter groups, dispatching to the appropriate logic internally.
 
     Args:
         params (iterable): iterable of parameters to optimize or dicts defining
@@ -65,6 +60,19 @@ class Muon_adv(torch.optim.Optimizer):
         normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
             (default: 0.2)
         normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
+        --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
+        adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
+        adam_eps (float): Epsilon for the AdamW optimizer part.
+        adam_weight_decay (float): Weight decay for the AdamW optimizer part.
+        adam_use_bias_correction (bool): Bias correction for AdamW.
+        adam_use_atan2 (bool): Atan2 update rule for AdamW.
+        adam_cautious_mask (bool): Cautious masking for AdamW.
+        adam_grams_moment (bool): Grams-style updates for AdamW.
+        adam_orthogonal_gradient (bool): OrthoGrad for AdamW.
+        adam_use_AdEMAMix (bool): AdEMAMix for AdamW.
+        adam_beta3_ema (float): Beta3 for AdEMAMix.
+        adam_alpha (float): Alpha for AdEMAMix.
+        adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
     """
 
     def __init__(
@@ -92,6 +100,25 @@ class Muon_adv(torch.optim.Optimizer):
         normuon_eps: float = 1e-8,
         normuon_lr_scale: float = 0.2,
         normuon_atan2: bool = False,
+        # Compiled
+        compiled_optimizer: bool = False,
+        # --- AdamW_adv specific parameters ---
+        adam_betas: tuple[float, float] = (0.9, 0.99),
+        adam_eps: float = 1e-8,
+        adam_weight_decay: float = 0.0,
+        adam_use_bias_correction: bool = True,
+        adam_use_atan2: bool = False,
+        adam_cautious_mask: bool = False,
+        adam_grams_moment: bool = False,
+        adam_orthogonal_gradient: bool = False,
+        adam_use_AdEMAMix: bool = False,
+        adam_beta3_ema: float = 0.9999,
+        adam_alpha: float = 5.0,
+        adam_kourkoutas_beta: bool = False,
+        adam_beta2_min: float = 0.9,
+        adam_ema_alpha: float = 0.95,
+        adam_tiny_spike: float = 1e-9,
+        adam_k_warmup_steps: int = 0,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -104,7 +131,7 @@ class Muon_adv(torch.optim.Optimizer):
         if not (ns_steps > 0):
             raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
         if Simplified_AdEMAMix and nesterov:
-            print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling
+            print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling nesterov.")
             nesterov = False
 
         defaults = {
@@ -114,17 +141,43 @@ class Muon_adv(torch.optim.Optimizer):
             "vector_reshape": vector_reshape,
             "vector_reshape_muon": vector_reshape_muon,
             "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+            'compiled_optimizer': compiled_optimizer,
             # Low-rank Ortho
             "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
             # NorMuon
             "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
             "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
             "normuon_atan2": normuon_atan2,
+            # AdamW_adv defaults
+            "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
+            "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
+            "adam_cautious_mask": adam_cautious_mask, "adam_grams_moment": adam_grams_moment,
+            "adam_orthogonal_gradient": adam_orthogonal_gradient,
+            "adam_use_AdEMAMix": adam_use_AdEMAMix, "adam_beta3_ema": adam_beta3_ema, "adam_alpha": adam_alpha,
+            "adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
+            "adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
+            "adam_k_warmup_steps": adam_k_warmup_steps,
         }
         self.stochastic_rounding = stochastic_rounding
+        self.compiled_optimizer = compiled_optimizer
 
         super().__init__(params, defaults)
 
+        self.global_step = 0 # For Adam bias correction and Kourkoutas
+        self.kourkoutas_helper = None
+        if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
+            self.kourkoutas_helper = KourkoutasHelper(self)
+
+        self.init_step()
+
+        # Initialize compiled functions to None
+        self._compiled_muon_step = None
+        self._compiled_adam_step = None
+
+        if compiled_optimizer:
+            print("Compiling Muon_adv optimizer paths...")
+            torch._dynamo.config.cache_size_limit = 8192
+            self.compile(fullgraph=True)
 
     @property
     def supports_fused_back_pass(self):
@@ -138,53 +191,90 @@ class Muon_adv(torch.optim.Optimizer):
     def supports_flat_params(self):
         return False
 
+    def init_step(self):
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.__init_state(p, group)
+
     @torch.no_grad()
-    def
-
+    def __init_state(self, p, group):
+        state = self.state[p]
+
+        if len(state) > 0:
             return
 
-
-        state = self.state[p]
+        optim_type = group.get('optim_type', 'muon')
 
-
-
-
+        state['factored'] = (
+            group['nnmf_factor'] and
+            not (len(p.shape) == 1 and not group['vector_reshape'])
+        )
+        dtype = torch.float32 if state['factored'] else p.dtype
+        device = p.device
 
-
+        if optim_type == 'muon':
+
+            state['factored'] = (
                 group['nnmf_factor'] and
                 not (len(p.shape) == 1 and not group['vector_reshape'])
             )
-
-        state['factored'] = should_factor
-
-        state['reshaped_1d_muon'] = len(p.shape) == 1 and group['vector_reshape_muon']
-
-        dtype = torch.float32 if group['nnmf_factor'] else p.dtype
+            dtype = torch.float32 if state['factored'] else p.dtype
            device = p.device
-
+
+            if state['factored']:
                state['effective_shape'] = _get_effective_shape(p.numel())
                d1, d2 = state['effective_shape']
-
-                state['
-                state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                state['mu_mbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_mbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                packed_d2 = (d2 + 7) // 8
-                state['
+                state['sign_buf'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
            else:
-
-                state['momentum_buffer'] = torch.zeros_like(p)
-                if state['reshaped_1d_muon']:
-                    state['momentum_buffer'] = torch.zeros((d1, d2), device=device, dtype=dtype)
-                elif len(p.shape) == 1:
-                    state['momentum_buffer'] = torch.zeros_like(p)
+                state['momentum_buffer'] = torch.zeros_like(p)
 
            # NorMuon state initialization
            if group['normuon_variant']:
                if state['factored']:
                    state['normuon_v'] = torch.zeros(d1, device=p.device, dtype=torch.float32)
-                elif len(p.shape) >= 2
-
-
+                elif len(p.shape) >= 2:
+                    state['normuon_v'] = torch.zeros(p.shape[0], device=p.device, dtype=torch.float32)
+
+
+        elif optim_type == 'adam':
+
+            state['factored'] = (
+                group['adam_nnmf_factor'] and
+                not (len(p.shape) == 1 and not group['vector_reshape'])
+            )
+            dtype = torch.float32 if state['factored'] else p.dtype
+            device = p.device
+
+            if state['factored']:
+                state['effective_shape'] = _get_effective_shape(p.numel())
+                d1, d2 = state['effective_shape']
+                # First moment (m)
+                if group['adam_betas'][0] > 0:
+                    state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                    state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                    if not group.get('adam_grams_moment'):
+                        packed_d2 = (d2 + 7) // 8
+                        state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                if group.get('adam_use_AdEMAMix'):
+                    state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                    state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                    packed_d2 = (d2 + 7) // 8
+                    state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                # Second moment (v)
+                state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+            else: # Fallback to standard AdamW for non-factored tensors
+                if group['adam_betas'][0] > 0:
+                    state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+                if group.get('adam_use_AdEMAMix'):
+                    state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
+                state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
 
+    @torch.no_grad()
+    def _muon_step_parameter(self, p, grad, state, group, lr):
         beta1 = group['beta1']
         nesterov = group['nesterov']
         Simplified_AdEMAMix = group['Simplified_AdEMAMix']
@@ -194,8 +284,8 @@ class Muon_adv(torch.optim.Optimizer):
 
             # Reconstruct momentum from previous step's factors & sign
             d1, d2 = state['effective_shape']
-            mt_buf = _unnmf((state['
-            unpacked_sign = _unpack_bools(state['
+            mt_buf = _unnmf((state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
+            unpacked_sign = _unpack_bools(state['sign_buf'], original_m=d2)
             torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
             del unpacked_sign
 
@@ -260,51 +350,41 @@ class Muon_adv(torch.optim.Optimizer):
                 update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
                 # Scale learning rate
                 update_norm = torch.linalg.vector_norm(update)
-
-
-
-                scaled_lr = 0.0
+
+                scaled_lr = group['normuon_lr_scale'] * lr * (p.numel()**0.5) / update_norm.add_(group['normuon_eps'])
+
                 update = update.view(p.shape).mul_(scaled_lr)
+                del update_norm, scaled_lr
             else: # Original Muon learning rate application
-                update = update.view(p.shape).mul_(
+                update = update.view(p.shape).mul_(lr)
 
-            state['
-            _nnmf(mt_buf.abs(), out=(state['
+            state['sign_buf'] = _pack_bools(mt_buf > 0)
+            _nnmf(mt_buf.abs(), out=(state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
             del mt_buf
 
         else: # Standard Muon logic for non-factored tensors
 
-            if len(p.shape) >= 2
+            if len(p.shape) >= 2:
+
+                original_shape = p.shape
 
                 # Momentum update
                 mt_buf = state['momentum_buffer']
-
-                d1, d2 = state['effective_shape']
-                grad_reshaped = grad.view(d1, d2)
-                mt_buf.mul_(beta1).add_(grad_reshaped)
-                else:
-                    mt_buf.mul_(beta1).add_(grad)
+                mt_buf.mul_(beta1).add_(grad)
 
                 if nesterov:
                     # Nesterov momentum
-
-                    update = grad_reshaped.add(mt_buf, alpha=beta1)
-                    del grad_reshaped
-                    else:
-                        update = grad.add(mt_buf, alpha=beta1)
+                    update = grad.add(mt_buf, alpha=beta1)
                 elif Simplified_AdEMAMix:
-
-                    update = torch.add(mt_buf, grad_reshaped, alpha=alpha_grad)
-                    del grad_reshaped
-                    else:
-                        update = torch.add(mt_buf, grad, alpha=alpha_grad)
+                    update = torch.add(mt_buf, grad, alpha=alpha_grad)
                 else:
                     # Standard momentum
                     update = mt_buf.clone()
 
-                #
-
-
+                # flatten to 2D for orthogonalization.
+                # This is a no-op for 2D tensors and correctly flattens 4D+ tensors.
+                # This removes the dynamic control flow that breaks torch.compile.
+                update = update.view(original_shape[0], -1)
 
                 # Orthogonalization step
                 if group['low_rank_ortho']:
@@ -316,7 +396,7 @@ class Muon_adv(torch.optim.Optimizer):
                     # 1. Sketch the matrix
                     G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
                     MG = M @ G_sketch
-
+
                     # 2. QR decomposition to get orthogonal basis Q
                     if MG.dtype != torch.float32:
                         MG_dtype = M.dtype
@@ -324,10 +404,10 @@ class Muon_adv(torch.optim.Optimizer):
                         Q = Q.to(MG_dtype)
                     else:
                         Q, _ = torch.linalg.qr(MG)
-
+
                     # 3. Project M onto the basis
                     projected_M = Q.T @ M
-
+
                     # 4. Orthogonalize the smaller projected matrix
                     ortho_projected_M = _newton_schulz_iteration(
                         projected_M,
@@ -335,7 +415,7 @@ class Muon_adv(torch.optim.Optimizer):
                         eps=group['ns_eps'],
                         coeffs=group['ns_coeffs'],
                     )
-
+
                     # 5. Project back to the original space
                     update = Q @ ortho_projected_M
                 else: # Fallback for invalid rank
@@ -369,19 +449,16 @@ class Muon_adv(torch.optim.Optimizer):
                     update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
                     # Scale learning rate
                     update_norm = torch.linalg.vector_norm(update)
-
-                    scaled_lr = group['normuon_lr_scale'] * group['lr'] * (p.numel()**0.5) / update_norm
-                    else:
-                        scaled_lr = 0.0
+                    scaled_lr = group['normuon_lr_scale'] * lr * (p.numel()**0.5) / update_norm.add_(group['normuon_eps'])
                     update.mul_(scaled_lr)
                 else: # Original Muon learning rate application
-                    update.mul_(
+                    update.mul_(lr)
+
+                # reshape back to the original shape.
+                update = update.view(original_shape)
 
-
-
-                update = update.view(p.shape)
-
-            else: # Fallback to standard SGD with momentum for 1D params (biases, etc.) when not reshaped
+
+            else: # Fallback to standard SGD with momentum for 1D params (biases, etc.)
                 # Momentum update
                 mt_buf = state['momentum_buffer']
                 mt_buf.mul_(beta1).add_(grad)
@@ -393,14 +470,14 @@ class Muon_adv(torch.optim.Optimizer):
                 else:
                     # Standard momentum
                     update = mt_buf.clone()
-                update.mul_(
+                update.mul_(lr)
 
         # Decoupled weight decay
         if group["weight_decay"] != 0:
             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
-                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] *
+                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
             else:
-                p.data.add_(p.data, alpha=-group["weight_decay"] *
+                p.data.add_(p.data, alpha=-group["weight_decay"] * lr)
 
         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
             add_stochastic_(p.data, -update)
@@ -408,7 +485,198 @@ class Muon_adv(torch.optim.Optimizer):
             p.data.add_(-update)
         del update
 
-
+    @torch.no_grad()
+    def _adam_step_parameter(self, p, grad, state, group, lr, bias_correction1, bias_correction2):
+        if grad.dtype != torch.float32 and state.get('factored', False):
+            grad = grad.float()
+        if group.get("adam_orthogonal_gradient"):
+            grad = _orthogonalize_gradient(p, grad)
+
+        beta1_adam, beta2_adam = group['adam_betas']
+
+        if group.get('adam_kourkoutas_beta', False):
+            # Accumulate current grad's norm for the *next* step
+            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+            # Get the dynamic beta2_adam calculated in prepare_step()
+            beta2_adam = self.kourkoutas_helper.get_beta2(p, group)
+
+        step_size = lr / bias_correction1
+
+        if group.get('adam_use_AdEMAMix'):
+            beta3_ema = group['adam_beta3_ema']
+            alpha = group['adam_alpha']
+
+        if state['factored']:
+            d1, d2 = state['effective_shape']
+            grad_reshaped = grad.view(d1, d2)
+
+            # Reconstruct momentum from previous step's factors
+            if beta1_adam > 0:
+                mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+                if not group.get('adam_grams_moment'):
+                    unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+                    torch.where(unpacked_sign, mt, -mt, out=mt)
+                    del unpacked_sign
+                # Update momentum in full-size
+                mt.mul_(beta1_adam).add_(grad_reshaped, alpha=1.0 - beta1_adam)
+                if group.get('adam_grams_moment'):
+                    mt = (grad_reshaped.sign().mul_(mt.abs()))
+                elif group.get('adam_cautious_mask'):
+                    mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    mt.mul_(mask)
+                    del mask
+
+            vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+            vt.mul_(beta2_adam).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2_adam)
+
+            if group.get('adam_use_AdEMAMix'):
+                mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                if state['sign_slow'].dtype != torch.uint8:
+                    state['sign_slow'] = state['sign_slow'].to(torch.uint8)
+                unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
+                torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
+                del unpacked_sign_slow
+
+                mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
+                if beta1_adam > 0:
+                    update = torch.add(mt, mt_slow, alpha=alpha)
+                else:
+                    update = torch.add(grad_reshaped, mt_slow, alpha=alpha)
+            else:
+                update = mt.clone() if beta1_adam > 0 else grad_reshaped.clone()
+            del grad_reshaped
+
+            if group['adam_use_atan2']:
+                a = 1.2732395
+                denom = (vt.sqrt() / (bias_correction2**0.5))
+                update.atan2_(denom).mul_(a)
+            else:
+                denom = (vt.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
+                update.div_(denom)
+            del denom
+
+            update = update.view(p.shape).mul_(step_size)
+
+            # Compress updated moments and store new factors
+            if beta1_adam > 0:
+                if not group.get('adam_grams_moment'):
+                    state['sign'] = _pack_bools(mt > 0)
+                _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+                del mt
+            if group.get('adam_use_AdEMAMix'):
+                state['sign_slow'] = _pack_bools(mt_slow > 0)
+                _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                del mt_slow
+            _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+            del vt
+
+        else: # Standard AdamW logic for non-factored tensors
+            exp_avg_sq = state['exp_avg_sq']
+
+            if beta1_adam > 0:
+                exp_avg = state['exp_avg']
+                exp_avg.mul_(beta1_adam).add_(grad, alpha=1 - beta1_adam)
+                if group.get('adam_grams_moment'):
+                    exp_avg = grad.sign().mul_(exp_avg.abs())
+                elif group.get('adam_cautious_mask'):
+                    mask = (exp_avg * grad > 0).to(grad.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    exp_avg.mul_(mask)
+                    del mask
+
+            if group.get('adam_use_AdEMAMix'):
+                exp_avg_slow = state['exp_avg_slow']
+                exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
+                if beta1_adam > 0:
+                    update = torch.add(exp_avg, exp_avg_slow, alpha=alpha)
+                else:
+                    update = torch.add(grad, exp_avg_slow, alpha=alpha)
+            else:
+                update = exp_avg.clone() if beta1_adam > 0 else grad.clone()
+
+            exp_avg_sq.mul_(beta2_adam).addcmul_(grad, grad.conj(), value=1 - beta2_adam)
+
+            if group.get('adam_use_atan2'):
+                a = 1.2732395
+                denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5))
+                update.atan2_(denom).mul_(a)
+            else:
+                denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
+                update.div_(denom)
+            del denom
+
+            update.mul_(step_size)
+
+        # Decoupled weight decay
+        if group["adam_weight_decay"] != 0:
+            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                add_stochastic_(p.data, p.data, alpha=-group["adam_weight_decay"] * lr)
+            else:
+                p.data.add_(p.data, alpha=-group["adam_weight_decay"] * lr)
+
+        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+            add_stochastic_(p.data, -update)
+        else:
+            p.data.add_(-update)
+        del update
+
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        grad = p.grad
+        if grad is None:
+            return
+
+        state = self.state[p]
+
+        if self.kourkoutas_helper:
+            # Prepare Kourkoutas-β once per optimizer step.
+            self.kourkoutas_helper.maybe_prepare_step(self.global_step)
+
+        # Determine if using Adam or Muon based on state keys
+        # We can use optm_type but I see this as a safer way.
+        if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
+            use_adam = False
+        else:
+            use_adam = True
+
+        lr = group['lr']
+        is_compiled = group.get('compiled_optimizer', False)
+
+        if use_adam:
+            # Adam-specific setup (bias correction)
+            if group['adam_use_bias_correction']:
+                current_step = self.global_step + 1
+                beta1_adam, beta2_adam = group['adam_betas']
+                bias_correction1 = 1.0 - beta1_adam ** current_step
+                bias_correction2 = 1.0 - beta2_adam ** current_step
+            else:
+                bias_correction1 = 1.0
                bias_correction2 = 1.0

+            # Dispatch to compiled or uncompiled Adam step
+            if is_compiled and self._compiled_adam_step is not None:
+                # Tensors must be used for compiled functions
+                lr_tensor = torch.tensor(lr, device=p.device)
+                bc1_tensor = torch.tensor(bias_correction1, device=p.device)
+                bc2_tensor = torch.tensor(bias_correction2, device=p.device)
+                self._compiled_adam_step(p, grad, state, group, lr_tensor, bc1_tensor, bc2_tensor)
+            else:
+                self._adam_step_parameter(p, grad, state, group, lr, bias_correction1, bias_correction2)
+        else: # Muon path
+            # Dispatch to compiled or uncompiled Muon step
+            if is_compiled and self._compiled_muon_step is not None:
+                lr_tensor = torch.tensor(lr, device=p.device)
+                self._compiled_muon_step(p, grad, state, group, lr_tensor)
+            else:
+                self._muon_step_parameter(p, grad, state, group, lr)
+
+
+    def compile(self, *args, **kwargs):
+        print("Compiling Muon step path...")
+        self._compiled_muon_step = torch.compile(self._muon_step_parameter, *args, **kwargs)
+        print("Compiling AuxAdam step path...")
+        self._compiled_adam_step = torch.compile(self._adam_step_parameter, *args, **kwargs)
 
     @torch.no_grad()
     def step(self, closure=None):
@@ -422,4 +690,6 @@ class Muon_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)
 
-
+        self.global_step += 1
+
+        return loss
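Both files now keep torch.compile'd copies of their per-parameter step functions (`self.compile(fullgraph=True)`) and, on the compiled path, pass the learning rate and bias corrections as 0-d tensors. A minimal standalone sketch of the same pattern; this is not the package's own step function:

    import torch

    @torch.no_grad()
    def sgd_step(p: torch.Tensor, grad: torch.Tensor, lr: torch.Tensor) -> None:
        # lr arrives as a 0-d tensor; passing a Python float instead could make
        # torch.compile specialize on its value and recompile when it changes.
        p.sub_(grad * lr)

    compiled_step = torch.compile(sgd_step, fullgraph=True)

    w, g = torch.randn(8, 8), torch.randn(8, 8)
    for lr in (1e-2, 5e-3):                      # change lr without forcing a recompile
        compiled_step(w, g, torch.tensor(lr))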
{adv_optm-1.2.dev12.dist-info → adv_optm-1.2.dev13.dist-info}/RECORD
CHANGED

@@ -1,10 +1,10 @@
-adv_optm/__init__.py,sha256=
-adv_optm/optim/AdaMuon_adv.py,sha256
+adv_optm/__init__.py,sha256=b2cBBXd4W_tBKGcxO1SHA5SaQrCl_cjxlHdoZSICo3E,380
+adv_optm/optim/AdaMuon_adv.py,sha256=-UBw_mJj8JzDAi3zQ0nLnSOgzsTzl7b7kVksDRUziEE,30582
 adv_optm/optim/AdamW_adv.py,sha256=KL9SCJWZ_ckAQEApB6ofbndVYjancN-v7Us7hJLFf54,17475
 adv_optm/optim/Adopt_adv.py,sha256=S8XI2YA7683jsW8p7igc2YcU30lsN0H18qL02Kpvj8E,21244
 adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
 adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
-adv_optm/optim/Muon_adv.py,sha256=
+adv_optm/optim/Muon_adv.py,sha256=8d99NcXzLyxTbxVVXC8mHyeW7wM8jjK59QoXVTLScQA,32112
 adv_optm/optim/Prodigy_adv.py,sha256=lEjbtuQbomsCX39DnTPeI8Z5YG0f2aZPXN_E7-nGgWw,26060
 adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
 adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
@@ -16,8 +16,8 @@ adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnC
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=CXzS703GB4gil85khZi7sgKOnbzXGBOltshIOSPqj18,435
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
+adv_optm-1.2.dev13.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.2.dev13.dist-info/METADATA,sha256=NW2SU3Uw-ow_Nn7B7VURZ4LmWKsem5KebbDrYystfU4,14023
+adv_optm-1.2.dev13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.2.dev13.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.2.dev13.dist-info/RECORD,,
{adv_optm-1.2.dev12.dist-info → adv_optm-1.2.dev13.dist-info}/WHEEL: file without changes
{adv_optm-1.2.dev12.dist-info → adv_optm-1.2.dev13.dist-info}/licenses/LICENSE: file without changes
{adv_optm-1.2.dev12.dist-info → adv_optm-1.2.dev13.dist-info}/top_level.txt: file without changes