adv-optm 1.1.0.dev3-py3-none-any.whl → 1.1.0.dev5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +3 -3
- adv_optm/optim/Adopt_adv.py +435 -439
- adv_optm/optim/Lion_Prodigy_adv.py +315 -315
- adv_optm/optim/Lion_adv.py +1 -1
- adv_optm/optim/Prodigy_adv.py +13 -6
- adv_optm/optim/Simplified_AdEMAMix.py +3 -3
- adv_optm/util/Kourkoutas.py +71 -36
- {adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev5.dist-info}/METADATA +1 -1
- adv_optm-1.1.0.dev5.dist-info/RECORD +20 -0
- adv_optm-1.1.0.dev3.dist-info/RECORD +0 -20
- {adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev5.dist-info}/WHEEL +0 -0
- {adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev5.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev5.dist-info}/top_level.txt +0 -0
adv_optm/optim/Lion_Prodigy_adv.py
@@ -1,315 +1,315 @@
The hunk rewrites the entire file: all 315 lines are marked removed and re-added. Except for line 154 of the removed side, which is cut off in the page's rendering, the removed and re-added content is identical, so the dev5 content follows once:

import torch
import torch.distributed as dist

import math

from typing import Tuple, Optional

from ..util.BF16_Stochastic_Rounding import add_stochastic_
from ..util.Effective_Shape import _get_effective_shape
from ..util.NNMF import _nnmf,_unnmf
from ..util.OrthoGrad import _orthogonalize_gradient
from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools

class Lion_Prodigy_adv(torch.optim.Optimizer):
    """
    Implements the SMMF technique and Prodigy D-Adaptation method for Lion algorithm.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate (default: 1e-4).
        betas (Tuple[float, float], optional): coefficients for computing
            running averages of the update (default: (0.9, 0.99)).
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
        vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
            matrices to apply low-rank compression (default: True).
        stochastic_rounding (bool, optional): whether to use stochastic
            rounding for BF16 parameter updates (default: True).
        cautious_mask (bool): whether to use the cautious masking technique. (default: False).
        clip_threshold (float, optional): whether to clip the gradients norm
            per-parameter as proposed in the paper `Lions and Muons: Optimization via
            Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
            (default: 0.0).
        nnmf_factor (bool): whether to use the factorization or use the
            uncompressed optimizer. (default: True)
        d0 (float):
            Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
        d_coef (float):
            Coefficient in the expression for the estimate of d (default 1.0).
            Values such as 0.5 and 2.0 typically work as well.
            Changing this parameter is the preferred way to tune the method.
        growth_rate (float):
            prevent the D estimate from growing faster than this multiplicative rate.
            Default is inf, for unrestricted. Values like 1.02 give a kind of learning
            rate warmup effect.
        fsdp_in_use (bool):
            If you're using sharded parameters, this should be set to True. The optimizer
            will attempt to auto-detect this, but if you're using an implementation other
            than PyTorch's builtin version, the auto-detection won't work.
        slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
            pth entry of each tensor. For values greater than 1 this an an approximation to standard
            Prodigy. Values ~11 are reasonable (default 11).
    """

    def __init__(
        self,
        params,
        lr: float = 1,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        vector_reshape: bool = True,
        stochastic_rounding: bool = True,
        orthogonal_gradient: bool = False,
        cautious_mask: bool = False,
        clip_threshold: float = 0.0,
        nnmf_factor: bool = True,
        # prodigy parameters
        beta3: float = None,
        d0: float = 1e-6,
        d_coef: float = 1,
        growth_rate: float = float('inf'),
        safeguard_warmup: bool = False,
        fsdp_in_use: bool = False,
        slice_p: int = 11,
    ):
        if not lr > 0.0:
            raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
        if not all(0.0 <= beta <= 1.0 for beta in betas):
            raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
        if not weight_decay >= 0.0:
            raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")

        defaults = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            vector_reshape=vector_reshape,
            orthogonal_gradient=orthogonal_gradient,
            clip_threshold=clip_threshold,
            beta3=beta3, d=d0, d0=d0, d_max=d0, d_numerator=0.0, d_coef=d_coef,
            growth_rate=growth_rate, safeguard_warmup=safeguard_warmup, k=0, slice_p=slice_p,
            fsdp_in_use=fsdp_in_use,
        )
        self.stochastic_rounding = stochastic_rounding
        self.cautious_mask = cautious_mask
        self.factored = nnmf_factor
        self.fsdp_in_use = fsdp_in_use
        super().__init__(params, defaults)
        # Global state for accumulating metrics across parameter updates within a single step.
        self.init_step()

    @property
    def supports_fused_back_pass(self) -> bool:
        return True

    @property
    def supports_memory_efficient_fp16(self) -> bool:
        return True

    @property
    def supports_flat_params(self) -> bool:
        return False

    def init_step(self):
        """Resets accumulators and calculates dlr for the upcoming step."""
        self.d_denom = 0.0

        g_group = self.param_groups[0]
        self.beta1, self.beta2 = g_group['betas']
        self.beta3 = g_group['beta3']
        if self.beta3 is None:
            self.beta3 = math.sqrt(self.beta2)

        k = g_group['k']
        self.d = g_group['d']
        lr = g_group['lr']

        self.dlr = self.d * lr

        self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3

    @torch.no_grad()
    def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
        """Performs a single optimization step on a single parameter."""
        if p.grad is None:
            return

        if hasattr(p, "_fsdp_flattened"):
            self.fsdp_in_use = True

        grad = p.grad
        if grad.dtype != torch.float32 and self.factored:
            grad = grad.float()
        if group["clip_threshold"] > 0.0:
            grad_norm = torch.norm(grad.detach())
            if grad_norm > group["clip_threshold"]:
                clip_coef = group["clip_threshold"] / grad_norm
                grad.mul_(clip_coef)
        if group["orthogonal_gradient"]:
            grad = _orthogonalize_gradient(p, grad)
        state = self.state[p]

        # State Initialization
        if 'step' not in state:
            state['step'] = 0

            should_factor = (
                self.factored and
                not (len(p.shape) == 1 and not group['vector_reshape'])
            )

            state['factored'] = should_factor

            dtype = torch.float32 if self.factored else p.dtype

            slice_p = group['slice_p']

            # D-Adaptation states
            state['s'] = torch.zeros_like(p.flatten()[::slice_p]).detach()
            if p.any():
                state['p0'] = p.flatten()[::slice_p].detach().clone()
            else:
                state['p0'] = torch.tensor(0, device=p.device, dtype=p.dtype)

            if state['factored']:
                state['effective_shape'] = _get_effective_shape(p.numel())
                d1, d2 = state['effective_shape']
                state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                packed_d2 = (d2 + 7) // 8
                state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
            else: # Fallback to standard Lion
                state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)

        if state['factored']:
            # Factored Path
            d1, d2 = state['effective_shape']
            grad_reshaped = grad.view(d1, d2)
            # Reconstruct momentum m_{t-1}
            exp_avg = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
            unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
            torch.where(unpacked_sign, exp_avg, -exp_avg, out=exp_avg)
            del unpacked_sign
            if exp_avg.dtype != torch.float32:
                exp_avg = exp_avg.float()

            # Compute update term c_t = β1*m_{t-1} + (1-β1)*g_t
            signed_update = exp_avg.clone().mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1-self.beta1)).sign_()

            if self.cautious_mask:
                mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
                mask.div_(mask.mean().clamp_(min=1e-3))
                signed_update.mul_(mask)
                del mask

            # Parameter update: p_t = p_{t-1} - lr * sign(c_t)
            update_for_param = signed_update.view(p.shape).mul(self.dlr)

            # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
            exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1 - self.beta2))
            del grad_reshaped

            # Compress new momentum m_t and store factors
            state['sign'] = _pack_bools(exp_avg > 0)
            _nnmf(exp_avg.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
            del exp_avg

        else:
            # Fallback to standard Lion logic
            exp_avg = state["exp_avg"]

            # Compute update term and sign for the update
            if exp_avg.dtype != torch.float32 and self.factored:
                exp_avg = exp_avg.float()
            signed_update = exp_avg.clone().mul_(self.beta1).add_(grad, alpha=self.d * (1-self.beta1)).sign_()

            if self.cautious_mask:
                mask = (signed_update * grad > 0).to(grad.dtype)
                mask.div_(mask.mean().clamp_(min=1e-3))
                signed_update.mul_(mask)
                del mask

            update_for_param = signed_update.mul(self.dlr)

            # Update momentum
            exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))

        # --- Accumulate Prodigy stats ---
        d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
        s, p0 = state['s'], state['p0']
        grad_flat = grad.flatten().float()
        p_flat = p.data.flatten().float()
        p0 = p0.float()

        self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()

        alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
        s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
        self.d_denom += s.abs().sum().item()

        del s, p0, grad_flat, p_flat, alpha

        if group["weight_decay"] != 0:
            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
                add_stochastic_(p.data, p.data,
                                alpha=-group["weight_decay"] * self.dlr)
            else:
                p.data.add_(
                    p.data, alpha=-group["weight_decay"] * self.dlr
                )

        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
            add_stochastic_(p.data, -update_for_param)
        else:
            p.data.add_(-update_for_param)

        del update_for_param

    @torch.no_grad()
    def step(self, closure: Optional[callable] = None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for i, p in enumerate(group["params"]):
                if p.grad is not None:
                    self.step_parameter(p, group, i)

        self.calculate_d()
        self.init_step()
        return loss

    def calculate_d(self):
        """Calculates the new `d` based on the accumulated stats."""
        g_group = self.param_groups[0]
        d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']

        if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
            # Use the device of the first parameter to avoid hardcoding '.cuda()'
            device = self.param_groups[0]['params'][0].device
            dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
            dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
            global_d_numerator = dist_tensor[0].item()
            global_d_denom = dist_tensor[1].item()
        else:
            global_d_numerator = self.d_numerator
            global_d_denom = self.d_denom

        d_hat = self.d
        if global_d_denom > 0:
            d_hat = d_coef * global_d_numerator / global_d_denom
            if self.d == g_group['d0']:
                self.d = max(self.d, d_hat)
            d_max = max(d_max, d_hat)
            self.d = min(d_max, self.d * growth_rate)

        for group in self.param_groups:
            group['d_numerator'] = global_d_numerator
            group['d'] = self.d
            group['d_max'] = d_max
            group['k'] += 1
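For orientation, a minimal usage sketch of the optimizer defined in this file, assuming the dev5 wheel is installed. The import path is inferred from the file layout above (the class may also be re-exported from adv_optm directly), and the model, data, and hyperparameter values are placeholders rather than recommendations:

import torch

# Import path assumed from the package layout shown in the file list above.
from adv_optm.optim.Lion_Prodigy_adv import Lion_Prodigy_adv

model = torch.nn.Linear(128, 64)  # placeholder model

# The constructor defaults lr to 1; the effective step size is dlr = d * lr,
# with d adapted over time by calculate_d().
opt = Lion_Prodigy_adv(
    model.parameters(),
    lr=1.0,
    betas=(0.9, 0.99),
    weight_decay=0.0,
    nnmf_factor=True,  # factored (SMMF) momentum state; False keeps a full exp_avg buffer
)

for _ in range(10):
    opt.zero_grad()
    loss = model(torch.randn(32, 128)).pow(2).mean()  # dummy objective
    loss.backward()
    opt.step()

Because supports_fused_back_pass returns True, step_parameter(p, group, i) can also be called per parameter during the backward pass, followed by calculate_d() and init_step() once per step, which is what step() does internally.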