adv_optm-1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adv_optm/__init__.py +23 -0
- adv_optm/optim/AdaMuon_adv.py +720 -0
- adv_optm/optim/AdamW_adv.py +374 -0
- adv_optm/optim/Adopt_adv.py +437 -0
- adv_optm/optim/Lion_Prodigy_adv.py +341 -0
- adv_optm/optim/Lion_adv.py +210 -0
- adv_optm/optim/Muon_adv.py +723 -0
- adv_optm/optim/Prodigy_adv.py +539 -0
- adv_optm/optim/Simplified_AdEMAMix.py +298 -0
- adv_optm/optim/__init__.py +19 -0
- adv_optm/util/BF16_Stochastic_Rounding.py +47 -0
- adv_optm/util/Effective_Shape.py +8 -0
- adv_optm/util/Kourkoutas.py +159 -0
- adv_optm/util/NNMF.py +18 -0
- adv_optm/util/Newton_Schulz.py +87 -0
- adv_optm/util/One_Bit_Boolean.py +22 -0
- adv_optm/util/OrthoGrad.py +16 -0
- adv_optm/util/__init__.py +13 -0
- adv_optm-1.2.0.dist-info/METADATA +222 -0
- adv_optm-1.2.0.dist-info/RECORD +23 -0
- adv_optm-1.2.0.dist-info/WHEEL +5 -0
- adv_optm-1.2.0.dist-info/licenses/LICENSE +201 -0
- adv_optm-1.2.0.dist-info/top_level.txt +1 -0
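
Of the files above, the largest addition is adv_optm/optim/Muon_adv.py (723 lines), reproduced in the hunk below. For orientation before the source: the class accepts parameter groups tagged with an optim_type key ('muon' for hidden-layer weight matrices, 'adam' for parameters routed to the auxiliary AdamW path; the key defaults to 'muon'). The sketch below is a minimal usage example, assuming the package is installed from this wheel and that Muon_adv is re-exported from the top-level adv_optm package (the class itself is defined in adv_optm/optim/Muon_adv.py).

    import torch
    from adv_optm import Muon_adv  # assumed re-export; otherwise adv_optm.optim.Muon_adv

    model = torch.nn.Sequential(
        torch.nn.Linear(512, 512),
        torch.nn.ReLU(),
        torch.nn.Linear(512, 10),
    )

    # Muon orthogonalizes 2D weight matrices; 1D parameters (biases, norms)
    # are routed to the auxiliary AdamW path via the 'adam' group.
    muon_params = [p for p in model.parameters() if p.ndim >= 2]
    adam_params = [p for p in model.parameters() if p.ndim < 2]

    optimizer = Muon_adv(
        [
            {"params": muon_params, "optim_type": "muon"},
            {"params": adam_params, "optim_type": "adam"},
        ],
        lr=1e-3,
    )

    loss = model(torch.randn(8, 512)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
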
adv_optm/optim/Muon_adv.py
@@ -0,0 +1,723 @@
import torch

from ..util.BF16_Stochastic_Rounding import add_stochastic_
from ..util.Newton_Schulz import _newton_schulz_iteration
from ..util.Effective_Shape import _get_effective_shape
from ..util.NNMF import _nnmf, _unnmf
from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
from ..util.OrthoGrad import _orthogonalize_gradient
from ..util.Kourkoutas import KourkoutasHelper

class Muon_adv(torch.optim.Optimizer):
    """
    Implements an advanced Muon algorithm with an integrated auxiliary AdamW optimizer.

    Muon (MomentUm Orthogonalized by Newton-Schulz) is an optimizer designed for
    the hidden layers of neural networks. It applies SGD with momentum and then
    orthogonalizes the resulting update matrix using a Newton-Schulz iteration.

    When `MuonWithAuxAdam` is enabled, this single optimizer class handles both
    'muon' and 'adam' parameter groups, dispatching to the appropriate logic internally.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float): learning rate (default: 1e-3).
        beta1 (float): momentum factor (default: 0.9).
        weight_decay (float): weight decay (L2 penalty) (default: 0).
        nesterov (bool): enables Nesterov momentum (default: True).
        ns_steps (int): number of Newton-Schulz iterations to perform (default: 5).
        ns_eps (float): epsilon for Newton-Schulz normalization stability (default: 1e-7).
        ns_coeffs (tuple[float, float, float]): the (a, b, c) coefficients of the
            quintic polynomial used in the Newton-Schulz iteration
            (default: (3.4445, -4.7750, 2.0315)).
        Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
            This changes the update to `alpha_grad * grad + mt`, which can be
            more responsive, especially for small batch sizes. (default: False)
        alpha_grad (float): mixing coefficient for the Simplified AdEMAMix update rule
            (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
            current gradient. For small batch sizes, use high values (e.g., 10-100) for
            responsiveness; for large batch sizes, use low values (e.g., 0-1) for
            stability. (default: 100.0)
        stochastic_rounding (bool): whether to use stochastic rounding for
            BF16 parameter updates (default: True).
        orthogonal_gradient (bool): whether to use OrthoGrad (default: False).
        vector_reshape (bool): whether to reshape 1D vectors into 2D
            matrices so that low-rank compression can be applied (default: False).
        nnmf_factor (bool): whether to use NNMF factorization of the optimizer state;
            if False, the uncompressed optimizer is used (default: False).
        low_rank_ortho (bool): if True, enables low-rank orthogonalization, which
            projects the update to a lower rank before orthogonalization
            (default: False).
        ortho_rank (int): the rank used for low-rank orthogonalization (default: 128).
        normuon_variant (bool): if True, enables the NorMuon update rule, which adds
            neuron-wise normalization (default: False).
        beta2_normuon (float): exponential decay rate for the second-moment estimates
            used in NorMuon (default: 0.95).
        normuon_eps (float): epsilon for NorMuon normalization stability (default: 1e-8).
        rms_rescaling (bool): rescale the final update vector to a target root-mean-square,
            which allows existing Adam learning-rate schedules to be reused
            (default: True).
        accelerated_ns (bool): if True, enables Chebyshev-accelerated Newton-Schulz (CANS),
            which dynamically computes optimal 3rd-order polynomial coefficients
            (default: False).
        cns_a_bound (float): initial lower bound on the singular values for CANS (default: 1e-4).
        --- Auxiliary AdamW_adv parameters (used for 'adam' groups) ---
        adam_betas (tuple[float, float]): betas for the AdamW part.
        adam_eps (float): epsilon for the AdamW part.
        adam_weight_decay (float): weight decay for the AdamW part.
        adam_use_bias_correction (bool): bias correction for AdamW.
        adam_use_atan2 (bool): atan2 update rule for AdamW.
        adam_cautious_mask (bool): cautious masking for AdamW.
        adam_grams_moment (bool): Grams-style updates for AdamW.
        adam_orthogonal_gradient (bool): OrthoGrad for AdamW.
        adam_use_AdEMAMix (bool): AdEMAMix for AdamW.
        adam_beta3_ema (float): beta3 for AdEMAMix.
        adam_alpha (float): alpha for AdEMAMix.
        adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
        adam_nnmf_factor (bool): 1-bit factored state for AdamW.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-3,
        beta1: float = 0.9,
        weight_decay: float = 0.0,
        nesterov: bool = True,
        ns_steps: int = 5,
        ns_eps: float = 1e-7,
        ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
        Simplified_AdEMAMix: bool = False,
        alpha_grad: float = 100.0,
        stochastic_rounding: bool = True,
        orthogonal_gradient: bool = False,
        rms_rescaling: bool = True,
        vector_reshape: bool = False,
        nnmf_factor: bool = False,
        # Low-rank Muon
        low_rank_ortho: bool = False,
        ortho_rank: int = 128,
        # NorMuon additions
        normuon_variant: bool = False,
        beta2_normuon: float = 0.95,
        normuon_eps: float = 1e-8,
        # CANS
        accelerated_ns: bool = False,
        cns_a_bound: float = 1e-4,
        # Compiled
        compiled_optimizer: bool = False,
        # --- AdamW_adv specific parameters ---
        adam_betas: tuple[float, float] = (0.9, 0.99),
        adam_eps: float = 1e-8,
        adam_weight_decay: float = 0.0,
        adam_use_bias_correction: bool = True,
        adam_use_atan2: bool = False,
        adam_cautious_mask: bool = False,
        adam_grams_moment: bool = False,
        adam_orthogonal_gradient: bool = False,
        adam_use_AdEMAMix: bool = False,
        adam_beta3_ema: float = 0.9999,
        adam_alpha: float = 5.0,
        adam_kourkoutas_beta: bool = False,
        adam_beta2_min: float = 0.9,
        adam_ema_alpha: float = 0.95,
        adam_tiny_spike: float = 1e-9,
        adam_k_warmup_steps: int = 0,
        adam_nnmf_factor: bool = False,
    ):
        if not (lr >= 0.0):
            raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
        if not (0.0 <= beta1 < 1.0):
            raise ValueError(f"beta1 should be in [0.0, 1.0). Got {beta1}")
        if normuon_variant and not (0.0 <= beta2_normuon < 1.0):
            raise ValueError(f"beta2_normuon should be in [0.0, 1.0) for NorMuon. Got {beta2_normuon}")
        if not (weight_decay >= 0.0):
            raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
        if not (ns_steps > 0):
            raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
        if Simplified_AdEMAMix and nesterov:
            print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling nesterov.")
            nesterov = False

        defaults = {
            "lr": lr, "beta1": beta1, "weight_decay": weight_decay,
            "nesterov": nesterov, "ns_steps": ns_steps, "ns_eps": ns_eps,
            "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
            "vector_reshape": vector_reshape, "rms_rescaling": rms_rescaling,
            "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
            "orthogonal_gradient": orthogonal_gradient,
            "compiled_optimizer": compiled_optimizer,
            # Low-rank Ortho
            "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
            # NorMuon
            "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
            "normuon_eps": normuon_eps,
            # CANS
            "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
            # AdamW_adv defaults
            "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
            "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
            "adam_cautious_mask": adam_cautious_mask, "adam_grams_moment": adam_grams_moment,
            "adam_orthogonal_gradient": adam_orthogonal_gradient,
            "adam_use_AdEMAMix": adam_use_AdEMAMix, "adam_beta3_ema": adam_beta3_ema, "adam_alpha": adam_alpha,
            "adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
            "adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
            "adam_k_warmup_steps": adam_k_warmup_steps,
            "adam_nnmf_factor": adam_nnmf_factor,
        }
        self.stochastic_rounding = stochastic_rounding
        self.compiled_optimizer = compiled_optimizer

        super().__init__(params, defaults)

        self.kourkoutas_helper = None
        if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
            self.kourkoutas_helper = KourkoutasHelper(self)

        self.init_step()

        # Initialize compiled functions to None
        self._compiled_muon_step = None
        self._compiled_adam_step = None

        if compiled_optimizer:
            print("Compiling Muon_adv optimizer paths...")
            torch._dynamo.config.cache_size_limit = 8192
            self.compile(fullgraph=True)

    @property
    def supports_fused_back_pass(self):
        return True

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return False

    def init_step(self):
        for group in self.param_groups:
            for i, p in enumerate(group['params']):
                self.__init_state(p, group)

    @torch.no_grad()
    def __init_state(self, p, group):
        state = self.state[p]

        if len(state) > 0:
            return

        optim_type = group.get('optim_type', 'muon')

        state['factored'] = (
            group['nnmf_factor'] and
            not (len(p.shape) == 1 and not group['vector_reshape'])
        )
        dtype = torch.float32 if state['factored'] else p.dtype
        device = p.device

        if optim_type == 'muon':

            state['factored'] = (
                group['nnmf_factor'] and
                not (len(p.shape) == 1 and not group['vector_reshape'])
            )
            dtype = torch.float32 if state['factored'] else p.dtype
            device = p.device

            if state['factored']:
                state['effective_shape'] = _get_effective_shape(p.numel())
                d1, d2 = state['effective_shape']
                state['mu_mbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                state['mv_mbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                packed_d2 = (d2 + 7) // 8
                state['sign_buf'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
            else:
                state['momentum_buffer'] = torch.zeros_like(p)

            # NorMuon state initialization
            if group['normuon_variant']:
                if state['factored']:
                    state['normuon_v'] = torch.zeros(d1, device=p.device, dtype=torch.float32)
                elif len(p.shape) >= 2:
                    state['normuon_v'] = torch.zeros(p.shape[0], device=p.device, dtype=torch.float32)

            group['adam_kourkoutas_beta'] = False

        elif optim_type == 'adam':

            state['step'] = 0

            state['factored'] = (
                group['adam_nnmf_factor'] and
                not (len(p.shape) == 1 and not group['vector_reshape'])
            )
            dtype = torch.float32 if state['factored'] else p.dtype
            device = p.device

            if state['factored']:
                state['effective_shape'] = _get_effective_shape(p.numel())
                d1, d2 = state['effective_shape']
                # First moment (m)
                if group['adam_betas'][0] > 0:
                    state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                    state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                    if not group.get('adam_grams_moment'):
                        packed_d2 = (d2 + 7) // 8
                        state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
                if group.get('adam_use_AdEMAMix'):
                    state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                    state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                    packed_d2 = (d2 + 7) // 8
                    state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
                # Second moment (v)
                state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
            else:  # Fallback to standard AdamW for non-factored tensors
                if group['adam_betas'][0] > 0:
                    state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
                if group.get('adam_use_AdEMAMix'):
                    state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
                state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)

    @torch.no_grad()
    def _muon_step_parameter(self, p, grad, state, group, lr):
        beta1 = group['beta1']
        nesterov = group['nesterov']
        Simplified_AdEMAMix = group['Simplified_AdEMAMix']
        alpha_grad = group['alpha_grad']
        if grad.dtype != torch.float32 and state.get('factored', False):
            grad = grad.float()
        if group.get("orthogonal_gradient"):
            grad = _orthogonalize_gradient(p, grad)

        if state['factored']:  # Factored Muon

            # Reconstruct momentum from previous step's factors & sign
            d1, d2 = state['effective_shape']
            mt_buf = _unnmf((state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
            unpacked_sign = _unpack_bools(state['sign_buf'], original_m=d2)
            torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
            del unpacked_sign

            # Update momentum in full-size
            grad_reshaped = grad.view(d1, d2)
            mt_buf.mul_(beta1).add_(grad_reshaped)

            if nesterov:
                # Nesterov momentum
                update = grad_reshaped.add(mt_buf, alpha=beta1)
            elif Simplified_AdEMAMix:
                update = torch.add(mt_buf, grad_reshaped, alpha=alpha_grad)
            else:
                # Standard momentum
                update = mt_buf.clone()
            del grad_reshaped

            # Orthogonalization step
            if group['low_rank_ortho']:
                # Low-Rank Orthogonalization on the reconstructed matrix
                M = update
                r = min(group['ortho_rank'], M.shape[0], M.shape[1])
                if r > 0:
                    G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
                    MG = M @ G_sketch
                    if MG.dtype != torch.float32:
                        MG_dtype = M.dtype
                        Q, _ = torch.linalg.qr(MG.float())
                        Q = Q.to(MG_dtype)
                    else:
                        Q, _ = torch.linalg.qr(MG)
                    projected_M = Q.T @ M
                    ortho_projected_M = _newton_schulz_iteration(
                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'],
                        cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound'],
                    )
                    update = Q @ ortho_projected_M
                else:  # Fallback for invalid rank
                    update = _newton_schulz_iteration(
                        update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs'],
                        cns=group['accelerated_ns'], cns_a_bound=group['cns_a_bound'],
                    )
            else:
                # Original full Newton-Schulz
                update = _newton_schulz_iteration(
                    update,
                    steps=group['ns_steps'],
                    eps=group['ns_eps'],
                    coeffs=group['ns_coeffs'],
                    cns=group['accelerated_ns'],
                    cns_a_bound=group['cns_a_bound'],
                )

            if group['normuon_variant']:
                v_t = state['normuon_v']
                beta2_normuon = group['beta2_normuon']
                # Update 2nd moment estimate
                mean_squared_update = torch.mean(update.square(), dim=1)
                v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
                # Normalize update
                update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
                del mean_squared_update

            # RMS-aligned rescaling
            if group['rms_rescaling']:
                rms_target = 0.2  # default (Adam) value for RMS
                update_norm = torch.linalg.vector_norm(update)
                update = update.view(p.shape).mul_(rms_target * lr * (p.numel()**0.5) / update_norm.add_(1e-8))
                del update_norm
            else:
                update = update.view(p.shape).mul_(lr)

            state['sign_buf'] = _pack_bools(mt_buf > 0)
            _nnmf(mt_buf.abs(), out=(state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
            del mt_buf

        else:  # Standard Muon logic for non-factored tensors

            if len(p.shape) >= 2:

                original_shape = p.shape

                # Momentum update
                mt_buf = state['momentum_buffer']
                mt_buf.mul_(beta1).add_(grad)

                if nesterov:
                    # Nesterov momentum
                    update = grad.add(mt_buf, alpha=beta1)
                elif Simplified_AdEMAMix:
                    update = torch.add(mt_buf, grad, alpha=alpha_grad)
                else:
                    # Standard momentum
                    update = mt_buf.clone()

                # flatten to 2D for orthogonalization.
                # This is a no-op for 2D tensors and correctly flattens 4D+ tensors.
                # This removes the dynamic control flow that breaks torch.compile.
                update = update.view(original_shape[0], -1)

                # Orthogonalization step
                if group['low_rank_ortho']:
                    # Low-Rank Orthogonalization based on Gaussian Sketching
                    M = update
                    r = min(group['ortho_rank'], M.shape[0], M.shape[1])

                    if r > 0:
                        # 1. Sketch the matrix
                        G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
                        MG = M @ G_sketch

                        # 2. QR decomposition to get orthogonal basis Q
                        if MG.dtype != torch.float32:
                            MG_dtype = M.dtype
                            Q, _ = torch.linalg.qr(MG.float())
                            Q = Q.to(MG_dtype)
                        else:
                            Q, _ = torch.linalg.qr(MG)

                        # 3. Project M onto the basis
                        projected_M = Q.T @ M

                        # 4. Orthogonalize the smaller projected matrix
                        ortho_projected_M = _newton_schulz_iteration(
                            projected_M,
                            steps=group['ns_steps'],
                            eps=group['ns_eps'],
                            coeffs=group['ns_coeffs'],
                            cns=group['accelerated_ns'],
                            cns_a_bound=group['cns_a_bound'],
                        )

                        # 5. Project back to the original space
                        update = Q @ ortho_projected_M
                    else:  # Fallback for invalid rank
                        update = _newton_schulz_iteration(
                            update,
                            steps=group['ns_steps'],
                            eps=group['ns_eps'],
                            coeffs=group['ns_coeffs'],
                            cns=group['accelerated_ns'],
                            cns_a_bound=group['cns_a_bound'],
                        )
                else:
                    # Original Newton-Schulz
                    update = _newton_schulz_iteration(
                        update,
                        steps=group['ns_steps'],
                        eps=group['ns_eps'],
                        coeffs=group['ns_coeffs'],
                        cns=group['accelerated_ns'],
                        cns_a_bound=group['cns_a_bound'],
                    )

                # NorMuon Logic
                if group['normuon_variant']:
                    v_t = state['normuon_v']
                    beta2_normuon = group['beta2_normuon']
                    # Update 2nd moment estimate
                    mean_squared_update = torch.mean(update.square(), dim=1)
                    v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
                    # Normalize update
                    update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))

                # RMS-aligned rescaling
                if group['rms_rescaling']:
                    rms_target = 0.2  # default (Adam) value for RMS
                    update_norm = torch.linalg.vector_norm(update)
                    update = update.view(original_shape).mul_(rms_target * lr * (p.numel()**0.5) / update_norm.add_(1e-8))
                    del update_norm
                else:
                    update = update.view(original_shape).mul_(lr)

            else:  # Fallback to standard SGD with momentum for 1D params (biases, etc.)
                # Momentum update
                mt_buf = state['momentum_buffer']
                mt_buf.mul_(beta1).add_(grad)
                if nesterov:
                    # Nesterov momentum
                    update = grad.add(mt_buf, alpha=beta1)
                # FIXME, Simplified_AdEMAMix will break SGD since it requires x100 lower LR
                # elif Simplified_AdEMAMix:
                #     update = torch.add(mt_buf, grad, alpha=alpha_grad)
                else:
                    # Standard momentum
                    update = mt_buf.clone()
                update.mul_(lr)

        # Decoupled weight decay
        if group["weight_decay"] != 0:
            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
            else:
                p.data.add_(p.data, alpha=-group["weight_decay"] * lr)

        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
            add_stochastic_(p.data, -update)
        else:
            p.data.add_(-update)
        del update

    @torch.no_grad()
    def _adam_step_parameter(self, p, grad, state, group, lr, bias_correction1, bias_correction2):
        if grad.dtype != torch.float32 and state.get('factored', False):
            grad = grad.float()
        if group.get("adam_orthogonal_gradient"):
            grad = _orthogonalize_gradient(p, grad)

        beta1_adam, beta2_adam = group['adam_betas']

        if group.get('adam_kourkoutas_beta', False):
            # Accumulate current grad's norm for the *next* step
            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
            # Get the dynamic beta2_adam calculated in prepare_step()
            beta2_adam = self.kourkoutas_helper.get_beta2(p, group)

        step_size = lr / bias_correction1

        if group.get('adam_use_AdEMAMix'):
            beta3_ema = group['adam_beta3_ema']
            alpha = group['adam_alpha']

        if state['factored']:
            d1, d2 = state['effective_shape']
            grad_reshaped = grad.view(d1, d2)

            # Reconstruct momentum from previous step's factors
            if beta1_adam > 0:
                mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
                if not group.get('adam_grams_moment'):
                    unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
                    torch.where(unpacked_sign, mt, -mt, out=mt)
                    del unpacked_sign
                # Update momentum in full-size
                mt.mul_(beta1_adam).add_(grad_reshaped, alpha=1.0 - beta1_adam)
                if group.get('adam_grams_moment'):
                    mt = (grad_reshaped.sign().mul_(mt.abs()))
                elif group.get('adam_cautious_mask'):
                    mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
                    mask.div_(mask.mean().clamp_(min=1e-3))
                    mt.mul_(mask)
                    del mask

            vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
            vt.mul_(beta2_adam).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2_adam)

            if group.get('adam_use_AdEMAMix'):
                mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
                if state['sign_slow'].dtype != torch.uint8:
                    state['sign_slow'] = state['sign_slow'].to(torch.uint8)
                unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
                torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
                del unpacked_sign_slow

                mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
                if beta1_adam > 0:
                    update = torch.add(mt, mt_slow, alpha=alpha)
                else:
                    update = torch.add(grad_reshaped, mt_slow, alpha=alpha)
            else:
                update = mt.clone() if beta1_adam > 0 else grad_reshaped.clone()
            del grad_reshaped

            if group['adam_use_atan2']:
                a = 1.2732395
                denom = (vt.sqrt() / (bias_correction2**0.5))
                update.atan2_(denom).mul_(a)
            else:
                denom = (vt.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
                update.div_(denom)
            del denom

            update = update.view(p.shape).mul_(step_size)

            # Compress updated moments and store new factors
            if beta1_adam > 0:
                if not group.get('adam_grams_moment'):
                    state['sign'] = _pack_bools(mt > 0)
                _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
                del mt
            if group.get('adam_use_AdEMAMix'):
                state['sign_slow'] = _pack_bools(mt_slow > 0)
                _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
                del mt_slow
            _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
            del vt

        else:  # Standard AdamW logic for non-factored tensors
            exp_avg_sq = state['exp_avg_sq']

            if beta1_adam > 0:
                exp_avg = state['exp_avg']
                exp_avg.mul_(beta1_adam).add_(grad, alpha=1 - beta1_adam)
                if group.get('adam_grams_moment'):
                    exp_avg = grad.sign().mul_(exp_avg.abs())
                elif group.get('adam_cautious_mask'):
                    mask = (exp_avg * grad > 0).to(grad.dtype)
                    mask.div_(mask.mean().clamp_(min=1e-3))
                    exp_avg.mul_(mask)
                    del mask

            if group.get('adam_use_AdEMAMix'):
                exp_avg_slow = state['exp_avg_slow']
                exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
                if beta1_adam > 0:
                    update = torch.add(exp_avg, exp_avg_slow, alpha=alpha)
                else:
                    update = torch.add(grad, exp_avg_slow, alpha=alpha)
            else:
                update = exp_avg.clone() if beta1_adam > 0 else grad.clone()

            exp_avg_sq.mul_(beta2_adam).addcmul_(grad, grad.conj(), value=1 - beta2_adam)

            if group.get('adam_use_atan2'):
                a = 1.2732395
                denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5))
                update.atan2_(denom).mul_(a)
            else:
                denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
                update.div_(denom)
            del denom

            update.mul_(step_size)

        # Decoupled weight decay
        if group["adam_weight_decay"] != 0:
            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
                add_stochastic_(p.data, p.data, alpha=-group["adam_weight_decay"] * lr)
            else:
                p.data.add_(p.data, alpha=-group["adam_weight_decay"] * lr)

        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
            add_stochastic_(p.data, -update)
        else:
            p.data.add_(-update)
        del update

    @torch.no_grad()
    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
        grad = p.grad
        if grad is None:
            return

        state = self.state[p]

        # Determine whether this parameter uses the Adam or the Muon path from its state keys.
        # We could check group['optim_type'], but inspecting the state is safer.
        if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
            use_adam = False
        else:
            use_adam = True

        lr = group['lr']
        is_compiled = group.get('compiled_optimizer', False)

        if use_adam:
            step = state['step']

            if self.kourkoutas_helper:
                # Prepare Kourkoutas-β once per optimizer step.
                self.kourkoutas_helper.maybe_prepare_step(step)

            # Adam-specific setup (bias correction)
            if group['adam_use_bias_correction']:
                current_step = step + 1
                beta1_adam, beta2_adam = group['adam_betas']
                bias_correction1 = 1.0 - beta1_adam ** current_step
                bias_correction2 = 1.0 - beta2_adam ** current_step
            else:
                bias_correction1 = 1.0
                bias_correction2 = 1.0

            self.state[p]['step'] += 1

            # Dispatch to compiled or uncompiled Adam step
            if is_compiled and self._compiled_adam_step is not None:
                # Convert to tensors for the compiled path once per step
                if not hasattr(self, 'lr_adam_tensor') or self.lr_adam_tensor is None:
                    self.lr_adam_tensor = torch.tensor(group['lr'])
                    self.bc1 = torch.tensor(bias_correction1)
                    self.bc2 = torch.tensor(bias_correction2)
                self._compiled_adam_step(p, grad, state, group, self.lr_adam_tensor, self.bc1, self.bc2)
            else:
                self._adam_step_parameter(p, grad, state, group, lr, bias_correction1, bias_correction2)
        else:  # Muon path
            # Dispatch to compiled or uncompiled Muon step
            if is_compiled and self._compiled_muon_step is not None:
                # Convert to tensors for the compiled path once per step
                if not hasattr(self, 'lr_tensor') or self.lr_tensor is None:
                    self.lr_tensor = torch.tensor(group['lr'])
                self._compiled_muon_step(p, grad, state, group, self.lr_tensor)
            else:
                self._muon_step_parameter(p, grad, state, group, lr)

    def compile(self, *args, **kwargs):
        print("Compiling Muon step path...")
        self._compiled_muon_step = torch.compile(self._muon_step_parameter, *args, **kwargs)
        print("Compiling AuxAdam step path...")
        self._compiled_adam_step = torch.compile(self._adam_step_parameter, *args, **kwargs)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for i, p in enumerate(group['params']):
                self.step_parameter(p, group, i)

        if self.param_groups[0].get('compiled_optimizer', False):
            # Reset the compiled-path tensors once per step
            self.lr_tensor = None
            self.lr_adam_tensor = None
        return loss
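
The supports_fused_back_pass property and the per-parameter step_parameter(p, group, i) method indicate that the optimizer can also be driven one parameter at a time as gradients are accumulated, so each gradient can be freed immediately instead of being held for a monolithic step() call. Below is a sketch of that pattern, assuming PyTorch 2.1+ for Tensor.register_post_accumulate_grad_hook and the optimizer object from the earlier example; the trainer integration actually intended by the package may differ.

    # Map each parameter to its group and index once.
    param_to_group = {}
    for group in optimizer.param_groups:
        for i, p in enumerate(group["params"]):
            param_to_group[p] = (group, i)

    def fused_step_hook(param):
        # Called right after param.grad has been fully accumulated during backward().
        group, i = param_to_group[param]
        optimizer.step_parameter(param, group, i)
        param.grad = None  # free the gradient immediately

    for p in param_to_group:
        p.register_post_accumulate_grad_hook(fused_step_hook)

    # loss.backward() now both computes gradients and applies the updates;
    # no separate optimizer.step() call is needed.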