adv-optm 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic.
- adv_optm/__init__.py +13 -0
- adv_optm/optim/AdamW_adv.py +293 -0
- adv_optm/optim/Adopt_adv.py +336 -0
- adv_optm/optim/Prodigy_adv.py +367 -0
- adv_optm/optim/__init__.py +9 -0
- adv_optm/util/BF16_Stochastic_Rounding.py +47 -0
- adv_optm/util/Effective_Shape.py +8 -0
- adv_optm/util/NNMF.py +18 -0
- adv_optm/util/One_Bit_Boolean.py +22 -0
- adv_optm/util/OrthoGrad.py +16 -0
- adv_optm/util/Randomized_SVD.py +37 -0
- adv_optm/util/__init__.py +11 -0
- adv_optm-0.1.0.dist-info/METADATA +134 -0
- adv_optm-0.1.0.dist-info/RECORD +17 -0
- adv_optm-0.1.0.dist-info/WHEEL +5 -0
- adv_optm-0.1.0.dist-info/licenses/LICENSE +201 -0
- adv_optm-0.1.0.dist-info/top_level.txt +1 -0

adv_optm/optim/Prodigy_adv.py
ADDED
@@ -0,0 +1,367 @@
import torch
from typing import Optional
import math

from ..util.BF16_Stochastic_Rounding import add_stochastic_
from ..util.Effective_Shape import _get_effective_shape
from ..util.NNMF import _nnmf, _unnmf
from ..util.OrthoGrad import _orthogonalize_gradient
from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools

class Prodigy_adv(torch.optim.Optimizer):
    """
    Implements a factored Prodigy/AdamW algorithm.
    This is an advanced version of Prodigy with optional features like
    low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate (default: 1e-3)
        betas (tuple[float, float]): coefficients used for computing running
            averages of the gradient and its square (default: (0.9, 0.999))
        eps (float): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float): weight decay (L2 penalty) (default: 0)
        vector_reshape (bool): whether to reshape 1D vectors into 2D
            matrices to apply low-rank compression (default: True).
        stochastic_rounding (bool): whether to use stochastic
            rounding for BF16 parameter updates (default: True).
        use_atan2 (bool): whether to use the atan2 update rule. (default: False)
        use_grams (bool): whether to use Grams-style updates. (default: False)
        use_cautious (bool): whether to use cautious masking to align the gradient's
            direction with the first moment's. (default: False)
        use_orthograd (bool): whether to use OrthoGrad. (default: False)
        use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
            a second, slow-moving average of the momentum (`mt_slow`) which is
            combined with the primary momentum (`mt`) to stabilize updates,
            especially in noisy, small-batch settings. If `False`, the
            optimizer behaves as standard AdamW. (default: False)
        beta3_ema (float): The decay rate for the slow exponential moving average of
            the momentum (only used when `use_AdEMAMix` is `True`). A higher
            value (e.g., 0.9999) gives the EMA a longer memory, making it more
            stable but slower to adapt. A lower value (e.g., 0.999) is often
            better for shorter training runs. (default: 0.9999)
        alpha (float): The mixing coefficient that scales the slow momentum term
            before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
            A higher value increases the stabilizing influence of the slow
            momentum. (default: 5.0)
        t_alpha (Optional[int]): The number of steps for a linear warmup of the
            `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
            highly recommended to prevent instability at the beginning of training,
            as it gradually introduces the stabilizing slow momentum term. During
            the warmup, `alpha` ramps from 0 to its target value. If `None`,
            the scheduler is disabled and the full `alpha` is used from the first step. (default: None)
        factored (bool): whether to use the factorization or disable it to use
            the uncompressed optimizer. (default: True)
    """

    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        vector_reshape: bool = True,
        stochastic_rounding: bool = True,
        use_atan2: bool = False,
        use_cautious: bool = False,
        use_grams: bool = False,
        use_orthograd: bool = False,
        use_AdEMAMix: bool = False,
        beta3_ema: float = 0.9999,
        alpha: float = 5.0,
        t_alpha: int | None = None,
        factored: bool = True,
        # prodigy parameters
        beta3: float = None,
        d0: float = 1e-6,
        d_coef: float = 1,
        growth_rate: float = float('inf'),
        safeguard_warmup: bool = False,
        slice_p: int = 11,
    ):
        if not (lr >= 0.0):
            raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
        if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
            raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
        if not (eps >= 0.0):
            raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
        if not (weight_decay >= 0.0):
            raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")

        defaults = {
            "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
            "vector_reshape": vector_reshape, "use_atan2": use_atan2,
            "use_orthograd": use_orthograd,
            "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
            "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
            "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
        }
        self.stochastic_rounding = stochastic_rounding
        self.use_cautious = use_cautious
        self.use_grams = use_grams
        self.use_AdEMAMix = use_AdEMAMix
        self.factored = factored
        super().__init__(params, defaults)
        self.init_step()

    @property
    def supports_fused_back_pass(self):
        return True

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return False

    def init_step(self):
        """Resets accumulators and calculates dlr for the upcoming step."""
        self.d_denom = 0.0

        g_group = self.param_groups[0]
        self.beta1, self.beta2 = g_group['betas']
        self.beta3 = g_group['beta3']
        if self.beta3 is None:
            self.beta3 = math.sqrt(self.beta2)

        k = g_group['k']
        self.d = g_group['d']
        lr = g_group['lr']

        self.dlr = self.d * lr

        self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3

    @torch.no_grad()
    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
        if p.grad is None:
            return

        grad = p.grad
        if grad.dtype != torch.float32 and self.factored:
            grad = grad.float()
        if group["use_orthograd"]:
            grad = _orthogonalize_gradient(p, grad)
        state = self.state[p]

        # State Initialization
        if len(state) == 0:
            state['step'] = 0

            should_factor = (
                self.factored and
                not (len(p.shape) == 1 and not group['vector_reshape'])
            )

            state['factored'] = should_factor

            slice_p = group['slice_p']

            dtype = torch.float32 if self.factored else p.dtype
            device = p.device

            if state['factored']:
                state['effective_shape'] = _get_effective_shape(p.numel())
                d1, d2 = state['effective_shape']

                # First moment (m)
                state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                if not self.use_grams:
                    packed_d2 = (d2 + 7) // 8
                    state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
                if self.use_AdEMAMix:
                    state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                    state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                    packed_d2 = (d2 + 7) // 8
                    state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
                # Second moment (v)
                state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
            else:  # Fallback to standard AdamW for non-factored tensors
                state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
                if self.use_AdEMAMix:
                    state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
                state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)

            state['s'] = torch.zeros_like(p.flatten()[::slice_p]).detach()
            if p.any():
                state['p0'] = p.flatten()[::slice_p].detach().clone()
            else:
                state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)

        if self.use_AdEMAMix:
            beta3_ema = group['beta3_ema']
            alpha = group['alpha']
            t_alpha = group['t_alpha']
            current_step = state['step'] + 1
            alpha_t = alpha
            if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
                alpha_t = min(current_step * alpha / t_alpha, alpha)

        if state['factored']:
            d1, d2 = state['effective_shape']

            # Reconstruct momentum from previous step's factors
            mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
            if not self.use_grams:
                unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
                torch.where(unpacked_sign, mt, -mt, out=mt)
                del unpacked_sign
            # Update momentum in full-size
            grad_reshaped = grad.view(d1, d2)
            mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
            if self.use_grams:
                mt.copy_(grad_reshaped.sign() * mt.abs())
            elif self.use_cautious:
                mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
                mask.div_(mask.mean().clamp_(min=1e-3))
                mt.mul_(mask)
                del mask

            vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
            vt.mul_(self.beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - self.beta2))

            if self.use_AdEMAMix:
                mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
                if state['sign_slow'].dtype != torch.uint8:
                    state['sign_slow'] = state['sign_slow'].to(torch.uint8)
                unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
                torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
                del unpacked_sign_slow

                mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
                update_m = mt + (alpha_t * mt_slow)
            else:
                update_m = mt
            del grad_reshaped

            if group['use_atan2']:
                a = 1.2732395
                denom = vt.sqrt()
                update = torch.atan2(update_m, denom).mul_(a)
            else:
                denom = vt.sqrt().add_(group['eps'])
                update = update_m / denom
            del update_m, denom

            update = update.view(p.shape)
            update.mul_(self.dlr)

            # Compress updated moments and store new factors
            if not self.use_grams:
                state['sign'] = _pack_bools(mt > 0)
            _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
            del mt
            if self.use_AdEMAMix:
                state['sign_slow'] = _pack_bools(mt_slow > 0)
                _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
                del mt_slow
            _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
            del vt

        else:  # Standard AdamW logic for non-factored tensors
            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

            exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
            if self.use_grams:
                exp_avg = grad.sign() * exp_avg.abs()
            elif self.use_cautious:
                mask = (exp_avg * grad > 0).to(grad.dtype)
                mask.div_(mask.mean().clamp_(min=1e-3))
                exp_avg.mul_(mask)
                del mask

            if self.use_AdEMAMix:
                exp_avg_slow = state['exp_avg_slow']
                exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
                update_m = exp_avg + (alpha_t * exp_avg_slow)
            else:
                update_m = exp_avg

            exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - self.beta2))

            if group['use_atan2']:
                a = 1.2732395
                denom = exp_avg_sq.sqrt()
                update = torch.atan2(update_m, denom).mul_(a)
            else:
                denom = exp_avg_sq.sqrt().add_(group['eps'])
                update = update_m / denom
            del update_m, denom

            update = update.mul_(self.dlr)

        # --- Accumulate Prodigy stats ---
        d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
        s, p0 = state['s'], state['p0']
        grad_flat = grad.flatten().float()
        p_flat = p.data.flatten().float()
        p0 = p0.float()

        self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()

        alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
        s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
        self.d_denom += s.abs().sum().item()

        del s, p0, grad_flat, p_flat, alpha

        # Decoupled weight decay
        if group["weight_decay"] != 0:
            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * self.dlr)
            else:
                p.data.add_(p.data, alpha=-group["weight_decay"] * self.dlr)

        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
            add_stochastic_(p.data, -update)
        else:
            p.data.add_(-update)
        del update

        state['step'] += 1

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for i, p in enumerate(group['params']):
                self.step_parameter(p, group, i)

        self.calculate_d()
        self.init_step()
        return loss

    def calculate_d(self):
        """Calculates the new `d` based on the accumulated stats."""
        g_group = self.param_groups[0]
        d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']

        global_d_numerator = self.d_numerator
        global_d_denom = self.d_denom

        d_hat = self.d
        if global_d_denom > 0:
            d_hat = d_coef * global_d_numerator / global_d_denom
            if self.d == g_group['d0']:
                self.d = max(self.d, d_hat)
            d_max = max(d_max, d_hat)
            self.d = min(d_max, self.d * growth_rate)

        for group in self.param_groups:
            group['d_numerator'] = global_d_numerator
            group['d'] = self.d
            group['d_max'] = d_max
            group['k'] += 1

adv_optm/util/BF16_Stochastic_Rounding.py
ADDED
@@ -0,0 +1,47 @@
import torch
from torch import Tensor

def copy_stochastic_(target: Tensor, source: Tensor):
    """
    Nerogar's implementation of the stochastic rounding described in "Revisiting BFloat16 Training"
    (https://arxiv.org/abs/2010.06192).
    see:
    https://github.com/pytorch/pytorch/issues/120376
    https://github.com/Nerogar/OneTrainer/blob/daae18eaed8c0fa39289b2ff79cc2c1e08577fcb/modules/util/bf16_stochastic_rounding.py

    Args:
        target: the target tensor with dtype=bfloat16
        source: the source tensor with dtype=float32
    """
    # create a random 16 bit integer
    result = torch.randint_like(
        source,
        dtype=torch.int32,
        low=0,
        high=(1 << 16),
    )

    # add the random number to the lower 16 bit of the mantissa
    result.add_(source.view(dtype=torch.int32))

    # mask off the lower 16 bit of the mantissa
    result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32

    # copy the higher 16 bit into the target tensor
    target.copy_(result.view(dtype=torch.float32))

    del result

def add_stochastic_(input: Tensor, other: Tensor, alpha: float = 1.0):
    """
    Adds `other` to `input` using stochastic rounding.

    Args:
        input: the input tensor with dtype=bfloat16
        other: the other tensor
        alpha: a multiplier for other
    """
    result = other.clone() if other.dtype == torch.float32 else other.to(dtype=torch.float32)

    result.add_(input, alpha=alpha)
    copy_stochastic_(input, result)

adv_optm/util/Effective_Shape.py
ADDED
@@ -0,0 +1,8 @@
def _get_effective_shape(numel: int) -> tuple[int, int]:
    """Finds two factors of numel that are closest to its square root."""
    if numel <= 0:
        return (0, 0)
    for i in reversed(range(1, int(numel ** 0.5) + 1)):
        if numel % i == 0:
            return (numel // i, i)
    return (numel, 1)

adv_optm/util/NNMF.py
ADDED
@@ -0,0 +1,18 @@
import torch

def _unnmf(row_col: tuple) -> torch.Tensor:
    """Reconstructs a matrix from its rank-1 factors (outer product)."""
    return torch.outer(row_col[0], row_col[1])

def _nnmf(matrix: torch.Tensor, out: tuple):
    """Performs a rank-1 non-negative matrix factorization."""
    shape = matrix.shape
    torch.sum(matrix, dim=1, out=out[0])
    torch.sum(matrix, dim=0, out=out[1])
    # Normalize one of the factors for stability
    if shape[0] < shape[1]:
        scale = out[0].sum()
        if scale != 0: out[0].div_(scale)
    else:
        scale = out[1].sum()
        if scale != 0: out[1].div_(scale)

adv_optm/util/One_Bit_Boolean.py
ADDED
@@ -0,0 +1,22 @@
import torch

@torch.no_grad()
def _pack_bools(tensor: torch.Tensor) -> torch.Tensor:
    """Packs a boolean tensor into a uint8 tensor to achieve 1-bit storage."""
    n, m = tensor.shape
    packed_m = (m + 7) // 8
    padded_tensor = torch.nn.functional.pad(tensor, (0, packed_m * 8 - m), 'constant', 0)
    reshaped = padded_tensor.view(n, packed_m, 8)
    shifter = torch.arange(8, device=tensor.device, dtype=torch.uint8)
    packed = (reshaped.to(torch.uint8) * (2**shifter)).sum(dim=2).to(torch.uint8)
    return packed

@torch.no_grad()
def _unpack_bools(packed_tensor: torch.Tensor, original_m: int) -> torch.Tensor:
    """Unpacks a uint8 tensor back into a boolean tensor."""
    if packed_tensor.dtype != torch.uint8:
        packed_tensor = packed_tensor.to(torch.uint8)
    shifter = (2**torch.arange(8, device=packed_tensor.device, dtype=torch.uint8)).view(1, 1, 8)
    unpacked_padded = (packed_tensor.unsqueeze(2) & shifter) != 0
    unpacked = unpacked_padded.view(packed_tensor.shape[0], -1)[:, :original_m]
    return unpacked

adv_optm/util/OrthoGrad.py
ADDED
@@ -0,0 +1,16 @@
import torch

def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    """Projects the gradient `grad` to be orthogonal to the parameter `p`."""
    if grad.is_sparse: raise RuntimeError("OrthoGrad logic does not support sparse gradients.")
    original_shape = grad.shape
    original_dtype = grad.dtype
    w = p.view(-1).float()
    g = grad.view(-1).float()
    w_norm_sq = torch.dot(w, w).add_(1e-30)
    proj = torch.dot(w, g) / w_norm_sq
    g_orth = g.sub(w, alpha=proj)
    g_norm = g.norm(2)
    g_orth_norm = g_orth.norm(2).add_(1e-30)
    g_orth_scaled = g_orth * (g_norm / g_orth_norm)
    return g_orth_scaled.view(original_shape).to(original_dtype)

adv_optm/util/Randomized_SVD.py
ADDED
@@ -0,0 +1,37 @@
import torch
from torch import Tensor

from typing import Tuple

def _rsvd(A: torch.Tensor, rank: int, oversampling: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Performs Randomized SVD."""
    orig_dtype, device, (m, n) = A.dtype, A.device, A.shape
    A_float = A.float()
    l, true_rank = rank + oversampling, min(m, n, rank)

    if true_rank == 0:
        return (
            torch.zeros(m, rank, dtype=orig_dtype, device=device),
            torch.zeros(rank, dtype=orig_dtype, device=device),
            torch.zeros(rank, n, dtype=orig_dtype, device=device),
        )

    if l >= min(m, n):  # Fallback to full SVD
        U_full, S_full, Vh_full = torch.linalg.svd(A_float, full_matrices=False)
        U, S, Vh = U_full[:, :true_rank], S_full[:true_rank], Vh_full[:true_rank, :]
    else:  # Standard RSVD path
        Omega = torch.randn(n, l, dtype=A_float.dtype, device=device)
        Y = A_float @ Omega
        Q, _ = torch.linalg.qr(Y.float())
        B = Q.T @ A_float
        U_tilde, S, Vh = torch.linalg.svd(B.float(), full_matrices=False)
        U, S, Vh = (Q @ U_tilde)[:, :true_rank], S[:true_rank], Vh[:true_rank, :]

    if true_rank < rank:  # Pad factors with zeros
        U_padded = torch.zeros(m, rank, dtype=A_float.dtype, device=device)
        S_padded = torch.zeros(rank, dtype=A_float.dtype, device=device)
        Vh_padded = torch.zeros(rank, n, dtype=A_float.dtype, device=device)
        U_padded[:, :true_rank], S_padded[:true_rank], Vh_padded[:true_rank, :] = U, S, Vh
        U, S, Vh = U_padded, S_padded, Vh_padded

    return U.to(orig_dtype), S.to(orig_dtype), Vh.to(orig_dtype)

adv_optm/util/__init__.py
ADDED
@@ -0,0 +1,11 @@
from .BF16_Stochastic_Rounding import add_stochastic_, copy_stochastic_
from .Effective_Shape import _get_effective_shape
from .One_Bit_Boolean import _pack_bools, _unpack_bools
from .OrthoGrad import _orthogonalize_gradient

__all__ = [
    "_pack_bools", "_unpack_bools",
    "add_stochastic_",
    "_get_effective_shape",
    "_orthogonalize_gradient",
]

adv_optm-0.1.0.dist-info/METADATA
ADDED
@@ -0,0 +1,134 @@
Metadata-Version: 2.4
Name: adv_optm
Version: 0.1.0
Summary: A family of highly efficient, lightweight yet powerful optimizers.
Home-page: https://github.com/Koratahiu/Advanced_Optimizers
Author: Koratahiu
Author-email: hiuhonor@gmail.com
License: Apache 2.0
Keywords: llm,fine-tuning,memory-efficient,low-rank,compression,pytorch,optimizer,adam
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: keywords
Dynamic: license
Dynamic: license-file
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

# Advanced Optimizers

This repo introduces a new family of highly efficient, lightweight yet powerful optimizers, born from extensive research into recent academic literature and validated through practical training runs across diverse models.

---

### Install

`pip install adv_optm`

---

### Theory (Inspired by SMMF)

Based primarily on:
**[SMMF: Square-Matricized Momentum Factorization for Memory-Efficient Optimization](https://arxiv.org/abs/2412.08894)**

The core innovation:
- Uses fast, non-negative matrix factorization (rank 1, à la Adafactor), but **reconstructs the full state before each update** to preserve momentum accuracy, then re-factors afterward (factor → reconstruct → update → factor cycle).
- For the *signed first moment*, the state is split into **sign + absolute value**:
  - The sign is stored as a **1-bit state** via bitwise ops (SMMF originally used 8-bit storage, wasting 7 bits).
  - The absolute value goes through the factor/reconstruct cycle using two factored vectors plus the sign state.
- Final storage: **four factored vectors + one 1-bit sign**.
- Updates behave like full-state Adam but with drastically reduced memory.

> ✅ **TL;DR**: Lightweight, strong, memory-efficient optimizer.
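
To make the cycle concrete, here is a minimal sketch of one factored first-moment update using the internal helpers that ship in this wheel (`_nnmf`, `_unnmf`, `_pack_bools`, `_unpack_bools`). The shapes and betas are illustrative, and this is not the exact call sequence inside the optimizers:

```python
import torch
from adv_optm.util import _pack_bools, _unpack_bools
from adv_optm.util.NNMF import _nnmf, _unnmf

d1, d2 = 1024, 1024                     # effective 2D shape of a flattened parameter
grad = torch.randn(d1, d2)

# persistent state: two rank-1 factors + one packed 1-bit sign matrix
mu, mv = torch.zeros(d1), torch.zeros(d2)
sign = _pack_bools(torch.zeros(d1, d2, dtype=torch.bool))

# 1) reconstruct the signed momentum from the previous step's factors
m = _unnmf((mu, mv))
m = torch.where(_unpack_bools(sign, original_m=d2), m, -m)

# 2) ordinary EMA update in full size
m.mul_(0.9).add_(grad, alpha=0.1)

# 3) re-factor: keep the sign in 1 bit, factor only the absolute value
sign = _pack_bools(m > 0)
_nnmf(m.abs(), out=(mu, mv))            # mu/mv again hold the rank-1 factors
```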

---

### Memory Cost

- **Adopt_Factored** for a full SDXL finetune: **328 MB** (4 small vectors + one 1-bit state)
- **Adopt_Factored with AdEMAMix** for a full SDXL finetune: **625 MB** (6 small vectors + two 1-bit states)

> SDXL is a 6.5 GB model.
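
To see where the savings come from, here is a back-of-the-envelope calculation for a single hypothetical 4096 x 4096 weight; the layer size is illustrative and not taken from the SDXL measurements above:

```python
d1 = d2 = 4096                                      # hypothetical square weight, ~16.8M params

adam_state_bytes = 2 * d1 * d2 * 4                  # exp_avg + exp_avg_sq in fp32  -> ~134 MB
factored_bytes = 2 * (d1 + d2) * 4 + d1 * d2 // 8   # four fp32 factor vectors + 1-bit sign -> ~2.2 MB

print(f"{adam_state_bytes / 1e6:.1f} MB vs {factored_bytes / 1e6:.1f} MB")
```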

---

### ⏱️ Speed (my tests on SDXL, BS 4)

- **Adopt_Factored**: ~10s/it
- **Adopt_Factored with AdEMAMix**: ~12s/it
- **Adafactor**: ~8.5s/it

→ The overhead comes from the compression/reconstruction cycles.
→ It's faster than [MLorc](https://arxiv.org/abs/2506.01897) (~12s/it), which uses RSVD compression, and should be the fastest momentum-compression method (AFAIK).

---

### 📈 Performance

- **Better than the Adafactor and CAME factorization methods**
- **Comparable or identical to Adam** (see the SMMF paper results)

---

### Available Optimizers (all support the `Factored` toggle)

Set `Factored=False` to disable factorization and run as a full uncompressed optimizer (like vanilla Adam). A usage sketch follows the list.

1. **Adam**
2. **Prodigy**
3. **Adopt**
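
A minimal usage sketch, with constructor arguments taken from the `Prodigy_adv` signature shown earlier in this diff. It assumes the top-level package re-exports the optimizer classes (the 13-line `adv_optm/__init__.py` is not shown here); if it does not, import from `adv_optm.optim.Prodigy_adv` instead. The model and loss are placeholders:

```python
import torch
from adv_optm import Prodigy_adv  # assumed re-export; otherwise: from adv_optm.optim.Prodigy_adv import Prodigy_adv

model = torch.nn.Linear(128, 128).bfloat16()   # toy stand-in for a real network

opt = Prodigy_adv(
    model.parameters(),
    lr=1.0,                    # Prodigy-style optimizers are usually run near lr=1.0 and let d do the scaling
    weight_decay=0.01,
    factored=True,             # set False for the full, uncompressed optimizer
    stochastic_rounding=True,  # BF16 parameters benefit from stochastic rounding
)

for _ in range(10):
    x = torch.randn(4, 128, dtype=torch.bfloat16)
    loss = model(x).float().pow(2).mean()      # placeholder loss
    loss.backward()
    opt.step()
    opt.zero_grad()
```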

---

### Bonus Features (Built-in)

- **Fused Backward Pass**

- **Stochastic Rounding (SR)**: Improves quality and convergence for **BF16 training**.

- **[AdEMAMix](https://arxiv.org/abs/2409.03137)**
→ Adds a second, slow-moving EMA that is combined with the primary momentum to stabilize updates, especially during long full-finetuning runs.
→ A higher `beta3_ema` (e.g., 0.9999) gives the EMA a longer memory, making it more stable but slower to adapt. A lower value (e.g., 0.999) is often better for shorter training runs (2k-4k steps).
→ When `factored` is true, the slow momentum is compressed the same way as the first moment (1-bit sign + 2 vectors). This introduces noticeable overhead, since a third state is compressed/reconstructed each step.

⚠️ **Note**: AdEMAMix updates are more aggressive than normal Adam/Adopt, so use a 2x to 5x smaller LR than usual (or use Prodigy).

⚠️ **Note**: The factored AdEMAMix is **experimental** (it needs more tests and validation, but it should work). Adopt with AdEMAMix is also **experimental**, since Adopt normalizes the gradients used for the momentum.
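
In update form, this mirrors what `Prodigy_adv` above does per step: a fast and a slow EMA of the gradient are mixed into the numerator, with `alpha` optionally warmed up linearly. A minimal sketch (Prodigy's `d`-scaling of the gradient is omitted here):

```python
import torch

def ademamix_numerator(grad, m_fast, m_slow, step,
                       beta1=0.9, beta3_ema=0.9999, alpha=5.0, t_alpha=None):
    """Mix a fast and a slow EMA of the gradient into the Adam numerator."""
    m_fast.mul_(beta1).add_(grad, alpha=1 - beta1)          # regular first moment
    m_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)  # long-memory EMA
    # linear warmup of the mixing coefficient, as in step_parameter
    alpha_t = alpha if t_alpha is None else min(step * alpha / t_alpha, alpha)
    return m_fast + alpha_t * m_slow

g = torch.randn(10)
m_f, m_s = torch.zeros(10), torch.zeros(10)
numerator = ademamix_numerator(g, m_f, m_s, step=1, t_alpha=1000)  # alpha_t is still tiny this early
```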

- **[`atan2` smoothing & scaling](https://github.com/lucidrains/adam-atan2-pytorch)**
→ A robust `eps` replacement (no tuning needed) with built-in gradient clipping.
→ *Ideal for ADOPT* (which normally needs a higher `eps` and clipping), so `use_atan2` is an all-in-one switch for it.
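
Concretely, the flag swaps the usual `m / (sqrt(v) + eps)` division for a bounded `atan2` form; the `1.2732395` constant in the optimizer code above equals 4/π:

```python
import torch

m = torch.randn(5)    # first-moment numerator
v = torch.rand(5)     # second moment

eps_update   = m / (v.sqrt() + 1e-8)                   # standard Adam-style step
atan2_update = 1.2732395 * torch.atan2(m, v.sqrt())    # bounded, eps-free, implicitly clipped
```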

- **[OrthoGrad](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability)**
→ Removes the gradient component parallel to the weights, which prevents "naïve loss minimization" (NLM) and reduces natural overfitting.
→ Perfect for fine-tuning the direction of existing features (e.g., a full finetune, or further training an already-trained LoRA) without weight decay erasing prior knowledge.

⚠️ **Note**: OrthoGrad introduces **~33% time overhead**, so take this into account.
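
The projection itself is small; a sketch equivalent to `_orthogonalize_gradient` in `adv_optm/util/OrthoGrad.py` (shown earlier in this diff):

```python
import torch

def orthogonalize(w: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    """Remove the component of g parallel to w, then rescale to g's original norm."""
    w, g = w.flatten().float(), g.flatten().float()
    proj = torch.dot(w, g) / (torch.dot(w, w) + 1e-30)    # scalar projection coefficient
    g_orth = g - proj * w                                 # orthogonal component
    return g_orth * (g.norm() / (g_orth.norm() + 1e-30))  # keep the update magnitude
```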

- **[Grams: Gradient Descent with Adaptive Momentum Scaling](https://github.com/Gunale0926/Grams)**
→ Eliminates the need for 1-bit momentum sign storage by using the **sign of the gradient** for the first moment.

⚠️ **Not recommended for small batch sizes**: gradients are too noisy, which can destabilize momentum (tested with Prodigy; at BS 4 it made the optimizer slower to find the LR and to converge).
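
In Grams form, the magnitude of the step comes from the momentum while its direction comes from the current gradient, which is exactly why no momentum sign has to be stored; a minimal sketch matching the `grad.sign() * m.abs()` line in the optimizer code:

```python
import torch

def grams_update_direction(m: torch.Tensor, grad: torch.Tensor, beta1: float = 0.9) -> torch.Tensor:
    """EMA magnitude, current-gradient sign (Grams-style first moment)."""
    m.mul_(beta1).add_(grad, alpha=1 - beta1)   # ordinary first-moment EMA
    return grad.sign() * m.abs()                # direction from the gradient, size from the EMA
```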

### Other Notes

- **Adopt** skips the first step (it only initializes the states) and has built-in clipping, following the original optimizer. Both are skipped when `use_atan2` is enabled, since the optimizer becomes scale-invariant and the state values no longer cause issues or instability.

- When `use_atan2` is True, `eps` is ignored and you should also disable any external gradient clipping.

- I don't recommend **OrthoGrad** for training LoRAs or embeddings: their weights are zero-initialized, so weight decay is safe and even beneficial for them (OrthoGrad is intended for fine-tuning pretrained models with no weight decay).

---