adv-optm 1.1.1__py3-none-any.whl → 1.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +14 -10
- adv_optm/optim/Adopt_adv.py +24 -20
- adv_optm/optim/Lion_Prodigy_adv.py +62 -36
- adv_optm/optim/Prodigy_adv.py +16 -12
- adv_optm/optim/Simplified_AdEMAMix.py +2 -2
- {adv_optm-1.1.1.dist-info → adv_optm-1.1.4.dist-info}/METADATA +1 -2
- {adv_optm-1.1.1.dist-info → adv_optm-1.1.4.dist-info}/RECORD +11 -11
- {adv_optm-1.1.1.dist-info → adv_optm-1.1.4.dist-info}/WHEEL +0 -0
- {adv_optm-1.1.1.dist-info → adv_optm-1.1.4.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.1.1.dist-info → adv_optm-1.1.4.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdamW_adv.py
CHANGED
@@ -209,7 +209,7 @@ class AdamW_adv(torch.optim.Optimizer):
  beta1, beta2 = group['betas']

  current_step = state['step']
- if group
+ if group.get('kourkoutas_beta', False):
  # Call prepare_step() once at the beginning of the step for all params
  self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
@@ -220,7 +220,7 @@ class AdamW_adv(torch.optim.Optimizer):
  step = state['step'] + 1
  if group['use_bias_correction']:
  bias_correction1 = 1.0 - beta1 ** step
- if group
+ if group.get('kourkoutas_beta', False):
  bias_correction2 = 1.0 - group['betas'][1] ** step
  # Use beta2_max for bias correction
  else:
@@ -252,12 +252,14 @@ class AdamW_adv(torch.optim.Optimizer):
  grad_reshaped = grad.view(d1, d2)
  mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
  if self.grams_moment:
-
+ update_mt = (grad_reshaped.sign().mul_(mt.abs()))
  elif self.cautious_mask:
  mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- mt.
+ update_mt = mt.mul(mask)
  del mask
+ else:
+ update_mt = mt.clone()

  vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
  vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
@@ -272,11 +274,11 @@ class AdamW_adv(torch.optim.Optimizer):

  mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
  if beta1 > 0:
- update = torch.add(
+ update = torch.add(update_mt, mt_slow, alpha=alpha_t)
  else:
  update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
  else:
- update =
+ update = update_mt if beta1 > 0 else grad_reshaped.clone()
  del grad_reshaped

  if group['use_atan2']:
@@ -310,22 +312,24 @@ class AdamW_adv(torch.optim.Optimizer):
  exp_avg = state['exp_avg']
  exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
  if self.grams_moment:
-
+ update_mt = grad.sign().mul_(exp_avg.abs())
  elif self.cautious_mask:
  mask = (exp_avg * grad > 0).to(grad.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- exp_avg.
+ update_mt = exp_avg.mul(mask)
  del mask
+ else:
+ update_mt = exp_avg.clone()

  if self.use_AdEMAMix:
  exp_avg_slow = state['exp_avg_slow']
  exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
  if beta1 > 0:
- update = torch.add(
+ update = torch.add(update_mt, exp_avg_slow, alpha=alpha_t)
  else:
  update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
  else:
- update =
+ update = update_mt if beta1 > 0 else grad.clone()

  exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
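The substantive change above is that the Grams and cautious-mask branches now build a separate `update_mt` tensor (falling back to `mt.clone()` / `exp_avg.clone()`), instead of writing the masked result back into the momentum buffer; the removed lines are truncated in this view, but they started with `mt.` / `exp_avg.`, which suggests the 1.1.1 code mutated the buffer in place. A minimal sketch of the new pattern, with hypothetical helper names rather than the package's own code:

```python
import torch

def build_update(exp_avg: torch.Tensor, grad: torch.Tensor,
                 grams: bool = False, cautious: bool = False) -> torch.Tensor:
    """Sketch of the 1.1.4 pattern: derive the step direction from the momentum
    buffer without mutating it (hypothetical helper, not the package API)."""
    if grams:
        # Grams: gradient's sign combined with the momentum's magnitude.
        return grad.sign().mul_(exp_avg.abs())      # sign() allocates a new tensor
    if cautious:
        # Cautious: drop components where momentum and gradient disagree,
        # rescaling the survivors by the inverse of the kept fraction.
        mask = (exp_avg * grad > 0).to(grad.dtype)
        mask.div_(mask.mean().clamp_(min=1e-3))
        return exp_avg.mul(mask)                    # out-of-place: exp_avg is untouched
    return exp_avg.clone()                          # plain momentum step

# Toy check: the momentum buffer is identical before and after.
m = torch.tensor([0.5, -0.2, 0.1])
g = torch.tensor([1.0, 0.3, -0.4])
print(build_update(m, g, cautious=True))
assert torch.equal(m, torch.tensor([0.5, -0.2, 0.1]))
```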
adv_optm/optim/Adopt_adv.py
CHANGED
@@ -13,7 +13,7 @@ class Adopt_adv(torch.optim.Optimizer):
  Implements an advanced ADOPT algorithm.

  The ADOPT update rule modifies Adam by:
- 1. **Initialization:** The second moment `
+ 1. **Initialization:** The second moment `vt` is initialized as `v₀ = g₀²`.
  2. **Decorrelation:** The current gradient is normalized using the second-moment estimate
  from the *previous* step (`v_{t-1}`).
  3. **Order of Operations:** This normalization occurs *before* updating the
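The docstring lines above spell out what distinguishes ADOPT from Adam: the second moment is seeded with g₀², the current gradient is normalized by the *previous* step's estimate, and only afterwards are the moments refreshed. A toy single-tensor sketch of that ordering, written independently of Adopt_adv (plain division instead of the optional atan2 path; names and defaults are made up):

```python
import torch

def adopt_step(p, grad, state, lr=1e-3, beta1=0.9, beta2=0.9999, eps=1e-6):
    """Illustrates ADOPT's order of operations only; not the Adopt_adv code."""
    if "v" not in state:
        # Initialization: v_0 = g_0^2, no parameter update on the very first call.
        state["v"] = grad.square()
        state["m"] = torch.zeros_like(p)
        return
    v, m = state["v"], state["m"]
    # Step A: decorrelate g_t using v_{t-1} from the previous step.
    normalized_grad = grad / v.sqrt().clamp_(min=eps)
    # Step B: update the momentum with the normalized gradient.
    m.mul_(beta1).add_(normalized_grad, alpha=1 - beta1)
    # Step C: only now refresh v_t from the raw gradient, for the next step.
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    p.add_(m, alpha=-lr)

p, state = torch.ones(3), {}
for _ in range(3):
    adopt_step(p, torch.randn(3), state)
```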
@@ -225,7 +225,7 @@ class Adopt_adv(torch.optim.Optimizer):
  state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
  # v_0 = g_0^2 (SMMF_ADOPT NMF storage)
  vt_init = grad.view(d1, d2).square_()
- # Allocate NMF factors for
+ # Allocate NMF factors for vt
  state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  # Initialize v_0 using NMF
@@ -240,7 +240,7 @@ class Adopt_adv(torch.optim.Optimizer):
  beta1, beta2 = group['betas']

  current_step = state['step']
- if group
+ if group.get('kourkoutas_beta', False):
  # Call prepare_step() once at the beginning of the step for all params
  self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
@@ -310,23 +310,25 @@ class Adopt_adv(torch.optim.Optimizer):
  else:
  mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
  if self.grams_moment:
-
+ update_mt = grad_reshaped.sign().mul_(mt.abs())
  elif self.cautious_mask:
  mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- mt.
+ update_mt= mt.mul(mask)
  del mask
+ else:
+ update_mt = mt.clone()

  if self.use_AdEMAMix:
  mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
  if beta1 > 0:
- update = torch.add(
+ update = torch.add(update_mt, mt_slow, alpha=alpha_t)
  else:
  update = torch.add(normalized_grad, mt_slow, alpha=alpha_t)
  elif self.Simplified_AdEMAMix:
- update = torch.add(
+ update = torch.add(update_mt, normalized_grad, alpha=alpha_grad)
  else:
- update =
+ update = update_mt if beta1 > 0 else normalized_grad

  update = update.view(p.shape)

@@ -356,10 +358,10 @@ class Adopt_adv(torch.optim.Optimizer):
  del vt

  else: # Standard ADOPT logic for non-factored tensors
-
+ vt = state['exp_avg_sq'] # v_{t-1}

  # ADOPT Step A: Decorrelate g_t using v_{t-1}
- denom =
+ denom = vt.sqrt()

  if self.use_atan2:
  normalized_grad = torch.atan2(grad, denom)
@@ -372,31 +374,33 @@ class Adopt_adv(torch.optim.Optimizer):

  # ADOPT Step B: Update momentum m_t
  if beta1 > 0:
-
+ mt = state['exp_avg'] # m_{t-1},
  if self.Simplified_AdEMAMix:
-
+ mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
  else:
-
+ mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)

  if self.grams_moment:
-
+ update_mt = grad.sign().mul_(mt.abs())
  elif self.cautious_mask:
- mask = (
+ mask = (mt * grad > 0).to(grad.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
-
+ update_mt = mt.mul(mask)
  del mask
+ else:
+ update_mt = mt.clone()

  if self.use_AdEMAMix:
  m_slow = state['exp_avg_slow']
  m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
  if beta1 > 0:
- update = torch.add(
+ update = torch.add(update_mt, m_slow, alpha=alpha_t)
  else:
  update = torch.add(normalized_grad, m_slow, alpha=alpha_t)
  elif self.Simplified_AdEMAMix:
- update = torch.add(
+ update = torch.add(update_mt, normalized_grad, alpha=alpha_grad)
  else:
- update =
+ update = update_mt if beta1 > 0 else normalized_grad

  if self.use_atan2:
  update.mul_(group['lr'] * 1.2732395447351628)
@@ -404,7 +408,7 @@ class Adopt_adv(torch.optim.Optimizer):
  update.mul_(group['lr'])

  # Update second moment v_t for the next step using raw g_t
-
+ vt.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

  # Parameter Update
  if group["weight_decay"] != 0:
adv_optm/optim/Lion_Prodigy_adv.py
CHANGED
@@ -50,6 +50,12 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
  pth entry of each tensor. For values greater than 1 this an an approximation to standard
  Prodigy. Values ~11 are reasonable (default 11).
+ prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
+ after the specified optimiser step and release all state memory required by Prodigy
+ (default: 0).
+ d_limiter (bool): whether to clamp the new step size estimate (`d_hat`)
+ to prevent sudden, volatile increases in the adaptive step size (`d`).
+ (default: True)
  """

  def __init__(
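A hedged usage sketch for the two options documented above. The import path and the surrounding training loop are assumptions (the class is expected to be exported like the other adv_optm optimizers), and the remaining constructor arguments are left at their defaults:

```python
import torch
from adv_optm import Lion_Prodigy_adv  # assumed export path

model = torch.nn.Linear(16, 4)
opt = Lion_Prodigy_adv(
    model.parameters(),
    lr=1.0,              # placeholder; Prodigy-style optimizers are usually run near lr=1
    prodigy_steps=1000,  # stop d-adaptation and free Prodigy state after step 1000
    d_limiter=True,      # clamp each d_hat proposal to at most d * 2**0.25
)

for _ in range(3):
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
```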
@@ -63,7 +69,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  orthogonal_gradient: bool = False,
  cautious_mask: bool = False,
  clip_threshold: float = 0.0,
- nnmf_factor: bool =
+ nnmf_factor: bool = False,
  # prodigy parameters
  beta3: float = None,
  d0: float = 1e-6,
@@ -72,6 +78,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  safeguard_warmup: bool = False,
  fsdp_in_use: bool = False,
  slice_p: int = 11,
+ prodigy_steps: int = 0,
+ d_limiter: bool = True,
  ):
  if not lr > 0.0:
  raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -90,6 +98,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  beta3=beta3, d=d0, d0=d0, d_max=d0, d_numerator=0.0, d_coef=d_coef,
  growth_rate=growth_rate, safeguard_warmup=safeguard_warmup, k=0, slice_p=slice_p,
  fsdp_in_use=fsdp_in_use,
+ prodigy_steps=prodigy_steps,
+ d_limiter=d_limiter,
  )
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask
@@ -235,20 +245,28 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  # Update momentum
  exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))

-
-
-
-
-
-
+ prodigy_steps = group['prodigy_steps']
+ if prodigy_steps <= 0 or group['k'] < prodigy_steps:
+ # --- Accumulate Prodigy stats ---
+ d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
+ s, p0 = state['s'], state['p0']
+ grad_flat = grad.flatten().float()
+ p_flat = p.data.flatten().float()
+ p0 = p0.float()

-
+ self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()

-
-
-
+ alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
+ s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
+ self.d_denom += s.abs().sum().item()

-
+ del s, p0, grad_flat, p_flat, alpha
+ else:
+ # Free memory if prodigy_steps is reached
+ if 's' in state:
+ del state['s']
+ if 'p0' in state:
+ del state['p0']

  if group["weight_decay"] != 0:
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
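For context on what the block above accumulates: the numerator grows with the inner product between the (sliced) gradient and the displacement from the initial parameters, and the per-parameter contribution to the denominator is the L1 norm of the EMA accumulator `s`; the two quantities later form the `d_hat` ratio in `calculate_d`. A rough standalone sketch of that bookkeeping, with the slicing detail dropped and hypothetical names:

```python
import torch

def prodigy_stats(grad, p, p0, s, d, dlr, d0, beta3=0.9):
    """Toy version of the Prodigy bookkeeping shown above (not the package API)."""
    g, x, x0 = grad.flatten().float(), p.flatten().float(), p0.flatten().float()
    # Contribution to the numerator: gradient dotted with (initial point - current point).
    num_inc = (d / d0) * dlr * torch.dot(g, x0 - x).item()
    # Contribution to the denominator: L1 norm of the EMA of scaled gradients.
    s.mul_(beta3).add_(g, alpha=(d / d0) * dlr)
    denom = s.abs().sum().item()
    return num_inc, denom

p0, p, s = torch.zeros(4), torch.tensor([0.1, -0.2, 0.3, 0.0]), torch.zeros(4)
num_inc, denom = prodigy_stats(torch.randn(4), p, p0, s, d=1e-6, dlr=1e-6, d0=1e-6)
```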
@@ -287,29 +305,37 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  def calculate_d(self):
  """Calculates the new `d` based on the accumulated stats."""
  g_group = self.param_groups[0]
-
-
-
-
-
-
- dist.
-
-
-
-
-
-
-
-
-
-
-
+ # Only perform d-adaptation if prodigy_steps has not been reached
+ prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])
+
+ if prodigy_active:
+ d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
+
+ if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
+ # Use the device of the first parameter to avoid hardcoding '.cuda()'
+ device = self.param_groups[0]['params'][0].device
+ dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+ dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+ global_d_numerator = dist_tensor[0].item()
+ global_d_denom = dist_tensor[1].item()
+ else:
+ global_d_numerator = self.d_numerator
+ global_d_denom = self.d_denom
+
+ d_hat = self.d
+ if global_d_denom > 0:
+ d_hat = d_coef * global_d_numerator / global_d_denom
+ if g_group.get('d_limiter', False):
+ d_hat = min(self.d * (2 ** 0.25), d_hat)
+ if self.d == g_group['d0']:
+ self.d = max(self.d, d_hat)
+ d_max = max(d_max, d_hat)
+ self.d = min(d_max, self.d * growth_rate)
+
+ for group in self.param_groups:
+ group['d_numerator'] = global_d_numerator
+ group['d'] = self.d
+ group['d_max'] = d_max
+ # Increment step counter for all groups, regardless of whether d was updated
  for group in self.param_groups:
- group['d_numerator'] = global_d_numerator
- group['d'] = self.d
- group['d_max'] = d_max
  group['k'] += 1
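A quick numeric illustration of the limiter reconstructed above (plain arithmetic with made-up statistics, not package code): even when the raw numerator/denominator ratio proposes a large jump, `d_limiter` caps each proposal at a factor of 2**0.25 over the current d.

```python
# Hypothetical values; mirrors the clamping sequence in calculate_d.
d = d0 = d_max = 1e-6
d_coef, growth_rate = 1.0, float("inf")
global_d_numerator, global_d_denom = 5e-9, 1e-3   # made-up accumulated stats

d_hat = d_coef * global_d_numerator / global_d_denom   # 5e-6: a 5x proposal
d_hat = min(d * 2 ** 0.25, d_hat)                      # d_limiter: ~1.19e-6
if d == d0:
    d = max(d, d_hat)
d_max = max(d_max, d_hat)
d = min(d_max, d * growth_rate)
print(f"{d:.3e}")   # ~1.189e-06 instead of 5.000e-06
```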
adv_optm/optim/Prodigy_adv.py
CHANGED
@@ -304,7 +304,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)

  current_step = state['step']
- if group
+ if group.get('kourkoutas_beta', False):
  # Call prepare_step() once at the beginning of the step for all params
  self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
@@ -343,12 +343,14 @@ class Prodigy_adv(torch.optim.Optimizer):
  else:
  mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
  if self.grams_moment:
-
+ update_mt = (grad_reshaped.sign().mul_(mt.abs()))
  elif self.cautious_mask:
  mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- mt.
+ update_mt = mt.mul(mask)
  del mask
+ else:
+ update_mt = mt.clone()

  vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
  vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - beta2))
@@ -362,13 +364,13 @@ class Prodigy_adv(torch.optim.Optimizer):
  del unpacked_sign_slow
  mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
  if self.beta1 > 0:
- update = torch.add(
+ update = torch.add(update_mt, mt_slow, alpha=alpha_t)
  else:
  update = torch.add(grad_reshaped.mul(self.d), mt_slow, alpha=alpha_t)
  elif self.Simplified_AdEMAMix:
- update = torch.add(
+ update = torch.add(update_mt, grad_reshaped, alpha=alpha_grad * self.d)
  else:
- update =
+ update = update_mt if self.beta1 > 0 else grad_reshaped.mul(self.d)
  del grad_reshaped

  if group['use_atan2']:
@@ -405,24 +407,26 @@ class Prodigy_adv(torch.optim.Optimizer):
  else:
  exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
  if self.grams_moment:
-
+ update_mt = grad.sign().mul_(exp_avg.abs())
  elif self.cautious_mask:
  mask = (exp_avg * grad > 0).to(grad.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- exp_avg.
+ update_mt = exp_avg.mul(mask)
  del mask
+ else:
+ update_mt = exp_avg.clone()

  if self.use_AdEMAMix:
  exp_avg_slow = state['exp_avg_slow']
  exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
  if self.beta1 > 0:
- update = torch.add(
+ update = torch.add(update_mt, exp_avg_slow, alpha=alpha_t)
  else:
  update = torch.add(grad.mul(self.d), exp_avg_slow, alpha=alpha_t)
  elif self.Simplified_AdEMAMix:
- update = torch.add(
+ update = torch.add(update_mt, grad, alpha=alpha_grad * self.d)
  else:
- update =
+ update = update_mt if self.beta1 > 0 else grad.mul(self.d)

  exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - beta2))

@@ -515,7 +519,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  d_hat = self.d
  if global_d_denom > 0:
  d_hat = d_coef * global_d_numerator / global_d_denom
- if g_group
+ if g_group.get('d_limiter', False):
  d_hat = min(self.d * (2 ** 0.25), d_hat)
  if self.d == g_group['d0']:
  self.d = max(self.d, d_hat)
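The masking changes here mirror the AdamW file; the Prodigy-specific detail visible in the hunks is that the adaptive step size d is folded directly into the moment updates, multiplying the gradient in the first moment and g² by d² in the second. A compressed toy restatement of those two lines with explanatory comments (not the class itself):

```python
import torch

def prodigy_scaled_moments(exp_avg, exp_avg_sq, grad, d, beta1=0.9, beta2=0.999):
    """Toy restatement of the d-scaled EMA updates shown in the diff above."""
    # First moment: EMA of d-scaled gradients.
    exp_avg.mul_(beta1).add_(grad, alpha=d * (1.0 - beta1))
    # Second moment: EMA of (d * grad)^2.
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=d * d * (1.0 - beta2))

m, v = torch.zeros(3), torch.zeros(3)
prodigy_scaled_moments(m, v, torch.randn(3), d=1e-6)
```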
adv_optm/optim/Simplified_AdEMAMix.py
CHANGED
@@ -191,7 +191,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  beta1_final, beta2 = group["betas"]

  current_step = state['step']
- if group
+ if group.get('kourkoutas_beta', False):
  # Call prepare_step() once at the beginning of the step for all params
  self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
@@ -210,7 +210,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):

  if group['use_bias_correction']:
  state['num_sum'] = beta1 * state['num_sum'] + 1.0
- if group
+ if group.get('kourkoutas_beta', False):
  state['den_sum'] = group['betas'][1] * state['den_sum'] + (1.0 - group['betas'][1])
  else:
  state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)
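Both hunks in this file, like the matching ones in the other optimizers, swap direct subscripting of the param-group dict for `dict.get` with a False default. The removed lines are truncated in this view, but the practical effect of the new form is that a param group without a `kourkoutas_beta` entry (for example one built by hand or carried over from an older configuration) simply treats the feature as disabled instead of raising `KeyError`:

```python
group = {"lr": 1e-4}   # a param group that never received the 'kourkoutas_beta' key

# A direct lookup would fail on such a group:
# group["kourkoutas_beta"]  -> KeyError

# 1.1.4-style lookup from the diff: a missing key just means "feature off".
enabled = group.get("kourkoutas_beta", False)
print(enabled)   # False
```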
{adv_optm-1.1.1.dist-info → adv_optm-1.1.4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.1.
+ Version: 1.1.4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -254,7 +254,6 @@ settings:
  • Full fine-tune: 1e-10
  • Embedding: 1e-7
  - d_coef: 1
- - d_limiter: True # To stablizie Prodigy with Simplified_AdEMAMix
  - factored: False # Can be true or false, quality should not degrade due to Simplified_AdEMAMix’s high tolerance to 1-bit factorization.
  ```
{adv_optm-1.1.1.dist-info → adv_optm-1.1.4.dist-info}/RECORD
CHANGED
@@ -1,10 +1,10 @@
- adv_optm/__init__.py,sha256=
- adv_optm/optim/AdamW_adv.py,sha256=
- adv_optm/optim/Adopt_adv.py,sha256=
- adv_optm/optim/Lion_Prodigy_adv.py,sha256=
+ adv_optm/__init__.py,sha256=Y1TYe8pweNoL-52qOQojMUf6_7BZaANYJExo043yi54,306
+ adv_optm/optim/AdamW_adv.py,sha256=sdeXzjjknKjYaFipPn6BWyo8aOuqWoF9tXIylJUZayw,17656
+ adv_optm/optim/Adopt_adv.py,sha256=gOUEahnvzIdg_650VIajRxMGCyGhfpk6OsiTY514yFA,21636
+ adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
  adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
- adv_optm/optim/Prodigy_adv.py,sha256=
- adv_optm/optim/Simplified_AdEMAMix.py,sha256=
+ adv_optm/optim/Prodigy_adv.py,sha256=5p9kV5gB11xdH15DL99GTfeEsVYe-IeS0WvvoeyvLpA,26083
+ adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
  adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
  adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
  adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
@@ -13,8 +13,8 @@ adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
  adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
  adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
  adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
- adv_optm-1.1.
- adv_optm-1.1.
- adv_optm-1.1.
- adv_optm-1.1.
- adv_optm-1.1.
+ adv_optm-1.1.4.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+ adv_optm-1.1.4.dist-info/METADATA,sha256=eaUrKC9WbjSIjwNZaqIuGdn11tZC_Ob39fxnFo_Rbd0,13950
+ adv_optm-1.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ adv_optm-1.1.4.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+ adv_optm-1.1.4.dist-info/RECORD,,
{adv_optm-1.1.1.dist-info → adv_optm-1.1.4.dist-info}/WHEEL
File without changes
{adv_optm-1.1.1.dist-info → adv_optm-1.1.4.dist-info}/licenses/LICENSE
File without changes
{adv_optm-1.1.1.dist-info → adv_optm-1.1.4.dist-info}/top_level.txt
File without changes