adv-optm 2.4.dev2__tar.gz → 2.4.dev4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/AdaMuon_adv.py +2 -2
  4. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/AdamW_adv.py +13 -6
  5. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/Adopt_adv.py +33 -21
  6. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/Lion_adv.py +9 -7
  7. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/Muon_adv.py +2 -2
  8. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/Prodigy_adv.py +13 -7
  9. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/SignSGD_adv.py +10 -11
  10. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/Simplified_AdEMAMix.py +11 -5
  11. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/Kourkoutas.py +43 -12
  12. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/Muon_AuxAdam.py +8 -2
  13. adv_optm-2.4.dev4/adv_optm/util/OrthoGrad.py +50 -0
  14. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/centered_decay.py +1 -1
  15. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/param_update.py +5 -5
  16. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/scaled_optm.py +9 -5
  17. adv_optm-2.4.dev4/adv_optm/util/update_util.py +73 -0
  18. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm.egg-info/PKG-INFO +1 -1
  19. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/setup.py +1 -1
  20. adv_optm-2.4.dev2/adv_optm/util/OrthoGrad.py +0 -21
  21. adv_optm-2.4.dev2/adv_optm/util/update_util.py +0 -32
  22. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/LICENSE +0 -0
  23. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/README.md +0 -0
  24. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  25. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/__init__.py +0 -0
  26. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/Muon_util.py +0 -0
  27. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/__init__.py +0 -0
  28. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/factorization_util.py +0 -0
  29. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/lion_k.py +0 -0
  30. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm.egg-info/SOURCES.txt +0 -0
  31. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm.egg-info/dependency_links.txt +0 -0
  32. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm.egg-info/requires.txt +0 -0
  33. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm.egg-info/top_level.txt +0 -0
  34. {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev2
3
+ Version: 2.4.dev4
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -22,4 +22,4 @@ __all__ = [
22
22
  "SignSGD_adv",
23
23
  ]
24
24
 
25
- __version__ = "2.4.dev2"
25
+ __version__ = "2.4.dev4"
@@ -280,8 +280,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
280
280
  def load_state_dict(self, state_dict: dict) -> None:
281
281
  """
282
282
  Overrides default load_state_dict to implement a workaround for PyTorch's
283
- automatic dtype casting. It ensures factorized states remain float32 for
284
- stability, preserves integer/float8 quantized anchor states, and forces
283
+ automatic dtype casting. It ensures factorized states remain float32 for
284
+ stability, preserves integer/float8 quantized anchor states, and forces
285
285
  standard states onto the parameter's current dtype/device.
286
286
  """
287
287
  super().load_state_dict(state_dict)
@@ -91,7 +91,7 @@ class AdamW_adv(torch.optim.Optimizer):
91
91
  'int4': Uses 4-bit block-wise quantization (block size 32).
92
92
  nnmf_factor (bool): whether to use the factorization or disable it to use
93
93
  the uncompressed optimizer. (default: False)
94
- factored_2nd (bool): whether to keep the first moment uncompressed (dense)
94
+ factored_2nd (bool): whether to keep the first moment uncompressed (dense)
95
95
  while only factorizing the second moment. (default: True)
96
96
  """
97
97
 
@@ -192,8 +192,8 @@ class AdamW_adv(torch.optim.Optimizer):
192
192
  def load_state_dict(self, state_dict: dict) -> None:
193
193
  """
194
194
  Overrides default load_state_dict to implement a workaround for PyTorch's
195
- automatic dtype casting. It ensures factorized states remain float32 for
196
- stability, preserves integer/float8 quantized anchor states, and forces
195
+ automatic dtype casting. It ensures factorized states remain float32 for
196
+ stability, preserves integer/float8 quantized anchor states, and forces
197
197
  standard states onto the parameter's current dtype/device.
198
198
  """
199
199
  super().load_state_dict(state_dict)
@@ -349,7 +349,11 @@ class AdamW_adv(torch.optim.Optimizer):
349
349
  update_mt = mt if not factored_2nd else mt.clone()
350
350
 
351
351
  vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
352
- vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
352
+
353
+ if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
354
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped * (1.0 - beta2))
355
+ else:
356
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
353
357
 
354
358
  if self.use_AdEMAMix:
355
359
  if factored_2nd:
@@ -363,7 +367,7 @@ class AdamW_adv(torch.optim.Optimizer):
363
367
  update = update_mt.add_(mt_slow, alpha=alpha)
364
368
  else:
365
369
  update = grad_reshaped.add(mt_slow, alpha=alpha)
366
-
370
+
367
371
  if not factored_2nd:
368
372
  # Factorize
369
373
  state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
@@ -413,7 +417,10 @@ class AdamW_adv(torch.optim.Optimizer):
413
417
  update = update_mt if beta1 > 0 else grad.clone()
414
418
 
415
419
  exp_avg_sq = state['exp_avg_sq']
416
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
420
+ if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
421
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad * (1.0 - beta2))
422
+ else:
423
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
417
424
 
418
425
  if group['use_atan2']:
419
426
  denom = exp_avg_sq.sqrt()
@@ -107,7 +107,7 @@ class Adopt_adv(torch.optim.Optimizer):
107
107
  'int4': Uses 4-bit block-wise quantization (block size 32).
108
108
  nnmf_factor (bool): whether to use the factorization or disable it to use
109
109
  the uncompressed optimizer. (default: False)
110
- factored_2nd (bool): whether to keep the first moment uncompressed (dense)
110
+ factored_2nd (bool): whether to keep the first moment uncompressed (dense)
111
111
  while only factorizing the second moment. (default: True)
112
112
  """
113
113
 
@@ -189,7 +189,7 @@ class Adopt_adv(torch.optim.Optimizer):
189
189
  "scaled_optm": scaled_optm,
190
190
  "centered_wd": centered_wd,
191
191
  "centered_wd_mode": centered_wd_mode,
192
- "nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd,
192
+ "nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd,
193
193
  "compiled_optimizer": compiled_optimizer,
194
194
  }
195
195
  self.clip_lambda = clip_lambda
@@ -222,8 +222,8 @@ class Adopt_adv(torch.optim.Optimizer):
222
222
  def load_state_dict(self, state_dict: dict) -> None:
223
223
  """
224
224
  Overrides default load_state_dict to implement a workaround for PyTorch's
225
- automatic dtype casting. It ensures factorized states remain float32 for
226
- stability, preserves integer/float8 quantized anchor states, and forces
225
+ automatic dtype casting. It ensures factorized states remain float32 for
226
+ stability, preserves integer/float8 quantized anchor states, and forces
227
227
  standard states onto the parameter's current dtype/device.
228
228
  """
229
229
  super().load_state_dict(state_dict)
@@ -244,6 +244,19 @@ class Adopt_adv(torch.optim.Optimizer):
244
244
  grad = p.grad
245
245
  state = self.state[p]
246
246
 
247
+
248
+ beta1, beta2 = group['betas']
249
+
250
+ if group.get('kourkoutas_beta', False):
251
+ if 'step' not in state:
252
+ current_step = 0
253
+ else:
254
+ current_step = state['step']
255
+ # Call prepare_step() once at the beginning of the step for all params
256
+ self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
257
+ # Get the dynamic beta2 calculated in prepare_step()
258
+ beta2 = self.kourkoutas_helper.get_beta2(p, group)
259
+
247
260
  # State Initialization
248
261
  if 'step' not in state:
249
262
  state['step'] = 0
@@ -256,6 +269,12 @@ class Adopt_adv(torch.optim.Optimizer):
256
269
 
257
270
  dtype = torch.float32 if state['factored'] else p.dtype
258
271
 
272
+ vt_init = grad.pow(2).to(dtype)
273
+ if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
274
+ vt_init.mul_(beta2).addcmul_(grad.to(dtype), grad.to(dtype) * (1.0 - beta2))
275
+ else:
276
+ vt_init.mul_(beta2).addcmul_(grad.to(dtype), grad.to(dtype), value=1.0 - beta2)
277
+
259
278
  if state['factored']:
260
279
  state['effective_shape'] = _get_effective_shape(p.numel())
261
280
  d1, d2 = state['effective_shape']
@@ -279,33 +298,21 @@ class Adopt_adv(torch.optim.Optimizer):
279
298
  if self.use_AdEMAMix:
280
299
  state['exp_avg_slow'] = torch.zeros_like(p, device=p.device, dtype=dtype)
281
300
  # Second moment (v)
282
- vt_init = grad.to(dtype).view(d1, d2).square()
283
- # Allocate NMF factors for vt
284
- state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
285
- state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
286
- # Initialize v_0
287
- state['mu_v_nmf'], state['mv_v_nmf'] = _nnmf(vt_init)
301
+ state['mu_v_nmf'], state['mv_v_nmf'] = _nnmf(vt_init.view(d1, d2))
288
302
  del vt_init
289
303
  else: # Fallback for non-factored tensors
290
304
  if group['betas'][0] > 0:
291
305
  state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
292
306
  if self.use_AdEMAMix:
293
307
  state['exp_avg_slow'] = torch.zeros_like(p, device=p.device, dtype=dtype)
294
- state['exp_avg_sq'] = grad.to(dtype).square()
308
+ state['exp_avg_sq'] = vt_init
295
309
 
296
310
  if group.get('scaled_optm', False) and is_spectral(p):
297
311
  init_spectral_norm(group, state, p)
298
312
 
299
313
  _init_anchor(p, state, group)
300
314
 
301
- beta1, beta2 = group['betas']
302
-
303
315
  current_step = state['step']
304
- if group.get('kourkoutas_beta', False):
305
- # Call prepare_step() once at the beginning of the step for all params
306
- self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
307
- # Get the dynamic beta2 calculated in prepare_step()
308
- beta2 = self.kourkoutas_helper.get_beta2(p, group)
309
316
 
310
317
  # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
311
318
  if state['step'] == 0 and not self.use_atan2:
@@ -361,7 +368,10 @@ class Adopt_adv(torch.optim.Optimizer):
361
368
  denom = vt.sqrt()
362
369
 
363
370
  # Update second moment v_t for the *next* step using raw g_t
364
- vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
371
+ if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
372
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped * (1.0 - beta2))
373
+ else:
374
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
365
375
  # Factorize
366
376
  state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(vt, signed=False)
367
377
  del vt
@@ -475,9 +485,11 @@ class Adopt_adv(torch.optim.Optimizer):
475
485
  else:
476
486
  update = normalized_grad
477
487
 
478
-
479
488
  # Update second moment v_t for the next step using raw g_t
480
- vt.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
489
+ if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
490
+ vt.mul_(beta2).addcmul_(grad, grad * (1.0 - beta2))
491
+ else:
492
+ vt.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
481
493
 
482
494
  update_scaling = lr * A if self.use_atan2 else lr
483
495
 
@@ -8,6 +8,7 @@ from ..util.factorization_util import _get_effective_shape, _reconstruct_state,
8
8
  from ..util.lion_k import _get_lion_k_update
9
9
  from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
10
10
  from ..util.centered_decay import _init_anchor
11
+ from ..util.update_util import _get_l1_adaptive_lr
11
12
 
12
13
 
13
14
  class Lion_adv(torch.optim.Optimizer):
@@ -46,7 +47,7 @@ class Lion_adv(torch.optim.Optimizer):
46
47
  updates. Overrides explicit kappa_p value. (default: False).
47
48
  freeze_on_flip (bool): Projected SignGD One-hit freeze. Masks updates for
48
49
  coordinates where the gradient sign flips compared to the previous step. (default: False)
49
- l1_adaptive (bool): Scales learning rate dynamically
50
+ l1_adaptive (bool): Scales learning rate dynamically
50
51
  by the L1 norm of the gradient to handle gradient heterogeneity. (default: False).
51
52
  centered_wd (float): Centered Weight Decay coefficient. Instead of decaying weights
52
53
  toward zero, they are decayed toward their initial values (anchors). This
@@ -137,8 +138,8 @@ class Lion_adv(torch.optim.Optimizer):
137
138
  def load_state_dict(self, state_dict: dict) -> None:
138
139
  """
139
140
  Overrides default load_state_dict to implement a workaround for PyTorch's
140
- automatic dtype casting. It ensures factorized states remain float32 for
141
- stability, preserves integer/float8 quantized anchor states, and forces
141
+ automatic dtype casting. It ensures factorized states remain float32 for
142
+ stability, preserves integer/float8 quantized anchor states, and forces
142
143
  standard states onto the parameter's current dtype/device.
143
144
  """
144
145
  super().load_state_dict(state_dict)
@@ -251,8 +252,7 @@ class Lion_adv(torch.optim.Optimizer):
251
252
  # Compute update term c_t
252
253
  update = torch.lerp(grad_reshaped, exp_avg, beta1)
253
254
 
254
- if group.get("l1_adaptive", False) and kappa_p == 1:
255
- lr = lr * (update.norm(p=1))
255
+ l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p)
256
256
 
257
257
  # Standard Lion momentum update
258
258
  # m_t = beta2 * m_{t-1} + (1-beta2) * g_t
@@ -286,8 +286,7 @@ class Lion_adv(torch.optim.Optimizer):
286
286
  # Compute update term
287
287
  update = torch.lerp(grad, exp_avg, beta1)
288
288
 
289
- if group.get("l1_adaptive", False) and kappa_p == 1:
290
- lr = lr * (update.norm(p=1))
289
+ l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p)
291
290
 
292
291
  update = _get_lion_k_update(update, kappa_p)
293
292
 
@@ -305,6 +304,9 @@ class Lion_adv(torch.optim.Optimizer):
305
304
  update = torch.where(current_sign == state['prev_sign'], update, 0.0)
306
305
  state['prev_sign'] = current_sign
307
306
 
307
+ if l1_mean is not None:
308
+ update.mul_(l1_mean)
309
+
308
310
  if group.get('scaled_optm', False):
309
311
  update = scale_update(p, update, lr, vector_state=state.get('spectral_v'))
310
312
  else:
@@ -259,8 +259,8 @@ class Muon_adv(torch.optim.Optimizer):
259
259
  def load_state_dict(self, state_dict: dict) -> None:
260
260
  """
261
261
  Overrides default load_state_dict to implement a workaround for PyTorch's
262
- automatic dtype casting. It ensures factorized states remain float32 for
263
- stability, preserves integer/float8 quantized anchor states, and forces
262
+ automatic dtype casting. It ensures factorized states remain float32 for
263
+ stability, preserves integer/float8 quantized anchor states, and forces
264
264
  standard states onto the parameter's current dtype/device.
265
265
  """
266
266
  super().load_state_dict(state_dict)
@@ -67,7 +67,7 @@ class Prodigy_adv(torch.optim.Optimizer):
67
67
  stability. (default: 100.0)
68
68
  nnmf_factor (bool): whether to use the factorization or disable it to use
69
69
  the uncompressed optimizer. (default: False)
70
- factored_2nd (bool): whether to keep the first moment uncompressed (dense)
70
+ factored_2nd (bool): whether to keep the first moment uncompressed (dense)
71
71
  while only factorizing the second moment. (default: True)
72
72
  d0 (float):
73
73
  Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
@@ -255,8 +255,8 @@ class Prodigy_adv(torch.optim.Optimizer):
255
255
  def load_state_dict(self, state_dict: dict) -> None:
256
256
  """
257
257
  Overrides default load_state_dict to implement a workaround for PyTorch's
258
- automatic dtype casting. It ensures factorized states remain float32 for
259
- stability, preserves integer/float8 quantized anchor states, and forces
258
+ automatic dtype casting. It ensures factorized states remain float32 for
259
+ stability, preserves integer/float8 quantized anchor states, and forces
260
260
  standard states onto the parameter's current dtype/device.
261
261
  """
262
262
  super().load_state_dict(state_dict)
@@ -440,7 +440,10 @@ class Prodigy_adv(torch.optim.Optimizer):
440
440
  update_mt = mt if not factored_2nd else mt.clone()
441
441
 
442
442
  vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
443
- vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=d * d * (1.0 - beta2))
443
+ if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
444
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped * (d * d * (1.0 - beta2)))
445
+ else:
446
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=d * d * (1.0 - beta2))
444
447
 
445
448
  if self.use_AdEMAMix:
446
449
  if factored_2nd:
@@ -453,7 +456,7 @@ class Prodigy_adv(torch.optim.Optimizer):
453
456
  update = update_mt.add_(mt_slow, alpha=alpha)
454
457
  else:
455
458
  update = grad_reshaped.mul(d).add_(mt_slow, alpha=alpha)
456
-
459
+
457
460
  if not factored_2nd:
458
461
  # Factorize
459
462
  state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
@@ -514,7 +517,10 @@ class Prodigy_adv(torch.optim.Optimizer):
514
517
  update = grad.mul(d)
515
518
 
516
519
  exp_avg_sq = state['exp_avg_sq']
517
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=d * d * (1.0 - beta2))
520
+ if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
521
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad * (d * d * (1.0 - beta2)))
522
+ else:
523
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=d * d * (1.0 - beta2))
518
524
 
519
525
  if group['use_atan2']:
520
526
  denom = exp_avg_sq.sqrt()
@@ -608,4 +614,4 @@ class Prodigy_adv(torch.optim.Optimizer):
608
614
 
609
615
  # Increment step counter for all groups, regardless of whether d was updated
610
616
  for group in self.param_groups:
611
- group['k'] += 1
617
+ group['k'] += 1
@@ -6,8 +6,8 @@ from ..util import param_update
6
6
  from ..util.OrthoGrad import _orthogonalize_gradient
7
7
  from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state, _pack_bools, _unpack_bools
8
8
  from ..util.lion_k import _get_lion_k_update
9
+ from ..util.update_util import _get_l1_adaptive_lr
9
10
  from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
10
- from ..util.update_util import _scale_sim_AdEMAMix_update
11
11
  from ..util.centered_decay import _init_anchor
12
12
 
13
13
 
@@ -49,8 +49,8 @@ class SignSGD_adv(torch.optim.Optimizer):
49
49
  stability. (default: 100.0)
50
50
  freeze_on_flip (bool): Projected SignGD One-hit freeze. Masks updates for
51
51
  coordinates where the gradient sign flips compared to the previous step. (default: False)
52
- l1_adaptive (bool): Scales learning rate dynamically.
53
- by the L1 norm of the gradient to handle gradient heterogeneity. (default: False)
52
+ l1_adaptive (bool): Scales the update step magnitude dynamically
53
+ by the mean L1 norm of the momentum/gradient to handle gradient heterogeneity.(default: False)
54
54
  centered_wd (float): Centered Weight Decay coefficient. Instead of decaying weights
55
55
  toward zero, they are decayed toward their initial values (anchors). This
56
56
  can be used together with standard weight decay. (default: 0.0)
@@ -140,8 +140,8 @@ class SignSGD_adv(torch.optim.Optimizer):
140
140
  def load_state_dict(self, state_dict: dict) -> None:
141
141
  """
142
142
  Overrides default load_state_dict to implement a workaround for PyTorch's
143
- automatic dtype casting. It ensures factorized states remain float32 for
144
- stability, preserves integer/float8 quantized anchor states, and forces
143
+ automatic dtype casting. It ensures factorized states remain float32 for
144
+ stability, preserves integer/float8 quantized anchor states, and forces
145
145
  standard states onto the parameter's current dtype/device.
146
146
  """
147
147
  super().load_state_dict(state_dict)
@@ -269,9 +269,7 @@ class SignSGD_adv(torch.optim.Optimizer):
269
269
  if freeze_on_flip:
270
270
  state['sign'] = _pack_bools(raw_update > 0)
271
271
 
272
- if group.get("l1_adaptive", False) and kappa_p == 1:
273
- scale_factor = 1 / _scale_sim_AdEMAMix_update(momentum, state["step"] + 1, alpha_grad, 1, False)
274
- lr = lr * (raw_update.norm(p=1)/scale_factor)
272
+ l1_mean = _get_l1_adaptive_lr(p, raw_update, state, group, kappa_p)
275
273
 
276
274
  update = _get_lion_k_update(raw_update, kappa_p)
277
275
  update = update.view(p.shape)
@@ -296,9 +294,7 @@ class SignSGD_adv(torch.optim.Optimizer):
296
294
  else:
297
295
  raw_update = grad.clone()
298
296
 
299
- if group.get("l1_adaptive", False) and kappa_p == 1:
300
- scale_factor = 1 / _scale_sim_AdEMAMix_update(momentum, state["step"] + 1, alpha_grad, 1, False)
301
- lr = lr * (raw_update.norm(p=1)/scale_factor)
297
+ l1_mean = _get_l1_adaptive_lr(p, raw_update, state, group, kappa_p)
302
298
 
303
299
  update = _get_lion_k_update(raw_update, kappa_p)
304
300
 
@@ -307,6 +303,9 @@ class SignSGD_adv(torch.optim.Optimizer):
307
303
  update = torch.where(current_sign == state['prev_sign'], update, 0.0)
308
304
  state['prev_sign'] = current_sign
309
305
 
306
+ if l1_mean is not None:
307
+ update.mul_(l1_mean)
308
+
310
309
  if group.get('scaled_optm', False):
311
310
  update = scale_update(p, update, lr, vector_state=state.get('spectral_v'))
312
311
  else:
@@ -86,7 +86,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
86
86
  'int4': Uses 4-bit block-wise quantization (block size 32).
87
87
  nnmf_factor (bool): whether to use the factorization or disable it to use
88
88
  the uncompressed optimizer. (default: False)
89
- factored_2nd (bool): whether to keep the first moment uncompressed (dense)
89
+ factored_2nd (bool): whether to keep the first moment uncompressed (dense)
90
90
  while only factorizing the second moment. (default: True)
91
91
  """
92
92
 
@@ -176,8 +176,8 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
176
176
  def load_state_dict(self, state_dict: dict) -> None:
177
177
  """
178
178
  Overrides default load_state_dict to implement a workaround for PyTorch's
179
- automatic dtype casting. It ensures factorized states remain float32 for
180
- stability, preserves integer/float8 quantized anchor states, and forces
179
+ automatic dtype casting. It ensures factorized states remain float32 for
180
+ stability, preserves integer/float8 quantized anchor states, and forces
181
181
  standard states onto the parameter's current dtype/device.
182
182
  """
183
183
  super().load_state_dict(state_dict)
@@ -320,7 +320,10 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
320
320
  mt.mul_(beta1).add_(grad_reshaped)
321
321
 
322
322
  vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
323
- vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
323
+ if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
324
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped * (1.0 - beta2))
325
+ else:
326
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
324
327
 
325
328
  # update = mt + (grad_reshaped * alpha_grad)
326
329
  update = torch.add(mt, grad_reshaped, alpha=alpha_grad)
@@ -347,7 +350,10 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
347
350
 
348
351
  update = torch.add(exp_avg, grad, alpha=alpha_grad)
349
352
 
350
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
353
+ if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
354
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad * (1.0 - beta2))
355
+ else:
356
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
351
357
 
352
358
  denom = exp_avg_sq.sqrt().add_(sqrt_den_eps)
353
359
  update.div_(denom)
@@ -34,8 +34,12 @@ class KourkoutasHelper:
34
34
  else:
35
35
  # No key function was provided. Default to coarse, shape-based bucketing.
36
36
  self.optimizer.layer_key_fn = lambda p: \
37
- (id(p),) if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096} \
38
- else tuple(p.shape)
37
+ (id(p),) if (
38
+ getattr(p, '_is_oft', False) or
39
+ getattr(p, '_is_lora_A', False) or
40
+ getattr(p, '_is_lora_B', False) or
41
+ getattr(p, '_is_dora_scale', False)
42
+ ) else tuple(p.shape)
39
43
  # This ensures that we won't mix embeddings with tokens (1 to 10)
40
44
  # TODO find a better way to safeguard the embeddings
41
45
 
@@ -55,13 +59,21 @@ class KourkoutasHelper:
55
59
  def _get_or_init_layer_ema_tensor(self, layer_key, layer_params, device):
56
60
  """
57
61
  Retrieves the EMA tensor for this layer.
58
- It handles synchronization between the internal layer_state and
62
+ It handles synchronization between the internal layer_state and
59
63
  the external optimizer.state (which is required for state_dict saving/loading).
60
64
  """
61
65
  # Initialize container in layer_state if missing
62
66
  if layer_key not in self.layer_state:
67
+ p = layer_params[0]
68
+ if getattr(p, '_is_oft', False) or getattr(p, '_is_lora_A', False):
69
+ shape = (p.shape[0], 1)
70
+ elif getattr(p, '_is_lora_B', False):
71
+ shape = (1, p.shape[1])
72
+ else:
73
+ shape = ()
74
+
63
75
  self.layer_state[layer_key] = {
64
- 'sum_sq_accumulator': torch.tensor(0.0, device=device, dtype=torch.float32)
76
+ 'sum_sq_accumulator': torch.zeros(shape, device=device, dtype=torch.float32)
65
77
  }
66
78
 
67
79
  internal_ema = self.layer_state[layer_key].get('kourkoutas_r_ema')
@@ -87,7 +99,15 @@ class KourkoutasHelper:
87
99
 
88
100
  # Case B: No state anywhere. Create new.
89
101
  if internal_ema is None:
90
- new_ema = torch.tensor(0.0, device=device, dtype=torch.float32)
102
+ p = layer_params[0]
103
+ if getattr(p, '_is_oft', False) or getattr(p, '_is_lora_A', False):
104
+ shape = (p.shape[0], 1)
105
+ elif getattr(p, '_is_lora_B', False):
106
+ shape = (1, p.shape[1])
107
+ else:
108
+ shape = ()
109
+
110
+ new_ema = torch.zeros(shape, device=device, dtype=torch.float32)
91
111
  self.layer_state[layer_key]['kourkoutas_r_ema'] = new_ema
92
112
 
93
113
  # Register this tensor in optimizer.state for ALL params so it gets saved
@@ -107,7 +127,7 @@ class KourkoutasHelper:
107
127
 
108
128
  def prepare_step(self, current_step: int, device):
109
129
  """
110
- Calculates dynamic beta2 for all layers using the completed scalar accumulators
130
+ Calculates dynamic beta2 for all layers using the completed accumulators
111
131
  from the PREVIOUS step. Should be called once at the start of an optimizer step.
112
132
  """
113
133
  beta2_log = []
@@ -154,7 +174,10 @@ class KourkoutasHelper:
154
174
  beta2 = beta2_max - (beta2_max - beta2_min) * sun
155
175
 
156
176
  # Store the final calculated beta2 in the helper's transient state for this step.
157
- self.layer_state[layer_key]['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) and not group.get('compiled_optimizer', False) else beta2
177
+ if isinstance(beta2, torch.Tensor) and beta2.numel() == 1 and not group.get('compiled_optimizer', False):
178
+ self.layer_state[layer_key]['dynamic_beta2'] = beta2.item()
179
+ else:
180
+ self.layer_state[layer_key]['dynamic_beta2'] = beta2
158
181
 
159
182
  # Reset the accumulator for the next optimizer step.
160
183
  accumulator.zero_()
@@ -163,10 +186,11 @@ class KourkoutasHelper:
163
186
 
164
187
  # Compute stats for TensorBoard
165
188
  if beta2_log:
166
- beta2_tensor = torch.as_tensor(beta2_log, device='cpu')
189
+ # Handles lists containing both standard floats and heterogeneous tensors
190
+ means = [b.mean().item() if isinstance(b, torch.Tensor) else float(b) for b in beta2_log]
167
191
  self.last_beta2_stats = {
168
- 'mean': beta2_tensor.mean().item()
169
- }
192
+ 'mean': sum(means) / len(means)
193
+ }
170
194
 
171
195
  def maybe_prepare_step(self, current_step: int, device):
172
196
  """
@@ -184,9 +208,16 @@ class KourkoutasHelper:
184
208
 
185
209
  if layer_key in self.layer_info and layer_key in self.layer_state:
186
210
  # Accumulate for the *next* step's prepare_step call
187
- self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()
211
+ if getattr(p, '_is_oft', False) or getattr(p, '_is_lora_A', False):
212
+ sq_norm = torch.sum(grad.detach().pow(2), dim=1, keepdim=True).float()
213
+ elif getattr(p, '_is_lora_B', False):
214
+ sq_norm = torch.sum(grad.detach().pow(2), dim=0, keepdim=True).float()
215
+ else:
216
+ sq_norm = torch.sum(grad.detach().pow(2)).float()
217
+
218
+ self.layer_state[layer_key]['sum_sq_accumulator'] += sq_norm
188
219
 
189
- def get_beta2(self, p: torch.Tensor, group: dict) -> float:
220
+ def get_beta2(self, p: torch.Tensor, group: dict) -> float | torch.Tensor:
190
221
  """
191
222
  Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
192
223
  """
@@ -87,7 +87,10 @@ def _adam_step_parameter(self, p, grad, state, group, beta1_adam, beta2_adam, sq
87
87
  update_mt = mt
88
88
 
89
89
  vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
90
- vt.mul_(beta2_adam).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2_adam)
90
+ if isinstance(beta2_adam, torch.Tensor) and beta2_adam.dim() > 0:
91
+ vt.mul_(beta2_adam).addcmul_(grad_reshaped, grad_reshaped * (1.0 - beta2_adam))
92
+ else:
93
+ vt.mul_(beta2_adam).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2_adam)
91
94
 
92
95
  if group.get('adam_use_AdEMAMix'):
93
96
  mt_slow = _reconstruct_state((state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'], d2), signed=True)
@@ -148,7 +151,10 @@ def _adam_step_parameter(self, p, grad, state, group, beta1_adam, beta2_adam, sq
148
151
  update = update_mt if beta1_adam > 0 else grad.clone()
149
152
 
150
153
  exp_avg_sq = state['exp_avg_sq']
151
- exp_avg_sq.mul_(beta2_adam).addcmul_(grad, grad, value=1 - beta2_adam)
154
+ if isinstance(beta2_adam, torch.Tensor) and beta2_adam.dim() > 0:
155
+ exp_avg_sq.mul_(beta2_adam).addcmul_(grad, grad * (1.0 - beta2_adam))
156
+ else:
157
+ exp_avg_sq.mul_(beta2_adam).addcmul_(grad, grad, value=1.0 - beta2_adam)
152
158
 
153
159
  if group.get('adam_use_atan2'):
154
160
  denom = exp_avg_sq.sqrt()
@@ -0,0 +1,50 @@
1
+ import torch
2
+
3
+ def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
4
+ """
5
+ Projects the gradient `grad` to be orthogonal to the parameter `p`.
6
+ Modified from:
7
+ https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability/blob/720d2444df12b851d6cb417ab08cf125c822b2ae/orthograd.py
8
+ """
9
+ if getattr(p, '_is_oft', False) or getattr(p, '_is_lora_A', False):
10
+ return _orthogonalize_gradient_granular(p, grad, dim=1)
11
+ elif getattr(p, '_is_lora_B', False):
12
+ return _orthogonalize_gradient_granular(p, grad, dim=0)
13
+
14
+ original_shape = grad.shape
15
+ original_dtype = grad.dtype
16
+ w = p.view(-1).float()
17
+ g = grad.view(-1).float()
18
+ w_norm_sq = torch.dot(w, w).add_(1e-30)
19
+ proj = torch.dot(w, g) / w_norm_sq
20
+ g_orth = g.sub(w * proj)
21
+ g_norm = g.norm(2)
22
+ g_orth_norm = g_orth.norm(2).add_(1e-30)
23
+ g_orth_scaled = g_orth * (g_norm / g_orth_norm)
24
+ return g_orth_scaled.view(original_shape).to(original_dtype)
25
+
26
+ def _orthogonalize_gradient_granular(p: torch.Tensor, grad: torch.Tensor, dim: int = 1, eps: float = 1e-30) -> torch.Tensor:
27
+ """
28
+ Projects the gradient `grad` to be orthogonal to the parameter `p` row/col-wise,
29
+ while preserving the original norm of the gradient for each row/col.
30
+ """
31
+ original_dtype = grad.dtype
32
+ p_f32 = p.float()
33
+ grad_f32 = grad.float()
34
+
35
+ # Calculate the dot product <p, grad> for each row/col
36
+ dot_prod = torch.sum(p_f32 * grad_f32, dim=dim, keepdim=True)
37
+
38
+ # Calculate ||p||^2 for each row/col
39
+ p_norm_sq = torch.sum(p_f32 * p_f32, dim=dim, keepdim=True).add_(eps)
40
+
41
+ # Project: g_orth = g - (p * <p, g> / ||p||^2)
42
+ proj = dot_prod / p_norm_sq
43
+ grad_orth = grad_f32 - (proj * p_f32)
44
+
45
+ # Magnitude Preservation
46
+ g_norm = torch.norm(grad_f32, p=2, dim=dim, keepdim=True)
47
+ g_orth_norm = torch.norm(grad_orth, p=2, dim=dim, keepdim=True).add_(eps)
48
+ grad_orth_scaled = grad_orth * (g_norm / g_orth_norm)
49
+
50
+ return grad_orth_scaled.to(original_dtype)
@@ -109,4 +109,4 @@ def dequantize_anchor(p, state, group, dtype):
109
109
  anchor_blocks = quantized_blocks.to(dtype) * scales.unsqueeze(1) + mins.unsqueeze(1)
110
110
 
111
111
  # Flatten, truncate any padding added during quantization, and reshape
112
- return anchor_blocks.view(-1)[:orig_numel].view(orig_shape)
112
+ return anchor_blocks.view(-1)[:orig_numel].view(orig_shape)
@@ -138,7 +138,7 @@ def set_seed(device: torch.device):
138
138
 
139
139
  def get_generator(device: torch.device) -> torch.Generator:
140
140
  """
141
- Retrieves (and initializes if necessary) the deterministic generator
141
+ Retrieves (and initializes if necessary) the deterministic generator
142
142
  for the specified device.
143
143
  """
144
144
  if device not in _generators:
@@ -241,9 +241,9 @@ def post_process_loaded_state(optimizer: Optimizer) -> None:
241
241
  # Deterministically check if this parameter skipped quantization
242
242
  numel = p.numel()
243
243
  is_skipped = (
244
- numel == 0 or
245
- (mode in ['int8', 'int4'] and numel < 10000) or
246
- p.ndim == 1 or
244
+ numel == 0 or
245
+ (mode in ['int8', 'int4'] and numel < 10000) or
246
+ p.ndim == 1 or
247
247
  getattr(p, '_is_dora_scale', False)
248
248
  )
249
249
 
@@ -283,4 +283,4 @@ def post_process_loaded_state(optimizer: Optimizer) -> None:
283
283
 
284
284
  # Ensure device match
285
285
  if state[key].device != p.device:
286
- state[key] = state[key].to(p.device)
286
+ state[key] = state[key].to(p.device)
@@ -9,7 +9,7 @@ def scale_update(
9
9
  vector_state: torch.Tensor | None = None
10
10
  ) -> torch.Tensor:
11
11
  """
12
- Applies adaptive scaling to the parameter update based on the parameter's
12
+ Applies adaptive scaling to the parameter update based on the parameter's
13
13
  role (DoRA, OFT, or LoRA/Full Finetuning).
14
14
 
15
15
  Args:
@@ -28,11 +28,15 @@ def scale_update(
28
28
  if is_dora_scale or p.ndim == 1:
29
29
  return rms_normalization(update, dim=None, lr=lr)
30
30
 
31
- # Orthogonal Fine-Tuning (OFT)
32
- # RMS normalization (dim=1 normalizes per block)
31
+ # Orthogonal Fine-Tuning (OFT)
33
32
  # This guarantees O(1) update complexity scaling, independent of block sizes.
34
33
  if is_oft:
35
- return rms_normalization(update, dim=1, lr=lr)
34
+ n = update.shape[1]
35
+ # Calculate block size (b)
36
+ b = (1 + (1 + 8 * n) ** 0.5) / 2
37
+ target_norm = (b / 8) ** 0.5
38
+ scale = target_norm / (n ** 0.5)
39
+ return rms_normalization(update, dim=1, lr=lr * scale)
36
40
 
37
41
  # LoRA Factors or Full Finetuning weights
38
42
  # Scales update to maintain consistent spectral norm across different layer sizes and ranks.
@@ -44,7 +48,7 @@ def scale_update(
44
48
 
45
49
  def scale_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
46
50
  """
47
- Adjusts standard weight decay and centered weight decay based on the parameter's
51
+ Adjusts standard weight decay and centered weight decay based on the parameter's
48
52
  shape and type to maintain effective regularization strength.
49
53
  """
50
54
  # DoRA Scale (Magnitude Vector)
@@ -0,0 +1,73 @@
1
+ import torch
2
+
3
+ def _grams_update(mt: torch.Tensor, grad: torch.Tensor, inplace: bool=False):
4
+ """
5
+ Applies the update rule of "Gradient Descent with Adaptive Momentum Scaling"
6
+ (https://arxiv.org/abs/2412.17107).
7
+ """
8
+ if inplace:
9
+ return mt.abs_().mul_(grad.sign())
10
+ return grad.sign().mul_(mt.abs())
11
+
12
+ def _cautious_update(mt: torch.Tensor, grad: torch.Tensor, inplace: bool=False):
13
+ """
14
+ Applies the update rule of "Cautious Optimizers: Improving Training with One
15
+ Line of Code" (https://arxiv.org/abs/2411.16085).
16
+ """
17
+ mask = (mt * grad > 0).to(grad.dtype)
18
+ mask.div_(mask.mean().clamp_min_(1e-3))
19
+ if inplace:
20
+ update_mt = mt.mul_(mask)
21
+ else:
22
+ update_mt = mt.mul(mask)
23
+ del mask
24
+ return update_mt
25
+
26
+ def _scale_sim_AdEMAMix_update(beta: float, current_step: int, alpha_grad: float, lr: float, scaled_optm: bool=False):
27
+ if scaled_optm:
28
+ return lr
29
+ momentum_scale = (1 - beta ** current_step) / (1 - beta)
30
+ total_scale = 1 / (momentum_scale + alpha_grad)
31
+ lr = lr * total_scale
32
+ return lr
33
+
34
+ def _get_l1_adaptive_lr(
35
+ p: torch.Tensor,
36
+ update: torch.Tensor,
37
+ state: dict,
38
+ group: dict,
39
+ kappa_p: float
40
+ ) -> torch.Tensor:
41
+ """
42
+ Calculates the L1 adaptive learning rate based on gradient heterogeneity.
43
+ """
44
+ if not group.get("l1_adaptive", False) and kappa_p != 1:
45
+ return None
46
+
47
+ momentum = group["momentum"]
48
+ alpha_grad = group["alpha_grad"]
49
+ update_view = update.view(p.shape)
50
+
51
+ # Calculate scale factor based on momentum/update magnitude
52
+ scale_factor = _scale_sim_AdEMAMix_update(
53
+ momentum, state["step"] + 1, alpha_grad, 1, False
54
+ )
55
+
56
+ # Determine dimension for mean calculation based on parameter type
57
+ if getattr(p, '_is_oft', False) or getattr(p, '_is_lora_A', False):
58
+ l1_dim = 1
59
+ elif getattr(p, '_is_lora_B', False):
60
+ l1_dim = 0
61
+ else:
62
+ update_abs = update_view.abs() * scale_factor
63
+ if update_abs.ndim >= 2:
64
+ orig_shape = update_abs.shape
65
+ update_2d = update_abs.view(orig_shape[0], -1)
66
+ mean_l1_norm_2d = torch.outer(update_2d.mean(dim=1), update_2d.mean(dim=0))
67
+ return mean_l1_norm_2d.view(orig_shape)
68
+ else:
69
+ return update_abs.mean()
70
+
71
+ mean_l1_norm = update_view.abs().mean(dim=l1_dim, keepdim=True) * scale_factor
72
+
73
+ return mean_l1_norm
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev2
3
+ Version: 2.4.dev4
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
5
5
 
6
6
  setup(
7
7
  name="adv_optm",
8
- version="2.4.dev2",
8
+ version="2.4.dev4",
9
9
  author="Koratahiu",
10
10
  author_email="hiuhonor@gmail.com",
11
11
  license='Apache 2.0',
@@ -1,21 +0,0 @@
1
- import torch
2
-
3
- def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
4
- """
5
- Projects the gradient `grad` to be orthogonal to the parameter `p`.
6
- Modified from:
7
- https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability/blob/720d2444df12b851d6cb417ab08cf125c822b2ae/orthograd.py
8
- """
9
- if grad.is_sparse:
10
- raise RuntimeError("OrthoGrad logic does not support sparse gradients.")
11
- original_shape = grad.shape
12
- original_dtype = grad.dtype
13
- w = p.view(-1).float()
14
- g = grad.view(-1).float()
15
- w_norm_sq = torch.dot(w, w).add_(1e-30)
16
- proj = torch.dot(w, g) / w_norm_sq
17
- g_orth = g.sub(w * proj)
18
- g_norm = g.norm(2)
19
- g_orth_norm = g_orth.norm(2).add_(1e-30)
20
- g_orth_scaled = g_orth * (g_norm / g_orth_norm)
21
- return g_orth_scaled.view(original_shape).to(original_dtype)
@@ -1,32 +0,0 @@
1
- import torch
2
-
3
- def _grams_update(mt: torch.Tensor, grad: torch.Tensor, inplace: bool=False):
4
- """
5
- Applies the update rule of "Gradient Descent with Adaptive Momentum Scaling"
6
- (https://arxiv.org/abs/2412.17107).
7
- """
8
- if inplace:
9
- return mt.abs_().mul_(grad.sign())
10
- return grad.sign().mul_(mt.abs())
11
-
12
- def _cautious_update(mt: torch.Tensor, grad: torch.Tensor, inplace: bool=False):
13
- """
14
- Applies the update rule of "Cautious Optimizers: Improving Training with One
15
- Line of Code" (https://arxiv.org/abs/2411.16085).
16
- """
17
- mask = (mt * grad > 0).to(grad.dtype)
18
- mask.div_(mask.mean().clamp_min_(1e-3))
19
- if inplace:
20
- update_mt = mt.mul_(mask)
21
- else:
22
- update_mt = mt.mul(mask)
23
- del mask
24
- return update_mt
25
-
26
- def _scale_sim_AdEMAMix_update(beta: float, current_step: int, alpha_grad: float, lr: float, scaled_optm: bool=False):
27
- if scaled_optm:
28
- return lr
29
- momentum_scale = (1 - beta ** current_step) / (1 - beta)
30
- total_scale = 1 / (momentum_scale + alpha_grad)
31
- lr = lr * total_scale
32
- return lr
File without changes
File without changes
File without changes