adv-optm 0.1.0-py3-none-any.whl → 0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


adv_optm/__init__.py CHANGED
@@ -2,12 +2,16 @@ from .optim import (
2
2
  AdamW_adv,
3
3
  Prodigy_adv,
4
4
  Adopt_adv,
5
+ Lion_adv,
6
+ Lion_Prodigy_adv,
5
7
  )
6
8
 
7
9
  __all__ = [
8
10
  "AdamW_adv",
9
11
  "Prodigy_adv",
10
12
  "Adopt_adv",
13
+ "Lion_adv",
14
+ "Lion_Prodigy_adv",
11
15
  ]
12
16
 
13
- __version__ = "0.1.0"
17
+ __version__ = "0.1.2"
@@ -22,7 +22,6 @@ class AdamW_adv(torch.optim.Optimizer):
22
22
  eps (float): term added to the denominator to improve
23
23
  numerical stability (default: 1e-8)
24
24
  weight_decay (float): weight decay (L2 penalty) (default: 0)
25
- use_bias_correction (boolean): Turn on Adam's bias correction. (default: False)
26
25
  vector_reshape (bool): whether to reshape 1D vectors into 2D
27
26
  matrices to apply low-rank compression (default: True).
28
27
  stochastic_rounding (bool): whether to use stochastic
@@ -37,7 +36,7 @@ class AdamW_adv(torch.optim.Optimizer):
37
36
  combined with the primary momentum (`mt`) to stabilize updates,
38
37
  especially in noisy, small-batch settings. If `False`, the
39
38
  optimizer behaves as standard AdamW. (default: False)
40
- beta3 (float): The decay rate for the slow exponential moving average of
39
+ beta3_ema (float): The decay rate for the slow exponential moving average of
41
40
  the momentum (only used when `use_AdEMAMix` is `True`). A higher
42
41
  value (e.g., 0.9999) gives the EMA a longer memory, making it more
43
42
  stable but slower to adapt. A lower value (e.g., 0.999) is often
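The rename from `beta3` to `beta3_ema` only affects the AdEMAMix-related keyword. A hedged sketch of a constructor call that exercises it under the signature shown in this hunk (values are illustrative, not recommendations):

import torch
from adv_optm import AdamW_adv

model = torch.nn.Linear(16, 4)
opt = AdamW_adv(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    use_AdEMAMix=True,
    beta3_ema=0.9999,   # was passed as `beta3=` in 0.1.0
    alpha=5.0,
)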
@@ -63,7 +62,6 @@ class AdamW_adv(torch.optim.Optimizer):
63
62
  betas: tuple[float, float] = (0.9, 0.999),
64
63
  eps: float = 1e-8,
65
64
  weight_decay: float = 0.0,
66
- use_bias_correction: bool = False,
67
65
  vector_reshape: bool = True,
68
66
  stochastic_rounding: bool = True,
69
67
  use_atan2: bool = False,
@@ -71,7 +69,7 @@ class AdamW_adv(torch.optim.Optimizer):
71
69
  use_grams: bool = False,
72
70
  use_orthograd: bool = False,
73
71
  use_AdEMAMix: bool = False,
74
- beta3: float = 0.9999,
72
+ beta3_ema: float = 0.9999,
75
73
  alpha: float = 5.0,
76
74
  t_alpha: int | None = None,
77
75
  factored: bool = True,
@@ -88,8 +86,8 @@ class AdamW_adv(torch.optim.Optimizer):
88
86
  defaults = {
89
87
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
90
88
  "vector_reshape": vector_reshape, "use_atan2": use_atan2,
91
- "use_orthograd": use_orthograd, "use_bias_correction": use_bias_correction,
92
- "beta3": beta3, "alpha": alpha, "t_alpha": t_alpha,
89
+ "use_orthograd": use_orthograd,
90
+ "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
93
91
  }
94
92
  self.stochastic_rounding = stochastic_rounding
95
93
  self.use_cautious = use_cautious
@@ -122,6 +120,8 @@ class AdamW_adv(torch.optim.Optimizer):
122
120
  grad = _orthogonalize_gradient(p, grad)
123
121
  state = self.state[p]
124
122
 
123
+ beta1, beta2 = group['betas']
124
+
125
125
  # State Initialization
126
126
  if len(state) == 0:
127
127
  state['step'] = 0
@@ -141,11 +141,12 @@ class AdamW_adv(torch.optim.Optimizer):
141
141
  d1, d2 = state['effective_shape']
142
142
 
143
143
  # First moment (m)
144
- state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
145
- state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
146
- if not self.use_grams:
147
- packed_d2 = (d2 + 7) // 8
148
- state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
144
+ if beta1 > 0:
145
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
146
+ state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
147
+ if not self.use_grams:
148
+ packed_d2 = (d2 + 7) // 8
149
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
149
150
  if self.use_AdEMAMix:
150
151
  state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
151
152
  state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
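The `(d2 + 7) // 8` sizing above reserves one bit per column for the sign of the first moment. The helpers `_pack_bools` / `_unpack_bools` are imported but not shown in this diff; the sketch below is a hypothetical implementation that is merely consistent with that packed uint8 shape, not the package's actual code:

import torch

def pack_bools(mask: torch.Tensor) -> torch.Tensor:
    """Pack a (d1, d2) boolean matrix into (d1, ceil(d2 / 8)) uint8, 8 flags per byte."""
    d1, d2 = mask.shape
    packed_d2 = (d2 + 7) // 8
    padded = torch.zeros(d1, packed_d2 * 8, dtype=torch.bool, device=mask.device)
    padded[:, :d2] = mask
    bits = padded.view(d1, packed_d2, 8).to(torch.uint8)
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=mask.device)
    return (bits * weights).sum(dim=-1, dtype=torch.uint8)

def unpack_bools(packed: torch.Tensor, original_m: int) -> torch.Tensor:
    """Inverse of pack_bools, truncated back to the original column count."""
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=packed.device)
    bits = packed.unsqueeze(-1).bitwise_and(weights).ne(0)
    return bits.view(packed.shape[0], -1)[:, :original_m]

signs = torch.randn(3, 10) > 0
assert torch.equal(unpack_bools(pack_bools(signs), 10), signs)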
@@ -155,14 +156,14 @@ class AdamW_adv(torch.optim.Optimizer):
155
156
  state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
156
157
  state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
157
158
  else: # Fallback to standard AdamW for non-factored tensors
158
- state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
159
+ if beta1 > 0:
160
+ state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
159
161
  if self.use_AdEMAMix:
160
162
  state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
161
163
  state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
162
164
 
163
- beta1, beta2 = group['betas']
164
165
  if self.use_AdEMAMix:
165
- beta3 = group['beta3']
166
+ beta3_ema = group['beta3_ema']
166
167
  alpha = group['alpha']
167
168
  t_alpha = group['t_alpha']
168
169
  current_step = state['step'] + 1
@@ -174,21 +175,22 @@ class AdamW_adv(torch.optim.Optimizer):
174
175
  d1, d2 = state['effective_shape']
175
176
 
176
177
  # Reconstruct momentum from previous step's factors
177
- mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
178
- if not self.use_grams:
179
- unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
180
- torch.where(unpacked_sign, mt, -mt, out=mt)
181
- del unpacked_sign
182
- # Update momentum in full-size
183
- grad_reshaped = grad.view(d1, d2)
184
- mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
185
- if self.use_grams:
186
- mt.copy_(grad_reshaped.sign() * mt.abs())
187
- elif self.use_cautious:
188
- mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
189
- mask.div_(mask.mean().clamp_(min=1e-3))
190
- mt.mul_(mask)
191
- del mask
178
+ if beta1 > 0:
179
+ mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
180
+ if not self.use_grams:
181
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
182
+ torch.where(unpacked_sign, mt, -mt, out=mt)
183
+ del unpacked_sign
184
+ # Update momentum in full-size
185
+ grad_reshaped = grad.view(d1, d2)
186
+ mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
187
+ if self.use_grams:
188
+ mt.copy_(grad_reshaped.sign() * mt.abs())
189
+ elif self.use_cautious:
190
+ mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
191
+ mask.div_(mask.mean().clamp_(min=1e-3))
192
+ mt.mul_(mask)
193
+ del mask
192
194
 
193
195
  vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
194
196
  vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
@@ -201,29 +203,29 @@ class AdamW_adv(torch.optim.Optimizer):
201
203
  torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
202
204
  del unpacked_sign_slow
203
205
 
204
- mt_slow.mul_(beta3).add_(grad_reshaped, alpha=1.0 - beta3)
205
- update_m = mt + (alpha_t * mt_slow)
206
+ mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
207
+ update = mt + (alpha_t * mt_slow) if beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
206
208
  else:
207
- update_m = mt
209
+ update = mt if beta1 > 0 else grad_reshaped
208
210
  del grad_reshaped
209
211
 
210
212
  if group['use_atan2']:
211
213
  a = 1.2732395
212
214
  denom = vt.sqrt()
213
- update = torch.atan2(update_m, denom).mul_(a)
215
+ update.atan2_(denom).mul_(a)
214
216
  else:
215
- denom = vt.sqrt().add_(group['eps'])
216
- update = update_m / denom
217
- del update_m, denom
217
+ denom = vt.sqrt()
218
+ update.div_(denom.add_(group['eps']))
219
+ del denom
218
220
 
219
- update = update.view(p.shape)
220
- update.mul_(group['lr'])
221
+ update.view(p.shape).mul_(group['lr'])
221
222
 
222
223
  # Compress updated moments and store new factors
223
- if not self.use_grams:
224
- state['sign'] = _pack_bools(mt > 0)
225
- _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
226
- del mt
224
+ if beta1 > 0:
225
+ if not self.use_grams:
226
+ state['sign'] = _pack_bools(mt > 0)
227
+ _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
228
+ del mt
227
229
  if self.use_AdEMAMix:
228
230
  state['sign_slow'] = _pack_bools(mt_slow > 0)
229
231
  _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
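The in-place `atan2_` path above replaces the usual `m / (sqrt(v) + eps)` step; the constant `1.2732395 ≈ 4/π` rescales atan2's bounded output so typical updates keep roughly the same magnitude as the eps form, while extreme ratios are smoothly clipped. A standalone comparison sketch (illustrative tensors only, not the optimizer's internals):

import torch

m = torch.randn(4, 4)          # first moment
v = torch.rand(4, 4) + 1e-12   # second moment (non-negative)
eps = 1e-8

classic = m / (v.sqrt() + eps)                        # can blow up as v -> 0
atan2_update = torch.atan2(m, v.sqrt()) * 1.2732395   # bounded to about (-2, 2), eps-free

print(classic.abs().max(), atan2_update.abs().max())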
@@ -232,36 +234,38 @@ class AdamW_adv(torch.optim.Optimizer):
232
234
  del vt
233
235
 
234
236
  else: # Standard AdamW logic for non-factored tensors
235
- exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
236
-
237
- exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
238
- if self.use_grams:
239
- exp_avg = grad.sign() * exp_avg.abs()
240
- elif self.use_cautious:
241
- mask = (exp_avg * grad > 0).to(grad.dtype)
242
- mask.div_(mask.mean().clamp_(min=1e-3))
243
- exp_avg.mul_(mask)
244
- del mask
237
+ exp_avg_sq = state['exp_avg_sq']
238
+
239
+ if beta1 > 0:
240
+ exp_avg = state['exp_avg']
241
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
242
+ if self.use_grams:
243
+ exp_avg = grad.sign() * exp_avg.abs()
244
+ elif self.use_cautious:
245
+ mask = (exp_avg * grad > 0).to(grad.dtype)
246
+ mask.div_(mask.mean().clamp_(min=1e-3))
247
+ exp_avg.mul_(mask)
248
+ del mask
245
249
 
246
250
  if self.use_AdEMAMix:
247
251
  exp_avg_slow = state['exp_avg_slow']
248
- exp_avg_slow.mul_(beta3).add_(grad, alpha=1 - beta3)
249
- update_m = exp_avg + (alpha_t * exp_avg_slow)
252
+ exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
253
+ update = exp_avg + (alpha_t * exp_avg_slow) if beta1 > 0 else grad + (alpha_t * exp_avg_slow)
250
254
  else:
251
- update_m = exp_avg
255
+ update = exp_avg if beta1 > 0 else grad
252
256
 
253
257
  exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
254
258
 
255
259
  if group['use_atan2']:
256
260
  a = 1.2732395
257
261
  denom = exp_avg_sq.sqrt()
258
- update = torch.atan2(update_m, denom).mul_(a)
262
+ update.atan2_(denom).mul_(a)
259
263
  else:
260
- denom = exp_avg_sq.sqrt().add_(group['eps'])
261
- update = update_m / denom
262
- del update_m, denom
264
+ denom = exp_avg_sq.sqrt()
265
+ update.div_(denom.add_(group['eps']))
266
+ del denom
263
267
 
264
- update = update.mul_(group['lr'])
268
+ update.mul_(group['lr'])
265
269
 
266
270
  # Decoupled weight decay
267
271
  if group["weight_decay"] != 0:
@@ -30,8 +30,6 @@ class Adopt_adv(torch.optim.Optimizer):
30
30
  clip_lambda (Callable, optional): A function that takes the current step
31
31
  and returns a value to clip the normalized gradient. Only used when
32
32
  `use_atan2` is False. (default: `lambda step: step**0.25`)
33
- rank (int): the rank for the low-rank approximation (default: 4).
34
- oversampling (int): oversampling parameter for Randomized SVD. (default: 0).
35
33
  vector_reshape (bool): whether to reshape 1D vectors into 2D
36
34
  matrices for low-rank compression (default: True).
37
35
  stochastic_rounding (bool): whether to use stochastic
@@ -48,7 +46,7 @@ class Adopt_adv(torch.optim.Optimizer):
48
46
  combined with the primary momentum (`mt`) to stabilize updates,
49
47
  especially in noisy, small-batch settings. If `False`, the
50
48
  optimizer behaves as standard ADOPT. (default: False)
51
- beta3 (float): The decay rate for the slow exponential moving average of
49
+ beta3_ema (float): The decay rate for the slow exponential moving average of
52
50
  the momentum (only used when `use_AdEMAMix` is `True`). A higher
53
51
  value (e.g., 0.9999) gives the EMA a longer memory, making it more
54
52
  stable but slower to adapt. A lower value (e.g., 0.999) is often
@@ -83,7 +81,7 @@ class Adopt_adv(torch.optim.Optimizer):
83
81
  use_grams: bool = False,
84
82
  use_orthograd: bool = False,
85
83
  use_AdEMAMix: bool = False,
86
- beta3: float = 0.9999,
84
+ beta3_ema: float = 0.9999,
87
85
  alpha: float = 5.0,
88
86
  t_alpha: int | None = None,
89
87
  factored: bool = True,
@@ -99,7 +97,7 @@ class Adopt_adv(torch.optim.Optimizer):
99
97
 
100
98
  defaults = {
101
99
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
102
- "vector_reshape": vector_reshape, "beta3": beta3, "alpha": alpha,
100
+ "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
103
101
  "t_alpha": t_alpha,
104
102
  }
105
103
  self.clip_lambda = clip_lambda
@@ -179,7 +177,7 @@ class Adopt_adv(torch.optim.Optimizer):
179
177
 
180
178
  beta1, beta2 = group['betas']
181
179
  if self.use_AdEMAMix:
182
- beta3 = group['beta3']
180
+ beta3_ema = group['beta3_ema']
183
181
  alpha = group['alpha']
184
182
  t_alpha = group['t_alpha']
185
183
  # Use step+1 for 1-based step count in scheduler
@@ -192,29 +190,29 @@ class Adopt_adv(torch.optim.Optimizer):
192
190
  d1, d2 = state['effective_shape']
193
191
 
194
192
  # Reconstruct m_{t-1}
195
- mt_prev = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
193
+ mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
196
194
  if not self.use_grams:
197
195
  if state['sign'].dtype != torch.uint8:
198
196
  state['sign'] = state['sign'].to(torch.uint8)
199
197
  unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
200
- torch.where(unpacked_sign, mt_prev, -mt_prev, out=mt_prev)
198
+ torch.where(unpacked_sign, mt, -mt, out=mt)
201
199
  del unpacked_sign
202
200
 
203
201
  # Reconstruct AdEMAMix EMA
204
202
  if self.use_AdEMAMix:
205
- mt_slow_prev = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
203
+ mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
206
204
  if state['sign_slow'].dtype != torch.uint8:
207
205
  state['sign_slow'] = state['sign_slow'].to(torch.uint8)
208
206
  unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
209
- torch.where(unpacked_sign_slow, mt_slow_prev, -mt_slow_prev, out=mt_slow_prev)
207
+ torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
210
208
  del unpacked_sign_slow
211
209
 
212
210
  # Reconstruct v_{t-1} using NNMF
213
- vt_prev = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
211
+ vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
214
212
 
215
213
  # ADOPT Step A: Decorrelate g_t using v_{t-1}
216
214
  grad_reshaped = grad.view(d1, d2)
217
- denom = vt_prev.sqrt()
215
+ denom = vt.sqrt()
218
216
 
219
217
  if self.use_atan2:
220
218
  normalized_grad = torch.atan2(grad_reshaped, denom)
@@ -226,7 +224,7 @@ class Adopt_adv(torch.optim.Optimizer):
226
224
  del denom
227
225
 
228
226
  # ADOPT Step B: Update momentum m_t using normalized gradient
229
- mt = mt_prev.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
227
+ mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
230
228
  if self.use_grams:
231
229
  mt = grad_reshaped.sign() * mt.abs()
232
230
  elif self.use_cautious:
@@ -236,7 +234,7 @@ class Adopt_adv(torch.optim.Optimizer):
236
234
  del mask
237
235
 
238
236
  if self.use_AdEMAMix:
239
- mt_slow = mt_slow_prev.mul_(beta3).add_(normalized_grad, alpha=1.0 - beta3)
237
+ mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
240
238
  update = mt + (alpha_t * mt_slow)
241
239
  update = update.view(p.shape)
242
240
  else:
@@ -248,20 +246,23 @@ class Adopt_adv(torch.optim.Optimizer):
248
246
  update.mul_(group['lr'])
249
247
 
250
248
  # Update second moment v_t for the *next* step using raw g_t
251
- vt_updated = vt_prev.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
249
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
252
250
  del grad_reshaped
253
251
 
254
252
  # Compress and store new factors
255
253
  if not self.use_grams:
256
254
  state['sign'] = _pack_bools(mt > 0)
257
255
  _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
256
+ del mt
258
257
 
259
258
  if self.use_AdEMAMix:
260
259
  state['sign_slow'] = _pack_bools(mt_slow > 0)
261
260
  _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
261
+ del mt_slow
262
262
 
263
263
  # factorize v_t using NMF compression
264
- _nnmf(vt_updated, out=(state['mu_v_nmf'], state['mv_v_nmf']))
264
+ _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
265
+ del vt
265
266
 
266
267
  else: # Standard ADOPT logic for non-factored tensors
267
268
  m, v = state['exp_avg'], state['exp_avg_sq'] # m_{t-1}, v_{t-1}
@@ -293,7 +294,7 @@ class Adopt_adv(torch.optim.Optimizer):
293
294
  del mask
294
295
 
295
296
  if self.use_AdEMAMix:
296
- m_slow.mul_(beta3).add_(normalized_grad, alpha=1.0 - beta3)
297
+ m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
297
298
  update = m + (alpha_t * m_slow)
298
299
  else:
299
300
  update = m
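For completeness, a hedged sketch of constructing ADOPT with the renamed `beta3_ema` and either the step-dependent clipping or the atan2 path (the `clip_lambda` default and the `use_atan2` keyword are assumed from the docstring above):

import torch
from adv_optm import Adopt_adv

model = torch.nn.Linear(16, 4)

# Clip the normalized gradient to step**0.25 each step (the documented default).
opt = Adopt_adv(model.parameters(), lr=1e-3,
                clip_lambda=lambda step: step ** 0.25)

# Alternative: atan2 smoothing, which the README recommends for ADOPT;
# eps and external clipping are then irrelevant.
opt_atan2 = Adopt_adv(model.parameters(), lr=1e-3, use_atan2=True,
                      use_AdEMAMix=True, beta3_ema=0.9999)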
@@ -0,0 +1,335 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+
4
+ import math
5
+
6
+ from typing import Tuple, Optional
7
+
8
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
9
+ from ..util.Effective_Shape import _get_effective_shape
10
+ from ..util.NNMF import _nnmf,_unnmf
11
+ from ..util.OrthoGrad import _orthogonalize_gradient
12
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
13
+
14
+ class Lion_Prodigy_adv(torch.optim.Optimizer):
15
+ """
16
+ Implements the SMMF technique and the Prodigy D-Adaptation method for the Lion algorithm.
17
+
18
+ Args:
19
+ params (iterable): iterable of parameters to optimize or dicts defining
20
+ parameter groups.
21
+ lr (float, optional): learning rate (default: 1).
22
+ betas (Tuple[float, float], optional): coefficients for computing
23
+ running averages of the update (default: (0.9, 0.99)).
24
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
25
+ vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
26
+ matrices to apply low-rank compression (default: True).
27
+ stochastic_rounding (bool, optional): whether to use stochastic
28
+ rounding for BF16 parameter updates (default: True).
29
+ use_cautious (bool): whether to use the cautious masking technique. (default: False).
30
+ clip_threshold (float, optional): threshold for clipping the gradient norm
31
+ per-parameter as proposed in the paper `Lions and Muons: Optimization via
32
+ Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
33
+ (default: 0.0).
34
+ factored (bool): whether to use the factorization or use the
35
+ uncompressed optimizer. (default: True)
36
+ variance_reduction (bool): whether to use the variance reduction technique
37
+ from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
38
+ d0 (float):
39
+ Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
40
+ d_coef (float):
41
+ Coefficient in the expression for the estimate of d (default 1.0).
42
+ Values such as 0.5 and 2.0 typically work as well.
43
+ Changing this parameter is the preferred way to tune the method.
44
+ growth_rate (float):
45
+ prevent the D estimate from growing faster than this multiplicative rate.
46
+ Default is inf, for unrestricted. Values like 1.02 give a kind of learning
47
+ rate warmup effect.
48
+ fsdp_in_use (bool):
49
+ If you're using sharded parameters, this should be set to True. The optimizer
50
+ will attempt to auto-detect this, but if you're using an implementation other
51
+ than PyTorch's builtin version, the auto-detection won't work.
52
+ slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
53
+ pth entry of each tensor. For values greater than 1 this is an approximation to standard
54
+ Prodigy. Values ~11 are reasonable (default 11).
55
+ """
56
+
57
+ def __init__(
58
+ self,
59
+ params,
60
+ lr: float = 1,
61
+ betas: Tuple[float, float] = (0.9, 0.99),
62
+ weight_decay: float = 0.0,
63
+ vector_reshape: bool = True,
64
+ stochastic_rounding: bool = True,
65
+ use_orthograd: bool = False,
66
+ use_cautious: bool = False,
67
+ clip_threshold: float = 0.0,
68
+ factored: bool = True,
69
+ variance_reduction: bool = False,
70
+ # prodigy parameters
71
+ beta3: float = None,
72
+ d0: float = 1e-6,
73
+ d_coef: float = 1,
74
+ growth_rate: float = float('inf'),
75
+ safeguard_warmup: bool = False,
76
+ fsdp_in_use: bool = False,
77
+ slice_p: int = 11,
78
+ ):
79
+ if not lr > 0.0:
80
+ raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
81
+ if not all(0.0 <= beta <= 1.0 for beta in betas):
82
+ raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
83
+ if not weight_decay >= 0.0:
84
+ raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
85
+ if variance_reduction and use_cautious:
86
+ print("Warning: Using both 'variance_reduction' and 'use_cautious' is not recommended and may lead to unintended effects.")
87
+
88
+ defaults = dict(
89
+ lr=lr,
90
+ betas=betas,
91
+ weight_decay=weight_decay,
92
+ vector_reshape=vector_reshape,
93
+ use_orthograd=use_orthograd,
94
+ clip_threshold=clip_threshold,
95
+ beta3=beta3, d=d0, d0=d0, d_max=d0, d_numerator=0.0, d_coef=d_coef,
96
+ growth_rate=growth_rate, safeguard_warmup=safeguard_warmup, k=0, slice_p=slice_p,
97
+ fsdp_in_use=fsdp_in_use,
98
+ )
99
+ self.stochastic_rounding = stochastic_rounding
100
+ self.use_cautious = use_cautious
101
+ self.factored = factored
102
+ self.variance_reduction = variance_reduction
103
+ self.fsdp_in_use = fsdp_in_use
104
+ super().__init__(params, defaults)
105
+ # Global state for accumulating metrics across parameter updates within a single step.
106
+ self.init_step()
107
+
108
+ @property
109
+ def supports_fused_back_pass(self) -> bool:
110
+ return True
111
+
112
+ @property
113
+ def supports_memory_efficient_fp16(self) -> bool:
114
+ return True
115
+
116
+ @property
117
+ def supports_flat_params(self) -> bool:
118
+ return False
119
+
120
+ def init_step(self):
121
+ """Resets accumulators and calculates dlr for the upcoming step."""
122
+ self.d_denom = 0.0
123
+
124
+ g_group = self.param_groups[0]
125
+ self.beta1, self.beta2 = g_group['betas']
126
+ self.beta3 = g_group['beta3']
127
+ if self.beta3 is None:
128
+ self.beta3 = math.sqrt(self.beta2)
129
+
130
+ k = g_group['k']
131
+ self.d = g_group['d']
132
+ lr = g_group['lr']
133
+
134
+ self.dlr = self.d * lr
135
+
136
+ self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3
137
+
138
+ @torch.no_grad()
139
+ def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
140
+ """Performs a single optimization step on a single parameter."""
141
+ if p.grad is None:
142
+ return
143
+
144
+ if hasattr(p, "_fsdp_flattened"):
145
+ self.fsdp_in_use = True
146
+
147
+ grad = p.grad
148
+ if grad.dtype != torch.float32 and self.factored:
149
+ grad = grad.float()
150
+ if group["clip_threshold"] > 0.0:
151
+ grad_norm = torch.norm(grad.detach())
152
+ if grad_norm > group["clip_threshold"]:
153
+ clip_coef = group["clip_threshold"] / grad_norm
154
+ grad.mul_(clip_coef)
155
+ if group["use_orthograd"]:
156
+ grad = _orthogonalize_gradient(p, grad)
157
+ state = self.state[p]
158
+
159
+ # State Initialization
160
+ if len(state) == 0:
161
+ state['step'] = 0
162
+
163
+ should_factor = (
164
+ self.factored and
165
+ not (len(p.shape) == 1 and not group['vector_reshape'])
166
+ )
167
+
168
+ state['factored'] = should_factor
169
+
170
+ dtype = torch.float32 if self.factored else p.dtype
171
+
172
+ slice_p = group['slice_p']
173
+
174
+ # D-Adaptation states
175
+ state['s'] = torch.zeros_like(p.flatten()[::slice_p]).detach()
176
+ if p.any():
177
+ state['p0'] = p.flatten()[::slice_p].detach().clone()
178
+ else:
179
+ state['p0'] = torch.tensor(0, device=p.device, dtype=p.dtype)
180
+
181
+ if state['factored']:
182
+ state['effective_shape'] = _get_effective_shape(p.numel())
183
+ d1, d2 = state['effective_shape']
184
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
185
+ state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
186
+ packed_d2 = (d2 + 7) // 8
187
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
188
+ if self.variance_reduction:
189
+ state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
190
+ else: # Fallback to standard Lion
191
+ state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
192
+ if self.variance_reduction:
193
+ state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
194
+
195
+ if state['factored']:
196
+ # Factored Path
197
+ d1, d2 = state['effective_shape']
198
+ grad_reshaped = grad.view(d1, d2)
199
+ # Reconstruct momentum m_{t-1}
200
+ exp_avg = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
201
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
202
+ torch.where(unpacked_sign, exp_avg, -exp_avg, out=exp_avg)
203
+ del unpacked_sign
204
+ if exp_avg.dtype != torch.float32:
205
+ exp_avg = exp_avg.float()
206
+
207
+ # Compute update term c_t = β1*m_{t-1} + (1-β1)*g_t
208
+ signed_update = exp_avg.clone().mul_(self.beta1).add_(grad_reshaped, alpha=(1-self.beta1)).sign_()
209
+
210
+ if self.use_cautious:
211
+ mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
212
+ mask.div_(mask.mean().clamp_(min=1e-3))
213
+ signed_update.mul_(mask)
214
+ del mask
215
+
216
+ # Parameter update: p_t = p_{t-1} - (d * lr) * sign(c_t)
217
+ update_for_param = signed_update.view(p.shape).mul(self.dlr)
218
+
219
+ # Update momentum m_t = β2*m_{t-1} + (1-β2)*d*g_t
220
+ if self.variance_reduction:
221
+ vr_term = grad_reshaped - state['prev_grad']
222
+ exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1-self.beta2)).add_(vr_term, alpha=self.beta2)
223
+ state['prev_grad'].copy_(grad_reshaped)
224
+ else:
225
+ exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1-self.beta2))
226
+ del grad_reshaped
227
+
228
+ # Compress new momentum m_t and store factors
229
+ state['sign'] = _pack_bools(exp_avg > 0)
230
+ _nnmf(exp_avg.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
231
+ del exp_avg
232
+
233
+ else:
234
+ # Fallback to standard Lion logic
235
+ exp_avg = state["exp_avg"]
236
+
237
+ # Compute update term and sign for the update
238
+ if exp_avg.dtype != torch.float32 and self.factored:
239
+ exp_avg = exp_avg.float()
240
+ signed_update = exp_avg.clone().mul_(self.beta1).add_(grad, alpha=(1-self.beta1)).sign_()
241
+
242
+ if self.use_cautious:
243
+ mask = (signed_update * grad > 0).to(grad.dtype)
244
+ mask.div_(mask.mean().clamp_(min=1e-3))
245
+ signed_update.mul_(mask)
246
+ del mask
247
+
248
+ update_for_param = signed_update.mul(self.dlr)
249
+
250
+ # Update momentum
251
+ if self.variance_reduction:
252
+ vr_term = grad - state['prev_grad']
253
+ exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1-self.beta2)).add_(vr_term, alpha=self.beta2)
254
+ state['prev_grad'].copy_(grad)
255
+ else:
256
+ exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1-self.beta2))
257
+
258
+ # --- Accumulate Prodigy stats ---
259
+ d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
260
+ s, p0 = state['s'], state['p0']
261
+ grad_flat = grad.flatten().float()
262
+ p_flat = p.data.flatten().float()
263
+ p0 = p0.float()
264
+
265
+ self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
266
+
267
+ alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
268
+ s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
269
+ self.d_denom += s.abs().sum().item()
270
+
271
+ del s, p0, grad_flat, p_flat, alpha
272
+
273
+ if group["weight_decay"] != 0:
274
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
275
+ add_stochastic_(p.data, p.data,
276
+ alpha=-group["weight_decay"] * self.dlr)
277
+ else:
278
+ p.data.add_(
279
+ p.data, alpha=-group["weight_decay"] * self.dlr
280
+ )
281
+
282
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
283
+ add_stochastic_(p.data, -update_for_param)
284
+ else:
285
+ p.data.add_(-update_for_param)
286
+
287
+ del update_for_param
288
+
289
+ @torch.no_grad()
290
+ def step(self, closure: Optional[callable] = None):
291
+ """Performs a single optimization step."""
292
+ loss = None
293
+ if closure is not None:
294
+ with torch.enable_grad():
295
+ loss = closure()
296
+
297
+ for group in self.param_groups:
298
+ for i, p in enumerate(group["params"]):
299
+ if p.grad is not None:
300
+ self.step_parameter(p, group, i)
301
+
302
+
303
+ self.calculate_d()
304
+ self.init_step()
305
+ return loss
306
+
307
+ def calculate_d(self):
308
+ """Calculates the new `d` based on the accumulated stats."""
309
+ g_group = self.param_groups[0]
310
+ d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
311
+
312
+ if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
313
+ # Use the device of the first parameter to avoid hardcoding '.cuda()'
314
+ device = self.param_groups[0]['params'][0].device
315
+ dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
316
+ dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
317
+ global_d_numerator = dist_tensor[0].item()
318
+ global_d_denom = dist_tensor[1].item()
319
+ else:
320
+ global_d_numerator = self.d_numerator
321
+ global_d_denom = self.d_denom
322
+
323
+ d_hat = self.d
324
+ if global_d_denom > 0:
325
+ d_hat = d_coef * global_d_numerator / global_d_denom
326
+ if self.d == g_group['d0']:
327
+ self.d = max(self.d, d_hat)
328
+ d_max = max(d_max, d_hat)
329
+ self.d = min(d_max, self.d * growth_rate)
330
+
331
+ for group in self.param_groups:
332
+ group['d_numerator'] = global_d_numerator
333
+ group['d'] = self.d
334
+ group['d_max'] = d_max
335
+ group['k'] += 1
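As with the other Prodigy-style optimizers in this package, the intended usage keeps `lr` at 1 and lets the D-adaptation choose the effective step size. A hedged usage sketch:

import torch
from adv_optm import Lion_Prodigy_adv

model = torch.nn.Linear(16, 4)
opt = Lion_Prodigy_adv(
    model.parameters(),
    lr=1.0,            # Prodigy convention: keep at 1 and tune d_coef instead
    betas=(0.9, 0.99),
    d_coef=1.0,
    slice_p=11,
)

for _ in range(3):
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    opt.step()         # accumulates Prodigy stats, then recomputes d via calculate_d()
    opt.zero_grad()

print("adapted d:", opt.param_groups[0]["d"])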
@@ -0,0 +1,231 @@
1
+ import torch
2
+
3
+ from typing import Tuple, Optional
4
+
5
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
6
+ from ..util.Effective_Shape import _get_effective_shape
7
+ from ..util.NNMF import _nnmf,_unnmf
8
+ from ..util.OrthoGrad import _orthogonalize_gradient
9
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
10
+
11
+ class Lion_adv(torch.optim.Optimizer):
12
+ """
13
+ Implements the SMMF technique for the Lion algorithm.
14
+
15
+ This optimizer combines the Lion update rule with the memory-saving low-rank
16
+ compression (SMMF) technique from https://arxiv.org/abs/2412.08894.
17
+
18
+ Args:
19
+ params (iterable): iterable of parameters to optimize or dicts defining
20
+ parameter groups.
21
+ lr (float, optional): learning rate (default: 1e-4).
22
+ betas (Tuple[float, float], optional): coefficients for computing
23
+ running averages of the update (default: (0.9, 0.99)).
24
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
25
+ vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
26
+ matrices to apply low-rank compression (default: True).
27
+ stochastic_rounding (bool, optional): whether to use stochastic
28
+ rounding for BF16 parameter updates (default: True).
29
+ use_cautious (bool): whether to use the cautious masking technique. (default: False).
30
+ clip_threshold (float, optional): threshold for clipping the gradient norm
31
+ per-parameter as proposed in the paper `Lions and Muons: Optimization via
32
+ Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
33
+ (default: 0.0).
34
+ factored (bool): whether to use the factorization or use the
35
+ uncompressed optimizer. (default: True)
36
+ variance_reduction (bool): whether to use the variance reduction technique
37
+ from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ params,
43
+ lr: float = 1e-4,
44
+ betas: Tuple[float, float] = (0.9, 0.99),
45
+ weight_decay: float = 0.0,
46
+ vector_reshape: bool = True,
47
+ stochastic_rounding: bool = True,
48
+ use_orthograd: bool = False,
49
+ use_cautious: bool = False,
50
+ clip_threshold: float = 0.0,
51
+ factored: bool = True,
52
+ variance_reduction: bool = False,
53
+ ):
54
+ if not lr > 0.0:
55
+ raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
56
+ if not all(0.0 <= beta <= 1.0 for beta in betas):
57
+ raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
58
+ if not weight_decay >= 0.0:
59
+ raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
60
+ if variance_reduction and use_cautious:
61
+ print("Warning: Using both 'variance_reduction' and 'use_cautious' is not recommended and may lead to unintended effects.")
62
+
63
+ defaults = dict(
64
+ lr=lr,
65
+ betas=betas,
66
+ weight_decay=weight_decay,
67
+ vector_reshape=vector_reshape,
68
+ use_orthograd=use_orthograd,
69
+ clip_threshold=clip_threshold,
70
+ )
71
+ self.stochastic_rounding = stochastic_rounding
72
+ self.use_cautious = use_cautious
73
+ self.factored = factored
74
+ self.variance_reduction = variance_reduction
75
+ super().__init__(params, defaults)
76
+
77
+ @property
78
+ def supports_fused_back_pass(self) -> bool:
79
+ return True
80
+
81
+ @property
82
+ def supports_memory_efficient_fp16(self) -> bool:
83
+ return True
84
+
85
+ @property
86
+ def supports_flat_params(self) -> bool:
87
+ return False
88
+
89
+ @torch.no_grad()
90
+ def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
91
+ """Performs a single optimization step on a single parameter."""
92
+ if p.grad is None:
93
+ return
94
+
95
+ grad = p.grad
96
+ if grad.dtype != torch.float32 and self.factored:
97
+ grad = grad.float()
98
+ if group["clip_threshold"] > 0.0:
99
+ grad_norm = torch.norm(grad.detach())
100
+ if grad_norm > group["clip_threshold"]:
101
+ clip_coef = group["clip_threshold"] / grad_norm
102
+ grad.mul_(clip_coef)
103
+ if group["use_orthograd"]:
104
+ grad = _orthogonalize_gradient(p, grad)
105
+ state = self.state[p]
106
+
107
+ # State Initialization
108
+ if len(state) == 0:
109
+ state['step'] = 0
110
+
111
+ should_factor = (
112
+ self.factored and
113
+ not (len(p.shape) == 1 and not group['vector_reshape'])
114
+ )
115
+
116
+ state['factored'] = should_factor
117
+
118
+ dtype = torch.float32 if self.factored else p.dtype
119
+
120
+ if state['factored']:
121
+ state['effective_shape'] = _get_effective_shape(p.numel())
122
+ d1, d2 = state['effective_shape']
123
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
124
+ state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
125
+ packed_d2 = (d2 + 7) // 8
126
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
127
+ if self.variance_reduction:
128
+ state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
129
+ else: # Fallback to standard Lion
130
+ state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
131
+ if self.variance_reduction:
132
+ state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
133
+
134
+ state['step'] += 1
135
+ beta1, beta2 = group["betas"]
136
+ lr = group["lr"]
137
+
138
+ if state['factored']:
139
+ # Factored Path
140
+ d1, d2 = state['effective_shape']
141
+ grad_reshaped = grad.view(d1, d2)
142
+ # Reconstruct momentum m_{t-1}
143
+ exp_avg = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
144
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
145
+ torch.where(unpacked_sign, exp_avg, -exp_avg, out=exp_avg)
146
+ del unpacked_sign
147
+ if exp_avg.dtype != torch.float32:
148
+ exp_avg = exp_avg.float()
149
+
150
+ # Compute update term c_t = β1*m_{t-1} + (1-β1)*g_t
151
+ signed_update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()
152
+
153
+ if self.use_cautious:
154
+ mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
155
+ mask.div_(mask.mean().clamp_(min=1e-3))
156
+ signed_update.mul_(mask)
157
+ del mask
158
+
159
+ # Parameter update: p_t = p_{t-1} - lr * sign(c_t)
160
+ update_for_param = signed_update.view(p.shape).mul_(lr)
161
+
162
+ # Update momentum m_t = β2*m_{t-1} + (1-β2)*g_t
163
+ if self.variance_reduction:
164
+ vr_term = grad_reshaped - state['prev_grad']
165
+ exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2).add_(vr_term, alpha=beta2)
166
+ del vr_term
167
+ state['prev_grad'].copy_(grad_reshaped)
168
+ else:
169
+ exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
170
+ del grad_reshaped
171
+
172
+ # Compress new momentum m_t and store factors
173
+ state['sign'] = _pack_bools(exp_avg > 0)
174
+ _nnmf(exp_avg.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
175
+ del exp_avg
176
+
177
+ else:
178
+ # Fallback to standard Lion logic
179
+ exp_avg = state["exp_avg"]
180
+
181
+ # Compute update term and sign for the update
182
+ if exp_avg.dtype != torch.float32 and self.factored:
183
+ exp_avg = exp_avg.float()
184
+ signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()
185
+
186
+ if self.use_cautious:
187
+ mask = (signed_update * grad > 0).to(grad.dtype)
188
+ mask.div_(mask.mean().clamp_(min=1e-3))
189
+ signed_update.mul_(mask)
190
+ del mask
191
+
192
+ update_for_param = signed_update.mul_(lr)
193
+
194
+ # Update momentum
195
+ if self.variance_reduction:
196
+ vr_term = grad - state['prev_grad']
197
+ exp_avg.mul_(beta2).add_(grad, alpha=1-beta2).add_(vr_term, alpha=beta2)
198
+ state['prev_grad'].copy_(grad)
199
+ else:
200
+ exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
201
+
202
+ if group["weight_decay"] != 0:
203
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
204
+ add_stochastic_(p.data, p.data,
205
+ alpha=-group["weight_decay"] * lr)
206
+ else:
207
+ p.data.add_(
208
+ p.data, alpha=-group["weight_decay"] * lr
209
+ )
210
+
211
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
212
+ add_stochastic_(p.data, -update_for_param)
213
+ else:
214
+ p.data.add_(-update_for_param)
215
+
216
+ del update_for_param
217
+
218
+ @torch.no_grad()
219
+ def step(self, closure: Optional[callable] = None):
220
+ """Performs a single optimization step."""
221
+ loss = None
222
+ if closure is not None:
223
+ with torch.enable_grad():
224
+ loss = closure()
225
+
226
+ for group in self.param_groups:
227
+ for i, p in enumerate(group["params"]):
228
+ if p.grad is not None:
229
+ self.step_parameter(p, group, i)
230
+
231
+ return loss
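Both Lion classes expose `supports_fused_back_pass` and a per-parameter `step_parameter`, which a training loop can drive from gradient post-accumulation hooks so each gradient is consumed and freed during the backward pass. The hook-based driver below is an assumption (it is not part of the package) and requires PyTorch 2.1+ for `register_post_accumulate_grad_hook`:

import torch
from adv_optm import Lion_adv

model = torch.nn.Linear(16, 4)
opt = Lion_adv(model.parameters(), lr=1e-4)

# Map each parameter to its group index so the hook can call step_parameter.
param_to_group = {p: g for g, group in enumerate(opt.param_groups)
                  for p in group["params"]}

def make_hook(p):
    def hook(*_):
        group = opt.param_groups[param_to_group[p]]
        opt.step_parameter(p, group)   # update this parameter right away
        p.grad = None                  # release the gradient immediately
    return hook

for p in model.parameters():
    p.register_post_accumulate_grad_hook(make_hook(p))

model(torch.randn(8, 16)).pow(2).mean().backward()  # params updated during backward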
@@ -1,5 +1,6 @@
1
1
  import torch
2
- from typing import Optional
2
+ import torch.distributed as dist
3
+
3
4
  import math
4
5
 
5
6
  from ..util.BF16_Stochastic_Rounding import add_stochastic_
@@ -17,7 +18,7 @@ class Prodigy_adv(torch.optim.Optimizer):
17
18
  Args:
18
19
  params (iterable): iterable of parameters to optimize or dicts defining
19
20
  parameter groups
20
- lr (float): learning rate (default: 1e-3)
21
+ lr (float): learning rate (default: 1)
21
22
  betas (tuple[float, float]): coefficients used for computing running
22
23
  averages of gradient and its square (default: (0.9, 0.999))
23
24
  eps (float): term added to the denominator to improve
@@ -54,12 +55,29 @@ class Prodigy_adv(torch.optim.Optimizer):
54
55
  the scheduler is disabled and th
55
56
  factored (bool): whether to use the factorization or disable it to use
56
57
  the uncompressed optimizer. (default: True)
58
+ d0 (float):
59
+ Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
60
+ d_coef (float):
61
+ Coefficient in the expression for the estimate of d (default 1.0).
62
+ Values such as 0.5 and 2.0 typically work as well.
63
+ Changing this parameter is the preferred way to tune the method.
64
+ growth_rate (float):
65
+ prevent the D estimate from growing faster than this multiplicative rate.
66
+ Default is inf, for unrestricted. Values like 1.02 give a kind of learning
67
+ rate warmup effect.
68
+ fsdp_in_use (bool):
69
+ If you're using sharded parameters, this should be set to True. The optimizer
70
+ will attempt to auto-detect this, but if you're using an implementation other
71
+ than PyTorch's builtin version, the auto-detection won't work.
72
+ slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
73
+ pth entry of each tensor. For values greater than 1 this an an approximation to standard
74
+ Prodigy. Values ~11 are reasonable (default 11).
57
75
  """
58
76
 
59
77
  def __init__(
60
78
  self,
61
79
  params,
62
- lr: float = 1e-3,
80
+ lr: float = 1,
63
81
  betas: tuple[float, float] = (0.9, 0.999),
64
82
  eps: float = 1e-8,
65
83
  weight_decay: float = 0.0,
@@ -80,6 +98,7 @@ class Prodigy_adv(torch.optim.Optimizer):
80
98
  d_coef: float = 1,
81
99
  growth_rate: float = float('inf'),
82
100
  safeguard_warmup: bool = False,
101
+ fsdp_in_use: bool = False,
83
102
  slice_p: int = 11,
84
103
  ):
85
104
  if not (lr >= 0.0):
@@ -98,12 +117,14 @@ class Prodigy_adv(torch.optim.Optimizer):
98
117
  "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
99
118
  "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
100
119
  "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
120
+ "fsdp_in_use": fsdp_in_use,
101
121
  }
102
122
  self.stochastic_rounding = stochastic_rounding
103
123
  self.use_cautious = use_cautious
104
124
  self.use_grams = use_grams
105
125
  self.use_AdEMAMix = use_AdEMAMix
106
126
  self.factored = factored
127
+ self.fsdp_in_use = fsdp_in_use
107
128
  super().__init__(params, defaults)
108
129
  self.init_step()
109
130
 
@@ -142,6 +163,9 @@ class Prodigy_adv(torch.optim.Optimizer):
142
163
  if p.grad is None:
143
164
  return
144
165
 
166
+ if hasattr(p, "_fsdp_flattened"):
167
+ self.fsdp_in_use = True
168
+
145
169
  grad = p.grad
146
170
  if grad.dtype != torch.float32 and self.factored:
147
171
  grad = grad.float()
@@ -246,7 +270,7 @@ class Prodigy_adv(torch.optim.Optimizer):
246
270
  denom = vt.sqrt()
247
271
  update = torch.atan2(update_m, denom).mul_(a)
248
272
  else:
249
- denom = vt.sqrt().add_(group['eps'])
273
+ denom = vt.sqrt().add_(self.d * group['eps'])
250
274
  update = update_m / denom
251
275
  del update_m, denom
252
276
 
@@ -291,7 +315,7 @@ class Prodigy_adv(torch.optim.Optimizer):
291
315
  denom = exp_avg_sq.sqrt()
292
316
  update = torch.atan2(update_m, denom).mul_(a)
293
317
  else:
294
- denom = exp_avg_sq.sqrt().add_(group['eps'])
318
+ denom = exp_avg_sq.sqrt().add_(self.d * group['eps'])
295
319
  update = update_m / denom
296
320
  del update_m, denom
297
321
 
@@ -349,8 +373,16 @@ class Prodigy_adv(torch.optim.Optimizer):
349
373
  g_group = self.param_groups[0]
350
374
  d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
351
375
 
352
- global_d_numerator = self.d_numerator
353
- global_d_denom = self.d_denom
376
+ if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
377
+ # Use the device of the first parameter to avoid hardcoding '.cuda()'
378
+ device = self.param_groups[0]['params'][0].device
379
+ dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
380
+ dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
381
+ global_d_numerator = dist_tensor[0].item()
382
+ global_d_denom = dist_tensor[1].item()
383
+ else:
384
+ global_d_numerator = self.d_numerator
385
+ global_d_denom = self.d_denom
354
386
 
355
387
  d_hat = self.d
356
388
  if global_d_denom > 0:
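The new reduction only kicks in when sharding is detected (or `fsdp_in_use=True` is passed) and `torch.distributed` is initialized; otherwise the local sums are used as before. A minimal sketch of opting in, assuming a process group is set up elsewhere in the training script:

import torch
from adv_optm import Prodigy_adv

model = torch.nn.Linear(16, 4)
# Explicitly flag sharded parameters; auto-detection only covers PyTorch's own FSDP.
opt = Prodigy_adv(model.parameters(), lr=1.0, fsdp_in_use=True)

model(torch.randn(8, 16)).pow(2).mean().backward()
opt.step()  # d_numerator / d_denom are all-reduced across ranks (when dist is initialized) before updating d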
@@ -1,9 +1,13 @@
1
1
  from .AdamW_adv import AdamW_adv
2
2
  from .Prodigy_adv import Prodigy_adv
3
3
  from .Adopt_adv import Adopt_adv
4
+ from .Lion_adv import Lion_adv
5
+ from .Lion_Prodigy_adv import Lion_Prodigy_adv
4
6
 
5
7
  __all__ = [
6
8
  "AdamW_adv",
7
9
  "Prodigy_adv",
8
10
  "Adopt_adv",
11
+ "Lion_adv",
12
+ "Lion_Prodigy_adv",
9
13
  ]
adv_optm/util/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from .BF16_Stochastic_Rounding import add_stochastic_, copy_stochastic_
1
+ from .BF16_Stochastic_Rounding import add_stochastic_
2
2
  from .Effective_Shape import _get_effective_shape
3
3
  from .One_Bit_Boolean import _pack_bools, _unpack_bools
4
4
  from .OrthoGrad import _orthogonalize_gradient
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -47,7 +47,7 @@ Based primarily on:
47
47
  **[SMMF: Square-Matricized Momentum Factorization for Memory-Efficient Optimization](https://arxiv.org/abs/2412.08894)**
48
48
 
49
49
  The core innovation:
50
- - Uses fast, non-negative matrix factorization (rank 1, à la Adafactor), but **reconstructs the full state before each update** to preserve momentum accuracy, then re-factors afterward (factor → reconstruct → update → factor cycle).
50
+ - Uses fast, non-negative matrix factorization (NNMF - rank 1), but **reconstructs the full state before each update** to preserve momentum accuracy, then re-factors afterward (factor → reconstruct → update → factor cycle).
51
51
  - For the *signed first moment*, we split into **sign + absolute value**:
52
52
  - Sign is stored as **1-bit state** via bitwise ops (SMMF originally used 8-bit with 7 bits wasted).
53
53
  - Absolute value goes through the factor/reconstruct cycle using two factored vectors + the signed state.
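To make the factor → reconstruct → update → factor cycle concrete, here is a hypothetical rank-1 non-negative factorization in the Adafactor style (row/column statistics), shown only to illustrate the memory shape of the stored state; the package's `_nnmf` / `_unnmf` may differ in detail:

import torch

def nnmf_rank1(M: torch.Tensor):
    """Factor a non-negative (d1, d2) matrix into a (d1,) and a (d2,) vector."""
    u = M.sum(dim=1)                      # row sums, shape (d1,)
    v = M.sum(dim=0)                      # column sums, shape (d2,)
    v = v / v.sum().clamp(min=1e-30)      # normalize so u v^T matches M's scale
    return u, v

def unnmf(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Reconstruct the rank-1 approximation u v^T."""
    return torch.outer(u, v)

state = torch.rand(64, 64)                 # e.g. |m_t| after an update
u, v = nnmf_rank1(state)                   # stored state: 64 + 64 floats instead of 64*64
approx = unnmf(u, v)                       # reconstructed before the next update
print(state.numel(), u.numel() + v.numel(), (state - approx).abs().mean().item())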
@@ -106,8 +106,6 @@ Set `Factored=False` to disable factorization and run as a full uncompressed opt
106
106
 
107
107
  ⚠️ **Note**: AdEMAMix updates are more aggressive than normal Adam/Adopt, so use a x2-x5 smaller LR than usual (or use Prodigy).
108
108
 
109
- ⚠️ **Note**: The factored AdEMAMix is **Experimental** (as it needs more tests and validation, but it should work). Also, Adopt with AdEMAMix is **Experimental** (as Adopt normalizes the gradients for the momentum).
110
-
111
109
  - **[`atan2` smoothing & scaling](https://github.com/lucidrains/adam-atan2-pytorch)**
112
110
  → Robust `eps` replacement (no tuning!) + built-in gradient clipping
113
111
  → *Ideal for ADOPT* (which normally needs higher `eps` and clipping), so `use_atan2` is all-in-one for it.
@@ -129,6 +127,4 @@ Set `Factored=False` to disable factorization and run as a full uncompressed opt
129
127
 
130
128
  - When `use_atan2` is True, `eps` will be ignored and you should also disable any gradient clipping.
131
129
 
132
- - I don't recommend using **OrthoGrad** for training LoRA or embeddings, as their weights are zero-initialized and using weight decay for them should be safe and also beneficial (OrthoGrad is intended for fine-tuning pretrained models with no weight decay).
133
-
134
130
  ---
@@ -0,0 +1,18 @@
1
+ adv_optm/__init__.py,sha256=BNYlxkuU8MFsWSY1_PLzp2XBSzpt-sxhnVuWVKRZGZ8,252
2
+ adv_optm/optim/AdamW_adv.py,sha256=_4Vt79EB18rnIkHttA0CdMpli8sZ5f03pesdrwT5K58,12887
3
+ adv_optm/optim/Adopt_adv.py,sha256=rzBWfFOPrMuC6vwETsw7QPKmVXcv4IJRDCTj-6eU1Qk,14798
4
+ adv_optm/optim/Lion_Prodigy_adv.py,sha256=ql6506h_IIZvTPdGYrQdd6iEhCXHTMntqmg739fc_dw,14102
5
+ adv_optm/optim/Lion_adv.py,sha256=jOoRbJ6u9HCK7IBI9ILOCcwprKIGTUNvUzhRd99WJK0,9410
6
+ adv_optm/optim/Prodigy_adv.py,sha256=InR50MoE32zG6qgEkg_JzXl7uXAVRy4EYG0JDl4eKok,17324
7
+ adv_optm/optim/__init__.py,sha256=e5UighM92LDvDB2JJwj8gDsTpXEedpytScwqS6F2FR8,300
8
+ adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
9
+ adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
10
+ adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
11
+ adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
12
+ adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
13
+ adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
14
+ adv_optm-0.1.2.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
15
+ adv_optm-0.1.2.dist-info/METADATA,sha256=iV5GBWtl4WphBeSIIsUoq1ay6-GJGnDD3XF6aSWWrqg,5846
16
+ adv_optm-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
17
+ adv_optm-0.1.2.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
18
+ adv_optm-0.1.2.dist-info/RECORD,,
@@ -1,37 +0,0 @@
1
- import torch
2
- from torch import Tensor
3
-
4
- from typing import Tuple
5
-
6
- def _rsvd(A: torch.Tensor, rank: int, oversampling: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
7
- """Performs Randomized SVD."""
8
- orig_dtype, device, (m, n) = A.dtype, A.device, A.shape
9
- A_float = A.float()
10
- l, true_rank = rank + oversampling, min(m, n, rank)
11
-
12
- if true_rank == 0:
13
- return (
14
- torch.zeros(m, rank, dtype=orig_dtype, device=device),
15
- torch.zeros(rank, dtype=orig_dtype, device=device),
16
- torch.zeros(rank, n, dtype=orig_dtype, device=device),
17
- )
18
-
19
- if l >= min(m, n): # Fallback to full SVD
20
- U_full, S_full, Vh_full = torch.linalg.svd(A_float, full_matrices=False)
21
- U, S, Vh = U_full[:, :true_rank], S_full[:true_rank], Vh_full[:true_rank, :]
22
- else: # Standard RSVD path
23
- Omega = torch.randn(n, l, dtype=A_float.dtype, device=device)
24
- Y = A_float @ Omega
25
- Q, _ = torch.linalg.qr(Y.float())
26
- B = Q.T @ A_float
27
- U_tilde, S, Vh = torch.linalg.svd(B.float(), full_matrices=False)
28
- U, S, Vh = (Q @ U_tilde)[:, :true_rank], S[:true_rank], Vh[:true_rank, :]
29
-
30
- if true_rank < rank: # Pad factors with zeros
31
- U_padded = torch.zeros(m, rank, dtype=A_float.dtype, device=device)
32
- S_padded = torch.zeros(rank, dtype=A_float.dtype, device=device)
33
- Vh_padded = torch.zeros(rank, n, dtype=A_float.dtype, device=device)
34
- U_padded[:, :true_rank], S_padded[:true_rank], Vh_padded[:true_rank, :] = U, S, Vh
35
- U, S, Vh = U_padded, S_padded, Vh_padded
36
-
37
- return U.to(orig_dtype), S.to(orig_dtype), Vh.to(orig_dtype)
@@ -1,17 +0,0 @@
1
- adv_optm/__init__.py,sha256=4JNXqWmFkMvsUIQorZLy43BbyqZiJxMRQkCCr09sPKw,172
2
- adv_optm/optim/AdamW_adv.py,sha256=cvCl3bRfkENfbXwfdzZZ8k3AJ_tNx-c5kBgaguf5fnQ,12689
3
- adv_optm/optim/Adopt_adv.py,sha256=jDmz2Fky2t5Gv9VY5UzltF5b5TDtY3xS5pNlnj-Eox4,14952
4
- adv_optm/optim/Prodigy_adv.py,sha256=FFATlt4VFb7o3UocP_W4KjBIzJa_0ncsji7BsFFU_9E,15482
5
- adv_optm/optim/__init__.py,sha256=kX9MQhLQZGlKFPCGLXsZtooigs4wXULTEmNSSOJvcCY,178
6
- adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
7
- adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
8
- adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
9
- adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
10
- adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
11
- adv_optm/util/Randomized_SVD.py,sha256=TFG417hh1t5f1n_mChnbgdQhpMoi37O04xVCe8wz8Qc,1708
12
- adv_optm/util/__init__.py,sha256=3yYKo23JDfHDZdGcjrDKxH8nYjk5KDB-i44kW-J4sPk,367
13
- adv_optm-0.1.0.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
14
- adv_optm-0.1.0.dist-info/METADATA,sha256=ig2YmYzdS6DmX0KEIGkdwX-n9eciG2S2aZYog1feqmE,6342
15
- adv_optm-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
16
- adv_optm-0.1.0.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
17
- adv_optm-0.1.0.dist-info/RECORD,,