adv-optm 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm might be problematic.

adv_optm/__init__.py CHANGED
@@ -16,4 +16,4 @@ __all__ = [
  "Lion_Prodigy_adv",
  ]
 
- __version__ = "0.1.7"
+ __version__ = "0.1.9"
@@ -55,7 +55,7 @@ class AdamW_adv(torch.optim.Optimizer):
  the warmup, `alpha` ramps from 0 to its target value. If `None`,
  the scheduler is disabled. (default: None)
  factored (bool): whether to use the factorization or disable it to use
- the uncompressed optimizer. (default: True)
+ the uncompressed optimizer. (default: False)
  """
 
  def __init__(
@@ -76,7 +76,7 @@ class AdamW_adv(torch.optim.Optimizer):
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
  t_alpha: int | None = None,
- factored: bool = True,
+ factored: bool = False,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -86,6 +86,9 @@ class AdamW_adv(torch.optim.Optimizer):
  raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
  if not (weight_decay >= 0.0):
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+ if use_cautious and use_grams:
+ print("Warning: use_cautious is incompatible with use_grams, Disabling use_cautious.")
+ use_cautious = False
 
  defaults = {
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
@@ -216,7 +219,10 @@ class AdamW_adv(torch.optim.Optimizer):
  del unpacked_sign_slow
 
  mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
- update = mt + (alpha_t * mt_slow) if beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
+ if beta1 > 0:
+ update = torch.add(mt, mt_slow, alpha=alpha_t)
+ else:
+ update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
  else:
  update = mt.clone() if beta1 > 0 else grad_reshaped.clone()
  del grad_reshaped
@@ -262,7 +268,10 @@ class AdamW_adv(torch.optim.Optimizer):
  if self.use_AdEMAMix:
  exp_avg_slow = state['exp_avg_slow']
  exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
- update = exp_avg + (alpha_t * exp_avg_slow) if beta1 > 0 else grad + (alpha_t * exp_avg_slow)
+ if beta1 > 0:
+ update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
+ else:
+ update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
  else:
  update = exp_avg.clone() if beta1 > 0 else grad.clone()
 
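The AdEMAMix branches above replace the expression `mt + (alpha_t * mt_slow)` with `torch.add(mt, mt_slow, alpha=alpha_t)`, which produces the same result while avoiding a temporary tensor for the scaled slow EMA. A minimal sketch (editor's illustration, not code from the package) showing the equivalence:

```python
import torch

mt, mt_slow = torch.randn(8), torch.randn(8)
alpha_t = 5.0

# 0.1.7-style expression: materializes (alpha_t * mt_slow) as an intermediate tensor.
old_update = mt + (alpha_t * mt_slow)

# 0.1.9-style call: torch.add(input, other, alpha=...) computes input + alpha * other in one op.
new_update = torch.add(mt, mt_slow, alpha=alpha_t)

assert torch.allclose(old_update, new_update)
```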
@@ -62,8 +62,18 @@ class Adopt_adv(torch.optim.Optimizer):
  the warmup, `alpha` ramps from 0 to its target value. If `None`,
  the scheduler is disabled and the full `alpha` value is used from
  the start. (default: None)
+ Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
+ This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
+ more responsive, especially for small batch sizes. Enabling this will
+ automatically disable `use_AdEMAMix`, `use_cautious`, `use_grams`,
+ and `use_atan2`. (default: False)
+ alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
+ (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
+ current gradient. For small batch sizes, use high values (e.g., 10-100) to be
+ more responsive. For large batch sizes, use low values (e.g., 0-1) for
+ stability. (default: 100.0)
  factored (bool): whether to use the factorization or disable it to use
- the uncompressed optimizer. (default: True)
+ the uncompressed optimizer. (default: False)
  """
 
  def __init__(
@@ -77,14 +87,16 @@ class Adopt_adv(torch.optim.Optimizer):
  vector_reshape: bool = True,
  stochastic_rounding: bool = True,
  use_atan2: bool = False,
- use_cautious: bool = True,
+ use_cautious: bool = False,
  use_grams: bool = False,
  use_orthograd: bool = False,
  use_AdEMAMix: bool = False,
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
  t_alpha: int | None = None,
- factored: bool = True,
+ Simplified_AdEMAMix: bool = False,
+ alpha_grad: float = 100.0,
+ factored: bool = False,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -94,19 +106,34 @@ class Adopt_adv(torch.optim.Optimizer):
  raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
  if not (weight_decay >= 0.0):
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+ if use_cautious and use_grams:
+ print("Warning: use_cautious is incompatible with use_grams, Disabling use_cautious.")
+ use_cautious = False
+ if betas[0] == 0.0 and Simplified_AdEMAMix:
+ raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
+ if use_AdEMAMix and Simplified_AdEMAMix:
+ print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
+ if use_grams and Simplified_AdEMAMix:
+ print("Warning: use_grams is incompatible with Simplified_AdEMAMix, Disabling use_grams.")
+ if use_cautious and Simplified_AdEMAMix:
+ print("Warning: use_cautious is incompatible with Simplified_AdEMAMix, Disabling use_cautious.")
+ if use_atan2 and Simplified_AdEMAMix:
+ print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
+ use_atan2 = False
 
  defaults = {
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
  "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
- "t_alpha": t_alpha,
+ "t_alpha": t_alpha, "alpha_grad": alpha_grad,
  }
  self.clip_lambda = clip_lambda
  self.stochastic_rounding = stochastic_rounding
- self.use_atan2 = use_atan2
- self.use_cautious = use_cautious
- self.use_grams = use_grams
+ self.use_atan2 = use_atan2 and not Simplified_AdEMAMix
+ self.use_cautious = use_cautious and not Simplified_AdEMAMix
+ self.use_grams = use_grams and not Simplified_AdEMAMix
  self.use_orthograd = use_orthograd
- self.use_AdEMAMix = use_AdEMAMix
+ self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
+ self.Simplified_AdEMAMix = Simplified_AdEMAMix
  self.factored = factored
  super().__init__(params, defaults)
 
@@ -185,6 +212,8 @@ class Adopt_adv(torch.optim.Optimizer):
  alpha_t = alpha
  if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
  alpha_t = min(current_step * alpha / t_alpha, alpha)
+ if self.Simplified_AdEMAMix:
+ alpha_grad = group["alpha_grad"]
 
  if state['factored']:
  d1, d2 = state['effective_shape']
@@ -224,7 +253,10 @@ class Adopt_adv(torch.optim.Optimizer):
  del denom
 
  # ADOPT Step B: Update momentum m_t using normalized gradient
- mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
+ if self.Simplified_AdEMAMix:
+ mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
+ else:
+ mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
  if self.use_grams:
  mt = grad_reshaped.sign() * mt.abs()
  elif self.use_cautious:
@@ -235,8 +267,10 @@ class Adopt_adv(torch.optim.Optimizer):
 
  if self.use_AdEMAMix:
  mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
- update = mt + (alpha_t * mt_slow)
+ update = torch.add(mt, m_slow, alpha=alpha_t)
  update = update.view(p.shape)
+ elif self.Simplified_AdEMAMix:
+ update = torch.add(mt, grad_reshaped, alpha=alpha_grad)
  else:
  update = mt.view(p.shape)
 
@@ -283,7 +317,10 @@ class Adopt_adv(torch.optim.Optimizer):
  del denom
 
  # ADOPT Step B: Update momentum m_t
- m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
+ if self.Simplified_AdEMAMix:
+ m.mul_(beta1).add_(normalized_grad, alpha=1.0)
+ else:
+ m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
 
  if self.use_grams:
  m = grad.sign() * m.abs()
@@ -295,9 +332,11 @@ class Adopt_adv(torch.optim.Optimizer):
 
  if self.use_AdEMAMix:
  m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
- update = m + (alpha_t * m_slow)
+ update = torch.add(m, m_slow, alpha=alpha_t)
+ elif self.Simplified_AdEMAMix:
+ update = torch.add(m, grad, alpha=alpha_grad)
  else:
- update = m
+ update = m.clone()
 
  if self.use_atan2:
  update.mul_(group['lr'] * 1.2732395447351628)
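As described in the docstring added above, `Simplified_AdEMAMix` turns the first moment into an accumulator (no `1 - beta1` scaling of the incoming gradient) and re-injects the raw gradient into the update numerator with weight `alpha_grad`. A standalone sketch of just that rule (editor's illustration; Adopt's second-moment normalization, clipping, and weight decay are omitted):

```python
import torch

def simplified_ademamix_update(p, grad, mt, lr, beta1=0.99, alpha_grad=100.0):
    """Illustrative parameter step for the Simplified AdEMAMix rule."""
    mt.mul_(beta1).add_(grad)                        # accumulator: beta1 * mt + g
    update = torch.add(mt, grad, alpha=alpha_grad)   # numerator: mt + alpha_grad * g
    p.add_(update, alpha=-lr)                        # plain step; the real optimizer also scales by the second moment

p = torch.zeros(4)
mt = torch.zeros(4)
simplified_ademamix_update(p, torch.tensor([0.1, -0.2, 0.3, 0.0]), mt, lr=1e-6)
```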
@@ -33,8 +33,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  (default: 0.0).
  factored (bool): whether to use the factorization or use the
  uncompressed optimizer. (default: True)
- variance_reduction (bool): whether to use the variance reduction technique
- from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
  d0 (float):
  Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
  d_coef (float):
@@ -66,7 +64,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  use_cautious: bool = False,
  clip_threshold: float = 0.0,
  factored: bool = True,
- variance_reduction: bool = False,
  # prodigy parameters
  beta3: float = None,
  d0: float = 1e-6,
@@ -97,7 +94,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  self.stochastic_rounding = stochastic_rounding
  self.use_cautious = use_cautious
  self.factored = factored
- self.variance_reduction = variance_reduction
  self.fsdp_in_use = fsdp_in_use
  super().__init__(params, defaults)
  # Global state for accumulating metrics across parameter updates within a single step.
@@ -183,12 +179,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
- if self.variance_reduction:
- state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
  else: # Fallback to standard Lion
  state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
- if self.variance_reduction:
- state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
 
  if state['factored']:
  # Factored Path
@@ -215,20 +207,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  update_for_param = signed_update.view(p.shape).mul(self.dlr)
 
  # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
- if self.variance_reduction:
- if state['step'] == 1:
- exp_avg.copy_(grad_reshaped)
- else:
- # Heuristic Prodigy-STORM update
- correction = exp_avg.sub(state['prev_grad'])
- grad_alpha = self.d * (1 - self.beta2) + self.beta2
- exp_avg.copy_(grad_reshaped).mul_(grad_alpha).add_(correction, alpha=self.beta2)
- del correction, grad_alpha
- state['prev_grad'].copy_(grad_reshaped)
- else:
- # Standard Prodigy-Lion
- alpha = self.d * (1 - self.beta2)
- exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=alpha)
+ exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1 - self.beta2))
  del grad_reshaped
 
  # Compress new momentum m_t and store factors
@@ -254,20 +233,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  update_for_param = signed_update.mul(self.dlr)
 
  # Update momentum
- if self.variance_reduction:
- if state['step'] == 1:
- exp_avg.copy_(grad)
- else:
- # Heuristic Prodigy-STORM update
- correction = exp_avg.sub(state['prev_grad'])
- grad_alpha = self.d * (1 - self.beta2) + self.beta2
- exp_avg.copy_(grad).mul_(grad_alpha).add_(correction, alpha=self.beta2)
- del grad_alpha, correction
- state['prev_grad'].copy_(grad)
- else:
- # Standard Prodigy-Lion
- alpha = self.d * (1 - self.beta2)
- exp_avg.mul_(self.beta2).add_(grad, alpha=alpha)
+ exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))
 
  # --- Accumulate Prodigy stats ---
  d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
@@ -298,7 +264,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  else:
  p.data.add_(-update_for_param)
 
- del update_for_param
+ del update_for_param
 
  @torch.no_grad()
  def step(self, closure: Optional[callable] = None):
@@ -33,8 +33,6 @@ class Lion_adv(torch.optim.Optimizer):
  (default: 0.0).
  factored (bool): whether to use the factorization or use the
  uncompressed optimizer. (default: True)
- variance_reduction (bool): whether to use the variance reduction technique
- from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
  """
 
  def __init__(
@@ -49,7 +47,6 @@ class Lion_adv(torch.optim.Optimizer):
  use_cautious: bool = False,
  clip_threshold: float = 0.0,
  factored: bool = True,
- variance_reduction: bool = False,
  ):
  if not lr > 0.0:
  raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -69,7 +66,6 @@ class Lion_adv(torch.optim.Optimizer):
  self.stochastic_rounding = stochastic_rounding
  self.use_cautious = use_cautious
  self.factored = factored
- self.variance_reduction = variance_reduction
  super().__init__(params, defaults)
 
  @property
@@ -122,12 +118,8 @@ class Lion_adv(torch.optim.Optimizer):
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
- if self.variance_reduction:
- state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
  else: # Fallback to standard Lion
  state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
- if self.variance_reduction:
- state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
 
  state['step'] += 1
  beta1, beta2 = group["betas"]
@@ -157,21 +149,9 @@ class Lion_adv(torch.optim.Optimizer):
  # Parameter update
  update_for_param = signed_update.view(p.shape).mul_(lr)
 
- # Update momentum
- if self.variance_reduction:
- if state['step'] == 1:
- exp_avg.copy_(grad_reshaped)
- else:
- # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
- correction = exp_avg.sub(state['prev_grad'])
- # Calculate the new momentum and store it back into exp_avg
- exp_avg.copy_(grad_reshaped).add_(correction, alpha=beta2)
- del correction
- # Update prev_grad for the next iteration
- state['prev_grad'].copy_(grad_reshaped)
- else:
- # Standard Lion momentum update
- exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
+ # Standard Lion momentum update
+ exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
+ del grad_reshaped
 
  # Compress new momentum m_t and store factors
  state['sign'] = _pack_bools(exp_avg > 0)
@@ -195,21 +175,8 @@ class Lion_adv(torch.optim.Optimizer):
 
  update_for_param = signed_update.mul_(lr)
 
- # Update momentum
- if self.variance_reduction:
- if state['step'] == 1:
- exp_avg.copy_(grad)
- else:
- # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
- correction = exp_avg.sub(state['prev_grad'])
- # Calculate the new momentum and store it back into exp_avg
- exp_avg.copy_(grad).add_(correction, alpha=beta2)
- del correction
- # Update prev_grad for the next iteration
- state['prev_grad'].copy_(grad)
- else:
- # Standard Lion momentum update
- exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
+ # Standard Lion momentum update
+ exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
 
  if group["weight_decay"] != 0:
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
@@ -225,7 +192,7 @@ class Lion_adv(torch.optim.Optimizer):
  else:
  p.data.add_(-update_for_param)
 
- del update_for_param
+ del update_for_param
 
  @torch.no_grad()
  def step(self, closure: Optional[callable] = None):
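With `variance_reduction` removed, both Lion variants keep only the baseline recipe visible in the hunks above: a signed update applied to the parameters, followed by `exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)`. A compact sketch of that baseline (editor's illustration; it assumes the textbook Lion interpolation for the signed update and omits the factorization, weight decay, and cautious masking):

```python
import torch

def lion_step(p, grad, exp_avg, lr, beta1=0.9, beta2=0.99):
    """Baseline Lion step, roughly matching the uncompressed path kept in 0.1.9."""
    # Signed update: sign of the beta1-interpolation of momentum and gradient
    # (textbook Lion; the package's factored path additionally packs the sign bits).
    signed_update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1).sign_()
    p.add_(signed_update, alpha=-lr)
    # Standard Lion momentum update, as retained after the variance_reduction removal.
    exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

p, m = torch.zeros(4), torch.zeros(4)
lion_step(p, torch.tensor([0.5, -0.5, 0.1, 0.0]), m, lr=1e-4)
```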
@@ -64,7 +64,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  more responsive. For large batch sizes, use low values (e.g., 0-1) for
  stability. (default: 100.0)
  factored (bool): whether to use the factorization or disable it to use
- the uncompressed optimizer. (default: True)
+ the uncompressed optimizer. (default: False)
  d0 (float):
  Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
  d_coef (float):
@@ -82,6 +82,9 @@ class Prodigy_adv(torch.optim.Optimizer):
  slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
  pth entry of each tensor. For values greater than 1 this an an approximation to standard
  Prodigy. Values ~11 are reasonable (default 11).
+ prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
+ after the specified optimiser step and release all state memory required by Prodigy
+ (default: 0).
  """
 
  def __init__(
@@ -103,7 +106,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  t_alpha: int | None = None,
  Simplified_AdEMAMix: bool = False,
  alpha_grad: float = 100.0,
- factored: bool = True,
+ factored: bool = False,
  # prodigy parameters
  beta3: float = None,
  d0: float = 1e-6,
@@ -112,6 +115,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  safeguard_warmup: bool = False,
  fsdp_in_use: bool = False,
  slice_p: int = 11,
+ prodigy_steps: int = 0,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -121,8 +125,13 @@ class Prodigy_adv(torch.optim.Optimizer):
  raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
  if not (weight_decay >= 0.0):
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+ if not (prodigy_steps >= 0):
+ raise ValueError(f"prodigy_steps should be >= 0. Got {prodigy_steps}")
+ if use_cautious and use_grams:
+ print("Warning: use_cautious is incompatible with use_grams, Disabling use_cautious.")
+ use_cautious = False
  if betas[0] == 0.0 and Simplified_AdEMAMix:
- raise ValueError(f"Beta 1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
+ raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
  if use_AdEMAMix and Simplified_AdEMAMix:
  print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
  if use_grams and Simplified_AdEMAMix:
@@ -140,7 +149,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
  "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
  "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
- "fsdp_in_use": fsdp_in_use,
+ "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
  "alpha_grad": alpha_grad,
  }
  self.stochastic_rounding = stochastic_rounding
@@ -293,7 +302,10 @@ class Prodigy_adv(torch.optim.Optimizer):
  torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
  del unpacked_sign_slow
  mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
- update = mt + (alpha_t * mt_slow) if self.beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
+ if self.beta1 > 0:
+ update = torch.add(mt, mt_slow, alpha=alpha_t)
+ else:
+ update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
  elif self.Simplified_AdEMAMix:
  update = torch.add(mt, grad_reshaped, alpha=alpha_grad * self.d)
  else:
@@ -344,7 +356,10 @@ class Prodigy_adv(torch.optim.Optimizer):
  if self.use_AdEMAMix:
  exp_avg_slow = state['exp_avg_slow']
  exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
- update = exp_avg + (alpha_t * exp_avg_slow) if self.beta1 > 0 else grad + (alpha_t * exp_avg_slow)
+ if self.beta1 > 0:
+ update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
+ else:
+ update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
  elif self.Simplified_AdEMAMix:
  update = torch.add(exp_avg, grad, alpha=alpha_grad * self.d)
  else:
@@ -364,19 +379,27 @@ class Prodigy_adv(torch.optim.Optimizer):
  update.mul_(self.dlr)
 
  # --- Accumulate Prodigy stats ---
- d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
- s, p0 = state['s'], state['p0']
- grad_flat = grad.flatten().float()
- p_flat = p.data.flatten().float()
- p0 = p0.float()
+ prodigy_steps = group['prodigy_steps']
+ if prodigy_steps <= 0 or group['k'] < prodigy_steps:
+ d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
+ s, p0 = state['s'], state['p0']
+ grad_flat = grad.flatten().float()
+ p_flat = p.data.flatten().float()
+ p0 = p0.float()
 
- self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
+ self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
 
- alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
- s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
- self.d_denom += s.abs().sum().item()
+ alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
+ s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
+ self.d_denom += s.abs().sum().item()
 
- del s, p0, grad_flat, p_flat, alpha
+ del s, p0, grad_flat, p_flat, alpha
+ else:
+ # Free memory if prodigy_steps is reached
+ if 's' in state:
+ del state['s']
+ if 'p0' in state:
+ del state['p0']
 
  # Decoupled weight decay
  if group["weight_decay"] != 0:
@@ -413,29 +436,43 @@ class Prodigy_adv(torch.optim.Optimizer):
  def calculate_d(self):
  """Calculates the new `d` based on the accumulated stats."""
  g_group = self.param_groups[0]
- d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
 
- if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
- # Use the device of the first parameter to avoid hardcoding '.cuda()'
- device = self.param_groups[0]['params'][0].device
- dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
- dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
- global_d_numerator = dist_tensor[0].item()
- global_d_denom = dist_tensor[1].item()
- else:
- global_d_numerator = self.d_numerator
- global_d_denom = self.d_denom
-
- d_hat = self.d
- if global_d_denom > 0:
- d_hat = d_coef * global_d_numerator / global_d_denom
- if self.d == g_group['d0']:
- self.d = max(self.d, d_hat)
- d_max = max(d_max, d_hat)
- self.d = min(d_max, self.d * growth_rate)
-
+ # Only perform d-adaptation if prodigy_steps has not been reached
+ prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])
+
+ if prodigy_active:
+ d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
+
+ if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
+ # Use the device of the first parameter to avoid hardcoding '.cuda()'
+ device = self.param_groups[0]['params'][0].device
+ dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+ dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+ global_d_numerator = dist_tensor[0].item()
+ global_d_denom = dist_tensor[1].item()
+ else:
+ global_d_numerator = self.d_numerator
+ global_d_denom = self.d_denom
+
+ d_hat = self.d
+ if global_d_denom > 0:
+ if self.Simplified_AdEMAMix and g_group['alpha_grad'] > 0:
+ # A simple and effective hack to make prodigy compatible with Simplified_AdEMAMix large step sizes
+ # by diving by alpha_grad we make sure that d_numerator that was influenced by (alpha_grad * grad)
+ # are now normalized by /alpha_grad. this is a heuristic way since the update is also influenced by
+ # the increasing and decaying accumulator but it's effective and it worked for me (for Lora/Finetune).
+ global_d_numerator /= g_group['alpha_grad']
+ d_hat = d_coef * global_d_numerator / global_d_denom
+ if self.d == g_group['d0']:
+ self.d = max(self.d, d_hat)
+ d_max = max(d_max, d_hat)
+ self.d = min(d_max, self.d * growth_rate)
+
+ for group in self.param_groups:
+ group['d_numerator'] = global_d_numerator
+ group['d'] = self.d
+ group['d_max'] = d_max
+
+ # Increment step counter for all groups, regardless of whether d was updated
  for group in self.param_groups:
- group['d_numerator'] = global_d_numerator
- group['d'] = self.d
- group['d_max'] = d_max
  group['k'] += 1
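The `calculate_d` change above gates d-adaptation on `prodigy_steps` and, when `Simplified_AdEMAMix` is active, divides the accumulated numerator by `alpha_grad` before forming `d_hat = d_coef * d_numerator / d_denom`. A small numeric illustration of that heuristic (editor's sketch with made-up accumulator values):

```python
# Hypothetical accumulated statistics for one d-adaptation step.
d_numerator, d_denom = 2.0e-3, 4.0e-2
d_coef, alpha_grad = 1.0, 100.0

# Without the heuristic: per the package comment, the numerator has been inflated
# by the alpha_grad-scaled updates, so d_hat comes out correspondingly large.
d_hat_raw = d_coef * d_numerator / d_denom             # 0.05

# With the 0.1.9 heuristic: renormalize the numerator by alpha_grad first.
d_hat = d_coef * (d_numerator / alpha_grad) / d_denom  # 0.0005

print(d_hat_raw, d_hat)
```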
@@ -0,0 +1,174 @@
+ Metadata-Version: 2.4
+ Name: adv_optm
+ Version: 0.1.9
+ Summary: A family of highly efficient, lightweight yet powerful optimizers.
+ Home-page: https://github.com/Koratahiu/Advanced_Optimizers
+ Author: Koratahiu
+ Author-email: hiuhonor@gmail.com
+ License: Apache 2.0
+ Keywords: llm,fine-tuning,memory-efficient,low-rank,compression,pytorch,optimizer,adam
+ Classifier: Programming Language :: Python :: 3
+ Classifier: License :: OSI Approved :: Apache Software License
+ Classifier: Operating System :: OS Independent
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch>=2.0
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: keywords
+ Dynamic: license
+ Dynamic: license-file
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
+
+ # Advanced Optimizers (AIO)
+
+ A comprehensive, all-in-one collection of optimization algorithms for deep learning, designed for maximum efficiency, minimal memory footprint, and superior performance across diverse model architectures and training scenarios.
+
+ [![PyPI](https://img.shields.io/pypi/v/adv_optm)](https://pypi.org/project/adv_optm/)
+
+ ---
+
+ ## 📦 Installation
+
+ ```bash
+ pip install adv_optm
+ ```
+
+ ---
+
+ ## 🧠 Core Innovations
+
+ This library integrates multiple state-of-the-art optimization techniques validated through extensive research and practical training, with 1-bit compression for optimizer states:
+
+ ### **Memory-Efficient Optimization (SMMF-inspired)**
+ - **Paper**: [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
+ - **Approach**: Uses rank-1 non-negative matrix factorization with reconstruction cycle (factor → reconstruct → update → factor)
+ - **Innovation**:
+ - First moment split into **1-bit sign + absolute value**
+ - Final storage: **four factored vectors + one 1-bit sign state**
+ - Preserves Adam-like update quality with drastically reduced memory
+
+ ---
+
+ ## ⚡ Performance Characteristics
+
+ ### Memory Efficiency (SDXL Model - 6.5GB)
+ | Optimizer | Memory Usage | Description |
+ |-----------|--------------|-------------|
+ | `Adopt_Factored` | 328 MB | 4 small vectors + 1-bit state |
+ | `Adopt_Factored + AdEMAMix` | 625 MB | 6 small vectors + two 1-bit states |
+ | `Simplified_AdEMAMix` | 328 MB | Same as standard factored (no extra state) |
+
+ ### Speed Comparison (SDXL, Batch Size 4)
+ | Optimizer | Speed | Notes |
+ |-----------|-------|-------|
+ | `Adafactor` | ~8.5s/it | Baseline |
+ | `Adopt_Factored` | ~10s/it | +18% overhead from compression |
+ | `Adopt_Factored + AdEMAMix` | ~12s/it | +41% overhead (3 factored states) |
+
+ ---
+
+ ## 🧪 Available Optimizers
+
+ ### Standard Optimizers (All support `factored=True/False`)
+ | Optimizer | Description | Best For |
+ |-----------|-------------|----------|
+ | `Adam_Adv` | Advanced Adam implementation | General purpose |
+ | `Adopt_Adv` | Adam-variant with independent beta2 | Stable training for small batch size regimes |
+ | `Prodigy_Adv` | Prodigy with D-Adaptation | Adam with automatic LR tuning |
+ | `Simplified_AdEMAMix` | Adam variant with accumulator momentum | Small/large batch training when tuned correctly |
+ | `Lion_Adv` | Advanced Lion implementation | Memory-constrained environments |
+ | `Prodigy_Lion_Adv` | Prodigy + Lion combination | Lion with automatic LR tuning |
+
+ ### Feature Matrix
+ | Feature | Adam_Adv | Adopt_Adv | Prodigy_Adv | Simplified_AdEMAMix | Lion_Adv |
+ |---------|----------|-----------|-------------|---------------------|----------|
+ | Factored | ✓ | ✓ | ✓ | ✓ | ✓ |
+ | AdEMAMix | ✓ | ✓ | ✓ | ✗ | ✗ |
+ | Simplified_AdEMAMix | ✗ | ✗ | ✓ | ✓ | ✗ |
+ | OrthoGrad | ✓ | ✓ | ✓ | ✓ | ✓ |
+ | Grams | ✓ | ✓ | ✓ | ✗ | ✗ |
+ | Cautious | ✓ | ✓ | ✓ | ✗ | ✓ |
+ | atan2 | ✓ | ✓ | ✓ | ✗ | ✗ |
+ | Stochastic Rounding | ✓ | ✓ | ✓ | ✓ | ✓ |
+ | Fused Backward Pass | ✓ | ✓ | ✓ | ✓ | ✓ |
+
+ ---
+
+ ## ⚙️ Key Features & Parameters
+
+ ### Comprehensive Feature Guide
+
+ | Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
+ |---------|-------------|-------------------|--------------------|-------------------|--------------|
+ | **Factored** | Memory-efficient optimization using rank-1 factorization | Enable for large models (>1B params) or limited VRAM | +12-41% time overhead, 1-bit memory usage | [SMMF](https://arxiv.org/abs/2412.08894) | All optimizers |
+ | **AdEMAMix** | Dual EMA system for momentum | Use for long training runs (10k+ steps) | +1 state memory. | [AdEMAMix](https://arxiv.org/abs/2409.03137) | Adam/Adopt/Prodigy |
+ | **Simplified_AdEMAMix** | Accumulator-based momentum | Small batch training (≤32) | Same memory as standard, no extra overhead | [Schedule-Free Connections](https://arxiv.org/abs/2502.02431) | Adam/Prodigy |
+ | **OrthoGrad** | Removes gradient component parallel to weights | Full finetuning without weight decay | +33% time overhead, no memory impact | [Grokking at Edge](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | All optimizers |
+ | **Stochastic Rounding** | Improves precision for BF16 training | BF16 training | Minimal overhead (<5%) | [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192) | All optimizers |
+ | **atan2** | Robust eps replacement + built-in clipping | Use with Adopt or unstable training | No overhead | [Adam-atan2](https://github.com/lucidrains/adam-atan2-pytorch) | Adam/Adopt/prodigy |
+ | **Cautious** | Update only when the direction align with the gradients | should faster the convergence | No overhead | [C-Optim](https://github.com/kyleliang919/C-Optim) | Adam/Adopt/prodigy |
+ | **Grams** | Update direction from the gradients | should have a stronger effect than cautious | No overhead | [Grams](https://github.com/Gunale0926/Grams) | Adam/Adopt/prodigy |
+
+ ---
+
+ ## Simplified_AdEMAMix Parameters
+ Simplified_AdEMAMix replaces standard momentum with an accumulator for better small-large batch performance.
+
+ | Parameter | Recommended Values | Description |
+ |-----------|---------------------|-------------|
+ | `beta1` | 0.9 (large BS), 0.99-0.9999 (small BS) | Determines memory length of accumulator |
+ | `alpha` | 100-10 (small BS), 1-0 (large BS) | Gradient smoothing factor |
+
+ **Alpha Tuning Guide**:
+ | Batch Size | Recommended α | Rationale |
+ |------------|---------------|-----------|
+ | Small (≤32) | 100, 50, 20, 10 | Emphasizes recent gradients for quick adaptation |
+ | Medium (32-512) | 10, 5, 2, 1 | Balanced approach |
+ | Large (≥512) | 1, 0.5, 0 | Emphasizes historical gradients for stability |
+
+ ⚠️ **Important**: Use **~100x smaller learning rate** with Simplified_AdEMAMix compared to AdamW (e.g., 1e-6 instead of 1e-4)
+
+ ### 📊 Performance Validation
+ Small Batch Training (SDXL, BS=2, 1.8K steps)
+ ![Training Comparison](https://github.com/user-attachments/assets/7eff0671-cc59-47fc-8b63-d5205456d649)
+
+ - **🟢 Prodigy_adv** (beta1=0.9, d0=1e-5): Final LR=2.9e-4
+ - **🔵 Prodigy_adv + Simplified_AdEMAMix** (beta1=0.99, α=100, d0=1e-7): Final LR=5.8e-6
+
+ **Results**:
+ - Simplified_AdEMAMix shows faster convergence and better final performance
+ - D-Adaptation automatically handles aggressive updates (50x smaller LR)
+ - Generated samples show significantly better quality with Simplified_AdEMAMix
+
+ ---
+
+ ## ⚠️ Known Limitations
+
+ ### 1. Prodigy_Adv Sensitivity
+ - Highly sensitive to gradient modifications (Adopt normalization, low-rank factorization)
+ - May fail to increase learning rate in some LoRA scenarios
+ - **Fix**: Disable factorization or set beta1=0
+
+ ### 2. Aggressive Learning Rates
+ - Can destabilize factored first moment
+ - **Recommendation**: Check Prodigy learning rate as reference for safe LR threshold
+
+ ---
+
+ ## 📚 References
+
+ 1. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
+ 2. [The AdEMAMix Optimizer: Better, Faster, Older](https://arxiv.org/abs/2409.03137)
+ 3. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants](https://arxiv.org/abs/2502.02431)
+
+ ---
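The new README stops short of a usage snippet; for orientation, a minimal sketch of how the constructor arguments shown in this diff fit together (editor's illustration, not part of the package: it assumes `Prodigy_adv` is importable from the top-level `adv_optm` package as the `__init__.py` hunk suggests, and the argument names are taken from the `__init__` signature above):

```python
import torch
from adv_optm import Prodigy_adv  # assumed top-level export

model = torch.nn.Linear(128, 128)

# Values mirror the README's small-batch guidance; not official defaults.
optimizer = Prodigy_adv(
    model.parameters(),
    lr=1.0,                      # Prodigy-style optimizers are typically run with lr around 1
    Simplified_AdEMAMix=True,    # accumulator momentum, per the docstring in this diff
    alpha_grad=100.0,            # high alpha for small batch sizes
    factored=False,              # the new 0.1.9 default; set True for 1-bit factorization
    prodigy_steps=0,             # keep d-adaptation active for the whole run
)

loss = model(torch.randn(4, 128)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```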
@@ -0,0 +1,19 @@
+ adv_optm/__init__.py,sha256=hHL2QwlnQMvIggC9ejOxGOKq65DnnYaHC1ScPQMuIIw,306
+ adv_optm/optim/AdamW_adv.py,sha256=Pu0TB14dOhcq9kwXclMIeKCI6ef_P0emwzxPu6xuBM0,14252
+ adv_optm/optim/Adopt_adv.py,sha256=71o9BHV3XFefJX21G37PKG96D09x-PSU0eW3Q7WkAjs,17427
+ adv_optm/optim/Lion_Prodigy_adv.py,sha256=kIAGXoMbDNRg5reKXtUC_vQQ2gyM-NXPB-Pv9zSpiE8,12787
+ adv_optm/optim/Lion_adv.py,sha256=05j_j6LIzHW5b79DVwMIf1FZHVNB8xnStNVjlOdVkCE,8256
+ adv_optm/optim/Prodigy_adv.py,sha256=NykG5gcAHjmhlMutknOjAoYKI-K6e5lA3Q9J9vkqnz0,22357
+ adv_optm/optim/Simplified_AdEMAMix.py,sha256=opIZjnGJ03-DDAIHTZyJBMReVfgusGDb8FZSWMU3-UM,9774
+ adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
+ adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
+ adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
+ adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
+ adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
+ adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
+ adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
+ adv_optm-0.1.9.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+ adv_optm-0.1.9.dist-info/METADATA,sha256=IvocLvlwTsZ5WPmO6ZsVffmybwZRf3tr_ALojuwL6dw,8422
+ adv_optm-0.1.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ adv_optm-0.1.9.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+ adv_optm-0.1.9.dist-info/RECORD,,
@@ -1,130 +0,0 @@
- Metadata-Version: 2.4
- Name: adv_optm
- Version: 0.1.7
- Summary: A family of highly efficient, lightweight yet powerful optimizers.
- Home-page: https://github.com/Koratahiu/Advanced_Optimizers
- Author: Koratahiu
- Author-email: hiuhonor@gmail.com
- License: Apache 2.0
- Keywords: llm,fine-tuning,memory-efficient,low-rank,compression,pytorch,optimizer,adam
- Classifier: Programming Language :: Python :: 3
- Classifier: License :: OSI Approved :: Apache Software License
- Classifier: Operating System :: OS Independent
- Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
- Classifier: Topic :: Software Development :: Libraries :: Python Modules
- Requires-Python: >=3.8
- Description-Content-Type: text/markdown
- License-File: LICENSE
- Requires-Dist: torch>=2.0
- Dynamic: author
- Dynamic: author-email
- Dynamic: classifier
- Dynamic: description
- Dynamic: description-content-type
- Dynamic: home-page
- Dynamic: keywords
- Dynamic: license
- Dynamic: license-file
- Dynamic: requires-dist
- Dynamic: requires-python
- Dynamic: summary
-
- # Advanced Optimizers
-
- This repo introduces a new family of highly efficient, lightweight yet powerful optimizers, born from extensive research into recent academic literature and validated through practical training runs across diverse models.
-
- ---
-
- ### Install
-
- `pip install adv_optm`
-
- ---
-
- ### Theory (Inspired by SMMF)
-
- Based primarily on:
- **[SMMF: Square-Matricized Momentum Factorization for Memory-Efficient Optimization](https://arxiv.org/abs/2412.08894)**
-
- The core innovation:
- - Uses fast, non-negative matrix factorization (NNMF - rank 1), but **reconstructs the full state before each update** to preserve momentum accuracy, then re-factors afterward (factor → reconstruct → update → factor cycle).
- - For the *signed first moment*, we split into **sign + absolute value**:
- - Sign is stored as **1-bit state** via bitwise ops (SMMF originally used 8-bit with 7 bits wasted).
- - Absolute value goes through the factor/reconstruct cycle using two factored vectors + the signed state.
- - Final storage: **four factored vectors + one 1-bit sign**.
- - Updates behave like full-state Adam but with drastically reduced memory.
-
- > ✅ **TL;DR**: Lightweight, strong, memory-efficient optimizer.
-
- ---
-
- ### Memory Cost
-
- - **Adopt_Factored** for full SDXL finetune: **328 MB** (4 small vectors + 1-bit state)
- - **Adopt_Factored with AdEMAMix** for full SDXL finetune: **625 MB** (6 small vectors + two 1-bit states)
- > SDXL is 6.5GB model.
-
- ---
-
- ### ⏱️ Speed (my tests in SDXL - BS 4)
-
- - **Adopt_Factored**: ~10s/it
- - **Adopt_Factored with AdEMAMix**: ~12s/it
- - **Adafactor**: ~8.5s/it
- → Overhead from compression/reconstruction cycles.
- → It's faster than [MLorc](https://arxiv.org/abs/2506.01897) (~12s/it), which uses RSVD compression, and should be the fastest momentum compression (AFAIK).
-
- ---
-
- ### 📈 Performance
-
- - **Better than Adafactor, and CAME factorzation methods**
- - **Comparable or identical to Adam** (see SMMF paper results)
-
- ---
-
- ### Available Optimizers (all support `Factored` toggle)
-
- Set `Factored=False` to disable factorization and run as a full uncompressed optimizer (like vanilla Adam).
-
- 1. **Adam**
- 2. **Prodigy**
- 3. **Adopt**
-
- ---
-
- ### Bonus Features (Built-in)
-
- - **Fused Backward Pass**
-
- - **Stochastic Rounding (SR)**: Improves quality and convergence for **BF16 training**.
-
- **[AdEMAMix](https://arxiv.org/abs/2409.03137)**
- → This adds a second, slow-moving EMA, which is combined with the primary momentum to stabilize updates, especially during long runs of full finetuning.
- → A higher value of beta3 (e.g., 0.9999) gives the EMA a longer memory, making it more stable but slower to adapt. A lower value (e.g., 0.999) is often better for shorter training runs (2k-4k steps).
- → When `factored` is true, it compresses the new momentum in the same way as the first moment (1-bit state + 2 vectors). However, this introduces noticeable overhead as we are compressing/reconstructing a third state each step.
-
- ⚠️ **Note**: AdEMAMix updates are more aggressive than normal Adam/Adopt, so use a x2-x5 smaller LR than usual (or use Prodigy).
-
- **[`atan2` smoothing & scaling](https://github.com/lucidrains/adam-atan2-pytorch)**
- → Robust `eps` replacement (no tuning!) + built-in gradient clipping
- → *Ideal for ADOPT* (which normally needs higher `eps` and clipping), so `use_atan2` is all-in-one for it.
-
- **[OrthoGrad](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability)**
- → Removes gradient component parallel to weights → prevents "naïve loss minimization" (NLM) → reduces natural overfitting
- → Perfect for fine-tuning the direction of existing features (e.g., full finetune or training a trained LoRA) without weight decay erasing prior knowledge.
-
- ⚠️ **Note**: OrthoGrad introduces **~33% time overhead**, so take this into account.
-
- **[Grams: Gradient Descent with Adaptive Momentum Scaling](https://github.com/Gunale0926/Grams)**
- → Eliminates the need for 1-bit momentum sign storage by using the **sign of gradients** for the first moment.
-
- ⚠️ **Not recommended for small batch sizes**: gradients are too noisy, which can destabilize momentum (tested for Prodigy and it made the optimizer slower to find the LR or converge in BS 4).
-
- ### Other Notes
-
- - **Adopt** skips the first step (only initializes the states) and has built-in clipping (sticking to the original optimizer), but we skip both of these when you enable `use_atan2`; as the optimizer becomes scale-invariant and the values of the states won't cause any issues or instability.
-
- - When `use_atan2` is True, `eps` will be ignored and you should also disable any gradient clipping.
-
- ---
@@ -1,19 +0,0 @@
- adv_optm/__init__.py,sha256=CZ_tjWWk5d5D8q_R0rcr8vvwlZyY_44zyAcIAmN_SDY,306
- adv_optm/optim/AdamW_adv.py,sha256=ZeNzk2tWbyd2QDI5hp4InwG3iuHHfqLrlhr_VmcQfRM,13884
- adv_optm/optim/Adopt_adv.py,sha256=rzBWfFOPrMuC6vwETsw7QPKmVXcv4IJRDCTj-6eU1Qk,14798
- adv_optm/optim/Lion_Prodigy_adv.py,sha256=JMss9X8lRpIU4E34PfFpWMMal_XNvZ8Yuqc6i7R5wIQ,14588
- adv_optm/optim/Lion_adv.py,sha256=BA4bSEhJiQ7BhGLDRn9nuMlBrLVh-OMscbmSTeGgRmI,10137
- adv_optm/optim/Prodigy_adv.py,sha256=gJL2r32R3xGD62jMR55ZyKxRv0yL70XHxj4FzEJbFc4,20196
- adv_optm/optim/Simplified_AdEMAMix.py,sha256=opIZjnGJ03-DDAIHTZyJBMReVfgusGDb8FZSWMU3-UM,9774
- adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
- adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
- adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
- adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
- adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
- adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
- adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
- adv_optm-0.1.7.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
- adv_optm-0.1.7.dist-info/METADATA,sha256=BEKyVG9zVdb9WThOw9YtgWZ_zqDmErumpY5Fr-AkbX0,5846
- adv_optm-0.1.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- adv_optm-0.1.7.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
- adv_optm-0.1.7.dist-info/RECORD,,