adv-optm 1.1.1-py3-none-any.whl → 1.1.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of adv-optm might be problematic.

adv_optm/__init__.py CHANGED
@@ -16,4 +16,4 @@ __all__ = [
  "Lion_Prodigy_adv",
  ]
 
- __version__ = "1.1.1"
+ __version__ = "1.1.4"
@@ -209,7 +209,7 @@ class AdamW_adv(torch.optim.Optimizer):
  beta1, beta2 = group['betas']
 
  current_step = state['step']
- if group['kourkoutas_beta']:
+ if group.get('kourkoutas_beta', False):
  # Call prepare_step() once at the beginning of the step for all params
  self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
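The recurring change from `group['kourkoutas_beta']` to `group.get('kourkoutas_beta', False)` makes the lookup tolerant of param groups that lack the key, presumably so that groups created or checkpointed before the option existed do not raise a `KeyError`; a missing key simply means the feature is off. A minimal illustration of the fallback, using a hypothetical group dict:

```python
# Hypothetical param-group dict saved before 'kourkoutas_beta' existed.
old_group = {"lr": 1e-3}

# old_group['kourkoutas_beta'] would raise KeyError here.
use_kb = old_group.get("kourkoutas_beta", False)   # falls back to False (feature disabled)
print(use_kb)   # False
```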
@@ -220,7 +220,7 @@ class AdamW_adv(torch.optim.Optimizer):
  step = state['step'] + 1
  if group['use_bias_correction']:
  bias_correction1 = 1.0 - beta1 ** step
- if group['kourkoutas_beta']:
+ if group.get('kourkoutas_beta', False):
  bias_correction2 = 1.0 - group['betas'][1] ** step
  # Use beta2_max for bias correction
  else:
@@ -252,12 +252,14 @@ class AdamW_adv(torch.optim.Optimizer):
  grad_reshaped = grad.view(d1, d2)
  mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
  if self.grams_moment:
- mt.copy_(grad_reshaped.sign() * mt.abs())
+ update_mt = (grad_reshaped.sign().mul_(mt.abs()))
  elif self.cautious_mask:
  mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- mt.mul_(mask)
+ update_mt = mt.mul(mask)
  del mask
+ else:
+ update_mt = mt.clone()
 
  vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
  vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
@@ -272,11 +274,11 @@ class AdamW_adv(torch.optim.Optimizer):
 
  mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
  if beta1 > 0:
- update = torch.add(mt, mt_slow, alpha=alpha_t)
+ update = torch.add(update_mt, mt_slow, alpha=alpha_t)
  else:
  update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
  else:
- update = mt.clone() if beta1 > 0 else grad_reshaped.clone()
+ update = update_mt if beta1 > 0 else grad_reshaped.clone()
  del grad_reshaped
 
  if group['use_atan2']:
@@ -310,22 +312,24 @@ class AdamW_adv(torch.optim.Optimizer):
  exp_avg = state['exp_avg']
  exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
  if self.grams_moment:
- exp_avg = grad.sign() * exp_avg.abs()
+ update_mt = grad.sign().mul_(exp_avg.abs())
  elif self.cautious_mask:
  mask = (exp_avg * grad > 0).to(grad.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- exp_avg.mul_(mask)
+ update_mt = exp_avg.mul(mask)
  del mask
+ else:
+ update_mt = exp_avg.clone()
 
  if self.use_AdEMAMix:
  exp_avg_slow = state['exp_avg_slow']
  exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
  if beta1 > 0:
- update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
+ update = torch.add(update_mt, exp_avg_slow, alpha=alpha_t)
  else:
  update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
  else:
- update = exp_avg.clone() if beta1 > 0 else grad.clone()
+ update = update_mt if beta1 > 0 else grad.clone()
 
  exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
 
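Across these AdamW_adv hunks (and the matching ones in the other optimizers below), the Grams and cautious-mask transforms now build a separate `update_mt` instead of writing back into the stored first moment, so the sign-matching or masking shapes only the applied update while the EMA buffer keeps its plain exponential average. A standalone sketch of the cautious branch with made-up tensors (not the package's API):

```python
import torch

torch.manual_seed(0)
grad = torch.randn(5)
exp_avg = torch.randn(5)                    # stands in for the stored first-moment EMA
beta1 = 0.9

exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)   # the EMA update itself stays in place

# Cautious mask: keep only components where momentum and gradient agree in sign,
# then rescale so the surviving components keep the update's overall magnitude.
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))

update_mt = exp_avg.mul(mask)               # out of place: exp_avg is left untouched
```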
@@ -13,7 +13,7 @@ class Adopt_adv(torch.optim.Optimizer):
  Implements an advanced ADOPT algorithm.
 
  The ADOPT update rule modifies Adam by:
- 1. **Initialization:** The second moment `v` is initialized as `v₀ = g₀²`.
+ 1. **Initialization:** The second moment `vt` is initialized as `v₀ = g₀²`.
  2. **Decorrelation:** The current gradient is normalized using the second-moment estimate
  from the *previous* step (`v_{t-1}`).
  3. **Order of Operations:** This normalization occurs *before* updating the
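The three numbered points in this docstring describe a concrete ordering. A rough, self-contained sketch of one such step is below; it is a simplification for illustration, not the class's actual factored/NNMF logic, and the clamp-based epsilon is an assumption:

```python
import torch

def adopt_like_step(p, grad, state, lr=1e-3, beta1=0.9, beta2=0.9999, eps=1e-6):
    # First call: v_0 = g_0^2 and no parameter update yet (point 1).
    if "v" not in state:
        state["v"] = grad.square()
        state["m"] = torch.zeros_like(p)
        return
    v_prev, m = state["v"], state["m"]
    # Point 2: normalize g_t with the *previous* second moment v_{t-1}.
    normalized = grad / v_prev.sqrt().clamp_(min=eps)
    # Point 3: the momentum update happens only after the normalization.
    m.mul_(beta1).add_(normalized, alpha=1 - beta1)
    p.add_(m, alpha=-lr)
    # Only now refresh the second moment with the raw gradient.
    v_prev.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

p, state = torch.zeros(3), {}
for g in (torch.tensor([1.0, -2.0, 0.5]), torch.tensor([0.5, -1.0, 0.25])):
    adopt_like_step(p, g, state)
```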
@@ -225,7 +225,7 @@ class Adopt_adv(torch.optim.Optimizer):
  state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
  # v_0 = g_0^2 (SMMF_ADOPT NMF storage)
  vt_init = grad.view(d1, d2).square_()
- # Allocate NMF factors for v
+ # Allocate NMF factors for vt
  state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  # Initialize v_0 using NMF
@@ -240,7 +240,7 @@ class Adopt_adv(torch.optim.Optimizer):
  beta1, beta2 = group['betas']
 
  current_step = state['step']
- if group['kourkoutas_beta']:
+ if group.get('kourkoutas_beta', False):
  # Call prepare_step() once at the beginning of the step for all params
  self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
@@ -310,23 +310,25 @@ class Adopt_adv(torch.optim.Optimizer):
  else:
  mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
  if self.grams_moment:
- mt = grad_reshaped.sign() * mt.abs()
+ update_mt = grad_reshaped.sign().mul_(mt.abs())
  elif self.cautious_mask:
  mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- mt.mul_(mask)
+ update_mt = mt.mul(mask)
  del mask
+ else:
+ update_mt = mt.clone()
 
  if self.use_AdEMAMix:
  mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
  if beta1 > 0:
- update = torch.add(mt, mt_slow, alpha=alpha_t)
+ update = torch.add(update_mt, mt_slow, alpha=alpha_t)
  else:
  update = torch.add(normalized_grad, mt_slow, alpha=alpha_t)
  elif self.Simplified_AdEMAMix:
- update = torch.add(mt, normalized_grad, alpha=alpha_grad)
+ update = torch.add(update_mt, normalized_grad, alpha=alpha_grad)
  else:
- update = mt.clone() if beta1 > 0 else normalized_grad
+ update = update_mt if beta1 > 0 else normalized_grad
 
  update = update.view(p.shape)
 
@@ -356,10 +358,10 @@ class Adopt_adv(torch.optim.Optimizer):
  del vt
 
  else: # Standard ADOPT logic for non-factored tensors
- v = state['exp_avg_sq'] # v_{t-1}
+ vt = state['exp_avg_sq'] # v_{t-1}
 
  # ADOPT Step A: Decorrelate g_t using v_{t-1}
- denom = v.sqrt()
+ denom = vt.sqrt()
 
  if self.use_atan2:
  normalized_grad = torch.atan2(grad, denom)
@@ -372,31 +374,33 @@ class Adopt_adv(torch.optim.Optimizer):
 
  # ADOPT Step B: Update momentum m_t
  if beta1 > 0:
- m = state['exp_avg'] # m_{t-1},
+ mt = state['exp_avg'] # m_{t-1},
  if self.Simplified_AdEMAMix:
- m.mul_(beta1).add_(normalized_grad, alpha=1.0)
+ mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
  else:
- m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
+ mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
 
  if self.grams_moment:
- m = grad.sign() * m.abs()
+ update_mt = grad.sign().mul_(mt.abs())
  elif self.cautious_mask:
- mask = (m * grad > 0).to(grad.dtype)
+ mask = (mt * grad > 0).to(grad.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- m.mul_(mask)
+ update_mt = mt.mul(mask)
  del mask
+ else:
+ update_mt = mt.clone()
 
  if self.use_AdEMAMix:
  m_slow = state['exp_avg_slow']
  m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
  if beta1 > 0:
- update = torch.add(m, m_slow, alpha=alpha_t)
+ update = torch.add(update_mt, m_slow, alpha=alpha_t)
  else:
  update = torch.add(normalized_grad, m_slow, alpha=alpha_t)
  elif self.Simplified_AdEMAMix:
- update = torch.add(m, normalized_grad, alpha=alpha_grad)
+ update = torch.add(update_mt, normalized_grad, alpha=alpha_grad)
  else:
- update = m.clone() if beta1 > 0 else normalized_grad
+ update = update_mt if beta1 > 0 else normalized_grad
 
  if self.use_atan2:
  update.mul_(group['lr'] * 1.2732395447351628)
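For context on the last two lines: `torch.atan2(update, denom)` bounds each element of the normalized update to (-π/2, π/2) without needing an explicit epsilon, and the constant `1.2732395447351628` equals 4/π, so the scaled output has a hard bound of 2 while staying close to `update / denom` for small ratios. A quick numeric check with illustrative values (standalone, not the optimizer's code):

```python
import math
import torch

update = torch.tensor([0.02, -0.5, 3.0])
denom = torch.tensor([1.0, 1.0, 1.0])

bounded = torch.atan2(update, denom)      # ~[0.0200, -0.4636, 1.2490], each in (-pi/2, pi/2)
scaled = bounded * (4 / math.pi)          # 1.2732395447351628, as in the diff above

print(scaled)                             # still close to update/denom for the small entries
print(scaled.abs().max() < 2.0)           # True: hard bound of 2 regardless of denom
```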
@@ -404,7 +408,7 @@ class Adopt_adv(torch.optim.Optimizer):
  update.mul_(group['lr'])
 
  # Update second moment v_t for the next step using raw g_t
- v.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
+ vt.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
 
  # Parameter Update
  if group["weight_decay"] != 0:
@@ -50,6 +50,12 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
  pth entry of each tensor. For values greater than 1 this is an approximation to standard
  Prodigy. Values ~11 are reasonable (default 11).
+ prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
+ after the specified optimiser step and release all state memory required by Prodigy
+ (default: 0).
+ d_limiter (bool): whether to clamp the new step size estimate (`d_hat`)
+ to prevent sudden, volatile increases in the adaptive step size (`d`).
+ (default: True)
  """
 
  def __init__(
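As a usage sketch of the two new arguments documented above: only `prodigy_steps` and `d_limiter` come from this diff, while the placeholder model and the `lr=1.0` choice are assumptions for illustration.

```python
import torch
from adv_optm import Lion_Prodigy_adv

model = torch.nn.Linear(16, 16)          # placeholder module
optimizer = Lion_Prodigy_adv(
    model.parameters(),
    lr=1.0,                              # assumed Prodigy-style setting, not taken from the diff
    prodigy_steps=500,                   # stop adapting d and free Prodigy state after step 500
    d_limiter=True,                      # cap each new d_hat at d * 2 ** 0.25
)
```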
@@ -63,7 +69,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  orthogonal_gradient: bool = False,
  cautious_mask: bool = False,
  clip_threshold: float = 0.0,
- nnmf_factor: bool = True,
+ nnmf_factor: bool = False,
  # prodigy parameters
  beta3: float = None,
  d0: float = 1e-6,
@@ -72,6 +78,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  safeguard_warmup: bool = False,
  fsdp_in_use: bool = False,
  slice_p: int = 11,
+ prodigy_steps: int = 0,
+ d_limiter: bool = True,
  ):
  if not lr > 0.0:
  raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -90,6 +98,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  beta3=beta3, d=d0, d0=d0, d_max=d0, d_numerator=0.0, d_coef=d_coef,
  growth_rate=growth_rate, safeguard_warmup=safeguard_warmup, k=0, slice_p=slice_p,
  fsdp_in_use=fsdp_in_use,
+ prodigy_steps=prodigy_steps,
+ d_limiter=d_limiter,
  )
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask
@@ -235,20 +245,28 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  # Update momentum
  exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))
 
- # --- Accumulate Prodigy stats ---
- d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
- s, p0 = state['s'], state['p0']
- grad_flat = grad.flatten().float()
- p_flat = p.data.flatten().float()
- p0 = p0.float()
+ prodigy_steps = group['prodigy_steps']
+ if prodigy_steps <= 0 or group['k'] < prodigy_steps:
+ # --- Accumulate Prodigy stats ---
+ d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
+ s, p0 = state['s'], state['p0']
+ grad_flat = grad.flatten().float()
+ p_flat = p.data.flatten().float()
+ p0 = p0.float()
 
- self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
+ self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
 
- alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
- s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
- self.d_denom += s.abs().sum().item()
+ alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
+ s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
+ self.d_denom += s.abs().sum().item()
 
- del s, p0, grad_flat, p_flat, alpha
+ del s, p0, grad_flat, p_flat, alpha
+ else:
+ # Free memory if prodigy_steps is reached
+ if 's' in state:
+ del state['s']
+ if 'p0' in state:
+ del state['p0']
 
  if group["weight_decay"] != 0:
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
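The accumulation is now gated on the step counter: `prodigy_steps <= 0` means "never stop", otherwise stats are only gathered while `k < prodigy_steps`, after which the per-parameter `s` and `p0` buffers are deleted to reclaim memory. A toy illustration of that gate with dummy dicts (not the optimizer's real state):

```python
group = {"prodigy_steps": 3, "k": 0}
state = {"s": [0.0], "p0": [0.0]}

for _ in range(5):
    if group["prodigy_steps"] <= 0 or group["k"] < group["prodigy_steps"]:
        pass  # accumulate Prodigy stats into state['s'] / state['p0'] here
    else:
        state.pop("s", None)   # once past the threshold, the buffers are dropped
        state.pop("p0", None)
    group["k"] += 1

print(state)  # {} -- adaptation is frozen and its memory released from step 3 on
```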
@@ -287,29 +305,37 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  def calculate_d(self):
  """Calculates the new `d` based on the accumulated stats."""
  g_group = self.param_groups[0]
- d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
-
- if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
- # Use the device of the first parameter to avoid hardcoding '.cuda()'
- device = self.param_groups[0]['params'][0].device
- dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
- dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
- global_d_numerator = dist_tensor[0].item()
- global_d_denom = dist_tensor[1].item()
- else:
- global_d_numerator = self.d_numerator
- global_d_denom = self.d_denom
-
- d_hat = self.d
- if global_d_denom > 0:
- d_hat = d_coef * global_d_numerator / global_d_denom
- if self.d == g_group['d0']:
- self.d = max(self.d, d_hat)
- d_max = max(d_max, d_hat)
- self.d = min(d_max, self.d * growth_rate)
-
+ # Only perform d-adaptation if prodigy_steps has not been reached
+ prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])
+
+ if prodigy_active:
+ d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
+
+ if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
+ # Use the device of the first parameter to avoid hardcoding '.cuda()'
+ device = self.param_groups[0]['params'][0].device
+ dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+ dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+ global_d_numerator = dist_tensor[0].item()
+ global_d_denom = dist_tensor[1].item()
+ else:
+ global_d_numerator = self.d_numerator
+ global_d_denom = self.d_denom
+
+ d_hat = self.d
+ if global_d_denom > 0:
+ d_hat = d_coef * global_d_numerator / global_d_denom
+ if g_group.get('d_limiter', False):
+ d_hat = min(self.d * (2 ** 0.25), d_hat)
+ if self.d == g_group['d0']:
+ self.d = max(self.d, d_hat)
+ d_max = max(d_max, d_hat)
+ self.d = min(d_max, self.d * growth_rate)
+
+ for group in self.param_groups:
+ group['d_numerator'] = global_d_numerator
+ group['d'] = self.d
+ group['d_max'] = d_max
+ # Increment step counter for all groups, regardless of whether d was updated
  for group in self.param_groups:
- group['d_numerator'] = global_d_numerator
- group['d'] = self.d
- group['d_max'] = d_max
  group['k'] += 1
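In the rewritten `calculate_d`, adaptation only runs while `prodigy_active`, and with `d_limiter` enabled the proposed `d_hat` is capped at `self.d * 2 ** 0.25`, so `d` can grow by at most roughly 19% per adjustment; the step counter `k` is still incremented for every group either way. A small numeric sketch of the limiter with made-up numbers, simplified from the code above:

```python
d, d_coef = 1e-4, 1.0
global_d_numerator, global_d_denom = 5.0, 2.0

d_hat = d_coef * global_d_numerator / global_d_denom   # 2.5 -- a huge jump from 1e-4
d_hat = min(d * (2 ** 0.25), d_hat)                     # d_limiter caps it at ~1.189e-4
d = max(d, d_hat)                                       # d creeps upward instead of exploding
print(d)                                                # ~0.000118921
```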
@@ -304,7 +304,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)
 
  current_step = state['step']
- if group['kourkoutas_beta']:
+ if group.get('kourkoutas_beta', False):
  # Call prepare_step() once at the beginning of the step for all params
  self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
@@ -343,12 +343,14 @@ class Prodigy_adv(torch.optim.Optimizer):
  else:
  mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
  if self.grams_moment:
- mt.copy_(grad_reshaped.sign() * mt.abs())
+ update_mt = (grad_reshaped.sign().mul_(mt.abs()))
  elif self.cautious_mask:
  mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- mt.mul_(mask)
+ update_mt = mt.mul(mask)
  del mask
+ else:
+ update_mt = mt.clone()
 
  vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
  vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - beta2))
@@ -362,13 +364,13 @@ class Prodigy_adv(torch.optim.Optimizer):
  del unpacked_sign_slow
  mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
  if self.beta1 > 0:
- update = torch.add(mt, mt_slow, alpha=alpha_t)
+ update = torch.add(update_mt, mt_slow, alpha=alpha_t)
  else:
  update = torch.add(grad_reshaped.mul(self.d), mt_slow, alpha=alpha_t)
  elif self.Simplified_AdEMAMix:
- update = torch.add(mt, grad_reshaped, alpha=alpha_grad * self.d)
+ update = torch.add(update_mt, grad_reshaped, alpha=alpha_grad * self.d)
  else:
- update = mt.clone() if self.beta1 > 0 else grad_reshaped.mul(self.d)
+ update = update_mt if self.beta1 > 0 else grad_reshaped.mul(self.d)
  del grad_reshaped
 
  if group['use_atan2']:
@@ -405,24 +407,26 @@ class Prodigy_adv(torch.optim.Optimizer):
  else:
  exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
  if self.grams_moment:
- exp_avg = grad.sign() * exp_avg.abs()
+ update_mt = grad.sign().mul_(exp_avg.abs())
  elif self.cautious_mask:
  mask = (exp_avg * grad > 0).to(grad.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- exp_avg.mul_(mask)
+ update_mt = exp_avg.mul(mask)
  del mask
+ else:
+ update_mt = exp_avg.clone()
 
  if self.use_AdEMAMix:
  exp_avg_slow = state['exp_avg_slow']
  exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
  if self.beta1 > 0:
- update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
+ update = torch.add(update_mt, exp_avg_slow, alpha=alpha_t)
  else:
  update = torch.add(grad.mul(self.d), exp_avg_slow, alpha=alpha_t)
  elif self.Simplified_AdEMAMix:
- update = torch.add(exp_avg, grad, alpha=alpha_grad * self.d)
+ update = torch.add(update_mt, grad, alpha=alpha_grad * self.d)
  else:
- update = exp_avg.clone() if self.beta1 > 0 else grad.mul(self.d)
+ update = update_mt if self.beta1 > 0 else grad.mul(self.d)
 
  exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - beta2))
 
@@ -515,7 +519,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  d_hat = self.d
  if global_d_denom > 0:
  d_hat = d_coef * global_d_numerator / global_d_denom
- if g_group['d_limiter']:
+ if g_group.get('d_limiter', False):
  d_hat = min(self.d * (2 ** 0.25), d_hat)
  if self.d == g_group['d0']:
  self.d = max(self.d, d_hat)
@@ -191,7 +191,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  beta1_final, beta2 = group["betas"]
 
  current_step = state['step']
- if group['kourkoutas_beta']:
+ if group.get('kourkoutas_beta', False):
  # Call prepare_step() once at the beginning of the step for all params
  self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
@@ -210,7 +210,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
 
  if group['use_bias_correction']:
  state['num_sum'] = beta1 * state['num_sum'] + 1.0
- if group['kourkoutas_beta']:
+ if group.get('kourkoutas_beta', False):
  state['den_sum'] = group['betas'][1] * state['den_sum'] + (1.0 - group['betas'][1])
  else:
  state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)
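In this hunk the bias-correction accumulators are running recurrences rather than closed-form powers: `num_sum_t = beta1 * num_sum_{t-1} + 1` equals `(1 - beta1**t) / (1 - beta1)`, and `den_sum_t = beta2 * den_sum_{t-1} + (1 - beta2)` equals `1 - beta2**t`; under `kourkoutas_beta` the denominator recurrence apparently keeps using the fixed `betas[1]` (the beta2 ceiling, as the AdamW hunk's "beta2_max" comment suggests) rather than a per-step value. A quick check of that equivalence (standalone, not the optimizer's code):

```python
beta1, beta2 = 0.9, 0.999
num_sum = den_sum = 0.0
for t in range(1, 6):
    num_sum = beta1 * num_sum + 1.0
    den_sum = beta2 * den_sum + (1.0 - beta2)
    assert abs(num_sum - (1 - beta1**t) / (1 - beta1)) < 1e-12
    assert abs(den_sum - (1 - beta2**t)) < 1e-12
print(num_sum, den_sum)  # ~4.0951 and ~0.0049900 after five steps
```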
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.1.1
+ Version: 1.1.4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -254,7 +254,6 @@ settings:
  • Full fine-tune: 1e-10
  • Embedding: 1e-7
  - d_coef: 1
- - d_limiter: True # To stablizie Prodigy with Simplified_AdEMAMix
  - factored: False # Can be true or false, quality should not degrade due to Simplified_AdEMAMix’s high tolerance to 1-bit factorization.
  ```
 
@@ -1,10 +1,10 @@
- adv_optm/__init__.py,sha256=TL9XFW3kQQ2Xrxl6UULMftBzNvg7uTIcxMRD0vTttPk,306
- adv_optm/optim/AdamW_adv.py,sha256=ddEUVOif1gfZPgEJNrEGZ2wnha4MPMWw5ppPd8acQ3o,17457
- adv_optm/optim/Adopt_adv.py,sha256=fhH3hS9K6z5Blxc7NFfzpCrUGbl9EQnwLPmKDxBC1zg,21415
- adv_optm/optim/Lion_Prodigy_adv.py,sha256=aJ9orEEw0QYbrDzn1be0SHvOBlIkLwWG9RpWFuNMskM,13163
+ adv_optm/__init__.py,sha256=Y1TYe8pweNoL-52qOQojMUf6_7BZaANYJExo043yi54,306
+ adv_optm/optim/AdamW_adv.py,sha256=sdeXzjjknKjYaFipPn6BWyo8aOuqWoF9tXIylJUZayw,17656
+ adv_optm/optim/Adopt_adv.py,sha256=gOUEahnvzIdg_650VIajRxMGCyGhfpk6OsiTY514yFA,21636
+ adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
  adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
- adv_optm/optim/Prodigy_adv.py,sha256=nD59cAWOJJCjZdIiuD5hD9MWO5sTjPQSvq-3dwGTcEM,25875
- adv_optm/optim/Simplified_AdEMAMix.py,sha256=gPjMhKulzmAeO42foe-d7xW0AcB50vKFYsvHgxbD3uc,12949
+ adv_optm/optim/Prodigy_adv.py,sha256=5p9kV5gB11xdH15DL99GTfeEsVYe-IeS0WvvoeyvLpA,26083
+ adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
  adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
  adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
  adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
@@ -13,8 +13,8 @@ adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
  adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
  adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
  adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
- adv_optm-1.1.1.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
- adv_optm-1.1.1.dist-info/METADATA,sha256=F30-DuFinS-633wznIM27NBGU5asYpnKdiExchOFPcI,14019
- adv_optm-1.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- adv_optm-1.1.1.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
- adv_optm-1.1.1.dist-info/RECORD,,
+ adv_optm-1.1.4.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+ adv_optm-1.1.4.dist-info/METADATA,sha256=eaUrKC9WbjSIjwNZaqIuGdn11tZC_Ob39fxnFo_Rbd0,13950
+ adv_optm-1.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ adv_optm-1.1.4.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+ adv_optm-1.1.4.dist-info/RECORD,,