adv-optm 1.2.dev11__tar.gz → 1.2.dev13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic.
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/PKG-INFO +1 -1
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/__init__.py +1 -1
- adv_optm-1.2.dev13/adv_optm/optim/AdaMuon_adv.py +664 -0
- adv_optm-1.2.dev13/adv_optm/optim/Muon_adv.py +695 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/setup.py +1 -1
- adv_optm-1.2.dev11/adv_optm/optim/AdaMuon_adv.py +0 -397
- adv_optm-1.2.dev11/adv_optm/optim/Muon_adv.py +0 -423
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/LICENSE +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/README.md +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/util/Newton_Schulz.py +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm/util/__init__.py +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-1.2.dev11 → adv_optm-1.2.dev13}/setup.cfg +0 -0

adv_optm-1.2.dev13/adv_optm/optim/AdaMuon_adv.py (new file)

```diff
@@ -0,0 +1,664 @@
+import torch
+
+from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.Newton_Schulz import _newton_schulz_iteration
+from ..util.Effective_Shape import _get_effective_shape
+from ..util.NNMF import _nnmf,_unnmf
+from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+from ..util.OrthoGrad import _orthogonalize_gradient
+from ..util.Kourkoutas import KourkoutasHelper
+
+class AdaMuon_adv(torch.optim.Optimizer):
+    """
+    Implements an advanced AdaMuon optimizer algorithm.
+
+    AdaMuon combines the geometry-aware updates of Muon with the element-wise
+    adaptivity of Adam. It is designed for 2D parameters (e.g., linear layers)
+    and can handle higher-dimensional parameters by flattening.
+
+    The algorithm incorporates three key mechanisms:
+    1. A sign-stabilized orthogonal update, where the sign of the momentum is
+       orthogonalized instead of the momentum itself.
+    2. An element-wise second momentum estimator applied to the orthogonalized
+       update directions.
+    3. An RMS-aligned rescaling strategy to match the update magnitude of Adam,
+       allowing for reuse of learning rate schedules.
+
+    When `MuonWithAuxAdam` is enabled, this single optimizer class handles both
+    'muon' and 'adam' parameter groups, dispatching to the appropriate logic internally.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float): learning rate (default: 1e-3).
+        betas (tuple[float, float]): coefficients used for both first and second moment
+            estimation (default: (0.95, 0.95))
+        weight_decay (float): weight decay (L2 penalty) (default: 0.1).
+        eps (float): term added to the denominator for adaptive scaling to improve
+            numerical stability (default: 1e-8).
+        rms_target (float): The target Root-Mean-Square value for the final update
+            vector, used for RMS-aligned rescaling. Allows for the reuse of existing Adam
+            learning rate schedules. (default: 0.2).
+        ns_steps (int): number of Newton-Schulz iterations to perform (default: 5).
+        ns_eps (float): epsilon for Newton-Schulz normalization stability (default: 1e-7).
+        ns_coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
+            quintic polynomial in the Newton-Schulz iteration.
+            (default: (3.4445, -4.7750, 2.0315)).
+        stochastic_rounding (bool): whether to use stochastic rounding for
+            BF16 parameter updates (default: True).
+        nesterov (bool): enables Nesterov momentum (default: False).
+        use_atan2 (bool): whether to use the atan2 update rule. (default: False)
+        Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
+            This changes the update to `alpha_grad * grad + mt`, which can be
+            more responsive, especially for small batch sizes. (default: False)
+        alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
+            (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
+            current gradient. For small batch sizes, use high values (e.g., 10-100) to be
+            more responsive. For large batch sizes, use low values (e.g., 0-1) for
+            stability. (default: 100.0)
+        vector_reshape (bool): whether to reshape 1D vectors into 2D
+            matrices to apply low-rank compression (default: True).
+        low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
+            projects the update to a lower rank before orthogonalization.
+            (default: False)
+        ortho_rank (int): The rank for low-rank orthogonalization.
+            (default: 128)
+        nnmf_factor (bool): whether to use the factorization or disable it to use
+            the uncompressed optimizer. (default: False)
+        --- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
+        adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
+        adam_eps (float): Epsilon for the AdamW optimizer part.
+        adam_weight_decay (float): Weight decay for the AdamW optimizer part.
+        adam_use_bias_correction (bool): Bias correction for AdamW.
+        adam_use_atan2 (bool): Atan2 update rule for AdamW.
+        adam_cautious_mask (bool): Cautious masking for AdamW.
+        adam_grams_moment (bool): Grams-style updates for AdamW.
+        adam_orthogonal_gradient (bool): OrthoGrad for AdamW.
+        adam_use_AdEMAMix (bool): AdEMAMix for AdamW.
+        adam_beta3_ema (float): Beta3 for AdEMAMix.
+        adam_alpha (float): Alpha for AdEMAMix.
+        adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-3,
+        betas: tuple[float, float] = (0.95, 0.95),
+        weight_decay: float = 0.1,
+        eps: float = 1e-8,
+        rms_target: float = 0.2,
+        ns_steps: int = 5,
+        ns_eps: float = 1e-7,
+        ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
+        stochastic_rounding: bool = False,
+        use_atan2: bool = False,
+        nesterov: bool = False,
+        Simplified_AdEMAMix: bool = False,
+        alpha_grad: float = 100.0,
+        vector_reshape: bool = False,
+        # Low-rank Muon
+        low_rank_ortho: bool = False,
+        ortho_rank: int = 128,
+        nnmf_factor: bool = False,
+        # Compiled
+        compiled_optimizer: bool = False,
+        # --- AdamW_adv specific parameters ---
+        adam_betas: tuple[float, float] = (0.9, 0.99),
+        adam_eps: float = 1e-8,
+        adam_weight_decay: float = 0.0,
+        adam_use_bias_correction: bool = True,
+        adam_use_atan2: bool = False,
+        adam_cautious_mask: bool = False,
+        adam_grams_moment: bool = False,
+        adam_orthogonal_gradient: bool = False,
+        adam_use_AdEMAMix: bool = False,
+        adam_beta3_ema: float = 0.9999,
+        adam_alpha: float = 5.0,
+        adam_kourkoutas_beta: bool = False,
+        adam_beta2_min: float = 0.9,
+        adam_ema_alpha: float = 0.95,
+        adam_tiny_spike: float = 1e-9,
+        adam_k_warmup_steps: int = 0,
+        adam_nnmf_factor: bool = False,
+    ):
+        if not (lr >= 0.0):
+            raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+        if not (weight_decay >= 0.0):
+            raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+        if not (ns_steps > 0):
+            raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
+        if Simplified_AdEMAMix and nesterov:
+            print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling nesterov.")
+            nesterov = False
+
+        defaults = {
+            "lr": lr, "betas": betas, "weight_decay": weight_decay,
+            "eps": eps, "rms_target": rms_target, "ns_steps": ns_steps,
+            "ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
+            "vector_reshape": vector_reshape,
+            "nesterov": nesterov, "use_atan2": use_atan2,
+            "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+            # Low-rank Ortho
+            "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
+            "compiled_optimizer": compiled_optimizer,
+            # AdamW_adv defaults
+            "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
+            "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
+            "adam_cautious_mask": adam_cautious_mask, "adam_grams_moment": adam_grams_moment,
+            "adam_orthogonal_gradient": adam_orthogonal_gradient,
+            "adam_use_AdEMAMix": adam_use_AdEMAMix, "adam_beta3_ema": adam_beta3_ema, "adam_alpha": adam_alpha,
+            "adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
+            "adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
+            "adam_k_warmup_steps": adam_k_warmup_steps, "adam_nnmf_factor": adam_nnmf_factor,
+        }
+        self.stochastic_rounding = stochastic_rounding
+
+        super().__init__(params, defaults)
+
+        self.global_step = 0 # For Adam bias correction and Kourkoutas
+        self.kourkoutas_helper = None
+        if any(group.get('adam_kourkoutas_beta', False) for group in self.param_groups):
+            self.kourkoutas_helper = KourkoutasHelper(self)
+
+        self.init_step()
+
+        # Initialize compiled functions to None
+        self._compiled_muon_step = None
+        self._compiled_adam_step = None
+
+        if compiled_optimizer:
+            print("Compiling AdaMuon_adv optimizer paths...")
+            torch._dynamo.config.cache_size_limit = 8192
+            self.compile(fullgraph=True)
+
+    @property
+    def supports_fused_back_pass(self):
+        return True
+
+    @property
+    def supports_memory_efficient_fp16(self):
+        return True
+
+    @property
+    def supports_flat_params(self):
+        return False
+
+    def init_step(self):
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.__init_state(p, group)
+
+    @torch.no_grad()
+    def __init_state(self, p, group):
+        state = self.state[p]
+
+        if len(state) > 0:
+            return
+
+        optim_type = group.get('optim_type', 'muon')
+
+        if optim_type == 'muon':
+
+            state['factored'] = (
+                group['nnmf_factor'] and
+                not (len(p.shape) == 1 and not group['vector_reshape'])
+            )
+            dtype = torch.float32 if state['factored'] else p.dtype
+            device = p.device
+
+            if state['factored']:
+                state['effective_shape'] = _get_effective_shape(p.numel())
+                d1, d2 = state['effective_shape']
+                state['mu_mbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_mbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                packed_d2 = (d2 + 7) // 8
+                state['sign_buf'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+            else:
+                if len(p.shape) >= 2:
+                    state['second_momentum_buffer'] = torch.zeros_like(p)
+                state['momentum_buffer'] = torch.zeros_like(p)
+
+
+        elif optim_type == 'adam':
+
+            state['factored'] = (
+                group['adam_nnmf_factor'] and
+                not (len(p.shape) == 1 and not group['vector_reshape'])
+            )
+            dtype = torch.float32 if state['factored'] else p.dtype
+            device = p.device
+
+            if state['factored']:
+                state['effective_shape'] = _get_effective_shape(p.numel())
+                d1, d2 = state['effective_shape']
+                # First moment (m)
+                if group['adam_betas'][0] > 0:
+                    state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                    state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                    if not group.get('adam_grams_moment'):
+                        packed_d2 = (d2 + 7) // 8
+                        state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                if group.get('adam_use_AdEMAMix'):
+                    state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                    state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                    packed_d2 = (d2 + 7) // 8
+                    state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                # Second moment (v)
+                state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+            else: # Fallback to standard AdamW for non-factored tensors
+                if group['adam_betas'][0] > 0:
+                    state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+                if group.get('adam_use_AdEMAMix'):
+                    state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
+                state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
+
+    @torch.no_grad()
+    def _muon_step_parameter(self, p, grad, state, group, lr):
+        # Retrieve hyperparameters
+        beta1, beta2 = group['betas']
+        nesterov = group['nesterov']
+        Simplified_AdEMAMix = group['Simplified_AdEMAMix']
+        alpha_grad = group['alpha_grad']
+
+        if state['factored']: # Factored AdaMuon
+
+            # Reconstruct momentum from previous step's factors & sign
+            d1, d2 = state['effective_shape']
+            mt_buf = _unnmf((state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
+            unpacked_sign = _unpack_bools(state['sign_buf'], original_m=d2)
+            torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
+            del unpacked_sign
+
+            # Update momentum in full-size
+            grad_reshaped = grad.view(d1, d2)
+            mt_buf.mul_(beta1).add_(grad_reshaped)
+
+            if nesterov:
+                signed_m_buf = torch.sign(grad_reshaped.add(mt_buf, alpha=beta1))
+            elif Simplified_AdEMAMix:
+                signed_m_buf = torch.sign(mt_buf.add(grad_reshaped, alpha=alpha_grad))
+            else:
+                signed_m_buf = torch.sign(mt_buf)
+            del grad_reshaped
+
+            # Orthogonalization step
+            if group['low_rank_ortho']:
+                # Low-Rank Orthogonalization on the reconstructed matrix
+                M = signed_m_buf
+                r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+                if r > 0:
+                    G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+                    MG = M @ G_sketch
+                    if MG.dtype != torch.float32:
+                        MG_dtype = M.dtype
+                        Q, _ = torch.linalg.qr(MG.float())
+                        Q = Q.to(MG_dtype)
+                    else:
+                        Q, _ = torch.linalg.qr(MG)
+                    projected_M = Q.T @ M
+                    ortho_projected_M = _newton_schulz_iteration(
+                        projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                    )
+                    update = Q @ ortho_projected_M
+                else: # Fallback for invalid rank
+                    update = _newton_schulz_iteration(
+                        signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                    )
+            else:
+                # Original full Newton-Schulz
+                update = _newton_schulz_iteration(
+                    signed_m_buf,
+                    steps=group['ns_steps'],
+                    eps=group['ns_eps'],
+                    coeffs=group['ns_coeffs'],
+                )
+            del signed_m_buf
+
+            # Reconstruct second momentum from previous step's factors
+            vt_buf = _unnmf((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
+
+            # Update second momentum in full-size
+            vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+
+            # Apply second momentum update (adaptive scaling)
+            if group['use_atan2']:
+                a = 1.2732395
+                denom = vt_buf.sqrt()
+                update.atan2_(denom).mul_(a)
+            else:
+                denom = vt_buf.sqrt().add_(group['eps'])
+                update.div_(denom)
+            del denom
+
+            # RMS-aligned rescaling
+            rms_target = group['rms_target']
+            num_elements = update.numel()
+            # Add eps to prevent division by zero
+            update.mul_(rms_target * (num_elements ** 0.5) / (update.norm() + group['eps']))
+
+            update = update.view(p.shape).mul_(lr)
+            del num_elements
+
+            # Compress updated moments and store new factors
+            state['sign_buf'] = _pack_bools(mt_buf > 0)
+            _nnmf(mt_buf.abs(), out=(state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
+            del mt_buf
+
+            _nnmf(vt_buf.abs(), out=(state['mu_vbuf_nmf'], state['mv_vbuf_nmf']))
+            del vt_buf
+
+        else: # Standard AdaMuon logic for non-factored tensors
+
+            if len(p.shape) >= 2:
+
+                original_shape = p.shape
+
+                # Momentum update
+                mt_buf = state['momentum_buffer']
+                mt_buf.mul_(beta1).add_(grad)
+
+                if nesterov:
+                    signed_m_buf = torch.sign(grad.add(mt_buf, alpha=beta1))
+                elif Simplified_AdEMAMix:
+                    signed_m_buf = torch.sign(mt_buf.add(grad, alpha=alpha_grad))
+                else:
+                    signed_m_buf = torch.sign(mt_buf)
+
+                # Flatten if necessary (e.g., for Conv layers)
+                signed_m_buf = signed_m_buf.view(original_shape[0], -1)
+
+                # Orthogonalization step
+                if group['low_rank_ortho']:
+                    # Low-Rank Orthogonalization on the reconstructed matrix
+                    M = signed_m_buf
+                    r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+                    if r > 0:
+                        G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+                        MG = M @ G_sketch
+                        if MG.dtype != torch.float32:
+                            MG_dtype = M.dtype
+                            Q, _ = torch.linalg.qr(MG.float())
+                            Q = Q.to(MG_dtype)
+                        else:
+                            Q, _ = torch.linalg.qr(MG)
+                        projected_M = Q.T @ M
+                        ortho_projected_M = _newton_schulz_iteration(
+                            projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                        )
+                        update = Q @ ortho_projected_M
+                    else: # Fallback for invalid rank
+                        update = _newton_schulz_iteration(
+                            signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                        )
+                else:
+                    # Original full Newton-Schulz
+                    update = _newton_schulz_iteration(
+                        signed_m_buf,
+                        steps=group['ns_steps'],
+                        eps=group['ns_eps'],
+                        coeffs=group['ns_coeffs'],
+                    )
+                del signed_m_buf
+
+                update = update.view(original_shape)
+
+                vt_buf = state['second_momentum_buffer']
+                vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+
+                # Apply second momentum update (adaptive scaling)
+                if group['use_atan2']:
+                    a = 1.2732395
+                    denom = vt_buf.sqrt()
+                    update.atan2_(denom).mul_(a)
+                else:
+                    denom = vt_buf.sqrt().add_(group['eps'])
+                    update.div_(denom)
+                del denom
+
+                # RMS-aligned rescaling
+                rms_target = group['rms_target']
+                num_elements = update.numel()
+                # Add eps to prevent division by zero
+                update.mul_(rms_target * (num_elements ** 0.5) / (update.norm() + group['eps']))
+
+                del num_elements
+
+                update.mul_(lr)
+
+            else: # Fallback to standard SGD with momentum for 1D params (biases, etc.)
+                # Momentum update
+                mt_buf = state['momentum_buffer']
+                mt_buf.mul_(beta1).add_(grad)
+                if nesterov:
+                    # Nesterov momentum
+                    update = grad.add(mt_buf, alpha=beta1)
+                # elif Simplified_AdEMAMix: # TODO, it will break SGD since it requires x100 lower LR
+                #     update = mt_buf.add(grad, alpha=alpha_grad)
+                else:
+                    update = mt_buf.clone()
+                update.mul_(lr)
+
+        # Decoupled weight decay
+        if group["weight_decay"] != 0:
+            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
+            else:
+                p.data.add_(p.data, alpha=-group["weight_decay"] * lr)
+
+        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+            add_stochastic_(p.data, -update)
+        else:
+            p.data.add_(-update)
+        del update
+
+    @torch.no_grad()
+    def _adam_step_parameter(self, p, grad, state, group, lr, bias_correction1, bias_correction2):
+        if grad.dtype != torch.float32 and state.get('factored', False):
+            grad = grad.float()
+        if group.get("adam_orthogonal_gradient"):
+            grad = _orthogonalize_gradient(p, grad)
+
+        beta1_adam, beta2_adam = group['adam_betas']
+
+        if group.get('adam_kourkoutas_beta', False):
+            # Accumulate current grad's norm for the *next* step
+            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+            # Get the dynamic beta2_adam calculated in prepare_step()
+            beta2_adam = self.kourkoutas_helper.get_beta2(p, group)
+
+        step_size = lr / bias_correction1
+
+        if group.get('adam_use_AdEMAMix'):
+            beta3_ema = group['adam_beta3_ema']
+            alpha = group['adam_alpha']
+
+        if state['factored']:
+            d1, d2 = state['effective_shape']
+            grad_reshaped = grad.view(d1, d2)
+
+            # Reconstruct momentum from previous step's factors
+            if beta1_adam > 0:
+                mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+                if not group.get('adam_grams_moment'):
+                    unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+                    torch.where(unpacked_sign, mt, -mt, out=mt)
+                    del unpacked_sign
+                # Update momentum in full-size
+                mt.mul_(beta1_adam).add_(grad_reshaped, alpha=1.0 - beta1_adam)
+                if group.get('adam_grams_moment'):
+                    mt = (grad_reshaped.sign().mul_(mt.abs()))
+                elif group.get('adam_cautious_mask'):
+                    mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    mt.mul_(mask)
+                    del mask
+
+            vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+            vt.mul_(beta2_adam).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2_adam)
+
+            if group.get('adam_use_AdEMAMix'):
+                mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                if state['sign_slow'].dtype != torch.uint8:
+                    state['sign_slow'] = state['sign_slow'].to(torch.uint8)
+                unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
+                torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
+                del unpacked_sign_slow
+
+                mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
+                if beta1_adam > 0:
+                    update = torch.add(mt, mt_slow, alpha=alpha)
+                else:
+                    update = torch.add(grad_reshaped, mt_slow, alpha=alpha)
+            else:
+                update = mt.clone() if beta1_adam > 0 else grad_reshaped.clone()
+            del grad_reshaped
+
+            if group['adam_use_atan2']:
+                a = 1.2732395
+                denom = (vt.sqrt() / (bias_correction2**0.5))
+                update.atan2_(denom).mul_(a)
+            else:
+                denom = (vt.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
+                update.div_(denom)
+            del denom
+
+            update = update.view(p.shape).mul_(step_size)
+
+            # Compress updated moments and store new factors
+            if beta1_adam > 0:
+                if not group.get('adam_grams_moment'):
+                    state['sign'] = _pack_bools(mt > 0)
+                _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+                del mt
+            if group.get('adam_use_AdEMAMix'):
+                state['sign_slow'] = _pack_bools(mt_slow > 0)
+                _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                del mt_slow
+            _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+            del vt
+
+        else: # Standard AdamW logic for non-factored tensors
+            exp_avg_sq = state['exp_avg_sq']
+
+            if beta1_adam > 0:
+                exp_avg = state['exp_avg']
+                exp_avg.mul_(beta1_adam).add_(grad, alpha=1 - beta1_adam)
+                if group.get('adam_grams_moment'):
+                    exp_avg = grad.sign().mul_(exp_avg.abs())
+                elif group.get('adam_cautious_mask'):
+                    mask = (exp_avg * grad > 0).to(grad.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    exp_avg.mul_(mask)
+                    del mask
+
+            if group.get('adam_use_AdEMAMix'):
+                exp_avg_slow = state['exp_avg_slow']
+                exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
+                if beta1_adam > 0:
+                    update = torch.add(exp_avg, exp_avg_slow, alpha=alpha)
+                else:
+                    update = torch.add(grad, exp_avg_slow, alpha=alpha)
+            else:
+                update = exp_avg.clone() if beta1_adam > 0 else grad.clone()
+
+            exp_avg_sq.mul_(beta2_adam).addcmul_(grad, grad.conj(), value=1 - beta2_adam)
+
+            if group.get('adam_use_atan2'):
+                a = 1.2732395
+                denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5))
+                update.atan2_(denom).mul_(a)
+            else:
+                denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group['adam_eps'])
+                update.div_(denom)
+            del denom
+
+            update.mul_(step_size)
+
+        # Decoupled weight decay
+        if group["adam_weight_decay"] != 0:
+            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                add_stochastic_(p.data, p.data, alpha=-group["adam_weight_decay"] * lr)
+            else:
+                p.data.add_(p.data, alpha=-group["adam_weight_decay"] * lr)
+
+        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+            add_stochastic_(p.data, -update)
+        else:
+            p.data.add_(-update)
+        del update
+
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        grad = p.grad
+        if grad is None:
+            return
+        state = self.state[p]
+
+
+
+        # Determine if using Adam or Muon based on state keys
+        # We can use optim_type but I see this as a safer way.
+        if 'momentum_buffer' in state or 'mu_mbuf_nmf' in state:
+            use_adam = False
+        else:
+            use_adam = True
+
+        lr = group['lr']
+        is_compiled = group.get('compiled_optimizer', False)
+
+        if use_adam:
+            if self.kourkoutas_helper:
+                # Prepare Kourkoutas-β once per optimizer step.
+                self.kourkoutas_helper.maybe_prepare_step(self.global_step)
+            # Adam-specific setup (bias correction)
+            if group['adam_use_bias_correction']:
+                current_step = self.global_step + 1
+                beta1_adam, beta2_adam = group['adam_betas']
+                bias_correction1 = 1.0 - beta1_adam ** current_step
+                bias_correction2 = 1.0 - beta2_adam ** current_step
+            else:
+                bias_correction1 = 1.0
+                bias_correction2 = 1.0
+            # Dispatch to compiled or uncompiled Adam step
+            if is_compiled and self._compiled_adam_step is not None:
+                # Tensors must be used for compiled functions
+                lr_tensor = torch.tensor(lr, device=p.device)
+                bc1_tensor = torch.tensor(bias_correction1, device=p.device)
+                bc2_tensor = torch.tensor(bias_correction2, device=p.device)
+                self._compiled_adam_step(p, grad, state, group, lr_tensor, bc1_tensor, bc2_tensor)
+            else:
+                self._adam_step_parameter(p, grad, state, group, lr, bias_correction1, bias_correction2)
+        else: # Muon path
+            # Dispatch to compiled or uncompiled Muon step
+            if is_compiled and self._compiled_muon_step is not None:
+                lr_tensor = torch.tensor(lr, device=p.device)
+                self._compiled_muon_step(p, grad, state, group, lr_tensor)
+            else:
+                self._muon_step_parameter(p, grad, state, group, lr)
+
+
+    def compile(self, *args, **kwargs):
+        print("Compiling AdaMuon step path...")
+        self._compiled_muon_step = torch.compile(self._muon_step_parameter, *args, **kwargs)
+        print("Compiling AuxAdam step path...")
+        self._compiled_adam_step = torch.compile(self._adam_step_parameter, *args, **kwargs)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.step_parameter(p, group, i)
+
+        self.global_step += 1
+
+        return loss
```
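
For orientation only, here is a minimal usage sketch (not taken from the package's README or docs) of how the 'muon'/'adam' dispatch above might be driven. The top-level import path and the two-group layout with an `optim_type` key are assumptions inferred from `__init_state()` and `step_parameter()` in the diff; the package may expose a different recommended setup.

```python
# Hypothetical usage sketch for AdaMuon_adv (not from the package docs).
# Assumptions: the class is re-exported from the top-level adv_optm package,
# and each parameter group carries an "optim_type" key ('muon' or 'adam'),
# matching the group.get('optim_type', 'muon') lookup in __init_state().
import torch
import torch.nn as nn

from adv_optm import AdaMuon_adv  # assumed import path

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))

# 2D weight matrices go through the Muon path; biases and other 1D tensors
# go through the auxiliary Adam path.
muon_params = [p for p in model.parameters() if p.ndim >= 2]
adam_params = [p for p in model.parameters() if p.ndim < 2]

optimizer = AdaMuon_adv(
    [
        {"params": muon_params, "optim_type": "muon"},
        {"params": adam_params, "optim_type": "adam"},
    ],
    lr=1e-3,
    weight_decay=0.1,
)

x = torch.randn(8, 512)
y = torch.randint(0, 10, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()       # step_parameter() routes each tensor to Muon or Adam logic
optimizer.zero_grad()
```

In this sketch a single `lr` and `weight_decay` from the constructor are shared by both groups; per-group overrides could instead be supplied in the group dicts, as with any `torch.optim.Optimizer`.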