adv-optm 1.2.dev19-py3-none-any.whl → 2.dev2-py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdaMuon_adv.py +11 -9
- adv_optm/optim/AdamW_adv.py +91 -61
- adv_optm/optim/Adopt_adv.py +113 -68
- adv_optm/optim/Lion_Prodigy_adv.py +79 -81
- adv_optm/optim/Lion_adv.py +59 -43
- adv_optm/optim/Muon_adv.py +13 -12
- adv_optm/optim/Prodigy_adv.py +108 -86
- adv_optm/optim/Simplified_AdEMAMix.py +93 -52
- adv_optm/optim/__init__.py +1 -1
- adv_optm/util/BF16_Stochastic_Rounding.py +1 -1
- adv_optm/util/Effective_Shape.py +1 -1
- adv_optm/util/Kourkoutas.py +10 -12
- adv_optm/util/NNMF.py +7 -2
- adv_optm/util/One_Bit_Boolean.py +1 -1
- adv_optm/util/OrthoGrad.py +4 -3
- adv_optm/util/__init__.py +1 -1
- {adv_optm-1.2.dev19.dist-info → adv_optm-2.dev2.dist-info}/METADATA +20 -20
- adv_optm-2.dev2.dist-info/RECORD +23 -0
- adv_optm-1.2.dev19.dist-info/RECORD +0 -23
- {adv_optm-1.2.dev19.dist-info → adv_optm-2.dev2.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev19.dist-info → adv_optm-2.dev2.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev19.dist-info → adv_optm-2.dev2.dist-info}/top_level.txt +0 -0
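Both rewritten optimizers gain a `compiled_optimizer` constructor flag in this release (see the Adopt_adv and Lion_Prodigy_adv diffs below); when enabled, the optimizer raises `torch._dynamo.config.cache_size_limit` and wraps its per-parameter update in `torch.compile(fullgraph=True)`. The following is a minimal usage sketch, not taken from the package's documentation: the import path and the training loop are assumptions, while the class name and the `compiled_optimizer` keyword come from the diff below.

import torch
from adv_optm import Adopt_adv  # assumed import path; only the class name appears in the diff

model = torch.nn.Linear(128, 64)

# lr/betas/weight_decay values here are illustrative, not the package defaults.
opt = Adopt_adv(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.9999),
    weight_decay=0.0,
    compiled_optimizer=True,  # new in 2.dev2: compile the per-parameter step once at construction
)

loss = model(torch.randn(4, 128)).sum()
loss.backward()
opt.step()       # 2.dev2 advances a single optimizer-level global_step counter here
opt.zero_grad()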
adv_optm/optim/Adopt_adv.py
CHANGED
@@ -56,13 +56,6 @@ class Adopt_adv(torch.optim.Optimizer):
             before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
             A higher value increases the stabilizing influence of the slow
             momentum. (default: 5.0)
-        t_alpha (Optional[int]): The number of steps for a linear warmup of the
-            `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
-            highly recommended to prevent instability at the beginning of training,
-            as it gradually introduces the stabilizing slow momentum term. During
-            the warmup, `alpha` ramps from 0 to its target value. If `None`,
-            the scheduler is disabled and the full `alpha` value is used from
-            the start. (default: None)
         Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
             This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
             more responsive, especially for small batch sizes. Enabling this will
@@ -90,10 +83,10 @@ class Adopt_adv(torch.optim.Optimizer):
         k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
             logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
             every logging steps. Useful for debugging and tuning. Set to 0 to disable
-            logging (default: 0).
+            logging (default: 0).
         layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
             and returns a unique, hashable key representing its "layer" or "bucket".
-            If `None`, parameters are bucketed by their
+            If `None`, parameters are bucketed by their shape.
             (default: None)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
@@ -107,7 +100,7 @@ class Adopt_adv(torch.optim.Optimizer):
         eps: float = 1e-6,
         weight_decay: float = 0.0,
         clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
-        vector_reshape: bool =
+        vector_reshape: bool = False,
         stochastic_rounding: bool = True,
         use_atan2: bool = False,
         cautious_mask: bool = False,
@@ -116,7 +109,6 @@ class Adopt_adv(torch.optim.Optimizer):
         use_AdEMAMix: bool = False,
         beta3_ema: float = 0.9999,
         alpha: float = 5.0,
-        t_alpha: int | None = None,
         Simplified_AdEMAMix: bool = False,
         alpha_grad: float = 100.0,
         kourkoutas_beta: bool = False,
@@ -127,6 +119,8 @@ class Adopt_adv(torch.optim.Optimizer):
         k_logging: int = 0,
         layer_key_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
+        # Compiled
+        compiled_optimizer: bool = False,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -141,7 +135,8 @@ class Adopt_adv(torch.optim.Optimizer):
             cautious_mask = False
         if betas[0] == 0.0 and Simplified_AdEMAMix:
             raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
-        if kourkoutas_beta and not (betas[1] > beta2_min):
+        if kourkoutas_beta and not (betas[1] > beta2_min):
+            raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
         if use_AdEMAMix and Simplified_AdEMAMix:
             print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
         if grams_moment and Simplified_AdEMAMix:
@@ -152,9 +147,10 @@ class Adopt_adv(torch.optim.Optimizer):
         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
             "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
-            "
+            "alpha_grad": alpha_grad,
             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
             "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+            "compiled_optimizer": compiled_optimizer,
         }
         self.clip_lambda = clip_lambda
         self.stochastic_rounding = stochastic_rounding
@@ -169,9 +165,17 @@ class Adopt_adv(torch.optim.Optimizer):
         self.layer_key_fn = layer_key_fn
         super().__init__(params, defaults)

+        self.init_step()
+
         if self.kourkoutas_beta:
             self.kourkoutas_helper = KourkoutasHelper(self)

+        self.global_step = 0
+
+        if compiled_optimizer:
+            torch._dynamo.config.cache_size_limit = 8192
+            self.compile(fullgraph=True)
+
     @property
     def supports_fused_back_pass(self): return True
     @property
@@ -179,29 +183,22 @@ class Adopt_adv(torch.optim.Optimizer):
     @property
     def supports_flat_params(self): return False

-
-
-
-
+    def init_step(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                self.__init_state(p, group)

-
-
-        grad = grad.float()
-        if self.orthogonal_gradient:
-            grad = _orthogonalize_gradient(p, grad)
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]

-
-        if 'step' not in state:
-            state['step'] = 0
+        if len(state) == 0:

-
+            state['factored'] = (
                 self.factored and
                 not (len(p.shape) == 1 and not group['vector_reshape'])
             )

-            state['factored'] = should_factor
-
             dtype = torch.float32 if self.factored else p.dtype

             if state['factored']:
@@ -210,55 +207,75 @@ class Adopt_adv(torch.optim.Optimizer):

                 # m_0 = 0
                 if group['betas'][0] > 0:
-                    state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                    state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                     state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                     if not self.grams_moment:
                         packed_d2 = (d2 + 7) // 8
                         state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
                 if self.use_AdEMAMix:
-                    state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                    state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                     state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                     packed_d2 = (d2 + 7) // 8
                     state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
-
-                vt_init = grad.view(d1, d2).square_()
-                # Allocate NMF factors for v
-                state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
-                state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
-                # Initialize v_0 using NMF
-                _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+
             else: # Fallback for non-factored tensors
                 if group['betas'][0] > 0:
                     state['exp_avg'] = torch.zeros_like(p, dtype=dtype) # m_0
                 if self.use_AdEMAMix:
                     state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
-
+
+    @torch.no_grad()
+    def __init_step(self, p, group):
+        if p.grad is None:
+            return
+
+        state = self.state[p]
+
+        if 'exp_avg_sq' in state or 'mu_v_nmf' in state:
+            return
+
+        grad = p.grad
+        dtype = torch.float32 if self.factored else p.dtype
+
+        if state['factored']:
+            d1, d2 = state['effective_shape']
+            # v_0 = g_0^2 (SMMF_ADOPT NMF storage)
+            vt_init = grad.view(d1, d2).square_()
+            # Allocate NMF factors for v
+            state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+            state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+            # Initialize v_0 using NMF
+            _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+            del vt_init
+        else:
+            state['exp_avg_sq'] = grad.square() # v_0
+
+
+    @torch.no_grad()
+    def __step_parameter(self, p: torch.Tensor, group: dict, lr: torch.Tensor | float):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        if self.factored and grad.dtype != torch.float32:
+            grad = grad.float()
+        if self.orthogonal_gradient:
+            grad = _orthogonalize_gradient(p, grad)
+        state = self.state[p]
+

         beta1, beta2 = group['betas']

-        current_step = state['step']
         if group.get('kourkoutas_beta', False):
-            # Call prepare_step() once at the beginning of the step for all params
-            self.kourkoutas_helper.maybe_prepare_step(current_step)
             # Accumulate current grad's norm for the *next* step
             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
             # Get the dynamic beta2 calculated in prepare_step()
-            beta2 = self.kourkoutas_helper.get_beta2(p, group
-
-        # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
-        if state['step'] == 0 and not self.use_atan2:
-            state['step'] += 1
-            return
+            beta2 = self.kourkoutas_helper.get_beta2(p, group)

         if self.use_AdEMAMix:
             beta3_ema = group['beta3_ema']
             alpha = group['alpha']
-
-            # Use step+1 for 1-based step count in scheduler
-            alpha_step = state['step'] + 1
-            alpha_t = alpha
-            if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
-                alpha_t = min(alpha_step * alpha / t_alpha, alpha)
+
         if self.Simplified_AdEMAMix:
             alpha_grad = group["alpha_grad"]

@@ -296,7 +313,7 @@ class Adopt_adv(torch.optim.Optimizer):
             else:
                 normalized_grad = grad_reshaped / denom.add_(group['eps'])
                 if self.clip_lambda is not None:
-                    clip_val = self.clip_lambda(
+                    clip_val = self.clip_lambda(self.global_step)
                     normalized_grad.clamp_(-clip_val, clip_val)
             del denom

@@ -317,9 +334,9 @@ class Adopt_adv(torch.optim.Optimizer):
             if self.use_AdEMAMix:
                 mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
                 if beta1 > 0:
-                    update = torch.add(mt, mt_slow, alpha=
+                    update = torch.add(mt, mt_slow, alpha=alpha)
                 else:
-                    update = torch.add(normalized_grad, mt_slow, alpha=
+                    update = torch.add(normalized_grad, mt_slow, alpha=alpha)
             elif self.Simplified_AdEMAMix:
                 update = torch.add(mt, normalized_grad, alpha=alpha_grad)
             else:
@@ -328,9 +345,9 @@ class Adopt_adv(torch.optim.Optimizer):
             update = update.view(p.shape)

             if self.use_atan2:
-                update.mul_(
+                update.mul_(lr * 1.2732395447351628)
             else:
-                update.mul_(
+                update.mul_(lr)

             # Update second moment v_t for the *next* step using raw g_t
             vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
@@ -353,7 +370,7 @@ class Adopt_adv(torch.optim.Optimizer):
             del vt

         else: # Standard ADOPT logic for non-factored tensors
-            v = state['exp_avg_sq'] # v_{t-1}
+            v = state['exp_avg_sq'] # v_{t-1}

             # ADOPT Step A: Decorrelate g_t using v_{t-1}
             denom = v.sqrt()
@@ -363,7 +380,7 @@ class Adopt_adv(torch.optim.Optimizer):
             else:
                 normalized_grad = grad / denom.add_(group['eps'])
                 if self.clip_lambda is not None:
-                    clip_val = self.clip_lambda(
+                    clip_val = self.clip_lambda(self.global_step)
                     normalized_grad.clamp_(-clip_val, clip_val)
             del denom

@@ -387,18 +404,18 @@ class Adopt_adv(torch.optim.Optimizer):
                 m_slow = state['exp_avg_slow']
                 m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
                 if beta1 > 0:
-                    update = torch.add(m, m_slow, alpha=
+                    update = torch.add(m, m_slow, alpha=alpha)
                 else:
-                    update = torch.add(normalized_grad, m_slow, alpha=
+                    update = torch.add(normalized_grad, m_slow, alpha=alpha)
             elif self.Simplified_AdEMAMix:
                 update = torch.add(m, normalized_grad, alpha=alpha_grad)
             else:
                 update = m.clone() if beta1 > 0 else normalized_grad

             if self.use_atan2:
-                update.mul_(
+                update.mul_(lr * 1.2732395447351628)
             else:
-                update.mul_(
+                update.mul_(lr)

             # Update second moment v_t for the next step using raw g_t
             v.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
@@ -406,9 +423,9 @@ class Adopt_adv(torch.optim.Optimizer):
         # Parameter Update
         if group["weight_decay"] != 0:
             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
-                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] *
+                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
             else:
-                p.data.add_(p.data, alpha=-group["weight_decay"] *
+                p.data.add_(p.data, alpha=-group["weight_decay"] * lr)

         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
             add_stochastic_(p.data, -update)
@@ -416,7 +433,33 @@ class Adopt_adv(torch.optim.Optimizer):
             p.data.add_(-update)
         del update

-
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if self.global_step is None and 'step' in self.state[p]:
+            # For backward compatibility
+            self.global_step = self.state[p]['step']
+
+        if self.global_step == 0:
+            self.__init_step(p, group)
+
+        # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
+        if self.global_step == 0 and not self.use_atan2:
+            self.global_step += 1
+            return
+
+        if group.get('kourkoutas_beta', False):
+            # Prepare Kourkoutas-β once per step using the global step counter.
+            self.kourkoutas_helper.maybe_prepare_step(self.global_step)
+
+        if not group.get('compiled_optimizer', False):
+            self.__step_parameter(p, group, group['lr'])
+        else:
+            lr_tensor = torch.tensor(group['lr'], device=p.device)
+            self._compiled_step_parameter(p, group, lr_tensor)
+
+    def compile(self, *args, **kwargs):
+        self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)
+

     @torch.no_grad()
     def step(self, closure=None):
@@ -430,4 +473,6 @@ class Adopt_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)

-
+        self.global_step += 1
+
+        return loss
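The Adopt_adv changes above replace the per-parameter `state['step']` counter with one optimizer-level `global_step`, split state creation out of the hot path (`init_step`/`__init_state`/`__init_step`), and route the update through `__step_parameter`, which can be wrapped by `torch.compile` and receives the learning rate as a tensor so that a changing `lr` does not trigger recompilation. The sketch below is a standalone toy optimizer that follows the same pattern; it is illustrative only and deliberately much simpler than Adopt_adv.

import torch

class ToySGD(torch.optim.Optimizer):
    """Illustrative only: mirrors the compiled per-parameter-step pattern, not Adopt_adv itself."""

    def __init__(self, params, lr=1e-3, compiled_optimizer=False):
        super().__init__(params, dict(lr=lr, compiled_optimizer=compiled_optimizer))
        self.global_step = 0  # one shared counter instead of per-parameter state['step']
        self._compiled_step = None
        if compiled_optimizer:
            torch._dynamo.config.cache_size_limit = 8192
            self._compiled_step = torch.compile(self._update, fullgraph=True)

    @torch.no_grad()
    def _update(self, p, lr):
        # `lr` arrives as a 0-d tensor on the compiled path, so changing its value
        # does not force torch.compile to retrace the graph.
        if p.grad is not None:
            p.sub_(lr * p.grad)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if group['compiled_optimizer']:
                    self._compiled_step(p, torch.tensor(group['lr'], device=p.device))
                else:
                    self._update(p, group['lr'])
        self.global_step += 1
        return loss

As in the diff, compilation happens once at construction and the scalar hyperparameter is re-wrapped as a tensor on the parameter's device each step.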
adv_optm/optim/Lion_Prodigy_adv.py
CHANGED
@@ -27,17 +27,13 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         stochastic_rounding (bool, optional): whether to use stochastic
             rounding for BF16 parameter updates (default: True).
         cautious_mask (bool): whether to use the cautious masking technique. (default: False).
-        clip_threshold (float, optional): whether to clip the gradients norm
-            per-parameter as proposed in the paper `Lions and Muons: Optimization via
-            Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
-            (default: 0.0).
         nnmf_factor (bool): whether to use the factorization or use the
             uncompressed optimizer. (default: True)
         d0 (float):
             Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
         d_coef (float):
             Coefficient in the expression for the estimate of d (default 1.0).
-            Values such as 0.5 and 2.0 typically work as well.
+            Values such as 0.5 and 2.0 typically work as well.
             Changing this parameter is the preferred way to tune the method.
         growth_rate (float):
             prevent the D estimate from growing faster than this multiplicative rate.
@@ -47,8 +43,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
             If you're using sharded parameters, this should be set to True. The optimizer
             will attempt to auto-detect this, but if you're using an implementation other
             than PyTorch's builtin version, the auto-detection won't work.
-        slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
-            pth entry of each tensor. For values greater than 1 this an an approximation to standard
+        slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
+            pth entry of each tensor. For values greater than 1 this an an approximation to standard
             Prodigy. Values ~11 are reasonable (default 11).
         prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
             after the specified optimiser step and release all state memory required by Prodigy
@@ -64,11 +60,10 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         lr: float = 1,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay: float = 0.0,
-        vector_reshape: bool =
+        vector_reshape: bool = False,
         stochastic_rounding: bool = True,
         orthogonal_gradient: bool = False,
         cautious_mask: bool = False,
-        clip_threshold: float = 0.0,
         nnmf_factor: bool = False,
         # prodigy parameters
         beta3: float = None,
@@ -80,6 +75,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         slice_p: int = 11,
         prodigy_steps: int = 0,
         d_limiter: bool = True,
+        # Compiled
+        compiled_optimizer: bool = False,
     ):
         if not lr > 0.0:
             raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -94,21 +91,28 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
             weight_decay=weight_decay,
             vector_reshape=vector_reshape,
             orthogonal_gradient=orthogonal_gradient,
-            clip_threshold=clip_threshold,
             beta3=beta3, d=d0, d0=d0, d_max=d0, d_numerator=0.0, d_coef=d_coef,
-            growth_rate=growth_rate, safeguard_warmup=safeguard_warmup,
+            growth_rate=growth_rate, safeguard_warmup=safeguard_warmup, slice_p=slice_p,
             fsdp_in_use=fsdp_in_use,
             prodigy_steps=prodigy_steps,
             d_limiter=d_limiter,
+            compiled_optimizer=compiled_optimizer,
         )
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask
         self.factored = nnmf_factor
         self.fsdp_in_use = fsdp_in_use
         super().__init__(params, defaults)
-        #
+        # Use the device of the first parameter to avoid hardcoding '.cuda()'
+        self.device = self.param_groups[0]['params'][0].device
+
+        self.global_step = 0
         self.init_step()

+        if compiled_optimizer:
+            torch._dynamo.config.cache_size_limit = 8192
+            self.compile(fullgraph=True)
+
     @property
     def supports_fused_back_pass(self) -> bool:
         return True
@@ -124,14 +128,13 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
     def init_step(self):
         """Resets accumulators and calculates dlr for the upcoming step."""
         self.d_denom = 0.0
-
+
         g_group = self.param_groups[0]
         self.beta1, self.beta2 = g_group['betas']
         self.beta3 = g_group['beta3']
         if self.beta3 is None:
             self.beta3 = math.sqrt(self.beta2)
-
-        k = g_group['k']
+
         self.d = g_group['d']
         lr = g_group['lr']

@@ -139,38 +142,21 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):

         self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3

-
-
-
-        if p.grad is None:
-            return
-
-        if hasattr(p, "_fsdp_flattened"):
-            self.fsdp_in_use = True
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.__init_state(p, group)

-
-
-        grad = grad.float()
-        if group["clip_threshold"] > 0.0:
-            grad_norm = torch.norm(grad.detach())
-            if grad_norm > group["clip_threshold"]:
-                clip_coef = group["clip_threshold"] / grad_norm
-                grad.mul_(clip_coef)
-        if group["orthogonal_gradient"]:
-            grad = _orthogonalize_gradient(p, grad)
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]

-
-        if 'step' not in state:
-            state['step'] = 0
+        if len(state) == 0:

-
+            state['factored'] = (
                 self.factored and
                 not (len(p.shape) == 1 and not group['vector_reshape'])
             )

-            state['factored'] = should_factor
-
             dtype = torch.float32 if self.factored else p.dtype

             slice_p = group['slice_p']
@@ -185,13 +171,28 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
             if state['factored']:
                 state['effective_shape'] = _get_effective_shape(p.numel())
                 d1, d2 = state['effective_shape']
-                state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                 state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                 packed_d2 = (d2 + 7) // 8
                 state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
             else: # Fallback to standard Lion
                 state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)

+    @torch.no_grad()
+    def __step_parameter(self, p: torch.Tensor, group: dict, d: torch.Tensor | float, dlr: torch.Tensor | float):
+        """Performs a single optimization step on a single parameter."""
+        if p.grad is None:
+            return
+
+
+        grad = p.grad
+        if grad.dtype != torch.float32 and self.factored:
+            grad = grad.float()
+        if group["orthogonal_gradient"]:
+            grad = _orthogonalize_gradient(p, grad)
+        state = self.state[p]
+
+
         if state['factored']:
             # Factored Path
             d1, d2 = state['effective_shape']
@@ -205,7 +206,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
                 exp_avg = exp_avg.float()

             # Compute update term c_t = β1*m_{t-1} + (1-β1)*g_t
-            signed_update = exp_avg.clone().mul_(self.beta1).add_(grad_reshaped, alpha=
+            signed_update = exp_avg.clone().mul_(self.beta1).add_(grad_reshaped, alpha=d * (1-self.beta1)).sign_()

             if self.cautious_mask:
                 mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
@@ -214,10 +215,10 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
                 del mask

             # Parameter update: p_t = p_{t-1} - lr * sign(c_t)
-            update_for_param = signed_update.view(p.shape).mul(
+            update_for_param = signed_update.view(p.shape).mul(dlr)

             # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
-            exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=
+            exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=d * (1 - self.beta2))
             del grad_reshaped

             # Compress new momentum m_t and store factors
@@ -232,7 +233,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
             # Compute update term and sign for the update
             if exp_avg.dtype != torch.float32 and self.factored:
                 exp_avg = exp_avg.float()
-            signed_update = exp_avg.clone().mul_(self.beta1).add_(grad, alpha=
+            signed_update = exp_avg.clone().mul_(self.beta1).add_(grad, alpha=d * (1-self.beta1)).sign_()

             if self.cautious_mask:
                 mask = (signed_update * grad > 0).to(grad.dtype)
@@ -240,41 +241,18 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
                 signed_update.mul_(mask)
                 del mask

-            update_for_param = signed_update.mul(
-
-            # Update momentum
-            exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))
-
-            prodigy_steps = group['prodigy_steps']
-            if prodigy_steps <= 0 or group['k'] < prodigy_steps:
-                # --- Accumulate Prodigy stats ---
-                d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
-                s, p0 = state['s'], state['p0']
-                grad_flat = grad.flatten().float()
-                p_flat = p.data.flatten().float()
-                p0 = p0.float()
-
-                self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
-
-                alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
-                s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
-                self.d_denom += s.abs().sum().item()
+            update_for_param = signed_update.mul(dlr)

-
-
-            # Free memory if prodigy_steps is reached
-            if 's' in state:
-                del state['s']
-            if 'p0' in state:
-                del state['p0']
+            # Update momentum
+            exp_avg.mul_(self.beta2).add_(grad, alpha=d * (1 - self.beta2))

         if group["weight_decay"] != 0:
             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
                 add_stochastic_(p.data, p.data,
-                                alpha=-group["weight_decay"] *
+                                alpha=-group["weight_decay"] * dlr)
             else:
                 p.data.add_(
-                    p.data, alpha=-group["weight_decay"] *
+                    p.data, alpha=-group["weight_decay"] * dlr
                 )

         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
@@ -284,6 +262,29 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):

         del update_for_param

+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if hasattr(p, "_fsdp_flattened"):
+            self.fsdp_in_use = True
+
+        if self.global_step is None and 'step' in self.state[p]:
+            # For backward compatibility
+            self.global_step = self.state[p]['step']
+
+        if isinstance(self.d_numerator, float):
+            self.d_numerator = torch.tensor(self.d_numerator, device=p.device)
+            self.d_denom = torch.tensor(self.d_denom, device=p.device)
+
+        if not group.get('compiled_optimizer', False):
+            self.__step_parameter(p, group, self.d, self.dlr)
+        else:
+            d_tensor = torch.tensor(self.d, device=p.device)
+            dlr_tensor = torch.tensor(self.dlr, device=p.device)
+            self._compiled_step_parameter(p, group, d_tensor, dlr_tensor)
+
+    def compile(self, *args, **kwargs):
+        self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)
+
     @torch.no_grad()
     def step(self, closure: Optional[callable] = None):
         """Performs a single optimization step."""
@@ -306,21 +307,19 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         """Calculates the new `d` based on the accumulated stats."""
         g_group = self.param_groups[0]
         # Only perform d-adaptation if prodigy_steps has not been reached
-        prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and
+        prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and self.global_step >= g_group['prodigy_steps'])

         if prodigy_active:
             d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
-
+
             if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
-
-                device = self.param_groups[0]['params'][0].device
-                dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+                dist_tensor = torch.stack([self.d_numerator, self.d_denom])
                 dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
                 global_d_numerator = dist_tensor[0].item()
                 global_d_denom = dist_tensor[1].item()
             else:
-                global_d_numerator = self.d_numerator
-                global_d_denom = self.d_denom
+                global_d_numerator = self.d_numerator.item()
+                global_d_denom = self.d_denom.item()

             d_hat = self.d
             if global_d_denom > 0:
@@ -337,5 +336,4 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
             group['d'] = self.d
             group['d_max'] = d_max
         # Increment step counter for all groups, regardless of whether d was updated
-
-            group['k'] += 1
+        self.global_step += 1
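A side effect of making Lion_Prodigy_adv compile-friendly is that `d_numerator` and `d_denom` are now kept as tensors on the parameters' device (they are converted from floats the first time `step_parameter` runs), so the FSDP path can simply `torch.stack` them and issue one `all_reduce` instead of building a fresh tensor on a hand-picked device each step. The helper below is a minimal, standalone illustration of that reduction under assumed names; it is not the package's code.

import torch
import torch.distributed as dist

def reduce_prodigy_stats(d_numerator: torch.Tensor, d_denom: torch.Tensor, fsdp_in_use: bool):
    """Sum Prodigy's D-adaptation statistics across ranks when sharded training is active."""
    if fsdp_in_use and dist.is_available() and dist.is_initialized():
        packed = torch.stack([d_numerator, d_denom])  # both live on the same device
        dist.all_reduce(packed, op=dist.ReduceOp.SUM)
        return packed[0].item(), packed[1].item()
    # Single-process fallback mirrors the non-distributed branch of the diff.
    return d_numerator.item(), d_denom.item()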