adv-optm 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +296 -296
- adv_optm/optim/Lion_Prodigy_adv.py +22 -6
- adv_optm/optim/Lion_adv.py +242 -228
- {adv_optm-0.1.3.dist-info → adv_optm-0.1.4.dist-info}/METADATA +1 -1
- {adv_optm-0.1.3.dist-info → adv_optm-0.1.4.dist-info}/RECORD +9 -9
- {adv_optm-0.1.3.dist-info → adv_optm-0.1.4.dist-info}/WHEEL +0 -0
- {adv_optm-0.1.3.dist-info → adv_optm-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-0.1.3.dist-info → adv_optm-0.1.4.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdamW_adv.py
CHANGED
@@ -1,297 +1,297 @@
+import torch
+from typing import Optional
+
+from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.Effective_Shape import _get_effective_shape
+from ..util.NNMF import _nnmf,_unnmf
+from ..util.OrthoGrad import _orthogonalize_gradient
+from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+class AdamW_adv(torch.optim.Optimizer):
+    """
+    Implements a factored AdamW algorithm.
+    This is an advanced version of AdamW with optional features like
+    low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float): learning rate (default: 1e-3)
+        betas (tuple[float, float]): coefficients used for computing running
+            averages of gradient and its square (default: (0.9, 0.999))
+        eps (float): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float): weight decay (L2 penalty) (default: 0)
+        vector_reshape (bool): whether to reshape 1D vectors into 2D
+            matrices to apply low-rank compression (default: True).
+        stochastic_rounding (bool): whether to use stochastic
+            rounding for BF16 parameter updates (default: True).
+        use_atan2 (bool): whether to use the atan2 update rule. (default: False)
+        use_grams (bool): whether to use Grams-style updates. (default: False)
+        use_cautious (bool): whether to use cautious masking to align the gradient's
+            direction with the first moment's. (default: False)
+        use_orthograd (bool): whether to use OrthoGrad. (default: False)
+        use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
+            a second, slow-moving average of the momentum (`mt_slow`) which is
+            combined with the primary momentum (`mt`) to stabilize updates,
+            especially in noisy, small-batch settings. If `False`, the
+            optimizer behaves as standard AdamW. (default: False)
+        beta3_ema (float): The decay rate for the slow exponential moving average of
+            the momentum (only used when `use_AdEMAMix` is `True`). A higher
+            value (e.g., 0.9999) gives the EMA a longer memory, making it more
+            stable but slower to adapt. A lower value (e.g., 0.999) is often
+            better for shorter training runs. (default: 0.9999)
+        alpha (float): The mixing coefficient that scales the slow momentum term
+            before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
+            A higher value increases the stabilizing influence of the slow
+            momentum. (default: 5.0)
+        t_alpha (Optional[int]): The number of steps for a linear warmup of the
+            `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
+            highly recommended to prevent instability at the beginning of training,
+            as it gradually introduces the stabilizing slow momentum term. During
+            the warmup, `alpha` ramps from 0 to its target value. If `None`,
+            the scheduler is disabled and th
+        factored (bool): whether to use the factorization or disable it to use
+            the uncompressed optimizer. (default: True)
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-3,
+        betas: tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+        vector_reshape: bool = True,
+        stochastic_rounding: bool = True,
+        use_atan2: bool = False,
+        use_cautious: bool = False,
+        use_grams: bool = False,
+        use_orthograd: bool = False,
+        use_AdEMAMix: bool = False,
+        beta3_ema: float = 0.9999,
+        alpha: float = 5.0,
+        t_alpha: int | None = None,
+        factored: bool = True,
+    ):
+        if not (lr >= 0.0):
+            raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+        if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
+            raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
+        if not (eps >= 0.0):
+            raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
+        if not (weight_decay >= 0.0):
+            raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+
+        defaults = {
+            "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
+            "vector_reshape": vector_reshape, "use_atan2": use_atan2,
+            "use_orthograd": use_orthograd,
+            "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
+        }
+        self.stochastic_rounding = stochastic_rounding
+        self.use_cautious = use_cautious
+        self.use_grams = use_grams
+        self.use_AdEMAMix = use_AdEMAMix
+        self.factored = factored
+        super().__init__(params, defaults)
+
+    @property
+    def supports_fused_back_pass(self):
+        return True
+
+    @property
+    def supports_memory_efficient_fp16(self):
+        return True
+
+    @property
+    def supports_flat_params(self):
+        return False
+
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        if grad.dtype != torch.float32 and self.factored:
+            grad = grad.float()
+        if group["use_orthograd"]:
+            grad = _orthogonalize_gradient(p, grad)
+        state = self.state[p]
+
+        beta1, beta2 = group['betas']
+
+        # State Initialization
+        if len(state) == 0:
+            state['step'] = 0
+
+            should_factor = (
+                self.factored and
+                not (len(p.shape) == 1 and not group['vector_reshape'])
+            )
+
+            state['factored'] = should_factor
+
+            dtype = torch.float32 if self.factored else p.dtype
+            device = p.device
+
+            if state['factored']:
+                state['effective_shape'] = _get_effective_shape(p.numel())
+                d1, d2 = state['effective_shape']
+
+                # First moment (m)
+                if beta1 > 0:
+                    state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                    state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                    if not self.use_grams:
+                        packed_d2 = (d2 + 7) // 8
+                        state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                    if self.use_AdEMAMix:
+                        state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                        state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                        packed_d2 = (d2 + 7) // 8
+                        state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                # Second moment (v)
+                state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+            else: # Fallback to standard AdamW for non-factored tensors
+                if beta1 > 0:
+                    state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+                    if self.use_AdEMAMix:
+                        state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
+                state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
+
+        if self.use_AdEMAMix:
+            beta3_ema = group['beta3_ema']
+            alpha = group['alpha']
+            t_alpha = group['t_alpha']
+            current_step = state['step'] + 1
+            alpha_t = alpha
+            if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
+                alpha_t = min(current_step * alpha / t_alpha, alpha)
+
+        if state['factored']:
+            d1, d2 = state['effective_shape']
+
+            # Reconstruct momentum from previous step's factors
+            if beta1 > 0:
+                mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+                if not self.use_grams:
+                    unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+                    torch.where(unpacked_sign, mt, -mt, out=mt)
+                    del unpacked_sign
+                # Update momentum in full-size
+                grad_reshaped = grad.view(d1, d2)
+                mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
+                if self.use_grams:
+                    mt.copy_(grad_reshaped.sign() * mt.abs())
+                elif self.use_cautious:
+                    mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    mt.mul_(mask)
+                    del mask
+
+            vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+            vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
+
+            if self.use_AdEMAMix:
+                mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                if state['sign_slow'].dtype != torch.uint8:
+                    state['sign_slow'] = state['sign_slow'].to(torch.uint8)
+                unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
+                torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
+                del unpacked_sign_slow
+
+                mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
+                update = mt + (alpha_t * mt_slow) if beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
+            else:
+                update = mt if beta1 > 0 else grad_reshaped
+            del grad_reshaped
+
+            if group['use_atan2']:
+                a = 1.2732395
+                denom = vt.sqrt()
+                update.atan2_(denom).mul_(a)
+            else:
+                denom = vt.sqrt()
+                update.div_(denom.add_(group['eps']))
+            del denom
+
+            update.view(p.shape).mul_(group['lr'])
+
+            # Compress updated moments and store new factors
+            if beta1 > 0:
+                if not self.use_grams:
+                    state['sign'] = _pack_bools(mt > 0)
+                _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+                del mt
+            if self.use_AdEMAMix:
+                state['sign_slow'] = _pack_bools(mt_slow > 0)
+                _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                del mt_slow
+            _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+            del vt
+
+        else: # Standard AdamW logic for non-factored tensors
+            exp_avg_sq = state['exp_avg_sq']
+
+            if beta1 > 0:
+                exp_avg = state['exp_avg']
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                if self.use_grams:
+                    exp_avg = grad.sign() * exp_avg.abs()
+                elif self.use_cautious:
+                    mask = (exp_avg * grad > 0).to(grad.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    exp_avg.mul_(mask)
+                    del mask
+
+            if self.use_AdEMAMix:
+                exp_avg_slow = state['exp_avg_slow']
+                exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
+                update = exp_avg + (alpha_t * exp_avg_slow) if beta1 > 0 else grad + (alpha_t * exp_avg_slow)
+            else:
+                update = exp_avg if beta1 > 0 else grad
+
+            exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
+
+            if group['use_atan2']:
+                a = 1.2732395
+                denom = exp_avg_sq.sqrt()
+                update.atan2_(denom).mul_(a)
+            else:
+                denom = exp_avg_sq.sqrt()
+                update.div_(denom.add_(group['eps']))
+            del denom
+
+            update.mul_(group['lr'])
+
+        # Decoupled weight decay
+        if group["weight_decay"] != 0:
+            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+            else:
+                p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+
+        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+            add_stochastic_(p.data, -update)
+        else:
+            p.data.add_(-update)
+        del update
+
+        state['step'] += 1
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.step_parameter(p, group, i)
+
+         return loss
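For orientation, a minimal usage sketch of the AdamW_adv constructor shown above; the top-level import is assumed from the package's adv_optm/__init__.py, and the argument values are illustrative rather than recommended:

import torch
from adv_optm import AdamW_adv  # assumed top-level export; see adv_optm/__init__.py

model = torch.nn.Linear(128, 64)
optimizer = AdamW_adv(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    use_AdEMAMix=True,  # enable the slow momentum EMA described in the docstring
    t_alpha=1000,       # warm alpha up linearly over the first 1000 steps
)

loss = model(torch.randn(16, 128)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()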
adv_optm/optim/Lion_Prodigy_adv.py
CHANGED
@@ -216,11 +216,19 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
 
             # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
             if self.variance_reduction:
-
-
+                if state['step'] == 1:
+                    exp_avg.copy_(grad_reshaped)
+                else:
+                    # Heuristic Prodigy-STORM update
+                    correction = exp_avg.sub(state['prev_grad'])
+                    grad_alpha = self.d * (1 - self.beta2) + self.beta2
+                    exp_avg.copy_(grad_reshaped).mul_(grad_alpha).add_(correction, alpha=self.beta2)
+                    del correction, grad_alpha
                 state['prev_grad'].copy_(grad_reshaped)
             else:
-
+                # Standard Prodigy-Lion
+                alpha = self.d * (1 - self.beta2)
+                exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=alpha)
             del grad_reshaped
 
             # Compress new momentum m_t and store factors
@@ -247,11 +255,19 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
 
             # Update momentum
             if self.variance_reduction:
-
-
+                if state['step'] == 1:
+                    exp_avg.copy_(grad)
+                else:
+                    # Heuristic Prodigy-STORM update
+                    correction = exp_avg.sub(state['prev_grad'])
+                    grad_alpha = self.d * (1 - self.beta2) + self.beta2
+                    exp_avg.copy_(grad).mul_(grad_alpha).add_(correction, alpha=self.beta2)
+                    del grad_alpha, correction
                 state['prev_grad'].copy_(grad)
             else:
-
+                # Standard Prodigy-Lion
+                alpha = self.d * (1 - self.beta2)
+                exp_avg.mul_(self.beta2).add_(grad, alpha=alpha)
 
         # --- Accumulate Prodigy stats ---
         d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
adv_optm/optim/Lion_adv.py
CHANGED
@@ -1,229 +1,243 @@
+import torch
+
+from typing import Tuple, Optional
+
+from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.Effective_Shape import _get_effective_shape
+from ..util.NNMF import _nnmf,_unnmf
+from ..util.OrthoGrad import _orthogonalize_gradient
+from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+class Lion_adv(torch.optim.Optimizer):
+    """
+    Implements the SMMF technique for Lion algorithm.
+
+    This optimizer combines the Lion update rule with the memory-saving low-rank
+    compression (SMMF) technique from https://arxiv.org/abs/2412.08894.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float, optional): learning rate (default: 1e-4).
+        betas (Tuple[float, float], optional): coefficients for computing
+            running averages of the update (default: (0.9, 0.99)).
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
+        vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
+            matrices to apply low-rank compression (default: True).
+        stochastic_rounding (bool, optional): whether to use stochastic
+            rounding for BF16 parameter updates (default: True).
+        use_cautious (bool): whether to use the cautious masking technique. (default: False).
+        clip_threshold (float, optional): whether to clip the gradients norm
+            per-parameter as proposed in the paper `Lions and Muons: Optimization via
+            Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
+            (default: 0.0).
+        factored (bool): whether to use the factorization or use the
+            uncompressed optimizer. (default: True)
+        variance_reduction (bool): whether to use the variance reduction technique
+            from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-4,
+        betas: Tuple[float, float] = (0.9, 0.99),
+        weight_decay: float = 0.0,
+        vector_reshape: bool = True,
+        stochastic_rounding: bool = True,
+        use_orthograd: bool = False,
+        use_cautious: bool = False,
+        clip_threshold: float = 0.0,
+        factored: bool = True,
+        variance_reduction: bool = False,
+    ):
+        if not lr > 0.0:
+            raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
+        if not all(0.0 <= beta <= 1.0 for beta in betas):
+            raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
+        if not weight_decay >= 0.0:
+            raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
+
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            weight_decay=weight_decay,
+            vector_reshape=vector_reshape,
+            use_orthograd=use_orthograd,
+            clip_threshold=clip_threshold,
+        )
+        self.stochastic_rounding = stochastic_rounding
+        self.use_cautious = use_cautious
+        self.factored = factored
+        self.variance_reduction = variance_reduction
+        super().__init__(params, defaults)
+
+    @property
+    def supports_fused_back_pass(self) -> bool:
+        return True
+
+    @property
+    def supports_memory_efficient_fp16(self) -> bool:
+        return True
+
+    @property
+    def supports_flat_params(self) -> bool:
+        return False
+
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
+        """Performs a single optimization step on a single parameter."""
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        if grad.dtype != torch.float32 and self.factored:
+            grad = grad.float()
+        if group["clip_threshold"] > 0.0:
+            grad_norm = torch.norm(grad.detach())
+            if grad_norm > group["clip_threshold"]:
+                clip_coef = group["clip_threshold"] / grad_norm
+                grad.mul_(clip_coef)
+        if group["use_orthograd"]:
+            grad = _orthogonalize_gradient(p, grad)
+        state = self.state[p]
+
+        # State Initialization
+        if len(state) == 0:
+            state['step'] = 0
+
+            should_factor = (
+                self.factored and
+                not (len(p.shape) == 1 and not group['vector_reshape'])
+            )
+
+            state['factored'] = should_factor
+
+            dtype = torch.float32 if self.factored else p.dtype
+
+            if state['factored']:
+                state['effective_shape'] = _get_effective_shape(p.numel())
+                d1, d2 = state['effective_shape']
+                state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                packed_d2 = (d2 + 7) // 8
+                state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                if self.variance_reduction:
+                    state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
+            else: # Fallback to standard Lion
+                state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+                if self.variance_reduction:
+                    state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+
+        state['step'] += 1
+        beta1, beta2 = group["betas"]
+        lr = group["lr"]
+
+        if state['factored']:
+            # Factored Path
+            d1, d2 = state['effective_shape']
+            grad_reshaped = grad.view(d1, d2)
+            # Reconstruct momentum m_{t-1}
+            exp_avg = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+            unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+            torch.where(unpacked_sign, exp_avg, -exp_avg, out=exp_avg)
+            del unpacked_sign
+            if exp_avg.dtype != torch.float32:
+                exp_avg = exp_avg.float()
+
+            # Compute update term c_t
+            signed_update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()
+
+            if self.use_cautious:
+                mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
+                mask.div_(mask.mean().clamp_(min=1e-3))
+                signed_update.mul_(mask)
+                del mask
+
+            # Parameter update
+            update_for_param = signed_update.view(p.shape).mul_(lr)
+
+            # Update momentum
+            if self.variance_reduction:
+                if state['step'] == 1:
+                    exp_avg.copy_(grad_reshaped)
+                else:
+                    # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
+                    correction = exp_avg.sub(state['prev_grad'])
+                    # Calculate the new momentum and store it back into exp_avg
+                    exp_avg.copy_(grad_reshaped).add_(correction, alpha=beta2)
+                    del correction
+                # Update prev_grad for the next iteration
+                state['prev_grad'].copy_(grad_reshaped)
+            else:
+                # Standard Lion momentum update
+                exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
+
+            # Compress new momentum m_t and store factors
+            state['sign'] = _pack_bools(exp_avg > 0)
+            _nnmf(exp_avg.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+            del exp_avg
+
+        else:
+            # Fallback to standard Lion logic
+            exp_avg = state["exp_avg"]
+
+            # Compute update term and sign for the update
+            if exp_avg.dtype != torch.float32 and self.factored:
+                exp_avg = exp_avg.float()
+            signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()
+
+            if self.use_cautious:
+                mask = (signed_update * grad > 0).to(grad.dtype)
+                mask.div_(mask.mean().clamp_(min=1e-3))
+                signed_update.mul_(mask)
+                del mask
+
+            update_for_param = signed_update.mul_(lr)
+
+            # Update momentum
+            if self.variance_reduction:
+                if state['step'] == 1:
+                    exp_avg.copy_(grad)
+                else:
+                    # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
+                    correction = exp_avg.sub(state['prev_grad'])
+                    # Calculate the new momentum and store it back into exp_avg
+                    exp_avg.copy_(grad).add_(correction, alpha=beta2)
+                    del correction
+                # Update prev_grad for the next iteration
+                state['prev_grad'].copy_(grad)
+            else:
+                # Standard Lion momentum update
+                exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
+
+        if group["weight_decay"] != 0:
+            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                add_stochastic_(p.data, p.data,
+                                alpha=-group["weight_decay"] * lr)
+            else:
+                p.data.add_(
+                    p.data, alpha=-group["weight_decay"] * lr
+                )
+
+        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+            add_stochastic_(p.data, -update_for_param)
+        else:
+            p.data.add_(-update_for_param)
+
+        del update_for_param
+
+    @torch.no_grad()
+    def step(self, closure: Optional[callable] = None):
+        """Performs a single optimization step."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for i, p in enumerate(group["params"]):
+                if p.grad is not None:
+                    self.step_parameter(p, group, i)
+
+         return loss
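A minimal usage sketch for the rewritten Lion_adv; the top-level import is assumed from the package's adv_optm/__init__.py, and the settings are illustrative:

import torch
from adv_optm import Lion_adv  # assumed top-level export; see adv_optm/__init__.py

model = torch.nn.Linear(64, 64)
optimizer = Lion_adv(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.99),
    weight_decay=1e-2,
    variance_reduction=True,  # opt into the STORM-style momentum added in this release
)

loss = model(torch.randn(8, 64)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()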
{adv_optm-0.1.3.dist-info → adv_optm-0.1.4.dist-info}/RECORD
CHANGED
@@ -1,8 +1,8 @@
-adv_optm/__init__.py,sha256=
-adv_optm/optim/AdamW_adv.py,sha256=
+adv_optm/__init__.py,sha256=CNgGMUz72nHycvrpa4VwrBs-qbehDdMJcnnJVvMRiqI,252
+adv_optm/optim/AdamW_adv.py,sha256=BF-h4g3g_mJlEwxCCMFCSH4cbnoDxsrtnDO2cvUcBPM,13183
 adv_optm/optim/Adopt_adv.py,sha256=rzBWfFOPrMuC6vwETsw7QPKmVXcv4IJRDCTj-6eU1Qk,14798
-adv_optm/optim/Lion_Prodigy_adv.py,sha256=
-adv_optm/optim/Lion_adv.py,sha256=
+adv_optm/optim/Lion_Prodigy_adv.py,sha256=JMss9X8lRpIU4E34PfFpWMMal_XNvZ8Yuqc6i7R5wIQ,14588
+adv_optm/optim/Lion_adv.py,sha256=BA4bSEhJiQ7BhGLDRn9nuMlBrLVh-OMscbmSTeGgRmI,10137
 adv_optm/optim/Prodigy_adv.py,sha256=H7MrZMjCkZdsHBXY17Jm7aTFNySoVkIXQSszdoHn6u4,17697
 adv_optm/optim/__init__.py,sha256=e5UighM92LDvDB2JJwj8gDsTpXEedpytScwqS6F2FR8,300
 adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
@@ -11,8 +11,8 @@ adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
-adv_optm-0.1.
-adv_optm-0.1.
-adv_optm-0.1.
-adv_optm-0.1.
-adv_optm-0.1.
+adv_optm-0.1.4.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-0.1.4.dist-info/METADATA,sha256=rlzAzZdUBHcz-j7xInfMz95jOi_SfRw5aOotm_MtY1o,5846
+adv_optm-0.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-0.1.4.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-0.1.4.dist-info/RECORD,,
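For reference, a RECORD line of this form can be reproduced locally with a short sketch like the one below (the file path is illustrative); wheel RECORD files store the urlsafe-base64 SHA-256 digest without padding, followed by the file size in bytes:

import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    # RECORD lines are "<path>,sha256=<urlsafe b64 digest, no padding>,<size in bytes>"
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return f"{path},sha256={digest},{len(data)}"

# Example: compare against the corresponding RECORD line above
print(record_entry("adv_optm/optim/AdamW_adv.py"))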
{adv_optm-0.1.3.dist-info → adv_optm-0.1.4.dist-info}/WHEEL
File without changes
{adv_optm-0.1.3.dist-info → adv_optm-0.1.4.dist-info}/licenses/LICENSE
File without changes
{adv_optm-0.1.3.dist-info → adv_optm-0.1.4.dist-info}/top_level.txt
File without changes