adv-optm 1.2.dev6.tar.gz → 1.2.dev8.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/PKG-INFO +1 -1
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/__init__.py +1 -1
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/AdaMuon_adv.py +74 -68
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/Muon_adv.py +166 -19
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/Kourkoutas.py +1 -1
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/setup.py +1 -1
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/LICENSE +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/README.md +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/MuonAdam_helper.py +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/Newton_Schulz.py +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm/util/__init__.py +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-1.2.dev6 → adv_optm-1.2.dev8}/setup.cfg +0 -0
adv_optm/optim/AdaMuon_adv.py

```diff
@@ -3,7 +3,6 @@ from typing import Optional, Callable
 
 from .AdamW_adv import AdamW_adv
 from ..util.MuonAdam_helper import MuonAdamHelper
-from ..util.Kourkoutas import KourkoutasHelper
 
 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Newton_Schulz import _newton_schulz_iteration
```
```diff
@@ -64,22 +63,13 @@ class AdaMuon_adv(torch.optim.Optimizer):
             matrices for muon NewtonSchulz (default: False).
         vector_reshape (bool): whether to reshape 1D vectors into 2D
             matrices to apply low-rank compression (default: True).
+        low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
+            projects the update to a lower rank before orthogonalization.
+            (default: False)
+        ortho_rank (int): The rank for low-rank orthogonalization.
+            (default: 128)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
-        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
-            If `False`, the optimizer behaves as standard AdamW. (default: False)
-        beta2_min (float): The minimum value for dynamic β₂, used during periods of
-            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
-            (default: 0.88)
-        ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
-            the pooled gradient norms. Corresponds to `α` in the paper.
-            (default: 0.93)
-        tiny_spike (float): A small constant added to the denominator of the
-            "sunspike" ratio calculation to prevent division by zero. Corresponds
-            to `ε_spike` in the paper. (default: 1e-9)
-        k_warmup_steps (int): The number of initial steps during which β₂ is held
-            at a fixed beta2 value before the
-            dynamic logic activates. (default: 0)
         MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
             Parameters designated by `layer_key_fn` will be optimized with
             AdamW_adv instead of Muon. (default: False)
```
```diff
@@ -110,15 +100,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
         alpha_grad: float = 100.0,
         vector_reshape_muon: bool = False,
         vector_reshape: bool = False,
+        # Low-rank Muon
+        low_rank_ortho: bool = False,
+        ortho_rank: int = 128,
         nnmf_factor: bool = False,
-        # K-b parameters
-        kourkoutas_beta: bool = False,
-        beta2_min: float = 0.9,
-        ema_alpha: float = 0.95,
-        tiny_spike: float = 1e-9,
-        k_warmup_steps: int = 0,
-        k_logging: int = 0,
-        layer_key_kb_fn: Optional[Callable] = None,
         # hybrid optimizer mode
         MuonWithAuxAdam: bool = False,
         layer_key_fn: Optional[Callable] = None,
```
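Taken together, the docstring and signature changes above replace the Kourkoutas-β knobs with two new arguments, `low_rank_ortho` and `ortho_rank`. A hypothetical construction against the new signature (import path and hyperparameters are illustrative, not taken from the package docs):

```python
import torch
from adv_optm import AdaMuon_adv  # import path assumed from the package layout

model = torch.nn.Linear(512, 512)  # placeholder module

# Sketch: enable the low-rank orthogonalization path added in 1.2.dev8.
opt = AdaMuon_adv(
    model.parameters(),
    lr=1e-3,
    low_rank_ortho=True,  # project updates to rank <= ortho_rank before Newton-Schulz
    ortho_rank=128,
)
```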
```diff
@@ -142,14 +127,11 @@ class AdaMuon_adv(torch.optim.Optimizer):
             "vector_reshape": vector_reshape,
             "vector_reshape_muon": vector_reshape_muon,
             "nesterov":nesterov, "use_atan2":use_atan2,
-            "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
-
-            "… (rest of line not rendered in the diff view)
+            "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+            # Low-rank Ortho
+            "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
         }
         self.stochastic_rounding = stochastic_rounding
-        self._kourkoutas_beta = kourkoutas_beta
-        self._kourkoutas_helper = None
-        self.layer_key_kb_fn = layer_key_kb_fn
         self.MuonWithAuxAdam = MuonWithAuxAdam
         self.helper = None
         self.aux_adam = None
```
```diff
@@ -182,14 +164,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
 
                 for key, value in defaults_to_use.items():
                     new_group.setdefault(key, value)
-                if '_kourkoutas_beta' not in new_group:
-                    if optim_type == 'adam':
-                        new_group['_kourkoutas_beta'] = False
-                    else:
-                        new_group['_kourkoutas_beta'] = muon_defaults['_kourkoutas_beta']
                 final_param_groups.append(new_group)
 
-        super().__init__(final_param_groups, …
+        super().__init__(final_param_groups, muon_defaults)
 
         # Now that self is initialized, create the helper
         self.helper = MuonAdamHelper(self, layer_key_fn)
```
```diff
@@ -219,9 +196,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
 
     @torch.no_grad()
     def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
-        if group['_kourkoutas_beta'] and self._kourkoutas_helper is None:
-            self._kourkoutas_helper = KourkoutasHelper(self)
-
         if self.MuonWithAuxAdam:
             optim_type = self.helper.get_optimizer_type(p)
             if optim_type == 'adam':
```
```diff
@@ -277,7 +251,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
 
         # Retrieve hyperparameters
         beta1, beta2 = group['betas']
-        current_step = state['step']
         nesterov = group['nesterov']
         Simplified_AdEMAMix = group['Simplified_AdEMAMix']
         alpha_grad = group['alpha_grad']
```
```diff
@@ -303,20 +276,37 @@ class AdaMuon_adv(torch.optim.Optimizer):
             signed_m_buf = torch.sign(mt_buf)
             del grad_reshaped
 
- … (14 deleted lines not rendered in the diff view)
+            # Orthogonalization step
+            if group['low_rank_ortho']:
+                # Low-Rank Orthogonalization on the reconstructed matrix
+                M = signed_m_buf
+                r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+                if r > 0:
+                    G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+                    MG = M @ G_sketch
+                    if MG.dtype != torch.float32:
+                        MG_dtype = M.dtype
+                        Q, _ = torch.linalg.qr(MG.float())
+                        Q = Q.to(MG_dtype)
+                    else:
+                        Q, _ = torch.linalg.qr(MG)
+                    projected_M = Q.T @ M
+                    ortho_projected_M = _newton_schulz_iteration(
+                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                    )
+                    update = Q @ ortho_projected_M
+                else:  # Fallback for invalid rank
+                    update = _newton_schulz_iteration(
+                        signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                    )
+            else:
+                # Original full Newton-Schulz
+                update = _newton_schulz_iteration(
+                    signed_m_buf,
+                    steps=group['ns_steps'],
+                    eps=group['ns_eps'],
+                    coeffs=group['ns_coeffs'],
+                )
 
             # Reconstruct second momentum from previous step's factors
             vt_buf = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
```
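The added branch is a sketch-and-project scheme: a Gaussian test matrix compresses the m×n update to r columns, QR extracts an orthonormal basis Q of the sketch's range, Newton-Schulz orthogonalizes only the small r×n projection QᵀM, and Q lifts the result back. A self-contained sketch of the same idea (the function name and the quintic coefficients are illustrative; the package's `_newton_schulz_iteration` takes its own `eps` and `coeffs`):

```python
import torch

def low_rank_orthogonalize(M: torch.Tensor, rank: int, ns_steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize M by running Newton-Schulz at reduced rank."""
    r = min(rank, M.shape[0], M.shape[1])
    G = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)  # Gaussian sketch
    Q, _ = torch.linalg.qr((M @ G).float())  # orthonormal basis of the sketched range
    Q = Q.to(M.dtype)
    X = Q.T @ M                        # small (r x n) projection
    X = X / (X.norm() + 1e-7)          # bound the spectral norm before iterating
    a, b, c = 3.4445, -4.7750, 2.0315  # widely used Muon quintic coefficients
    for _ in range(ns_steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return Q @ X                       # lift back to the full space
```

Since Newton-Schulz then multiplies r×n instead of m×n matrices, each iteration costs roughly (r/m)² of the full version's matmul work when r ≪ m; the trade-off is that the update is only orthogonalized within the rank-r sketched subspace.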
```diff
@@ -381,25 +371,41 @@ class AdaMuon_adv(torch.optim.Optimizer):
             if len(p.shape) > 2:
                 signed_m_buf = signed_m_buf.view(p.shape[0], -1)
 
-            #
- … (6 deleted lines not rendered in the diff view)
+            # Orthogonalization step
+            if group['low_rank_ortho']:
+                # Low-Rank Orthogonalization on the reconstructed matrix
+                M = signed_m_buf
+                r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+                if r > 0:
+                    G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+                    MG = M @ G_sketch
+                    if MG.dtype != torch.float32:
+                        MG_dtype = M.dtype
+                        Q, _ = torch.linalg.qr(MG.float())
+                        Q = Q.to(MG_dtype)
+                    else:
+                        Q, _ = torch.linalg.qr(MG)
+                    projected_M = Q.T @ M
+                    ortho_projected_M = _newton_schulz_iteration(
+                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                    )
+                    update = Q @ ortho_projected_M
+                else:  # Fallback for invalid rank
+                    update = _newton_schulz_iteration(
+                        signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                    )
+            else:
+                # Original full Newton-Schulz
+                update = _newton_schulz_iteration(
+                    signed_m_buf,
+                    steps=group['ns_steps'],
+                    eps=group['ns_eps'],
+                    coeffs=group['ns_coeffs'],
+                )
 
             if len(p.shape) > 2 or state['reshaped_1d_muon']:
                 update = update.view(p.shape)
 
-            if group['_kourkoutas_beta']:
-                # Call prepare_step() once at the beginning of the step for all params
-                self._kourkoutas_helper.maybe_prepare_step(current_step)
-                # Accumulate current sign-stabilized orthogonal update's norm for the *next* step
-                self._kourkoutas_helper.accumulate_gradient_sq_norm(p, update)
-                # Get the dynamic beta2 calculated in prepare_step()
-                beta2 = self._kourkoutas_helper.get_beta2(p, group, current_step)
-
             vt_buf = state['second_momentum_buffer']
             vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
 
```
adv_optm/optim/Muon_adv.py

```diff
@@ -18,6 +18,10 @@ class Muon_adv(torch.optim.Optimizer):
     the hidden layers of neural networks. It applies SGD with momentum and then
     orthogonalizes the resulting update matrix using a Newton-Schulz iteration.
 
+    NorMuon (Neuron-wise Normalized Muon) extends this by adding neuron-level
+    adaptive learning rates, combining the benefits of orthogonalization with
+    second-order momentum statistics.
+
     This implementation is designed for 2D parameters (e.g., linear layers) and
     can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
     flattening/reshaping them.
```
```diff
@@ -54,6 +58,19 @@ class Muon_adv(torch.optim.Optimizer):
             matrices to apply low-rank compression (default: True).
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
+        low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
+            projects the update to a lower rank before orthogonalization.
+            (default: False)
+        ortho_rank (int): The rank for low-rank orthogonalization.
+            (default: 128)
+        normuon_variant (bool): If True, enables the NorMuon update rule, which adds
+            neuron-wise normalization. (default: False)
+        beta2_normuon (float): The exponential decay rate for the second moment estimates
+            used in NorMuon. (default: 0.95)
+        normuon_eps (float): Epsilon for NorMuon normalization stability. (default: 1e-8)
+        normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
+            (default: 0.2)
+        normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
         MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
             Parameters designated by `layer_key_fn` will be optimized with
             AdamW_adv instead of Muon. (default: False)
```
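Beyond the low-rank option, Muon_adv gains a NorMuon mode with four knobs. A hypothetical construction (import path assumed from the package layout; the matching signature appears in the next hunk):

```python
import torch
from adv_optm import Muon_adv  # import path assumed

model = torch.nn.Linear(1024, 1024)  # placeholder module

# Sketch: combine low-rank orthogonalization with NorMuon's
# neuron-wise adaptive normalization.
opt = Muon_adv(
    model.parameters(),
    lr=0.02,
    low_rank_ortho=True,
    ortho_rank=128,
    normuon_variant=True,
    beta2_normuon=0.95,
    normuon_lr_scale=0.2,
)
```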
```diff
@@ -82,6 +99,15 @@ class Muon_adv(torch.optim.Optimizer):
         vector_reshape_muon: bool = False,
         vector_reshape: bool = False,
         nnmf_factor: bool = False,
+        # Low-rank Muon
+        low_rank_ortho: bool = False,
+        ortho_rank: int = 128,
+        # NorMuon additions
+        normuon_variant: bool = False,
+        beta2_normuon: float = 0.95,
+        normuon_eps: float = 1e-8,
+        normuon_lr_scale: float = 0.2,
+        normuon_atan2: bool = False,
         # hybrid optimizer mode
         MuonWithAuxAdam: bool = False,
         layer_key_fn: Optional[Callable] = None,
```
```diff
@@ -92,6 +118,8 @@ class Muon_adv(torch.optim.Optimizer):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
         if not (0.0 <= beta1 < 1.0):
             raise ValueError(f"beta1 should be in [0.0, 1.0). Got {beta1}")
+        if normuon_variant and not (0.0 <= beta2_normuon < 1.0):
+            raise ValueError(f"beta2_normuon should be in [0.0, 1.0) for NorMuon. Got {beta2_normuon}")
         if not (weight_decay >= 0.0):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
         if not (ns_steps > 0):
```
```diff
@@ -106,10 +134,16 @@ class Muon_adv(torch.optim.Optimizer):
             "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
             "vector_reshape": vector_reshape,
             "vector_reshape_muon": vector_reshape_muon,
-            "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+            "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+            # Low-rank Ortho
+            "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
+            # NorMuon
+            "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
+            "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
+            "normuon_atan2": normuon_atan2,
         }
         self.stochastic_rounding = stochastic_rounding
-
+
         self.MuonWithAuxAdam = MuonWithAuxAdam
         self.helper = None
         self.aux_adam = None
```
```diff
@@ -144,7 +178,7 @@ class Muon_adv(torch.optim.Optimizer):
 
                 final_param_groups.append(new_group)
 
-        super().__init__(final_param_groups, …
+        super().__init__(final_param_groups, muon_defaults)
 
         # Now that self is initialized, create the helper
         self.helper = MuonAdamHelper(self, layer_key_fn)
```
|
@@ -223,6 +257,12 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
223
257
|
elif len(p.shape) == 1:
|
|
224
258
|
state['momentum_buffer'] = torch.zeros_like(p)
|
|
225
259
|
|
|
260
|
+
# NorMuon state initialization
|
|
261
|
+
if group['normuon_variant']:
|
|
262
|
+
if len(p.shape) >= 2 or state['reshaped_1d_muon']:
|
|
263
|
+
num_rows = p.shape[0] if len(p.shape) >= 2 else state['effective_shape'][0]
|
|
264
|
+
state['normuon_v'] = torch.zeros(num_rows, device=p.device, dtype=torch.float32)
|
|
265
|
+
|
|
226
266
|
beta1 = group['beta1']
|
|
227
267
|
nesterov = group['nesterov']
|
|
228
268
|
Simplified_AdEMAMix = group['Simplified_AdEMAMix']
|
|
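Note the shape of the new state: `normuon_v` holds one float32 per output row (neuron), not one per element, so the extra memory is negligible next to the momentum buffers. A quick illustration with arbitrary shapes:

```python
import torch

p = torch.empty(4096, 1024)                               # placeholder weight matrix
normuon_v = torch.zeros(p.shape[0], dtype=torch.float32)  # one entry per output neuron
print(normuon_v.numel() / p.numel())                      # 1/1024 of an elementwise buffer
```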
```diff
@@ -251,14 +291,60 @@ class Muon_adv(torch.optim.Optimizer):
             update = mt_buf.clone()
             del grad_reshaped
 
- … (6 deleted lines not rendered in the diff view)
+            # Orthogonalization step
+            if group['low_rank_ortho']:
+                # Low-Rank Orthogonalization on the reconstructed matrix
+                M = update
+                r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+                if r > 0:
+                    G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+                    MG = M @ G_sketch
+                    if MG.dtype != torch.float32:
+                        MG_dtype = M.dtype
+                        Q, _ = torch.linalg.qr(MG.float())
+                        Q = Q.to(MG_dtype)
+                    else:
+                        Q, _ = torch.linalg.qr(MG)
+                    projected_M = Q.T @ M
+                    ortho_projected_M = _newton_schulz_iteration(
+                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                    )
+                    update = Q @ ortho_projected_M
+                else:  # Fallback for invalid rank
+                    update = _newton_schulz_iteration(
+                        update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                    )
+            else:
+                # Original full Newton-Schulz
+                update = _newton_schulz_iteration(
+                    update,
+                    steps=group['ns_steps'],
+                    eps=group['ns_eps'],
+                    coeffs=group['ns_coeffs'],
+                )
 
-
+
+            if group['normuon_variant'] and 'normuon_v' in state:
+                v_t = state['normuon_v']
+                beta2_normuon = group['beta2_normuon']
+                # Update 2nd moment estimate
+                mean_squared_update = torch.mean(update.square(), dim=1)
+                v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
+                # Normalize update
+                if group['normuon_atan2']:
+                    a = 1.2732395
+                    update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+                else:
+                    update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
+                # Scale learning rate
+                update_norm = torch.linalg.vector_norm(update)
+                if update_norm > 1e-12:
+                    scaled_lr = group['normuon_lr_scale'] * group['lr'] * (p.numel()**0.5) / update_norm
+                else:
+                    scaled_lr = 0.0
+                update = update.view(p.shape).mul_(scaled_lr)
+            else:  # Original Muon learning rate application
+                update = update.view(p.shape).mul_(group['lr'])
 
             state['sign'] = _pack_bools(mt_buf > 0)
             _nnmf(mt_buf.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
```
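The NorMuon branch keeps a per-row EMA of the orthogonalized update's mean square, divides each row by the resulting RMS, then rescales the whole matrix so the applied step has norm `normuon_lr_scale * lr * sqrt(numel)`. A minimal standalone sketch of that rule (the function name is ours; the orthogonalized update and state buffer are taken as given):

```python
import torch

def normuon_scale(update: torch.Tensor, v: torch.Tensor, lr: float,
                  beta2: float = 0.95, eps: float = 1e-8,
                  lr_scale: float = 0.2) -> torch.Tensor:
    """NorMuon-style neuron-wise normalization of an orthogonalized update.

    update: (rows, cols) orthogonalized momentum; v: (rows,) second-moment EMA.
    """
    # Per-neuron (row) EMA of the mean squared update
    v.mul_(beta2).add_(update.square().mean(dim=1), alpha=1 - beta2)
    # Normalize each row by its RMS estimate
    update = update / (v.sqrt().unsqueeze(1) + eps)
    # Rescale so the final step norm is lr_scale * lr * sqrt(numel)
    norm = torch.linalg.vector_norm(update)
    scale = lr_scale * lr * (update.numel() ** 0.5) / norm if norm > 1e-12 else 0.0
    return update * scale
```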
```diff
@@ -298,19 +384,80 @@ class Muon_adv(torch.optim.Optimizer):
             if len(p.shape) > 2:
                 update = update.view(p.shape[0], -1)
 
-            #
- … (6 deleted lines not rendered in the diff view)
+            # Orthogonalization step
+            if group['low_rank_ortho']:
+                # Low-Rank Orthogonalization based on Gaussian Sketching
+                M = update
+                r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+
+                if r > 0:
+                    # 1. Sketch the matrix
+                    G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+                    MG = M @ G_sketch
+
+                    # 2. QR decomposition to get orthogonal basis Q
+                    if MG.dtype != torch.float32:
+                        MG_dtype = M.dtype
+                        Q, _ = torch.linalg.qr(MG.float())
+                        Q = Q.to(MG_dtype)
+                    else:
+                        Q, _ = torch.linalg.qr(MG)
+
+                    # 3. Project M onto the basis
+                    projected_M = Q.T @ M
+
+                    # 4. Orthogonalize the smaller projected matrix
+                    ortho_projected_M = _newton_schulz_iteration(
+                        projected_M,
+                        steps=group['ns_steps'],
+                        eps=group['ns_eps'],
+                        coeffs=group['ns_coeffs'],
+                    )
+
+                    # 5. Project back to the original space
+                    update = Q @ ortho_projected_M
+                else:  # Fallback for invalid rank
+                    update = _newton_schulz_iteration(
+                        update,
+                        steps=group['ns_steps'],
+                        eps=group['ns_eps'],
+                        coeffs=group['ns_coeffs'],
+                    )
+            else:
+                # Original NewtonSchulz
+                update = _newton_schulz_iteration(
+                    update,
+                    steps=group['ns_steps'],
+                    eps=group['ns_eps'],
+                    coeffs=group['ns_coeffs'],
+                )
+
+            # NorMuon Logic
+            if group['normuon_variant'] and 'normuon_v' in state:
+                v_t = state['normuon_v']
+                beta2_normuon = group['beta2_normuon']
+                # Update 2nd moment estimate
+                mean_squared_update = torch.mean(update.square(), dim=1)
+                v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
+                # Normalize update
+                if group['normuon_atan2']:
+                    a = 1.2732395
+                    update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+                else:
+                    update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
+                # Scale learning rate
+                update_norm = torch.linalg.vector_norm(update)
+                if update_norm > 1e-12:
+                    scaled_lr = group['normuon_lr_scale'] * group['lr'] * (p.numel()**0.5) / update_norm
+                else:
+                    scaled_lr = 0.0
+                update.mul_(scaled_lr)
+            else:  # Original Muon learning rate application
+                update.mul_(group['lr'])
 
             # Reshape back to original if we flattened or reshaped
             if len(p.shape) > 2 or state['reshaped_1d_muon']:
                 update = update.view(p.shape)
-
-            update.mul_(group['lr'])
 
         else: # Fallback to standard SGD with momentum for 1D params (biases, etc.) when not reshaped
             # Momentum update
```
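One detail worth noting in the hunk above: with `normuon_atan2`, the division `u / (sqrt(v) + eps)` is swapped for `a * atan2(u, sqrt(v))` with `a = 1.2732395 ≈ 4/π`. That constant makes the two paths agree at `u = sqrt(v)` (since `(4/π) · atan2(1, 1) = 1`), and the result is bounded by `a · π/2 = 2` instead of blowing up as `v → 0`. A small numeric check with made-up values:

```python
import torch

u = torch.tensor([1e-4, 1.0, 10.0])   # update entries (made-up)
s = torch.tensor([1.0, 1.0, 1e-6])    # sqrt of the second-moment estimate

a = 1.2732395  # ~4/pi, so the bounded path saturates at a * pi/2 = 2
print(u / (s + 1e-8))         # tensor([1.0000e-04, 1.0000e+00, 9.9010e+06])
print(a * torch.atan2(u, s))  # ~tensor([1.2732e-04, 1.0000e+00, 2.0000e+00])
```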
|
@@ -94,7 +94,7 @@ class KourkoutasHelper:
|
|
|
94
94
|
for layer_key, info in self.layer_info.items():
|
|
95
95
|
params, group = info['params'], info['group_ref']
|
|
96
96
|
|
|
97
|
-
if not group.get('kourkoutas_beta', False):
|
|
97
|
+
if not group.get('kourkoutas_beta', False) and not group.get('_kourkoutas_beta', False):
|
|
98
98
|
continue
|
|
99
99
|
|
|
100
100
|
first_param_in_layer = info['params'][0]
|
|
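This one-line change makes the helper honor both spellings of the flag: plain groups use `kourkoutas_beta`, while internally built groups (as AdaMuon_adv's hybrid mode produced in 1.2.dev6) carry `_kourkoutas_beta`. A toy illustration of the new guard:

```python
# Hypothetical param groups, showing which ones the 1.2.dev8 guard processes:
groups = [
    {'kourkoutas_beta': True},   # processed before and after this change
    {'_kourkoutas_beta': True},  # skipped by 1.2.dev6's guard, processed now
    {},                          # still skipped
]
for group in groups:
    if not group.get('kourkoutas_beta', False) and not group.get('_kourkoutas_beta', False):
        continue
    print('dynamic beta2 enabled for', group)
```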