adv-optm 1.2.dev13__py3-none-any.whl → 2.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +85 -64
- adv_optm/optim/Adopt_adv.py +114 -69
- adv_optm/optim/Lion_Prodigy_adv.py +79 -81
- adv_optm/optim/Lion_adv.py +37 -42
- adv_optm/optim/Prodigy_adv.py +105 -85
- adv_optm/optim/Simplified_AdEMAMix.py +92 -51
- adv_optm/optim/__init__.py +1 -1
- adv_optm/util/BF16_Stochastic_Rounding.py +1 -1
- adv_optm/util/Effective_Shape.py +1 -1
- adv_optm/util/Kourkoutas.py +11 -12
- adv_optm/util/NNMF.py +7 -2
- adv_optm/util/Newton_Schulz.py +1 -2
- adv_optm/util/One_Bit_Boolean.py +1 -1
- adv_optm/util/OrthoGrad.py +4 -3
- adv_optm/util/__init__.py +1 -1
- {adv_optm-1.2.dev13.dist-info → adv_optm-2.dev1.dist-info}/METADATA +20 -20
- adv_optm-2.dev1.dist-info/RECORD +23 -0
- adv_optm-1.2.dev13.dist-info/RECORD +0 -23
- {adv_optm-1.2.dev13.dist-info → adv_optm-2.dev1.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev13.dist-info → adv_optm-2.dev1.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev13.dist-info → adv_optm-2.dev1.dist-info}/top_level.txt +0 -0
adv_optm/optim/Lion_Prodigy_adv.py
CHANGED
@@ -27,17 +27,13 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  stochastic_rounding (bool, optional): whether to use stochastic
  rounding for BF16 parameter updates (default: True).
  cautious_mask (bool): whether to use the cautious masking technique. (default: False).
- clip_threshold (float, optional): whether to clip the gradients norm
- per-parameter as proposed in the paper `Lions and Muons: Optimization via
- Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
- (default: 0.0).
  nnmf_factor (bool): whether to use the factorization or use the
  uncompressed optimizer. (default: True)
  d0 (float):
  Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
  d_coef (float):
  Coefficient in the expression for the estimate of d (default 1.0).
- Values such as 0.5 and 2.0 typically work as well.
+ Values such as 0.5 and 2.0 typically work as well.
  Changing this parameter is the preferred way to tune the method.
  growth_rate (float):
  prevent the D estimate from growing faster than this multiplicative rate.
@@ -47,8 +43,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  If you're using sharded parameters, this should be set to True. The optimizer
  will attempt to auto-detect this, but if you're using an implementation other
  than PyTorch's builtin version, the auto-detection won't work.
- slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
- pth entry of each tensor. For values greater than 1 this an an approximation to standard
+ slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
+ pth entry of each tensor. For values greater than 1 this an an approximation to standard
  Prodigy. Values ~11 are reasonable (default 11).
  prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
  after the specified optimiser step and release all state memory required by Prodigy
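Note: `slice_p` trades accuracy of the LR adaptation statistics for memory by looking at only every p-th entry of the flattened tensors (`state['p0']` and `state['s']` in the removed Prodigy block further down). A minimal standalone illustration of the subsampling, not the package's code; scaling by `d`, `dlr`, and `beta3` is omitted:

```python
import torch

slice_p = 11
w = torch.randn(1024, 1024)          # a parameter tensor
g = torch.randn_like(w)              # its gradient

w0_sliced = w.flatten()[::slice_p].clone()   # sliced snapshot of the initial weights
s = torch.zeros_like(w0_sliced)              # sliced Prodigy accumulator

g_sliced = g.flatten()[::slice_p]
numerator_term = torch.dot(g_sliced, w0_sliced - w.flatten()[::slice_p])
s.add_(g_sliced)
denominator_term = s.abs().sum()
```

The extra Prodigy state is then roughly `numel / slice_p` elements per tensor instead of `numel`.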
@@ -64,11 +60,10 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  lr: float = 1,
  betas: Tuple[float, float] = (0.9, 0.99),
  weight_decay: float = 0.0,
- vector_reshape: bool =
+ vector_reshape: bool = False,
  stochastic_rounding: bool = True,
  orthogonal_gradient: bool = False,
  cautious_mask: bool = False,
- clip_threshold: float = 0.0,
  nnmf_factor: bool = False,
  # prodigy parameters
  beta3: float = None,
@@ -80,6 +75,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  slice_p: int = 11,
  prodigy_steps: int = 0,
  d_limiter: bool = True,
+ # Compiled
+ compiled_optimizer: bool = False,
  ):
  if not lr > 0.0:
  raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -94,21 +91,28 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  weight_decay=weight_decay,
  vector_reshape=vector_reshape,
  orthogonal_gradient=orthogonal_gradient,
- clip_threshold=clip_threshold,
  beta3=beta3, d=d0, d0=d0, d_max=d0, d_numerator=0.0, d_coef=d_coef,
- growth_rate=growth_rate, safeguard_warmup=safeguard_warmup,
+ growth_rate=growth_rate, safeguard_warmup=safeguard_warmup, slice_p=slice_p,
  fsdp_in_use=fsdp_in_use,
  prodigy_steps=prodigy_steps,
  d_limiter=d_limiter,
+ compiled_optimizer=compiled_optimizer,
  )
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask
  self.factored = nnmf_factor
  self.fsdp_in_use = fsdp_in_use
  super().__init__(params, defaults)
- #
+ # Use the device of the first parameter to avoid hardcoding '.cuda()'
+ self.device = self.param_groups[0]['params'][0].device
+
+ self.global_step = 0
  self.init_step()

+ if compiled_optimizer:
+ torch._dynamo.config.cache_size_limit = 8192
+ self.compile(fullgraph=False, dynamic=False) #FIXME
+
  @property
  def supports_fused_back_pass(self) -> bool:
  return True
@@ -124,14 +128,13 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  def init_step(self):
  """Resets accumulators and calculates dlr for the upcoming step."""
  self.d_denom = 0.0
-
+
  g_group = self.param_groups[0]
  self.beta1, self.beta2 = g_group['betas']
  self.beta3 = g_group['beta3']
  if self.beta3 is None:
  self.beta3 = math.sqrt(self.beta2)
-
- k = g_group['k']
+
  self.d = g_group['d']
  lr = g_group['lr']

@@ -139,38 +142,21 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):

  self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3

-
-
-
- if p.grad is None:
- return
-
- if hasattr(p, "_fsdp_flattened"):
- self.fsdp_in_use = True
+ for group in self.param_groups:
+ for i, p in enumerate(group['params']):
+ self.__init_state(p, group)

-
-
- grad = grad.float()
- if group["clip_threshold"] > 0.0:
- grad_norm = torch.norm(grad.detach())
- if grad_norm > group["clip_threshold"]:
- clip_coef = group["clip_threshold"] / grad_norm
- grad.mul_(clip_coef)
- if group["orthogonal_gradient"]:
- grad = _orthogonalize_gradient(p, grad)
+ @torch.no_grad()
+ def __init_state(self, p, group):
  state = self.state[p]

-
- if 'step' not in state:
- state['step'] = 0
+ if len(state) == 0:

-
+ state['factored'] = (
  self.factored and
  not (len(p.shape) == 1 and not group['vector_reshape'])
  )

- state['factored'] = should_factor
-
  dtype = torch.float32 if self.factored else p.dtype

  slice_p = group['slice_p']
@@ -185,13 +171,28 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  if state['factored']:
  state['effective_shape'] = _get_effective_shape(p.numel())
  d1, d2 = state['effective_shape']
- state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
  else: # Fallback to standard Lion
  state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)

+ @torch.no_grad()
+ def __step_parameter(self, p: torch.Tensor, group: dict, d: torch.Tensor | float, dlr: torch.Tensor | float):
+ """Performs a single optimization step on a single parameter."""
+ if p.grad is None:
+ return
+
+
+ grad = p.grad
+ if grad.dtype != torch.float32 and self.factored:
+ grad = grad.float()
+ if group["orthogonal_gradient"]:
+ grad = _orthogonalize_gradient(p, grad)
+ state = self.state[p]
+
+
  if state['factored']:
  # Factored Path
  d1, d2 = state['effective_shape']
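The factored branch keeps the momentum's sign matrix bit-packed: `packed_d2 = (d2 + 7) // 8` bytes per row, stored as uint8 in `state['sign']`. A rough sketch of how such 1-bit packing can work (presumably related to adv_optm/util/One_Bit_Boolean.py from the file list, but this is an independent illustration, not the package's implementation):

```python
import torch

_WEIGHTS = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)

def pack_signs(sign_bool: torch.Tensor) -> torch.Tensor:
    """Pack a (d1, d2) boolean matrix into (d1, ceil(d2 / 8)) uint8."""
    d1, d2 = sign_bool.shape
    packed_d2 = (d2 + 7) // 8
    padded = torch.zeros(d1, packed_d2 * 8, dtype=torch.bool, device=sign_bool.device)
    padded[:, :d2] = sign_bool
    bits = padded.view(d1, packed_d2, 8).to(torch.uint8)
    return (bits * _WEIGHTS.to(sign_bool.device)).sum(dim=-1, dtype=torch.uint8)

def unpack_signs(packed: torch.Tensor, d2: int) -> torch.Tensor:
    """Recover the (d1, d2) boolean matrix from its packed form."""
    bits = (packed.unsqueeze(-1) & _WEIGHTS.to(packed.device)) != 0
    return bits.view(packed.shape[0], -1)[:, :d2]

signs = torch.randn(128, 257) > 0
assert torch.equal(unpack_signs(pack_signs(signs), 257), signs)
```

Storing signs this way costs 1 bit per momentum entry instead of 16 or 32.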
@@ -205,7 +206,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  exp_avg = exp_avg.float()

  # Compute update term c_t = β1*m_{t-1} + (1-β1)*g_t
- signed_update = exp_avg.clone().mul_(self.beta1).add_(grad_reshaped, alpha=
+ signed_update = exp_avg.clone().mul_(self.beta1).add_(grad_reshaped, alpha=d * (1-self.beta1)).sign_()

  if self.cautious_mask:
  mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
@@ -214,10 +215,10 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  del mask

  # Parameter update: p_t = p_{t-1} - lr * sign(c_t)
- update_for_param = signed_update.view(p.shape).mul(
+ update_for_param = signed_update.view(p.shape).mul(dlr)

  # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
- exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=
+ exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=d * (1 - self.beta2))
  del grad_reshaped

  # Compress new momentum m_t and store factors
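Read together, the added lines implement a d-scaled Lion update: the gradient contribution to both the update direction and the momentum is weighted by the Prodigy estimate d, and the step itself uses the adapted step size dlr. A compact dense-tensor paraphrase (sketch only; `dlr` is assumed to be the d-adjusted learning rate computed in `init_step`, which is outside this excerpt; weight decay, cautious masking, and the NNMF compression of `exp_avg` are layered on top of this core rule):

```python
import torch

def lion_prodigy_update(p, grad, exp_avg, beta1, beta2, d, dlr):
    # c_t = beta1 * m_{t-1} + d * (1 - beta1) * g_t
    c_t = exp_avg.mul(beta1).add_(grad, alpha=d * (1 - beta1)).sign_()
    # p_t = p_{t-1} - dlr * sign(c_t)
    p.add_(c_t, alpha=-dlr)
    # m_t = beta2 * m_{t-1} + d * (1 - beta2) * g_t
    exp_avg.mul_(beta2).add_(grad, alpha=d * (1 - beta2))
```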
@@ -232,7 +233,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  # Compute update term and sign for the update
  if exp_avg.dtype != torch.float32 and self.factored:
  exp_avg = exp_avg.float()
- signed_update = exp_avg.clone().mul_(self.beta1).add_(grad, alpha=
+ signed_update = exp_avg.clone().mul_(self.beta1).add_(grad, alpha=d * (1-self.beta1)).sign_()

  if self.cautious_mask:
  mask = (signed_update * grad > 0).to(grad.dtype)
@@ -240,41 +241,18 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  signed_update.mul_(mask)
  del mask

- update_for_param = signed_update.mul(
-
- # Update momentum
- exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))
-
- prodigy_steps = group['prodigy_steps']
- if prodigy_steps <= 0 or group['k'] < prodigy_steps:
- # --- Accumulate Prodigy stats ---
- d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
- s, p0 = state['s'], state['p0']
- grad_flat = grad.flatten().float()
- p_flat = p.data.flatten().float()
- p0 = p0.float()
-
- self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
-
- alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
- s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
- self.d_denom += s.abs().sum().item()
+ update_for_param = signed_update.mul(dlr)

-
-
- # Free memory if prodigy_steps is reached
- if 's' in state:
- del state['s']
- if 'p0' in state:
- del state['p0']
+ # Update momentum
+ exp_avg.mul_(self.beta2).add_(grad, alpha=d * (1 - self.beta2))

  if group["weight_decay"] != 0:
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
  add_stochastic_(p.data, p.data,
- alpha=-group["weight_decay"] *
+ alpha=-group["weight_decay"] * dlr)
  else:
  p.data.add_(
- p.data, alpha=-group["weight_decay"] *
+ p.data, alpha=-group["weight_decay"] * dlr
  )

  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
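For BF16 parameters both the decoupled weight decay and the final update go through `add_stochastic_` (adv_optm/util/BF16_Stochastic_Rounding.py in the file list above). For intuition, a self-contained sketch of stochastically rounded accumulation into bfloat16 (an independent illustration, not the package's helper):

```python
import torch

def add_stochastic(target_bf16: torch.Tensor, other: torch.Tensor, alpha: float = 1.0) -> None:
    """In-place target += alpha * other, rounding the fp32 result to bf16 stochastically."""
    result = target_bf16.float().add_(other.float(), alpha=alpha)
    bits = result.view(torch.int32)
    # bf16 keeps the top 16 bits of the fp32 pattern; adding a random offset to the
    # low 16 bits before truncation rounds up with probability equal to the discarded fraction.
    bits.add_(torch.randint(0, 1 << 16, bits.shape, dtype=torch.int32, device=bits.device))
    bits.bitwise_and_(-65536)  # mask 0xFFFF0000: zero out the low 16 bits
    target_bf16.copy_(bits.view(torch.float32))
```

Compared with round-to-nearest, this keeps small updates from being systematically lost when the parameter itself is stored in bfloat16.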
@@ -284,6 +262,29 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):

  del update_for_param

+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+ if hasattr(p, "_fsdp_flattened"):
+ self.fsdp_in_use = True
+
+ if self.global_step is None and 'step' in self.state[p]:
+ # For backward compatibility
+ self.global_step = self.state[p]['step']
+
+ if isinstance(self.d_numerator, float):
+ self.d_numerator = torch.tensor(self.d_numerator, device=p.device)
+ self.d_denom = torch.tensor(self.d_denom, device=p.device)
+
+ if not group.get('compiled_optimizer', False):
+ self.__step_parameter(p, group, self.d, self.dlr)
+ else:
+ d_tensor = torch.tensor(self.d, device=p.device)
+ dlr_tensor = torch.tensor(self.dlr, device=p.device)
+ self._compiled_step_parameter(p, group, d_tensor, dlr_tensor)
+
+ def compile(self, *args, **kwargs):
+ self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)
+
  @torch.no_grad()
  def step(self, closure: Optional[callable] = None):
  """Performs a single optimization step."""
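The new `compiled_optimizer` path raises `torch._dynamo.config.cache_size_limit`, wraps the private per-parameter step with `torch.compile` via `compile()`, and `step_parameter` then hands `d` and `dlr` to the compiled function as tensors rather than Python floats, which helps avoid re-specializing the graph each time those values change. A toy, standalone version of the same pattern (illustrative sketch only, not this optimizer):

```python
import torch

class TinyCompiledSGD(torch.optim.Optimizer):
    """Compile the per-parameter step once; pass changing scalars as tensors."""

    def __init__(self, params, lr: float = 1e-3, compiled: bool = True):
        super().__init__(params, dict(lr=lr))
        self._step_fn = (
            torch.compile(self._step_parameter, fullgraph=False, dynamic=False)
            if compiled else self._step_parameter
        )

    @torch.no_grad()
    def _step_parameter(self, p: torch.Tensor, lr: torch.Tensor) -> None:
        if p.grad is not None:
            p.sub_(lr * p.grad)  # lr stays a tensor inside the compiled graph

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            lr_t = torch.tensor(group["lr"])  # tensor, not a Python float
            for p in group["params"]:
                self._step_fn(p, lr_t)
```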
@@ -306,21 +307,19 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  """Calculates the new `d` based on the accumulated stats."""
  g_group = self.param_groups[0]
  # Only perform d-adaptation if prodigy_steps has not been reached
- prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and
+ prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and self.global_step >= g_group['prodigy_steps'])

  if prodigy_active:
  d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
-
+
  if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
-
- device = self.param_groups[0]['params'][0].device
- dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+ dist_tensor = torch.stack([self.d_numerator, self.d_denom])
  dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
  global_d_numerator = dist_tensor[0].item()
  global_d_denom = dist_tensor[1].item()
  else:
- global_d_numerator = self.d_numerator
- global_d_denom = self.d_denom
+ global_d_numerator = self.d_numerator.item()
+ global_d_denom = self.d_denom.item()

  d_hat = self.d
  if global_d_denom > 0:
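The accumulated statistics are now kept as tensors (`self.d_numerator`, `self.d_denom`) so the FSDP branch can `torch.stack` and all-reduce them directly. The update of `d` from these globals falls just outside the excerpt; for orientation only, the reference-style Prodigy rule is roughly as below (a hedged paraphrase, not a quote of this file, and it ignores the `d_limiter` option):

```python
def prodigy_d_update(d: float, d_numerator: float, d_denom: float,
                     d_coef: float = 1.0, growth_rate: float = float("inf")) -> float:
    """Reference-style Prodigy update of the adapted step size d."""
    if d_denom <= 0:
        return d
    d_hat = d_coef * d_numerator / d_denom
    return max(d, min(d_hat, d * growth_rate))
```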
@@ -337,5 +336,4 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  group['d'] = self.d
  group['d_max'] = d_max
  # Increment step counter for all groups, regardless of whether d was updated
-
- group['k'] += 1
+ self.global_step += 1
adv_optm/optim/Lion_adv.py
CHANGED
@@ -27,10 +27,6 @@ class Lion_adv(torch.optim.Optimizer):
  stochastic_rounding (bool, optional): whether to use stochastic
  rounding for BF16 parameter updates (default: True).
  cautious_mask (bool): whether to use the cautious masking technique. (default: False).
- clip_threshold (float, optional): whether to clip the gradients norm
- per-parameter as proposed in the paper `Lions and Muons: Optimization via
- Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
- (default: 0.0).
  nnmf_factor (bool): whether to use the factorization or use the
  uncompressed optimizer. (default: True)
  """
@@ -41,11 +37,10 @@ class Lion_adv(torch.optim.Optimizer):
  lr: float = 1e-4,
  betas: Tuple[float, float] = (0.9, 0.99),
  weight_decay: float = 0.0,
- vector_reshape: bool =
+ vector_reshape: bool = False,
  stochastic_rounding: bool = True,
  orthogonal_gradient: bool = False,
  cautious_mask: bool = False,
- clip_threshold: float = 0.0,
  nnmf_factor: bool = True,
  ):
  if not lr > 0.0:
@@ -61,7 +56,6 @@ class Lion_adv(torch.optim.Optimizer):
  weight_decay=weight_decay,
  vector_reshape=vector_reshape,
  orthogonal_gradient=orthogonal_gradient,
- clip_threshold=clip_threshold,
  )
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask
@@ -80,48 +74,49 @@ class Lion_adv(torch.optim.Optimizer):
  def supports_flat_params(self) -> bool:
  return False

-
-
-
-
- return
+ def init_step(self):
+ for group in self.param_groups:
+ for i, p in enumerate(group['params']):
+ self.__init_state(p, group)

-
-
- grad = grad.float()
- if group["clip_threshold"] > 0.0:
- grad_norm = torch.norm(grad.detach())
- if grad_norm > group["clip_threshold"]:
- clip_coef = group["clip_threshold"] / grad_norm
- grad.mul_(clip_coef)
- if group["orthogonal_gradient"]:
- grad = _orthogonalize_gradient(p, grad)
+ @torch.no_grad()
+ def __init_state(self, p, group):
  state = self.state[p]

-
- if 'step' not in state:
- state['step'] = 0
+ if len(state) == 0:

-
+ state['factored'] = (
  self.factored and
  not (len(p.shape) == 1 and not group['vector_reshape'])
  )

- state['factored'] = should_factor
-
  dtype = torch.float32 if self.factored else p.dtype

  if state['factored']:
  state['effective_shape'] = _get_effective_shape(p.numel())
  d1, d2 = state['effective_shape']
- state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
  else: # Fallback to standard Lion
  state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)

-
+ @torch.no_grad()
+ def __step_parameter(self, p: torch.Tensor, group: dict, lr: torch.Tensor | float):
+ """Performs a single optimization step on a single parameter."""
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ if grad.dtype != torch.float32 and self.factored:
+ grad = grad.float()
+ if group["orthogonal_gradient"]:
+ grad = _orthogonalize_gradient(p, grad)
+
+ state = self.state[p]
+
+
  beta1, beta2 = group["betas"]
  lr = group["lr"]

@@ -138,16 +133,16 @@ class Lion_adv(torch.optim.Optimizer):
  exp_avg = exp_avg.float()

  # Compute update term c_t
-
+ update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()

  if self.cautious_mask:
- mask = (
+ mask = (update * grad_reshaped > 0).to(grad_reshaped.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
-
+ update.mul_(mask)
  del mask

  # Parameter update
-
+ update = update.view(p.shape).mul_(lr)

  # Standard Lion momentum update
  exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
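`Lion_adv` follows the same structure without Prodigy: a plain Lion update with optional cautious masking, applied either to the reshaped factored momentum or to the dense `exp_avg`. A dense-path paraphrase of the reconstructed lines (sketch only):

```python
import torch

def lion_step(p, grad, exp_avg, lr, beta1, beta2, cautious=False):
    update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1).sign_()
    if cautious:
        mask = (update * grad > 0).to(grad.dtype)   # keep sign-agreeing coordinates
        mask.div_(mask.mean().clamp_(min=1e-3))     # rescale so the average step size is preserved
        update.mul_(mask)
    p.sub_(update, alpha=lr)
    exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
```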
@@ -165,15 +160,15 @@ class Lion_adv(torch.optim.Optimizer):
  # Compute update term and sign for the update
  if exp_avg.dtype != torch.float32 and self.factored:
  exp_avg = exp_avg.float()
-
+ update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()

  if self.cautious_mask:
- mask = (
+ mask = (update * grad > 0).to(grad.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
-
+ update.mul_(mask)
  del mask

-
+ update.mul_(lr)

  # Standard Lion momentum update
  exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
@@ -188,11 +183,11 @@ class Lion_adv(torch.optim.Optimizer):
  )

  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
- add_stochastic_(p.data, -
+ add_stochastic_(p.data, -update)
  else:
- p.data.add_(-
+ p.data.add_(-update)

- del
+ del update

  @torch.no_grad()
  def step(self, closure: Optional[callable] = None):
@@ -207,4 +202,4 @@ class Lion_adv(torch.optim.Optimizer):
  if p.grad is not None:
  self.step_parameter(p, group, i)

- return loss
+ return loss