adv-optm 1.0.0__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +19 -19
- adv_optm/optim/Adopt_adv.py +24 -24
- adv_optm/optim/Lion_Prodigy_adv.py +8 -8
- adv_optm/optim/Lion_adv.py +8 -8
- adv_optm/optim/Prodigy_adv.py +475 -475
- adv_optm/optim/Simplified_AdEMAMix.py +3 -3
- {adv_optm-1.0.0.dist-info → adv_optm-1.0.1.dist-info}/METADATA +1 -1
- adv_optm-1.0.1.dist-info/RECORD +19 -0
- adv_optm-1.0.0.dist-info/RECORD +0 -19
- {adv_optm-1.0.0.dist-info → adv_optm-1.0.1.dist-info}/WHEEL +0 -0
- {adv_optm-1.0.0.dist-info → adv_optm-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.0.0.dist-info → adv_optm-1.0.1.dist-info}/top_level.txt +0 -0
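For orientation, the Prodigy_adv constructor in 1.0.1 (shown in the diff below) exposes the Grams update, cautious masking, and NNMF factorization switches as the arguments grams_moment, cautious_mask, and nnmf_factor. A minimal usage sketch follows; the import path is inferred from the package layout above, and the model, data, and loss in the loop are placeholders rather than anything shipped by the package.

import torch
from adv_optm.optim.Prodigy_adv import Prodigy_adv  # path inferred from the file list above

model = torch.nn.Linear(128, 10)
optimizer = Prodigy_adv(
    model.parameters(),
    lr=1.0,                 # Prodigy-style: d-adaptation supplies the effective step size
    weight_decay=0.01,
    cautious_mask=True,     # constructor flag visible in the 1.0.1 diff below
    grams_moment=False,     # mutually exclusive with cautious_mask
    nnmf_factor=False,      # True switches to the factored (SMMF) optimizer state
)

for _ in range(10):
    x = torch.randn(32, 128)
    loss = model(x).pow(2).mean()   # placeholder objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()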
adv_optm/optim/Prodigy_adv.py
CHANGED
@@ -1,475 +1,475 @@
 import torch
 import torch.distributed as dist

 import math

 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools

 class Prodigy_adv(torch.optim.Optimizer):
     """
     Implements a factored Prodigy/AdamW algorithm.
     This is an advanced version of Prodigy with optional features like
     low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.

     Args:
         params (iterable): iterable of parameters to optimize or dicts defining
             parameter groups
         lr (float): learning rate (default: 1)
         betas (tuple[float, float]): coefficients used for computing running
             averages of gradient and its square (default: (0.9, 0.999))
         eps (float): term added to the denominator to improve
             numerical stability (default: 1e-8)
         weight_decay (float): weight decay (L2 penalty) (default: 0)
         vector_reshape (bool): whether to reshape 1D vectors into 2D
             matrices to apply low-rank compression (default: True).
         stochastic_rounding (bool): whether to use stochastic
             rounding for BF16 parameter updates (default: True).
         use_atan2 (bool): whether to use the atan2 update rule. (default: False)
-
-
+        grams_moment (bool): whether to use Grams-style updates. (default: False)
+        cautious_mask (bool): whether to use cautious masking to align the gradient's
             direction with the first moment's. (default: False)
         use_orthograd (bool): whether to use OrthoGrad. (default: False)
         use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
             a second, slow-moving average of the momentum (`mt_slow`) which is
             combined with the primary momentum (`mt`) to stabilize updates,
             especially in noisy, small-batch settings. If `False`, the
             optimizer behaves as standard AdamW. (default: False)
         beta3_ema (float): The decay rate for the slow exponential moving average of
             the momentum (only used when `use_AdEMAMix` is `True`). A higher
             value (e.g., 0.9999) gives the EMA a longer memory, making it more
             stable but slower to adapt. A lower value (e.g., 0.999) is often
             better for shorter training runs. (default: 0.9999)
         alpha (float): The mixing coefficient that scales the slow momentum term
             before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
             A higher value increases the stabilizing influence of the slow
             momentum. (default: 5.0)
         t_alpha (Optional[int]): The number of steps for a linear warmup of the
             `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
             highly recommended to prevent instability at the beginning of training,
             as it gradually introduces the stabilizing slow momentum term. During
             the warmup, `alpha` ramps from 0 to its target value. If `None`,
             the scheduler is disabled.
         Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
             This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
             more responsive, especially for small batch sizes. Enabling this will
-            automatically disable `use_AdEMAMix`, `
+            automatically disable `use_AdEMAMix`, `cautious_mask`, `grams_moment`,
             and `use_atan2`. (default: False)
         alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
             (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
             current gradient. For small batch sizes, use high values (e.g., 10-100) to be
             more responsive. For large batch sizes, use low values (e.g., 0-1) for
             stability. (default: 100.0)
-
+        nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
         d0 (float):
             Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
         d_coef (float):
             Coefficient in the expression for the estimate of d (default 1.0).
             Values such as 0.5 and 2.0 typically work as well.
             Changing this parameter is the preferred way to tune the method.
         growth_rate (float):
             prevent the D estimate from growing faster than this multiplicative rate.
             Default is inf, for unrestricted. Values like 1.02 give a kind of learning
             rate warmup effect.
         fsdp_in_use (bool):
             If you're using sharded parameters, this should be set to True. The optimizer
             will attempt to auto-detect this, but if you're using an implementation other
             than PyTorch's builtin version, the auto-detection won't work.
         slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
             pth entry of each tensor. For values greater than 1 this an an approximation to standard
             Prodigy. Values ~11 are reasonable (default 11).
         prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
             after the specified optimiser step and release all state memory required by Prodigy
             (default: 0).
     """

     def __init__(
         self,
         params,
         lr: float = 1,
         betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0.0,
         vector_reshape: bool = True,
         stochastic_rounding: bool = True,
         use_atan2: bool = False,
-
-
+        cautious_mask: bool = False,
+        grams_moment: bool = False,
         use_orthograd: bool = False,
         use_AdEMAMix: bool = False,
         beta3_ema: float = 0.9999,
         alpha: float = 5.0,
         t_alpha: int | None = None,
         Simplified_AdEMAMix: bool = False,
         alpha_grad: float = 100.0,
-
+        nnmf_factor: bool = False,
         # prodigy parameters
         beta3: float = None,
         d0: float = 1e-6,
         d_coef: float = 1,
         growth_rate: float = float('inf'),
         safeguard_warmup: bool = False,
         fsdp_in_use: bool = False,
         slice_p: int = 11,
         prodigy_steps: int = 0,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
         if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
             raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
         if not (eps >= 0.0):
             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
         if not (weight_decay >= 0.0):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
         if not (prodigy_steps >= 0):
             raise ValueError(f"prodigy_steps should be >= 0. Got {prodigy_steps}")
-        if
-            print("Warning:
-
+        if cautious_mask and grams_moment:
+            print("Warning: cautious is incompatible with grams, Disabling cautious.")
+            cautious_mask = False
         if betas[0] == 0.0 and Simplified_AdEMAMix:
             raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
         if use_AdEMAMix and Simplified_AdEMAMix:
             print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
-        if
-            print("Warning:
-        if
-            print("Warning:
+        if grams_moment and Simplified_AdEMAMix:
+            print("Warning: grams is incompatible with Simplified_AdEMAMix, Disabling grams.")
+        if cautious_mask and Simplified_AdEMAMix:
+            print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
         if use_atan2 and Simplified_AdEMAMix:
             print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
             use_atan2 = False
         if Simplified_AdEMAMix and alpha_grad > 0:
             # scales d_coef by alpha_grad, this force prodigy to behave well with Simplified_AdEMAMix
             d_coef = d_coef/alpha_grad

         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
             "vector_reshape": vector_reshape, "use_atan2": use_atan2,
             "use_orthograd": use_orthograd,
             "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
             "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
             "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
             "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
             "alpha_grad": alpha_grad,
         }
         self.stochastic_rounding = stochastic_rounding
-        self.
-        self.
+        self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
+        self.grams_moment = grams_moment and not Simplified_AdEMAMix
         self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
         self.Simplified_AdEMAMix = Simplified_AdEMAMix
-        self.factored =
+        self.factored = nnmf_factor
         self.fsdp_in_use = fsdp_in_use
         super().__init__(params, defaults)
         self.init_step()

     @property
     def supports_fused_back_pass(self):
         return True

     @property
     def supports_memory_efficient_fp16(self):
         return True

     @property
     def supports_flat_params(self):
         return False

     def init_step(self):
         """Resets accumulators and calculates dlr for the upcoming step."""
         self.d_denom = 0.0

         g_group = self.param_groups[0]
         self.beta1, self.beta2 = g_group['betas']
         self.beta3 = g_group['beta3']
         if self.beta3 is None:
             self.beta3 = math.sqrt(self.beta2)

         k = g_group['k']
         self.d = g_group['d']
         lr = g_group['lr']

         self.dlr = self.d * lr

         self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3

     @torch.no_grad()
     def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
         if p.grad is None:
             return

         if hasattr(p, "_fsdp_flattened"):
             self.fsdp_in_use = True

         grad = p.grad
         if grad.dtype != torch.float32 and self.factored:
             grad = grad.float()
         if group["use_orthograd"]:
             grad = _orthogonalize_gradient(p, grad)
         state = self.state[p]

         # State Initialization
         if len(state) == 0:
             state['step'] = 0

             should_factor = (
                 self.factored and
                 not (len(p.shape) == 1 and not group['vector_reshape'])
             )

             state['factored'] = should_factor

             slice_p = group['slice_p']

             dtype = torch.float32 if self.factored else p.dtype
             device = p.device

             if state['factored']:
                 state['effective_shape'] = _get_effective_shape(p.numel())
                 d1, d2 = state['effective_shape']

                 # First moment (m)
                 if self.beta1 > 0:
                     state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                     state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
-                    if not self.
+                    if not self.grams_moment:
                         packed_d2 = (d2 + 7) // 8
                         state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
                     if self.use_AdEMAMix:
                         state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                         state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                         packed_d2 = (d2 + 7) // 8
                         state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
                 # Second moment (v)
                 state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                 state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
             else: # Fallback to standard AdamW for non-factored tensors
                 if self.beta1 > 0:
                     state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
                     if self.use_AdEMAMix:
                         state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
                 state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)

             state['s'] = torch.zeros_like(p.flatten()[::slice_p]).detach()
             if p.any():
                 state['p0'] = p.flatten()[::slice_p].detach().clone()
             else:
                 state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)

         if self.use_AdEMAMix:
             beta3_ema = group['beta3_ema']
             alpha = group['alpha']
             t_alpha = group['t_alpha']
             current_step = state['step'] + 1
             alpha_t = alpha
             if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
                 alpha_t = min(current_step * alpha / t_alpha, alpha)
         if self.Simplified_AdEMAMix:
             alpha_grad = group["alpha_grad"]

         if state['factored']:
             d1, d2 = state['effective_shape']

             grad_reshaped = grad.view(d1, d2)

             # Reconstruct momentum from previous step's factors
             if self.beta1 > 0:
                 mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
-                if not self.
+                if not self.grams_moment:
                     unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
                     torch.where(unpacked_sign, mt, -mt, out=mt)
                     del unpacked_sign
                 # Update momentum in full-size
                 if self.Simplified_AdEMAMix:
                     mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d)
                 else:
                     mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
-                if self.
+                if self.grams_moment:
                     mt.copy_(grad_reshaped.sign() * mt.abs())
-                elif self.
+                elif self.cautious_mask:
                     mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
                     mask.div_(mask.mean().clamp_(min=1e-3))
                     mt.mul_(mask)
                     del mask

             vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
             vt.mul_(self.beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - self.beta2))

             if self.use_AdEMAMix:
                 mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
                 if state['sign_slow'].dtype != torch.uint8:
                     state['sign_slow'] = state['sign_slow'].to(torch.uint8)
                 unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
                 torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
                 del unpacked_sign_slow
                 mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
                 if self.beta1 > 0:
                     update = torch.add(mt, mt_slow, alpha=alpha_t)
                 else:
                     update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
             elif self.Simplified_AdEMAMix:
                 update = torch.add(mt, grad_reshaped, alpha=alpha_grad * self.d)
             else:
                 update = mt.clone() if self.beta1 > 0 else grad_reshaped.clone()
             del grad_reshaped

             if group['use_atan2']:
                 a = 1.2732395
                 denom = vt.sqrt()
                 update.atan2_(denom).mul_(a)
             else:
                 denom = vt.sqrt()
                 update.div_(denom.add_(self.d * group['eps']))
             del denom

             update = update.view(p.shape).mul_(self.dlr)

             # Compress updated moments and store new factors
             if self.beta1 > 0:
-                if not self.
+                if not self.grams_moment:
                     state['sign'] = _pack_bools(mt > 0)
                 _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
                 del mt
             if self.use_AdEMAMix:
                 state['sign_slow'] = _pack_bools(mt_slow > 0)
                 _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
                 del mt_slow
             _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
             del vt

         else: # Standard AdamW logic for non-factored tensors
             exp_avg_sq = state['exp_avg_sq']

             if self.beta1 > 0:
                 exp_avg = state['exp_avg']
                 if self.Simplified_AdEMAMix:
                     exp_avg.mul_(self.beta1).add_(grad, alpha=self.d)
                 else:
                     exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
-                if self.
+                if self.grams_moment:
                     exp_avg = grad.sign() * exp_avg.abs()
-                elif self.
+                elif self.cautious_mask:
                     mask = (exp_avg * grad > 0).to(grad.dtype)
                     mask.div_(mask.mean().clamp_(min=1e-3))
                     exp_avg.mul_(mask)
                     del mask

             if self.use_AdEMAMix:
                 exp_avg_slow = state['exp_avg_slow']
                 exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
                 if self.beta1 > 0:
                     update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
                 else:
                     update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
             elif self.Simplified_AdEMAMix:
                 update = torch.add(exp_avg, grad, alpha=alpha_grad * self.d)
             else:
                 update = exp_avg.clone() if self.beta1 > 0 else grad.clone()

             exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - self.beta2))

             if group['use_atan2']:
                 a = 1.2732395
                 denom = exp_avg_sq.sqrt()
                 update.atan2_(denom).mul_(a)
             else:
                 denom = exp_avg_sq.sqrt()
                 update.div_(denom.add_(self.d * group['eps']))
             del denom

             update.mul_(self.dlr)

         # --- Accumulate Prodigy stats ---
         prodigy_steps = group['prodigy_steps']
         if prodigy_steps <= 0 or group['k'] < prodigy_steps:
             d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
             s, p0 = state['s'], state['p0']
             grad_flat = grad.flatten().float()
             p_flat = p.data.flatten().float()
             p0 = p0.float()

             self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()

             alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
             s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
             self.d_denom += s.abs().sum().item()

             del s, p0, grad_flat, p_flat, alpha
         else:
             # Free memory if prodigy_steps is reached
             if 's' in state:
                 del state['s']
             if 'p0' in state:
                 del state['p0']

         # Decoupled weight decay
         if group["weight_decay"] != 0:
             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
                 add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * self.dlr)
             else:
                 p.data.add_(p.data, alpha=-group["weight_decay"] * self.dlr)

         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
             add_stochastic_(p.data, -update)
         else:
             p.data.add_(-update)
         del update

         state['step'] += 1

     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step."""
         loss = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()

         for group in self.param_groups:
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)


         self.calculate_d()
         self.init_step()
         return loss

     def calculate_d(self):
         """Calculates the new `d` based on the accumulated stats."""
         g_group = self.param_groups[0]

         # Only perform d-adaptation if prodigy_steps has not been reached
         prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])

         if prodigy_active:
             d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']

             if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
                 # Use the device of the first parameter to avoid hardcoding '.cuda()'
                 device = self.param_groups[0]['params'][0].device
                 dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
                 dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
                 global_d_numerator = dist_tensor[0].item()
                 global_d_denom = dist_tensor[1].item()
             else:
                 global_d_numerator = self.d_numerator
                 global_d_denom = self.d_denom

             d_hat = self.d
             if global_d_denom > 0:
                 d_hat = d_coef * global_d_numerator / global_d_denom
                 if self.d == g_group['d0']:
                     self.d = max(self.d, d_hat)
                 d_max = max(d_max, d_hat)
                 self.d = min(d_max, self.d * growth_rate)

             for group in self.param_groups:
                 group['d_numerator'] = global_d_numerator
                 group['d'] = self.d
                 group['d_max'] = d_max

         # Increment step counter for all groups, regardless of whether d was updated
         for group in self.param_groups:
             group['k'] += 1