adv-optm 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of adv-optm might be problematic.
- adv_optm/__init__.py +5 -1
- adv_optm/optim/AdamW_adv.py +58 -54
- adv_optm/optim/Adopt_adv.py +13 -12
- adv_optm/optim/Lion_Prodigy_adv.py +335 -0
- adv_optm/optim/Lion_adv.py +231 -0
- adv_optm/optim/Prodigy_adv.py +5 -5
- adv_optm/optim/__init__.py +4 -0
- adv_optm/util/__init__.py +1 -1
- {adv_optm-0.1.1.dist-info → adv_optm-0.1.2.dist-info}/METADATA +2 -6
- adv_optm-0.1.2.dist-info/RECORD +18 -0
- adv_optm/util/Randomized_SVD.py +0 -37
- adv_optm-0.1.1.dist-info/RECORD +0 -17
- {adv_optm-0.1.1.dist-info → adv_optm-0.1.2.dist-info}/WHEEL +0 -0
- {adv_optm-0.1.1.dist-info → adv_optm-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-0.1.1.dist-info → adv_optm-0.1.2.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
@@ -2,12 +2,16 @@ from .optim import (
     AdamW_adv,
     Prodigy_adv,
     Adopt_adv,
+    Lion_adv,
+    Lion_Prodigy_adv,
 )
 
 __all__ = [
     "AdamW_adv",
     "Prodigy_adv",
     "Adopt_adv",
+    "Lion_adv",
+    "Lion_Prodigy_adv",
 ]
 
-__version__ = "0.1.1"
+__version__ = "0.1.2"
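For orientation, a minimal import sketch against the new 0.1.2 surface (illustrative usage, not part of the diff; it only relies on the exports and defaults visible above and in the Lion files below):

```python
# Hypothetical usage sketch: the two Lion classes are now importable from the
# package root alongside the existing optimizers.
import torch
from adv_optm import Lion_adv, Lion_Prodigy_adv, __version__

model = torch.nn.Linear(128, 64)
print(__version__)  # "0.1.2"

# lr=1e-4 and betas=(0.9, 0.99) are the defaults shown in Lion_adv's signature.
opt = Lion_adv(model.parameters(), lr=1e-4, betas=(0.9, 0.99))
```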
adv_optm/optim/AdamW_adv.py
CHANGED
@@ -22,7 +22,6 @@ class AdamW_adv(torch.optim.Optimizer):
         eps (float): term added to the denominator to improve
             numerical stability (default: 1e-8)
         weight_decay (float): weight decay (L2 penalty) (default: 0)
-        use_bias_correction (boolean): Turn on Adam's bias correction. (default: False)
         vector_reshape (bool): whether to reshape 1D vectors into 2D
             matrices to apply low-rank compression (default: True).
         stochastic_rounding (bool): whether to use stochastic
@@ -63,7 +62,6 @@ class AdamW_adv(torch.optim.Optimizer):
         betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0.0,
-        use_bias_correction: bool = False,
         vector_reshape: bool = True,
         stochastic_rounding: bool = True,
         use_atan2: bool = False,
@@ -88,7 +86,7 @@ class AdamW_adv(torch.optim.Optimizer):
         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
             "vector_reshape": vector_reshape, "use_atan2": use_atan2,
-            "use_orthograd": use_orthograd,
+            "use_orthograd": use_orthograd,
             "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
         }
         self.stochastic_rounding = stochastic_rounding
@@ -122,6 +120,8 @@ class AdamW_adv(torch.optim.Optimizer):
             grad = _orthogonalize_gradient(p, grad)
         state = self.state[p]
 
+        beta1, beta2 = group['betas']
+
         # State Initialization
         if len(state) == 0:
             state['step'] = 0
@@ -141,11 +141,12 @@ class AdamW_adv(torch.optim.Optimizer):
                 d1, d2 = state['effective_shape']
 
                 # First moment (m)
-
-
-
-
-
+                if beta1 > 0:
+                    state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                    state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                    if not self.use_grams:
+                        packed_d2 = (d2 + 7) // 8
+                        state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
                 if self.use_AdEMAMix:
                     state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                     state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
@@ -155,12 +156,12 @@ class AdamW_adv(torch.optim.Optimizer):
                 state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                 state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
             else: # Fallback to standard AdamW for non-factored tensors
-
+                if beta1 > 0:
+                    state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
                 if self.use_AdEMAMix:
                     state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
                 state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
 
-        beta1, beta2 = group['betas']
         if self.use_AdEMAMix:
             beta3_ema = group['beta3_ema']
             alpha = group['alpha']
@@ -174,21 +175,22 @@ class AdamW_adv(torch.optim.Optimizer):
             d1, d2 = state['effective_shape']
 
             # Reconstruct momentum from previous step's factors
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if beta1 > 0:
+                mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+                if not self.use_grams:
+                    unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+                    torch.where(unpacked_sign, mt, -mt, out=mt)
+                    del unpacked_sign
+                # Update momentum in full-size
+                grad_reshaped = grad.view(d1, d2)
+                mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
+                if self.use_grams:
+                    mt.copy_(grad_reshaped.sign() * mt.abs())
+                elif self.use_cautious:
+                    mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    mt.mul_(mask)
+                    del mask
 
             vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
             vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
@@ -202,28 +204,28 @@ class AdamW_adv(torch.optim.Optimizer):
                 del unpacked_sign_slow
 
                 mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
-
+                update = mt + (alpha_t * mt_slow) if beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
             else:
-
+                update = mt if beta1 > 0 else grad_reshaped
             del grad_reshaped
 
             if group['use_atan2']:
                 a = 1.2732395
                 denom = vt.sqrt()
-                update
+                update.atan2_(denom).mul_(a)
             else:
-                denom = vt.sqrt()
-                update
-                del
+                denom = vt.sqrt()
+                update.div_(denom.add_(group['eps']))
+                del denom
 
-            update
-            update.mul_(group['lr'])
+            update.view(p.shape).mul_(group['lr'])
 
             # Compress updated moments and store new factors
-            if
-
-
-
+            if beta1 > 0:
+                if not self.use_grams:
+                    state['sign'] = _pack_bools(mt > 0)
+                _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+                del mt
             if self.use_AdEMAMix:
                 state['sign_slow'] = _pack_bools(mt_slow > 0)
                 _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
@@ -232,36 +234,38 @@ class AdamW_adv(torch.optim.Optimizer):
             del vt
 
         else: # Standard AdamW logic for non-factored tensors
-
-
-
-
-            exp_avg
-
-
-
-
-
+            exp_avg_sq = state['exp_avg_sq']
+
+            if beta1 > 0:
+                exp_avg = state['exp_avg']
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                if self.use_grams:
+                    exp_avg = grad.sign() * exp_avg.abs()
+                elif self.use_cautious:
+                    mask = (exp_avg * grad > 0).to(grad.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    exp_avg.mul_(mask)
+                    del mask
 
             if self.use_AdEMAMix:
                 exp_avg_slow = state['exp_avg_slow']
                 exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
-
+                update = exp_avg + (alpha_t * exp_avg_slow) if beta1 > 0 else grad + (alpha_t * exp_avg_slow)
             else:
-
+                update = exp_avg if beta1 > 0 else grad
 
             exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
 
             if group['use_atan2']:
                 a = 1.2732395
                 denom = exp_avg_sq.sqrt()
-                update
+                update.atan2_(denom).mul_(a)
             else:
-                denom = exp_avg_sq.sqrt()
-                update
-                del
+                denom = exp_avg_sq.sqrt()
+                update.div_(denom.add_(group['eps']))
+                del denom
 
-            update
+            update.mul_(group['lr'])
 
         # Decoupled weight decay
         if group["weight_decay"] != 0:
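The practical effect of the new `beta1 > 0` gates is that first-moment state (`exp_avg` or the factored `mu_m_nmf`/`mv_m_nmf`/`sign` tensors) is only allocated and updated when beta1 is nonzero. A hedged sketch of exercising that path (constructor keywords are taken from the signature in the diff; `step()` is assumed to follow the usual torch.optim convention):

```python
# Hypothetical sketch of the beta1 == 0 memory-saving path added in 0.1.2.
import torch
from adv_optm import AdamW_adv

params = [torch.nn.Parameter(torch.randn(256, 256))]
opt = AdamW_adv(params, lr=1e-3, betas=(0.0, 0.999))  # no first moment is stored

params[0].grad = torch.randn_like(params[0])
opt.step()       # per the gates above, the update uses the gradient over sqrt(v) directly
opt.zero_grad()
```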
adv_optm/optim/Adopt_adv.py
CHANGED
@@ -30,8 +30,6 @@ class Adopt_adv(torch.optim.Optimizer):
         clip_lambda (Callable, optional): A function that takes the current step
             and returns a value to clip the normalized gradient. Only used when
             `use_atan2` is False. (default: `lambda step: step**0.25`)
-        rank (int): the rank for the low-rank approximation (default: 4).
-        oversampling (int): oversampling parameter for Randomized SVD. (default: 0).
         vector_reshape (bool): whether to reshape 1D vectors into 2D
             matrices for low-rank compression (default: True).
         stochastic_rounding (bool): whether to use stochastic
@@ -192,29 +190,29 @@ class Adopt_adv(torch.optim.Optimizer):
             d1, d2 = state['effective_shape']
 
             # Reconstruct m_{t-1}
-
+            mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
             if not self.use_grams:
                 if state['sign'].dtype != torch.uint8:
                     state['sign'] = state['sign'].to(torch.uint8)
                 unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
-                torch.where(unpacked_sign,
+                torch.where(unpacked_sign, mt, -mt, out=mt)
                 del unpacked_sign
 
             # Reconstruct AdEMAMix EMA
             if self.use_AdEMAMix:
-
+                mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
                 if state['sign_slow'].dtype != torch.uint8:
                     state['sign_slow'] = state['sign_slow'].to(torch.uint8)
                 unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
-                torch.where(unpacked_sign_slow,
+                torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
                 del unpacked_sign_slow
 
             # Reconstruct v_{t-1} using NNMF
-
+            vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
 
             # ADOPT Step A: Decorrelate g_t using v_{t-1}
             grad_reshaped = grad.view(d1, d2)
-            denom =
+            denom = vt.sqrt()
 
             if self.use_atan2:
                 normalized_grad = torch.atan2(grad_reshaped, denom)
@@ -226,7 +224,7 @@ class Adopt_adv(torch.optim.Optimizer):
             del denom
 
             # ADOPT Step B: Update momentum m_t using normalized gradient
-            mt
+            mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
             if self.use_grams:
                 mt = grad_reshaped.sign() * mt.abs()
             elif self.use_cautious:
@@ -236,7 +234,7 @@ class Adopt_adv(torch.optim.Optimizer):
                 del mask
 
             if self.use_AdEMAMix:
-                mt_slow
+                mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
                 update = mt + (alpha_t * mt_slow)
                 update = update.view(p.shape)
             else:
@@ -248,20 +246,23 @@ class Adopt_adv(torch.optim.Optimizer):
             update.mul_(group['lr'])
 
             # Update second moment v_t for the *next* step using raw g_t
-
+            vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
             del grad_reshaped
 
             # Compress and store new factors
             if not self.use_grams:
                 state['sign'] = _pack_bools(mt > 0)
             _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+            del mt
 
             if self.use_AdEMAMix:
                 state['sign_slow'] = _pack_bools(mt_slow > 0)
                 _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                del mt_slow
 
             # factorize v_t using NMF compression
-            _nnmf(
+            _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+            del vt
 
         else: # Standard ADOPT logic for non-factored tensors
             m, v = state['exp_avg'], state['exp_avg_sq'] # m_{t-1}, v_{t-1}
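The comments in this hunk spell out the ADOPT ordering: normalize the gradient with the previous second moment, update the momentum from the normalized gradient, then refresh the second moment with the raw gradient. An illustrative plain-tensor sketch of that ordering (not the library's factored code path; the eps clamp stands in for the clipping/atan2 handling the package uses):

```python
# Sketch of the ADOPT step ordering described by the comments above.
import torch

def adopt_step(p, grad, m, v, lr=1e-3, beta1=0.9, beta2=0.9999, eps=1e-6):
    # Step A: decorrelate g_t with the *previous* second moment v_{t-1}
    normalized = grad / v.sqrt().clamp_min(eps)
    # Step B: momentum update uses the normalized gradient
    m.mul_(beta1).add_(normalized, alpha=1 - beta1)
    p.data.add_(m, alpha=-lr)
    # Second moment for the *next* step uses the raw gradient
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
```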
adv_optm/optim/Lion_Prodigy_adv.py
ADDED
@@ -0,0 +1,335 @@
+import torch
+import torch.distributed as dist
+
+import math
+
+from typing import Tuple, Optional
+
+from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.Effective_Shape import _get_effective_shape
+from ..util.NNMF import _nnmf,_unnmf
+from ..util.OrthoGrad import _orthogonalize_gradient
+from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+class Lion_Prodigy_adv(torch.optim.Optimizer):
+    """
+    Implements the SMMF technique and Prodigy D-Adaptation method for Lion algorithm.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float, optional): learning rate (default: 1e-4).
+        betas (Tuple[float, float], optional): coefficients for computing
+            running averages of the update (default: (0.9, 0.99)).
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
+        vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
+            matrices to apply low-rank compression (default: True).
+        stochastic_rounding (bool, optional): whether to use stochastic
+            rounding for BF16 parameter updates (default: True).
+        use_cautious (bool): whether to use the cautious masking technique. (default: False).
+        clip_threshold (float, optional): whether to clip the gradients norm
+            per-parameter as proposed in the paper `Lions and Muons: Optimization via
+            Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
+            (default: 0.0).
+        factored (bool): whether to use the factorization or use the
+            uncompressed optimizer. (default: True)
+        variance_reduction (bool): whether to use the variance reduction technique
+            from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
+        d0 (float):
+            Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
+        d_coef (float):
+            Coefficient in the expression for the estimate of d (default 1.0).
+            Values such as 0.5 and 2.0 typically work as well.
+            Changing this parameter is the preferred way to tune the method.
+        growth_rate (float):
+            prevent the D estimate from growing faster than this multiplicative rate.
+            Default is inf, for unrestricted. Values like 1.02 give a kind of learning
+            rate warmup effect.
+        fsdp_in_use (bool):
+            If you're using sharded parameters, this should be set to True. The optimizer
+            will attempt to auto-detect this, but if you're using an implementation other
+            than PyTorch's builtin version, the auto-detection won't work.
+        slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
+            pth entry of each tensor. For values greater than 1 this an an approximation to standard
+            Prodigy. Values ~11 are reasonable (default 11).
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1,
+        betas: Tuple[float, float] = (0.9, 0.99),
+        weight_decay: float = 0.0,
+        vector_reshape: bool = True,
+        stochastic_rounding: bool = True,
+        use_orthograd: bool = False,
+        use_cautious: bool = False,
+        clip_threshold: float = 0.0,
+        factored: bool = True,
+        variance_reduction: bool = False,
+        # prodigy parameters
+        beta3: float = None,
+        d0: float = 1e-6,
+        d_coef: float = 1,
+        growth_rate: float = float('inf'),
+        safeguard_warmup: bool = False,
+        fsdp_in_use: bool = False,
+        slice_p: int = 11,
+    ):
+        if not lr > 0.0:
+            raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
+        if not all(0.0 <= beta <= 1.0 for beta in betas):
+            raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
+        if not weight_decay >= 0.0:
+            raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
+        if variance_reduction and use_cautious:
+            print("Warning: Using both 'variance_reduction' and 'use_cautious' is not recommended and may lead to unintended effects.")
+
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            weight_decay=weight_decay,
+            vector_reshape=vector_reshape,
+            use_orthograd=use_orthograd,
+            clip_threshold=clip_threshold,
+            beta3=beta3, d=d0, d0=d0, d_max=d0, d_numerator=0.0, d_coef=d_coef,
+            growth_rate=growth_rate, safeguard_warmup=safeguard_warmup, k=0, slice_p=slice_p,
+            fsdp_in_use=fsdp_in_use,
+        )
+        self.stochastic_rounding = stochastic_rounding
+        self.use_cautious = use_cautious
+        self.factored = factored
+        self.variance_reduction = variance_reduction
+        self.fsdp_in_use = fsdp_in_use
+        super().__init__(params, defaults)
+        # Global state for accumulating metrics across parameter updates within a single step.
+        self.init_step()
+
+    @property
+    def supports_fused_back_pass(self) -> bool:
+        return True
+
+    @property
+    def supports_memory_efficient_fp16(self) -> bool:
+        return True
+
+    @property
+    def supports_flat_params(self) -> bool:
+        return False
+
+    def init_step(self):
+        """Resets accumulators and calculates dlr for the upcoming step."""
+        self.d_denom = 0.0
+
+        g_group = self.param_groups[0]
+        self.beta1, self.beta2 = g_group['betas']
+        self.beta3 = g_group['beta3']
+        if self.beta3 is None:
+            self.beta3 = math.sqrt(self.beta2)
+
+        k = g_group['k']
+        self.d = g_group['d']
+        lr = g_group['lr']
+
+        self.dlr = self.d * lr
+
+        self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3
+
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
+        """Performs a single optimization step on a single parameter."""
+        if p.grad is None:
+            return
+
+        if hasattr(p, "_fsdp_flattened"):
+            self.fsdp_in_use = True
+
+        grad = p.grad
+        if grad.dtype != torch.float32 and self.factored:
+            grad = grad.float()
+        if group["clip_threshold"] > 0.0:
+            grad_norm = torch.norm(grad.detach())
+            if grad_norm > group["clip_threshold"]:
+                clip_coef = group["clip_threshold"] / grad_norm
+                grad.mul_(clip_coef)
+        if group["use_orthograd"]:
+            grad = _orthogonalize_gradient(p, grad)
+        state = self.state[p]
+
+        # State Initialization
+        if len(state) == 0:
+            state['step'] = 0
+
+            should_factor = (
+                self.factored and
+                not (len(p.shape) == 1 and not group['vector_reshape'])
+            )
+
+            state['factored'] = should_factor
+
+            dtype = torch.float32 if self.factored else p.dtype
+
+            slice_p = group['slice_p']
+
+            # D-Adaptation states
+            state['s'] = torch.zeros_like(p.flatten()[::slice_p]).detach()
+            if p.any():
+                state['p0'] = p.flatten()[::slice_p].detach().clone()
+            else:
+                state['p0'] = torch.tensor(0, device=p.device, dtype=p.dtype)
+
+            if state['factored']:
+                state['effective_shape'] = _get_effective_shape(p.numel())
+                d1, d2 = state['effective_shape']
+                state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                packed_d2 = (d2 + 7) // 8
+                state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                if self.variance_reduction:
+                    state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
+            else: # Fallback to standard Lion
+                state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+                if self.variance_reduction:
+                    state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+
+        if state['factored']:
+            # Factored Path
+            d1, d2 = state['effective_shape']
+            grad_reshaped = grad.view(d1, d2)
+            # Reconstruct momentum m_{t-1}
+            exp_avg = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+            unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+            torch.where(unpacked_sign, exp_avg, -exp_avg, out=exp_avg)
+            del unpacked_sign
+            if exp_avg.dtype != torch.float32:
+                exp_avg = exp_avg.float()
+
+            # Compute update term c_t = β1*m_{t-1} + (1-β1)*g_t
+            signed_update = exp_avg.clone().mul_(self.beta1).add_(grad_reshaped, alpha=(1-self.beta1)).sign_()
+
+            if self.use_cautious:
+                mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
+                mask.div_(mask.mean().clamp_(min=1e-3))
+                signed_update.mul_(mask)
+                del mask
+
+            # Parameter update: p_t = p_{t-1} - lr * sign(c_t)
+            update_for_param = signed_update.view(p.shape).mul(self.dlr)
+
+            # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
+            if self.variance_reduction:
+                vr_term = grad_reshaped - state['prev_grad']
+                exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1-self.beta2)).add_(vr_term, alpha=self.beta2)
+                state['prev_grad'].copy_(grad_reshaped)
+            else:
+                exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1-self.beta2))
+            del grad_reshaped
+
+            # Compress new momentum m_t and store factors
+            state['sign'] = _pack_bools(exp_avg > 0)
+            _nnmf(exp_avg.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+            del exp_avg
+
+        else:
+            # Fallback to standard Lion logic
+            exp_avg = state["exp_avg"]
+
+            # Compute update term and sign for the update
+            if exp_avg.dtype != torch.float32 and self.factored:
+                exp_avg = exp_avg.float()
+            signed_update = exp_avg.clone().mul_(self.beta1).add_(grad, alpha=(1-self.beta1)).sign_()
+
+            if self.use_cautious:
+                mask = (signed_update * grad > 0).to(grad.dtype)
+                mask.div_(mask.mean().clamp_(min=1e-3))
+                signed_update.mul_(mask)
+                del mask
+
+            update_for_param = signed_update.mul(self.dlr)
+
+            # Update momentum
+            if self.variance_reduction:
+                vr_term = grad - state['prev_grad']
+                exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1-self.beta2)).add_(vr_term, alpha=self.beta2)
+                state['prev_grad'].copy_(grad)
+            else:
+                exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1-self.beta2))
+
+        # --- Accumulate Prodigy stats ---
+        d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
+        s, p0 = state['s'], state['p0']
+        grad_flat = grad.flatten().float()
+        p_flat = p.data.flatten().float()
+        p0 = p0.float()
+
+        self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
+
+        alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
+        s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
+        self.d_denom += s.abs().sum().item()
+
+        del s, p0, grad_flat, p_flat, alpha
+
+        if group["weight_decay"] != 0:
+            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                add_stochastic_(p.data, p.data,
+                                alpha=-group["weight_decay"] * self.dlr)
+            else:
+                p.data.add_(
+                    p.data, alpha=-group["weight_decay"] * self.dlr
+                )
+
+        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+            add_stochastic_(p.data, -update_for_param)
+        else:
+            p.data.add_(-update_for_param)
+
+        del update_for_param
+
+    @torch.no_grad()
+    def step(self, closure: Optional[callable] = None):
+        """Performs a single optimization step."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for i, p in enumerate(group["params"]):
+                if p.grad is not None:
+                    self.step_parameter(p, group, i)
+
+
+        self.calculate_d()
+        self.init_step()
+        return loss
+
+    def calculate_d(self):
+        """Calculates the new `d` based on the accumulated stats."""
+        g_group = self.param_groups[0]
+        d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
+
+        if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
+            # Use the device of the first parameter to avoid hardcoding '.cuda()'
+            device = self.param_groups[0]['params'][0].device
+            dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+            dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+            global_d_numerator = dist_tensor[0].item()
+            global_d_denom = dist_tensor[1].item()
+        else:
+            global_d_numerator = self.d_numerator
+            global_d_denom = self.d_denom
+
+        d_hat = self.d
+        if global_d_denom > 0:
+            d_hat = d_coef * global_d_numerator / global_d_denom
+            if self.d == g_group['d0']:
+                self.d = max(self.d, d_hat)
+            d_max = max(d_max, d_hat)
+            self.d = min(d_max, self.d * growth_rate)
+
+        for group in self.param_groups:
+            group['d_numerator'] = global_d_numerator
+            group['d'] = self.d
+            group['d_max'] = d_max
+            group['k'] += 1
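A hedged usage sketch for the new class (the constructor keywords come from the signature above; the model and loop are illustrative only). The learning rate stays at its default of 1 and the D-adaptation machinery supplies the effective step size:

```python
# Hypothetical training-loop sketch for Lion_Prodigy_adv.
import torch
from adv_optm import Lion_Prodigy_adv

model = torch.nn.Linear(512, 512)
opt = Lion_Prodigy_adv(model.parameters(), lr=1, betas=(0.9, 0.99), weight_decay=0.0)

for _ in range(10):
    loss = model(torch.randn(8, 512)).pow(2).mean()
    loss.backward()
    opt.step()       # accumulates Prodigy stats per parameter, then calls calculate_d()
    opt.zero_grad()
```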
adv_optm/optim/Lion_adv.py
ADDED
@@ -0,0 +1,231 @@
+import torch
+
+from typing import Tuple, Optional
+
+from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.Effective_Shape import _get_effective_shape
+from ..util.NNMF import _nnmf,_unnmf
+from ..util.OrthoGrad import _orthogonalize_gradient
+from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+class Lion_adv(torch.optim.Optimizer):
+    """
+    Implements the SMMF technique for Lion algorithm.
+
+    This optimizer combines the Lion update rule with the memory-saving low-rank
+    compression (SMMF) technique from https://arxiv.org/abs/2412.08894.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float, optional): learning rate (default: 1e-4).
+        betas (Tuple[float, float], optional): coefficients for computing
+            running averages of the update (default: (0.9, 0.99)).
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
+        vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
+            matrices to apply low-rank compression (default: True).
+        stochastic_rounding (bool, optional): whether to use stochastic
+            rounding for BF16 parameter updates (default: True).
+        use_cautious (bool): whether to use the cautious masking technique. (default: False).
+        clip_threshold (float, optional): whether to clip the gradients norm
+            per-parameter as proposed in the paper `Lions and Muons: Optimization via
+            Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
+            (default: 0.0).
+        factored (bool): whether to use the factorization or use the
+            uncompressed optimizer. (default: True)
+        variance_reduction (bool): whether to use the variance reduction technique
+            from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-4,
+        betas: Tuple[float, float] = (0.9, 0.99),
+        weight_decay: float = 0.0,
+        vector_reshape: bool = True,
+        stochastic_rounding: bool = True,
+        use_orthograd: bool = False,
+        use_cautious: bool = False,
+        clip_threshold: float = 0.0,
+        factored: bool = True,
+        variance_reduction: bool = False,
+    ):
+        if not lr > 0.0:
+            raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
+        if not all(0.0 <= beta <= 1.0 for beta in betas):
+            raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
+        if not weight_decay >= 0.0:
+            raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
+        if variance_reduction and use_cautious:
+            print("Warning: Using both 'variance_reduction' and 'use_cautious' is not recommended and may lead to unintended effects.")
+
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            weight_decay=weight_decay,
+            vector_reshape=vector_reshape,
+            use_orthograd=use_orthograd,
+            clip_threshold=clip_threshold,
+        )
+        self.stochastic_rounding = stochastic_rounding
+        self.use_cautious = use_cautious
+        self.factored = factored
+        self.variance_reduction = variance_reduction
+        super().__init__(params, defaults)
+
+    @property
+    def supports_fused_back_pass(self) -> bool:
+        return True
+
+    @property
+    def supports_memory_efficient_fp16(self) -> bool:
+        return True
+
+    @property
+    def supports_flat_params(self) -> bool:
+        return False
+
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
+        """Performs a single optimization step on a single parameter."""
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        if grad.dtype != torch.float32 and self.factored:
+            grad = grad.float()
+        if group["clip_threshold"] > 0.0:
+            grad_norm = torch.norm(grad.detach())
+            if grad_norm > group["clip_threshold"]:
+                clip_coef = group["clip_threshold"] / grad_norm
+                grad.mul_(clip_coef)
+        if group["use_orthograd"]:
+            grad = _orthogonalize_gradient(p, grad)
+        state = self.state[p]
+
+        # State Initialization
+        if len(state) == 0:
+            state['step'] = 0
+
+            should_factor = (
+                self.factored and
+                not (len(p.shape) == 1 and not group['vector_reshape'])
+            )
+
+            state['factored'] = should_factor
+
+            dtype = torch.float32 if self.factored else p.dtype
+
+            if state['factored']:
+                state['effective_shape'] = _get_effective_shape(p.numel())
+                d1, d2 = state['effective_shape']
+                state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                packed_d2 = (d2 + 7) // 8
+                state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                if self.variance_reduction:
+                    state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
+            else: # Fallback to standard Lion
+                state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+                if self.variance_reduction:
+                    state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+
+        state['step'] += 1
+        beta1, beta2 = group["betas"]
+        lr = group["lr"]
+
+        if state['factored']:
+            # Factored Path
+            d1, d2 = state['effective_shape']
+            grad_reshaped = grad.view(d1, d2)
+            # Reconstruct momentum m_{t-1}
+            exp_avg = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+            unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+            torch.where(unpacked_sign, exp_avg, -exp_avg, out=exp_avg)
+            del unpacked_sign
+            if exp_avg.dtype != torch.float32:
+                exp_avg = exp_avg.float()
+
+            # Compute update term c_t = β1*m_{t-1} + (1-β1)*g_t
+            signed_update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()
+
+            if self.use_cautious:
+                mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
+                mask.div_(mask.mean().clamp_(min=1e-3))
+                signed_update.mul_(mask)
+                del mask
+
+            # Parameter update: p_t = p_{t-1} - lr * sign(c_t)
+            update_for_param = signed_update.view(p.shape).mul_(lr)
+
+            # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
+            if self.variance_reduction:
+                vr_term = grad_reshaped - state['prev_grad']
+                exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2).add_(vr_term, alpha=beta2)
+                del vr_term
+                state['prev_grad'].copy_(grad_reshaped)
+            else:
+                exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
+            del grad_reshaped
+
+            # Compress new momentum m_t and store factors
+            state['sign'] = _pack_bools(exp_avg > 0)
+            _nnmf(exp_avg.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+            del exp_avg
+
+        else:
+            # Fallback to standard Lion logic
+            exp_avg = state["exp_avg"]
+
+            # Compute update term and sign for the update
+            if exp_avg.dtype != torch.float32 and self.factored:
+                exp_avg = exp_avg.float()
+            signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()
+
+            if self.use_cautious:
+                mask = (signed_update * grad > 0).to(grad.dtype)
+                mask.div_(mask.mean().clamp_(min=1e-3))
+                signed_update.mul_(mask)
+                del mask
+
+            update_for_param = signed_update.mul_(lr)
+
+            # Update momentum
+            if self.variance_reduction:
+                vr_term = grad - state['prev_grad']
+                exp_avg.mul_(beta2).add_(grad, alpha=1-beta2).add_(vr_term, alpha=beta2)
+                state['prev_grad'].copy_(grad)
+            else:
+                exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
+
+        if group["weight_decay"] != 0:
+            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                add_stochastic_(p.data, p.data,
+                                alpha=-group["weight_decay"] * lr)
+            else:
+                p.data.add_(
+                    p.data, alpha=-group["weight_decay"] * lr
+                )
+
+        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+            add_stochastic_(p.data, -update_for_param)
+        else:
+            p.data.add_(-update_for_param)
+
+        del update_for_param
+
+    @torch.no_grad()
+    def step(self, closure: Optional[callable] = None):
+        """Performs a single optimization step."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for i, p in enumerate(group["params"]):
+                if p.grad is not None:
+                    self.step_parameter(p, group, i)
+
+        return loss
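The comments in the factored path spell out the Lion rule: form c_t = β1*m + (1-β1)*g, step the parameter with sign(c_t), then refresh the momentum with β2. A compact uncompressed reference of that rule, for comparison with the factored path (plain-tensor sketch, not the library code):

```python
# Reference sketch of the plain Lion update described by the comments above.
import torch

def lion_step(p, grad, m, lr=1e-4, beta1=0.9, beta2=0.99, weight_decay=0.0):
    if weight_decay:
        p.data.add_(p.data, alpha=-weight_decay * lr)   # decoupled weight decay
    update = (beta1 * m + (1 - beta1) * grad).sign_()   # c_t, reduced to its sign
    p.data.add_(update, alpha=-lr)
    m.mul_(beta2).add_(grad, alpha=1 - beta2)           # momentum refresh
```

Lion_adv performs the same sequence but stores m as two rank-1 NNMF factors plus a packed 1-bit sign mask between steps.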
adv_optm/optim/Prodigy_adv.py
CHANGED
@@ -18,7 +18,7 @@ class Prodigy_adv(torch.optim.Optimizer):
     Args:
         params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
-        lr (float): learning rate (default:
+        lr (float): learning rate (default: 1)
         betas (tuple[float, float]): coefficients used for computing running
            averages of gradient and its square (default: (0.9, 0.999))
         eps (float): term added to the denominator to improve
@@ -71,13 +71,13 @@ class Prodigy_adv(torch.optim.Optimizer):
            than PyTorch's builtin version, the auto-detection won't work.
         slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
            pth entry of each tensor. For values greater than 1 this an an approximation to standard
-           Prodigy. Values ~11 are reasonable (default
+           Prodigy. Values ~11 are reasonable (default 11).
     """
 
     def __init__(
         self,
         params,
-        lr: float =
+        lr: float = 1,
         betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0.0,
@@ -270,7 +270,7 @@ class Prodigy_adv(torch.optim.Optimizer):
                 denom = vt.sqrt()
                 update = torch.atan2(update_m, denom).mul_(a)
             else:
-                denom = vt.sqrt().add_(group['eps'])
+                denom = vt.sqrt().add_(self.d * group['eps'])
                 update = update_m / denom
             del update_m, denom
 
@@ -315,7 +315,7 @@ class Prodigy_adv(torch.optim.Optimizer):
                 denom = exp_avg_sq.sqrt()
                 update = torch.atan2(update_m, denom).mul_(a)
             else:
-                denom = exp_avg_sq.sqrt().add_(group['eps'])
+                denom = exp_avg_sq.sqrt().add_(self.d * group['eps'])
                 update = update_m / denom
             del update_m, denom
 
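The only functional change in Prodigy_adv is scaling eps by the adapted step size d before it enters the denominator, so the epsilon floor grows with d just as the numerator does (this matches the reference Prodigy implementation, to the best of my reading). A tiny numeric sketch of why that matters, with illustrative values only:

```python
# Illustrative comparison of the old and new denominator floors.
import torch

d, eps = 250.0, 1e-8
v = torch.tensor([1e-12, 4e-2])          # second-moment entries of very different scale
old_denom = v.sqrt() + eps               # fixed eps becomes negligible once d is large
new_denom = v.sqrt() + d * eps           # floor now scales with the adapted step size
```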
adv_optm/optim/__init__.py
CHANGED
@@ -1,9 +1,13 @@
 from .AdamW_adv import AdamW_adv
 from .Prodigy_adv import Prodigy_adv
 from .Adopt_adv import Adopt_adv
+from .Lion_adv import Lion_adv
+from .Lion_Prodigy_adv import Lion_Prodigy_adv
 
 __all__ = [
     "AdamW_adv",
     "Prodigy_adv",
     "Adopt_adv",
+    "Lion_adv",
+    "Lion_Prodigy_adv",
 ]
adv_optm/util/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-from .BF16_Stochastic_Rounding import add_stochastic_
+from .BF16_Stochastic_Rounding import add_stochastic_
 from .Effective_Shape import _get_effective_shape
 from .One_Bit_Boolean import _pack_bools, _unpack_bools
 from .OrthoGrad import _orthogonalize_gradient
{adv_optm-0.1.1.dist-info → adv_optm-0.1.2.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 0.1.1
+Version: 0.1.2
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu
@@ -47,7 +47,7 @@ Based primarily on:
 **[SMMF: Square-Matricized Momentum Factorization for Memory-Efficient Optimization](https://arxiv.org/abs/2412.08894)**
 
 The core innovation:
-- Uses fast, non-negative matrix factorization (rank 1
+- Uses fast, non-negative matrix factorization (NNMF - rank 1), but **reconstructs the full state before each update** to preserve momentum accuracy, then re-factors afterward (factor → reconstruct → update → factor cycle).
 - For the *signed first moment*, we split into **sign + absolute value**:
   - Sign is stored as **1-bit state** via bitwise ops (SMMF originally used 8-bit with 7 bits wasted).
   - Absolute value goes through the factor/reconstruct cycle using two factored vectors + the signed state.
@@ -106,8 +106,6 @@ Set `Factored=False` to disable factorization and run as a full uncompressed opt
 
 ⚠️ **Note**: AdEMAMix updates are more aggressive than normal Adam/Adopt, so use a x2-x5 smaller LR than usual (or use Prodigy).
 
-⚠️ **Note**: The factored AdEMAMix is **Experimental** (as it needs more tests and validation, but it should work). Also, Adopt with AdEMAMix is **Experimental** (as Adopt normalizes the gradients for the momentum).
-
 - **[`atan2` smoothing & scaling](https://github.com/lucidrains/adam-atan2-pytorch)**
   → Robust `eps` replacement (no tuning!) + built-in gradient clipping
   → *Ideal for ADOPT* (which normally needs higher `eps` and clipping), so `use_atan2` is all-in-one for it.
@@ -129,6 +127,4 @@
 
 - When `use_atan2` is True, `eps` will be ignored and you should also disable any gradient clipping.
 
-- I don't recommend using **OrthoGrad** for training LoRA or embeddings, as their weights are zero-initialized and using weight decay for them should be safe and also beneficial (OrthoGrad is intended for fine-tuning pretrained models with no weight decay).
-
 ---
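The "factor → reconstruct → update → factor" cycle with a 1-bit sign mask that the README excerpt above describes can be illustrated in a few lines. The helpers below are simplified stand-ins, not the package's `_nnmf`/`_pack_bools` implementations (their internals are not part of this diff):

```python
# Illustrative sketch of the SMMF-style cycle: keep |m| as a rank-1 nonnegative
# factorization plus a per-entry sign, reconstruct before the update, re-factor after.
import torch

def nnmf_rank1(x_abs):                      # x_abs: nonnegative (d1, d2) matrix
    row = x_abs.sum(dim=1)                  # (d1,) factor
    col = x_abs.sum(dim=0)                  # (d2,) factor
    total = x_abs.sum().clamp_min(1e-30)
    return row, col / total                 # outer(row, col/total) approximates x_abs

def reconstruct(row, col):
    return torch.outer(row, col)

m = torch.randn(4, 8)                       # full momentum, for illustration
sign_bits = m > 0                           # the package packs this to 1 bit per entry
row, col = nnmf_rank1(m.abs())              # two small vectors instead of a full matrix

recon = reconstruct(row, col)               # reconstruct |m|, reapply the sign
m_rec = torch.where(sign_bits, recon, -recon)
# ...apply the optimizer update to m_rec, then re-pack:
sign_bits = m_rec > 0
row, col = nnmf_rank1(m_rec.abs())
```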
adv_optm-0.1.2.dist-info/RECORD
ADDED
@@ -0,0 +1,18 @@
+adv_optm/__init__.py,sha256=BNYlxkuU8MFsWSY1_PLzp2XBSzpt-sxhnVuWVKRZGZ8,252
+adv_optm/optim/AdamW_adv.py,sha256=_4Vt79EB18rnIkHttA0CdMpli8sZ5f03pesdrwT5K58,12887
+adv_optm/optim/Adopt_adv.py,sha256=rzBWfFOPrMuC6vwETsw7QPKmVXcv4IJRDCTj-6eU1Qk,14798
+adv_optm/optim/Lion_Prodigy_adv.py,sha256=ql6506h_IIZvTPdGYrQdd6iEhCXHTMntqmg739fc_dw,14102
+adv_optm/optim/Lion_adv.py,sha256=jOoRbJ6u9HCK7IBI9ILOCcwprKIGTUNvUzhRd99WJK0,9410
+adv_optm/optim/Prodigy_adv.py,sha256=InR50MoE32zG6qgEkg_JzXl7uXAVRy4EYG0JDl4eKok,17324
+adv_optm/optim/__init__.py,sha256=e5UighM92LDvDB2JJwj8gDsTpXEedpytScwqS6F2FR8,300
+adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
+adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
+adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
+adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
+adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
+adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
+adv_optm-0.1.2.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-0.1.2.dist-info/METADATA,sha256=iV5GBWtl4WphBeSIIsUoq1ay6-GJGnDD3XF6aSWWrqg,5846
+adv_optm-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-0.1.2.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-0.1.2.dist-info/RECORD,,
adv_optm/util/Randomized_SVD.py
DELETED
@@ -1,37 +0,0 @@
-import torch
-from torch import Tensor
-
-from typing import Tuple
-
-def _rsvd(A: torch.Tensor, rank: int, oversampling: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """Performs Randomized SVD."""
-    orig_dtype, device, (m, n) = A.dtype, A.device, A.shape
-    A_float = A.float()
-    l, true_rank = rank + oversampling, min(m, n, rank)
-
-    if true_rank == 0:
-        return (
-            torch.zeros(m, rank, dtype=orig_dtype, device=device),
-            torch.zeros(rank, dtype=orig_dtype, device=device),
-            torch.zeros(rank, n, dtype=orig_dtype, device=device),
-        )
-
-    if l >= min(m, n): # Fallback to full SVD
-        U_full, S_full, Vh_full = torch.linalg.svd(A_float, full_matrices=False)
-        U, S, Vh = U_full[:, :true_rank], S_full[:true_rank], Vh_full[:true_rank, :]
-    else: # Standard RSVD path
-        Omega = torch.randn(n, l, dtype=A_float.dtype, device=device)
-        Y = A_float @ Omega
-        Q, _ = torch.linalg.qr(Y.float())
-        B = Q.T @ A_float
-        U_tilde, S, Vh = torch.linalg.svd(B.float(), full_matrices=False)
-        U, S, Vh = (Q @ U_tilde)[:, :true_rank], S[:true_rank], Vh[:true_rank, :]
-
-    if true_rank < rank: # Pad factors with zeros
-        U_padded = torch.zeros(m, rank, dtype=A_float.dtype, device=device)
-        S_padded = torch.zeros(rank, dtype=A_float.dtype, device=device)
-        Vh_padded = torch.zeros(rank, n, dtype=A_float.dtype, device=device)
-        U_padded[:, :true_rank], S_padded[:true_rank], Vh_padded[:true_rank, :] = U, S, Vh
-        U, S, Vh = U_padded, S_padded, Vh_padded
-
-    return U.to(orig_dtype), S.to(orig_dtype), Vh.to(orig_dtype)
adv_optm-0.1.1.dist-info/RECORD
DELETED
@@ -1,17 +0,0 @@
-adv_optm/__init__.py,sha256=Ol6hg_EdQH1AXJsa_9l5iWnlUXuOXwD-6eU1OweL87A,172
-adv_optm/optim/AdamW_adv.py,sha256=VGGzLhLh6CdY4I8mxmlzIC90rWnc9oGNuuXK8vE1dE0,12729
-adv_optm/optim/Adopt_adv.py,sha256=-GRpXWISCq6HPkd7UB1S57jSzsg2D3nAhAt6082_7Ms,14992
-adv_optm/optim/Prodigy_adv.py,sha256=5N5GsTWYg_0q_R95E_ryZVa3zSe-q30p_bFK5dXOUpM,17311
-adv_optm/optim/__init__.py,sha256=kX9MQhLQZGlKFPCGLXsZtooigs4wXULTEmNSSOJvcCY,178
-adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
-adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
-adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
-adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
-adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
-adv_optm/util/Randomized_SVD.py,sha256=TFG417hh1t5f1n_mChnbgdQhpMoi37O04xVCe8wz8Qc,1708
-adv_optm/util/__init__.py,sha256=3yYKo23JDfHDZdGcjrDKxH8nYjk5KDB-i44kW-J4sPk,367
-adv_optm-0.1.1.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-adv_optm-0.1.1.dist-info/METADATA,sha256=Mej63zbzvVh1YkAydQojP6SZSqz_46JA6-Y_3i3b2Fs,6342
-adv_optm-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-adv_optm-0.1.1.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
-adv_optm-0.1.1.dist-info/RECORD,,
{adv_optm-0.1.1.dist-info → adv_optm-0.1.2.dist-info}/WHEEL
File without changes

{adv_optm-0.1.1.dist-info → adv_optm-0.1.2.dist-info}/licenses/LICENSE
File without changes

{adv_optm-0.1.1.dist-info → adv_optm-0.1.2.dist-info}/top_level.txt
File without changes