adv-optm 0.1.6__tar.gz → 0.1.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (24)
  1. {adv_optm-0.1.6 → adv_optm-0.1.8}/PKG-INFO +1 -1
  2. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/__init__.py +3 -1
  3. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/AdamW_adv.py +10 -4
  4. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/Adopt_adv.py +5 -5
  5. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/Lion_Prodigy_adv.py +3 -37
  6. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/Lion_adv.py +6 -39
  7. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/Prodigy_adv.py +112 -44
  8. adv_optm-0.1.8/adv_optm/optim/Simplified_AdEMAMix.py +246 -0
  9. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/__init__.py +2 -0
  10. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm.egg-info/PKG-INFO +1 -1
  11. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm.egg-info/SOURCES.txt +1 -0
  12. {adv_optm-0.1.6 → adv_optm-0.1.8}/setup.py +1 -1
  13. {adv_optm-0.1.6 → adv_optm-0.1.8}/LICENSE +0 -0
  14. {adv_optm-0.1.6 → adv_optm-0.1.8}/README.md +0 -0
  15. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  16. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/util/Effective_Shape.py +0 -0
  17. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/util/NNMF.py +0 -0
  18. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/util/One_Bit_Boolean.py +0 -0
  19. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/util/OrthoGrad.py +0 -0
  20. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/util/__init__.py +0 -0
  21. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm.egg-info/dependency_links.txt +0 -0
  22. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm.egg-info/requires.txt +0 -0
  23. {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm.egg-info/top_level.txt +0 -0
  24. {adv_optm-0.1.6 → adv_optm-0.1.8}/setup.cfg +0 -0
{adv_optm-0.1.6 → adv_optm-0.1.8}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 0.1.6
+ Version: 0.1.8
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu

{adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/__init__.py
@@ -2,6 +2,7 @@ from .optim import (
  AdamW_adv,
  Prodigy_adv,
  Adopt_adv,
+ Simplified_AdEMAMix,
  Lion_adv,
  Lion_Prodigy_adv,
  )
@@ -10,8 +11,9 @@ __all__ = [
  "AdamW_adv",
  "Prodigy_adv",
  "Adopt_adv",
+ "Simplified_AdEMAMix",
  "Lion_adv",
  "Lion_Prodigy_adv",
  ]

- __version__ = "0.1.6"
+ __version__ = "0.1.8"
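
Version 0.1.8's only change at the package root, besides the version bump, is the new `Simplified_AdEMAMix` export, so the optimizer can be imported straight from `adv_optm`. A minimal usage sketch under the assumption of a standard PyTorch training loop (the hyperparameter values are illustrative, not recommendations):

    import torch
    from adv_optm import Simplified_AdEMAMix  # exported at the package root as of 0.1.8

    model = torch.nn.Linear(128, 10)
    # Arguments follow the constructor signature added in this release; values are illustrative.
    opt = Simplified_AdEMAMix(model.parameters(), lr=1e-5,
                              betas=(0.99, 0.999), alpha_grad=100.0)

    loss = model(torch.randn(4, 128)).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
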
{adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/AdamW_adv.py
@@ -55,7 +55,7 @@ class AdamW_adv(torch.optim.Optimizer):
  the warmup, `alpha` ramps from 0 to its target value. If `None`,
  the scheduler is disabled. (default: None)
  factored (bool): whether to use the factorization or disable it to use
- the uncompressed optimizer. (default: True)
+ the uncompressed optimizer. (default: False)
  """

  def __init__(
@@ -76,7 +76,7 @@ class AdamW_adv(torch.optim.Optimizer):
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
  t_alpha: int | None = None,
- factored: bool = True,
+ factored: bool = False,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -216,7 +216,10 @@ class AdamW_adv(torch.optim.Optimizer):
  del unpacked_sign_slow

  mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
- update = mt + (alpha_t * mt_slow) if beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
+ if beta1 > 0:
+ update = torch.add(mt, mt_slow, alpha=alpha_t)
+ else:
+ update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
  else:
  update = mt.clone() if beta1 > 0 else grad_reshaped.clone()
  del grad_reshaped
@@ -262,7 +265,10 @@ class AdamW_adv(torch.optim.Optimizer):
  if self.use_AdEMAMix:
  exp_avg_slow = state['exp_avg_slow']
  exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
- update = exp_avg + (alpha_t * exp_avg_slow) if beta1 > 0 else grad + (alpha_t * exp_avg_slow)
+ if beta1 > 0:
+ update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
+ else:
+ update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
  else:
  update = exp_avg.clone() if beta1 > 0 else grad.clone()

{adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/Adopt_adv.py
@@ -63,7 +63,7 @@ class Adopt_adv(torch.optim.Optimizer):
  the scheduler is disabled and the full `alpha` value is used from
  the start. (default: None)
  factored (bool): whether to use the factorization or disable it to use
- the uncompressed optimizer. (default: True)
+ the uncompressed optimizer. (default: False)
  """

  def __init__(
@@ -84,7 +84,7 @@ class Adopt_adv(torch.optim.Optimizer):
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
  t_alpha: int | None = None,
- factored: bool = True,
+ factored: bool = False,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -235,7 +235,7 @@ class Adopt_adv(torch.optim.Optimizer):

  if self.use_AdEMAMix:
  mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
- update = mt + (alpha_t * mt_slow)
+ update = torch.add(mt, m_slow, alpha=alpha_t)
  update = update.view(p.shape)
  else:
  update = mt.view(p.shape)
@@ -295,9 +295,9 @@ class Adopt_adv(torch.optim.Optimizer):

  if self.use_AdEMAMix:
  m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
- update = m + (alpha_t * m_slow)
+ update = torch.add(m, m_slow, alpha=alpha_t)
  else:
- update = m
+ update = m.clone()

  if self.use_atan2:
  update.mul_(group['lr'] * 1.2732395447351628)
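
Two small semantic points in the AdamW_adv and Adopt_adv hunks above are easy to miss. First, the blended AdEMAMix numerator now goes through `torch.add(a, b, alpha=...)`, which computes `a + alpha * b` in one fused call instead of materializing the scaled temporary `alpha_t * slow_ema`. Second, the non-AdEMAMix branch of Adopt_adv now returns `m.clone()` rather than `m`, so the in-place scaling applied to `update` later in the step can no longer alias the stored momentum. A standalone illustration with plain tensors (not the optimizers' internals):

    import torch

    m = torch.ones(4)          # stands in for the stored first moment
    m_slow = torch.ones(4)     # stands in for the slow EMA
    alpha_t, lr = 5.0, 1e-3

    # Fused scale-and-add: equivalent to `m + alpha_t * m_slow`
    # without creating the `alpha_t * m_slow` temporary.
    update = torch.add(m, m_slow, alpha=alpha_t)

    # Why the clone matters on the non-AdEMAMix branch:
    aliased = m                # shares storage with the optimizer state
    aliased.mul_(lr)           # an in-place op here silently rescales `m` too
    copied = m.clone()         # independent buffer; later in-place ops leave `m` intact
    copied.mul_(lr)
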
{adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/Lion_Prodigy_adv.py
@@ -33,8 +33,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  (default: 0.0).
  factored (bool): whether to use the factorization or use the
  uncompressed optimizer. (default: True)
- variance_reduction (bool): whether to use the variance reduction technique
- from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
  d0 (float):
  Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
  d_coef (float):
@@ -66,7 +64,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  use_cautious: bool = False,
  clip_threshold: float = 0.0,
  factored: bool = True,
- variance_reduction: bool = False,
  # prodigy parameters
  beta3: float = None,
  d0: float = 1e-6,
@@ -97,7 +94,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  self.stochastic_rounding = stochastic_rounding
  self.use_cautious = use_cautious
  self.factored = factored
- self.variance_reduction = variance_reduction
  self.fsdp_in_use = fsdp_in_use
  super().__init__(params, defaults)
  # Global state for accumulating metrics across parameter updates within a single step.
@@ -183,12 +179,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
- if self.variance_reduction:
- state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
  else: # Fallback to standard Lion
  state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
- if self.variance_reduction:
- state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)

  if state['factored']:
  # Factored Path
@@ -215,20 +207,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  update_for_param = signed_update.view(p.shape).mul(self.dlr)

  # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
- if self.variance_reduction:
- if state['step'] == 1:
- exp_avg.copy_(grad_reshaped)
- else:
- # Heuristic Prodigy-STORM update
- correction = exp_avg.sub(state['prev_grad'])
- grad_alpha = self.d * (1 - self.beta2) + self.beta2
- exp_avg.copy_(grad_reshaped).mul_(grad_alpha).add_(correction, alpha=self.beta2)
- del correction, grad_alpha
- state['prev_grad'].copy_(grad_reshaped)
- else:
- # Standard Prodigy-Lion
- alpha = self.d * (1 - self.beta2)
- exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=alpha)
+ exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1 - self.beta2))
  del grad_reshaped

  # Compress new momentum m_t and store factors
@@ -254,20 +233,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  update_for_param = signed_update.mul(self.dlr)

  # Update momentum
- if self.variance_reduction:
- if state['step'] == 1:
- exp_avg.copy_(grad)
- else:
- # Heuristic Prodigy-STORM update
- correction = exp_avg.sub(state['prev_grad'])
- grad_alpha = self.d * (1 - self.beta2) + self.beta2
- exp_avg.copy_(grad).mul_(grad_alpha).add_(correction, alpha=self.beta2)
- del grad_alpha, correction
- state['prev_grad'].copy_(grad)
- else:
- # Standard Prodigy-Lion
- alpha = self.d * (1 - self.beta2)
- exp_avg.mul_(self.beta2).add_(grad, alpha=alpha)
+ exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))

  # --- Accumulate Prodigy stats ---
  d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
@@ -298,7 +264,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  else:
  p.data.add_(-update_for_param)

- del update_for_param
+ del update_for_param

  @torch.no_grad()
  def step(self, closure: Optional[callable] = None):
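
With the experimental variance-reduction branch removed, both the factored and uncompressed paths of Lion_Prodigy_adv keep a single momentum rule in which Prodigy's current `d` estimate scales the gradient's contribution. Isolated for clarity, with illustrative stand-in values rather than the class's own state handling:

    import torch

    exp_avg = torch.zeros(4)   # Lion momentum m_{t-1}
    grad = torch.randn(4)      # current gradient g_t
    beta2, d = 0.99, 1e-6      # Lion EMA factor and Prodigy's current d estimate

    # m_t = beta2 * m_{t-1} + d * (1 - beta2) * g_t
    exp_avg.mul_(beta2).add_(grad, alpha=d * (1 - beta2))
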
{adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/Lion_adv.py
@@ -33,8 +33,6 @@ class Lion_adv(torch.optim.Optimizer):
  (default: 0.0).
  factored (bool): whether to use the factorization or use the
  uncompressed optimizer. (default: True)
- variance_reduction (bool): whether to use the variance reduction technique
- from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
  """

  def __init__(
@@ -49,7 +47,6 @@ class Lion_adv(torch.optim.Optimizer):
  use_cautious: bool = False,
  clip_threshold: float = 0.0,
  factored: bool = True,
- variance_reduction: bool = False,
  ):
  if not lr > 0.0:
  raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -69,7 +66,6 @@ class Lion_adv(torch.optim.Optimizer):
  self.stochastic_rounding = stochastic_rounding
  self.use_cautious = use_cautious
  self.factored = factored
- self.variance_reduction = variance_reduction
  super().__init__(params, defaults)

  @property
@@ -122,12 +118,8 @@ class Lion_adv(torch.optim.Optimizer):
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
- if self.variance_reduction:
- state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
  else: # Fallback to standard Lion
  state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
- if self.variance_reduction:
- state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)

  state['step'] += 1
  beta1, beta2 = group["betas"]
@@ -157,21 +149,9 @@ class Lion_adv(torch.optim.Optimizer):
  # Parameter update
  update_for_param = signed_update.view(p.shape).mul_(lr)

- # Update momentum
- if self.variance_reduction:
- if state['step'] == 1:
- exp_avg.copy_(grad_reshaped)
- else:
- # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
- correction = exp_avg.sub(state['prev_grad'])
- # Calculate the new momentum and store it back into exp_avg
- exp_avg.copy_(grad_reshaped).add_(correction, alpha=beta2)
- del correction
- # Update prev_grad for the next iteration
- state['prev_grad'].copy_(grad_reshaped)
- else:
- # Standard Lion momentum update
- exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
+ # Standard Lion momentum update
+ exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
+ del grad_reshaped

  # Compress new momentum m_t and store factors
  state['sign'] = _pack_bools(exp_avg > 0)
@@ -195,21 +175,8 @@ class Lion_adv(torch.optim.Optimizer):

  update_for_param = signed_update.mul_(lr)

- # Update momentum
- if self.variance_reduction:
- if state['step'] == 1:
- exp_avg.copy_(grad)
- else:
- # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
- correction = exp_avg.sub(state['prev_grad'])
- # Calculate the new momentum and store it back into exp_avg
- exp_avg.copy_(grad).add_(correction, alpha=beta2)
- del correction
- # Update prev_grad for the next iteration
- state['prev_grad'].copy_(grad)
- else:
- # Standard Lion momentum update
- exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
+ # Standard Lion momentum update
+ exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)

  if group["weight_decay"] != 0:
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
@@ -225,7 +192,7 @@ class Lion_adv(torch.optim.Optimizer):
  else:
  p.data.add_(-update_for_param)

- del update_for_param
+ del update_for_param

  @torch.no_grad()
  def step(self, closure: Optional[callable] = None):
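
Lion_adv likewise falls back to the plain Lion recipe: take the sign of an interpolation between the stored momentum and the current gradient, step in that direction, then refresh the momentum with a slower EMA. A generic, uncompressed sketch of that recipe for orientation only; the class itself layers sign packing, NNMF factorization, the cautious/clipping options, and stochastic rounding on top of it:

    import torch

    def lion_step(p, grad, exp_avg, lr=1e-4, beta1=0.9, beta2=0.99, weight_decay=0.0):
        # c_t = beta1 * m_{t-1} + (1 - beta1) * g_t, used only through its sign
        direction = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1).sign_()
        # decoupled weight decay, then the signed step
        p.mul_(1 - lr * weight_decay).add_(direction, alpha=-lr)
        # m_t = beta2 * m_{t-1} + (1 - beta2) * g_t  (the line this release keeps)
        exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
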
{adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/Prodigy_adv.py
@@ -52,9 +52,19 @@ class Prodigy_adv(torch.optim.Optimizer):
  highly recommended to prevent instability at the beginning of training,
  as it gradually introduces the stabilizing slow momentum term. During
  the warmup, `alpha` ramps from 0 to its target value. If `None`,
- the scheduler is disabled and th
+ the scheduler is disabled.
+ Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
+ This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
+ more responsive, especially for small batch sizes. Enabling this will
+ automatically disable `use_AdEMAMix`, `use_cautious`, `use_grams`,
+ and `use_atan2`. (default: False)
+ alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
+ (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
+ current gradient. For small batch sizes, use high values (e.g., 10-100) to be
+ more responsive. For large batch sizes, use low values (e.g., 0-1) for
+ stability. (default: 100.0)
  factored (bool): whether to use the factorization or disable it to use
- the uncompressed optimizer. (default: True)
+ the uncompressed optimizer. (default: False)
  d0 (float):
  Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
  d_coef (float):
@@ -72,6 +82,9 @@ class Prodigy_adv(torch.optim.Optimizer):
  slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
  pth entry of each tensor. For values greater than 1 this an an approximation to standard
  Prodigy. Values ~11 are reasonable (default 11).
+ prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
+ after the specified optimiser step and release all state memory required by Prodigy
+ (default: 0).
  """

  def __init__(
@@ -91,7 +104,9 @@ class Prodigy_adv(torch.optim.Optimizer):
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
  t_alpha: int | None = None,
- factored: bool = True,
+ Simplified_AdEMAMix: bool = False,
+ alpha_grad: float = 100.0,
+ factored: bool = False,
  # prodigy parameters
  beta3: float = None,
  d0: float = 1e-6,
@@ -100,6 +115,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  safeguard_warmup: bool = False,
  fsdp_in_use: bool = False,
  slice_p: int = 11,
+ prodigy_steps: int = 0,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -109,6 +125,22 @@ class Prodigy_adv(torch.optim.Optimizer):
  raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
  if not (weight_decay >= 0.0):
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+ if not (prodigy_steps >= 0):
+ raise ValueError(f"prodigy_steps should be >= 0. Got {prodigy_steps}")
+ if betas[0] == 0.0 and Simplified_AdEMAMix:
+ raise ValueError(f"Beta 1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
+ if use_AdEMAMix and Simplified_AdEMAMix:
+ print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
+ if use_grams and Simplified_AdEMAMix:
+ print("Warning: use_grams is incompatible with Simplified_AdEMAMix, Disabling use_grams.")
+ if use_cautious and Simplified_AdEMAMix:
+ print("Warning: use_cautious is incompatible with Simplified_AdEMAMix, Disabling use_cautious.")
+ if use_atan2 and Simplified_AdEMAMix:
+ print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
+ use_atan2 = False
+ if Simplified_AdEMAMix and alpha_grad > 0:
+ # scales d_coef by alpha_grad, this force prodigy to behave well with Simplified_AdEMAMix
+ d_coef = d_coef/alpha_grad

  defaults = {
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
@@ -117,12 +149,14 @@ class Prodigy_adv(torch.optim.Optimizer):
  "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
  "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
  "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
- "fsdp_in_use": fsdp_in_use,
+ "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
+ "alpha_grad": alpha_grad,
  }
  self.stochastic_rounding = stochastic_rounding
- self.use_cautious = use_cautious
- self.use_grams = use_grams
- self.use_AdEMAMix = use_AdEMAMix
+ self.use_cautious = use_cautious and not Simplified_AdEMAMix
+ self.use_grams = use_grams and not Simplified_AdEMAMix
+ self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
+ self.Simplified_AdEMAMix = Simplified_AdEMAMix
  self.factored = factored
  self.fsdp_in_use = fsdp_in_use
  super().__init__(params, defaults)
@@ -229,6 +263,8 @@ class Prodigy_adv(torch.optim.Optimizer):
  alpha_t = alpha
  if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
  alpha_t = min(current_step * alpha / t_alpha, alpha)
+ if self.Simplified_AdEMAMix:
+ alpha_grad = group["alpha_grad"]

  if state['factored']:
  d1, d2 = state['effective_shape']
@@ -243,7 +279,10 @@ class Prodigy_adv(torch.optim.Optimizer):
  torch.where(unpacked_sign, mt, -mt, out=mt)
  del unpacked_sign
  # Update momentum in full-size
- mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
+ if self.Simplified_AdEMAMix:
+ mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d)
+ else:
+ mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
  if self.use_grams:
  mt.copy_(grad_reshaped.sign() * mt.abs())
  elif self.use_cautious:
@@ -263,7 +302,12 @@ class Prodigy_adv(torch.optim.Optimizer):
  torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
  del unpacked_sign_slow
  mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
- update = mt + (alpha_t * mt_slow) if self.beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
+ if self.beta1 > 0:
+ update = torch.add(mt, mt_slow, alpha=alpha_t)
+ else:
+ update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
+ elif self.Simplified_AdEMAMix:
+ update = torch.add(mt, grad_reshaped, alpha=alpha_grad * self.d)
  else:
  update = mt.clone() if self.beta1 > 0 else grad_reshaped.clone()
  del grad_reshaped
@@ -297,7 +341,10 @@ class Prodigy_adv(torch.optim.Optimizer):

  if self.beta1 > 0:
  exp_avg = state['exp_avg']
- exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
+ if self.Simplified_AdEMAMix:
+ exp_avg.mul_(self.beta1).add_(grad, alpha=self.d)
+ else:
+ exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
  if self.use_grams:
  exp_avg = grad.sign() * exp_avg.abs()
  elif self.use_cautious:
@@ -309,7 +356,12 @@ class Prodigy_adv(torch.optim.Optimizer):
  if self.use_AdEMAMix:
  exp_avg_slow = state['exp_avg_slow']
  exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
- update = exp_avg + (alpha_t * exp_avg_slow) if self.beta1 > 0 else grad + (alpha_t * exp_avg_slow)
+ if self.beta1 > 0:
+ update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
+ else:
+ update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
+ elif self.Simplified_AdEMAMix:
+ update = torch.add(exp_avg, grad, alpha=alpha_grad * self.d)
  else:
  update = exp_avg.clone() if self.beta1 > 0 else grad.clone()

@@ -327,19 +379,27 @@ class Prodigy_adv(torch.optim.Optimizer):
  update.mul_(self.dlr)

  # --- Accumulate Prodigy stats ---
- d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
- s, p0 = state['s'], state['p0']
- grad_flat = grad.flatten().float()
- p_flat = p.data.flatten().float()
- p0 = p0.float()
+ prodigy_steps = group['prodigy_steps']
+ if prodigy_steps <= 0 or group['k'] < prodigy_steps:
+ d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
+ s, p0 = state['s'], state['p0']
+ grad_flat = grad.flatten().float()
+ p_flat = p.data.flatten().float()
+ p0 = p0.float()

- self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
+ self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()

- alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
- s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
- self.d_denom += s.abs().sum().item()
+ alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
+ s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
+ self.d_denom += s.abs().sum().item()

- del s, p0, grad_flat, p_flat, alpha
+ del s, p0, grad_flat, p_flat, alpha
+ else:
+ # Free memory if prodigy_steps is reached
+ if 's' in state:
+ del state['s']
+ if 'p0' in state:
+ del state['p0']

  # Decoupled weight decay
  if group["weight_decay"] != 0:
@@ -376,29 +436,37 @@ class Prodigy_adv(torch.optim.Optimizer):
  def calculate_d(self):
  """Calculates the new `d` based on the accumulated stats."""
  g_group = self.param_groups[0]
- d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']

- if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
- # Use the device of the first parameter to avoid hardcoding '.cuda()'
- device = self.param_groups[0]['params'][0].device
- dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
- dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
- global_d_numerator = dist_tensor[0].item()
- global_d_denom = dist_tensor[1].item()
- else:
- global_d_numerator = self.d_numerator
- global_d_denom = self.d_denom
-
- d_hat = self.d
- if global_d_denom > 0:
- d_hat = d_coef * global_d_numerator / global_d_denom
- if self.d == g_group['d0']:
- self.d = max(self.d, d_hat)
- d_max = max(d_max, d_hat)
- self.d = min(d_max, self.d * growth_rate)
-
+ # Only perform d-adaptation if prodigy_steps has not been reached
+ prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])
+
+ if prodigy_active:
+ d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
+
+ if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
+ # Use the device of the first parameter to avoid hardcoding '.cuda()'
+ device = self.param_groups[0]['params'][0].device
+ dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+ dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+ global_d_numerator = dist_tensor[0].item()
+ global_d_denom = dist_tensor[1].item()
+ else:
+ global_d_numerator = self.d_numerator
+ global_d_denom = self.d_denom
+
+ d_hat = self.d
+ if global_d_denom > 0:
+ d_hat = d_coef * global_d_numerator / global_d_denom
+ if self.d == g_group['d0']:
+ self.d = max(self.d, d_hat)
+ d_max = max(d_max, d_hat)
+ self.d = min(d_max, self.d * growth_rate)
+
+ for group in self.param_groups:
+ group['d_numerator'] = global_d_numerator
+ group['d'] = self.d
+ group['d_max'] = d_max
+
+ # Increment step counter for all groups, regardless of whether d was updated
  for group in self.param_groups:
- group['d_numerator'] = global_d_numerator
- group['d'] = self.d
- group['d_max'] = d_max
  group['k'] += 1
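
Putting the Prodigy_adv changes together: `Simplified_AdEMAMix=True` turns the first moment into an accumulator (`m_t = beta1 * m_{t-1} + d * g_t`) and the numerator into `m_t + alpha_grad * d * g_t`, the constructor divides `d_coef` by `alpha_grad` so D-adaptation stays well behaved, and `prodigy_steps` freezes `d` and frees the Prodigy statistics after the given step. A usage sketch built only from constructor arguments visible in this diff; every value shown is illustrative:

    import torch
    from adv_optm import Prodigy_adv

    model = torch.nn.Linear(128, 10)
    opt = Prodigy_adv(
        model.parameters(),
        lr=1.0,                    # Prodigy-style: the adapted d supplies the effective step size
        Simplified_AdEMAMix=True,  # new in 0.1.8; disables use_AdEMAMix/use_cautious/use_grams/use_atan2
        alpha_grad=100.0,          # weight of the raw gradient in the update (high for small batches)
        prodigy_steps=1000,        # stop adapting d and release Prodigy state after 1000 steps
    )

    loss = model(torch.randn(4, 128)).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
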
adv_optm-0.1.8/adv_optm/optim/Simplified_AdEMAMix.py
@@ -0,0 +1,246 @@
+ import torch
+
+ import math
+
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
+ from ..util.Effective_Shape import _get_effective_shape
+ from ..util.NNMF import _nnmf,_unnmf
+ from ..util.OrthoGrad import _orthogonalize_gradient
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+ # A little helper from the original simplified_AdEMAMix
+ def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):
+
+     def f(beta, eps=1e-8):
+         return math.log(0.5)/math.log(beta+eps)-1
+
+     def f_inv(t):
+         return math.pow(0.5, 1/(t+1))
+
+     if step < warmup:
+         a = step / float(warmup)
+         return f_inv((1.0-a) * f(beta_start) + a * f(beta_end))
+     return beta_end
+
+ class Simplified_AdEMAMix(torch.optim.Optimizer):
+     """
+     Implements the Simplified AdEMAMix algorithm.
+     Refactored from:
+     https://github.com/DepenM/Simplified-AdEMAMix/blob/main/simplified_AdEMAMix.py
+
+     Args:
+         params (iterable): iterable of parameters to optimize or dicts defining
+             parameter groups
+         lr (float): learning rate (default: 1e-5)
+         betas (tuple[float, float]): coefficients used for computing running
+             averages of gradient and its square (default: (0.99, 0.999))
+         eps (float): term added to the denominator to improve
+             numerical stability (default: 1e-8)
+         weight_decay (float): weight decay (L2 penalty) (default: 0).
+         alpha_grad (float): Coeficient for mixing the current gradient and EMA. for small batch
+             sizes set it to high values, up to 100. And for large batch sized set it to small
+             value, down to 0. (default: 100)
+         beta1_warmup (int, optional): number of warmup steps used to increase beta1 (default: None)
+         min_beta1 (float, optional): minimum value of beta1 to start from (default 0.9)
+         vector_reshape (bool): whether to reshape 1D vectors into 2D
+             matrices to apply low-rank compression (default: True).
+         stochastic_rounding (bool): whether to use stochastic
+             rounding for BF16 parameter updates (default: True).
+         use_orthograd (bool): whether to use OrthoGrad. (default: False)
+         factored (bool): whether to use the factorization or disable it to use
+             the uncompressed optimizer. (default: False)
+     """
+
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-5,
+         betas: tuple[float, float] = (0.99, 0.999),
+         eps: float = 1e-8,
+         weight_decay: float = 0.0,
+         alpha_grad: float = 100.0,
+         beta1_warmup: int | None = None,
+         min_beta1: float | None = 0.9,
+         use_bias_correction: bool = True,
+         vector_reshape: bool = True,
+         stochastic_rounding: bool = True,
+         use_orthograd: bool = False,
+         factored: bool = False,
+     ):
+         if not (lr >= 0.0):
+             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+         if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
+             raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
+         if not (eps >= 0.0):
+             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
+         if not (weight_decay >= 0.0):
+             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+         if not 0.0 <= alpha_grad:
+             raise ValueError("Invalid alpha value: {}".format(alpha_grad))
+
+         defaults = {
+             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
+             "alpha_grad": alpha_grad, "beta1_warmup": beta1_warmup, "min_beta1": min_beta1,
+             "vector_reshape": vector_reshape,
+             "use_orthograd": use_orthograd, "use_bias_correction": use_bias_correction,
+         }
+         self.stochastic_rounding = stochastic_rounding
+         self.factored = factored
+         super().__init__(params, defaults)
+
+     @property
+     def supports_fused_back_pass(self):
+         return True
+
+     @property
+     def supports_memory_efficient_fp16(self):
+         return True
+
+     @property
+     def supports_flat_params(self):
+         return False
+
+     @torch.no_grad()
+     def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+         if p.grad is None:
+             return
+
+         grad = p.grad
+         if grad.dtype != torch.float32 and self.factored:
+             grad = grad.float()
+         if group["use_orthograd"]:
+             grad = _orthogonalize_gradient(p, grad)
+         state = self.state[p]
+
+         # State Initialization
+         if len(state) == 0:
+             state['step'] = 0
+
+             should_factor = (
+                 self.factored and
+                 not (len(p.shape) == 1 and not group['vector_reshape'])
+             )
+
+             state['factored'] = should_factor
+
+             dtype = torch.float32 if self.factored else p.dtype
+             device = p.device
+
+             if state['factored']:
+                 state['effective_shape'] = _get_effective_shape(p.numel())
+                 d1, d2 = state['effective_shape']
+
+                 # First moment (m)
+                 state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                 state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                 packed_d2 = (d2 + 7) // 8
+                 state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                 # Second moment (v)
+                 state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                 state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+             else: # Fallback to standard optimizer for non-factored tensors
+                 state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+                 state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
+
+             if group['use_bias_correction']:
+                 state['num_sum'] = 0.0
+                 state['den_sum'] = 0.0
+             else:
+                 state['num_sum'] = 1.0
+                 state['den_sum'] = 1.0
+
+         beta1_final, beta2 = group["betas"]
+         beta1_warmup = group["beta1_warmup"]
+         alpha_grad = group["alpha_grad"]
+
+         if beta1_warmup is not None:
+             step = state['step'] + 1
+             beta1 = linear_hl_warmup_scheduler(step, beta_end=beta1_final, beta_start=group['min_beta1'], warmup=beta1_warmup)
+         else:
+             beta1 = beta1_final
+
+         if group['use_bias_correction']:
+             state['num_sum'] = beta1 * state['num_sum'] + 1.0
+             state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)
+
+         if state['factored']:
+             d1, d2 = state['effective_shape']
+
+             # Reconstruct momentum from previous step's factors
+             mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+             unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+             torch.where(unpacked_sign, mt, -mt, out=mt)
+             del unpacked_sign
+             # Update momentum in full-size
+             grad_reshaped = grad.view(d1, d2)
+             mt.mul_(beta1).add_(grad_reshaped, alpha=1.0)
+
+             vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+             vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
+
+             update = torch.add(mt, grad_reshaped, alpha=alpha_grad)
+             del grad_reshaped
+
+             denom = vt.sqrt().add_(group['eps'] * math.sqrt(state['den_sum']))
+             update.div_(denom)
+             del denom
+
+             if group['use_bias_correction']:
+                 update = (update / state['num_sum']) * math.sqrt(state['den_sum'])
+
+             update = update.view(p.shape).mul_(group['lr'])
+
+             # Compress updated moments and store new factors
+             state['sign'] = _pack_bools(mt > 0)
+             _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+             del mt
+             _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+             del vt
+
+         else: # Standard optimizer logic for non-factored tensors
+             exp_avg_sq = state['exp_avg_sq']
+
+             exp_avg = state['exp_avg']
+             exp_avg.mul_(beta1).add_(grad, alpha=1.0)
+
+             update = torch.add(exp_avg, grad, alpha=alpha_grad)
+
+             exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+             denom = exp_avg_sq.sqrt().add_(group['eps'] * math.sqrt(state['den_sum']))
+             update.div_(denom)
+             del denom
+
+             if group['use_bias_correction']:
+                 update = (update / state['num_sum']) * math.sqrt(state['den_sum'])
+
+             update.mul_(group['lr'])
+
+         # Decoupled weight decay
+         if group["weight_decay"] != 0:
+             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                 add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+             else:
+                 p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+
+         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+             add_stochastic_(p.data, -update)
+         else:
+             p.data.add_(-update)
+         del update
+
+         state['step'] += 1
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         """Performs a single optimization step."""
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             for i, p in enumerate(group['params']):
+                 self.step_parameter(p, group, i)
+
+         return loss
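
For orientation, the uncompressed branch of `step_parameter` reduces to an accumulator first moment, an Adam-style second moment, and an update of the form `(m_t + alpha_grad * g_t) / (sqrt(v_t) + eps * sqrt(den_sum))`, rescaled by the stored bias-correction sums. A condensed reference of just that branch, ignoring weight decay, BF16 stochastic rounding, OrthoGrad, beta1 warmup, and the NNMF factorization; the closed forms below stand in for the stored `num_sum`/`den_sum` recurrences:

    import torch

    def simplified_ademamix_reference_step(p, grad, m, v, step,
                                           lr=1e-5, beta1=0.99, beta2=0.999,
                                           eps=1e-8, alpha_grad=100.0):
        # Accumulator-style first moment: note the bare `+ grad`, not `(1 - beta1) * grad`.
        m.mul_(beta1).add_(grad)
        # Adam-style second moment.
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        # Closed forms of num_sum = beta1*num_sum + 1 and den_sum = beta2*den_sum + (1 - beta2),
        # both started from 0, after `step + 1` applications.
        num_sum = (1 - beta1 ** (step + 1)) / (1 - beta1)
        den_sum = 1 - beta2 ** (step + 1)
        update = torch.add(m, grad, alpha=alpha_grad)
        update.div_(v.sqrt().add_(eps * den_sum ** 0.5))
        update.mul_(den_sum ** 0.5 / num_sum)
        p.add_(update, alpha=-lr)
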
{adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/__init__.py
@@ -1,6 +1,7 @@
  from .AdamW_adv import AdamW_adv
  from .Prodigy_adv import Prodigy_adv
  from .Adopt_adv import Adopt_adv
+ from .Simplified_AdEMAMix import Simplified_AdEMAMix
  from .Lion_adv import Lion_adv
  from .Lion_Prodigy_adv import Lion_Prodigy_adv

@@ -8,6 +9,7 @@ __all__ = [
  "AdamW_adv",
  "Prodigy_adv",
  "Adopt_adv",
+ "Simplified_AdEMAMix",
  "Lion_adv",
  "Lion_Prodigy_adv",
  ]

{adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 0.1.6
+ Version: 0.1.8
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu

{adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm.egg-info/SOURCES.txt
@@ -12,6 +12,7 @@ adv_optm/optim/Adopt_adv.py
  adv_optm/optim/Lion_Prodigy_adv.py
  adv_optm/optim/Lion_adv.py
  adv_optm/optim/Prodigy_adv.py
+ adv_optm/optim/Simplified_AdEMAMix.py
  adv_optm/optim/__init__.py
  adv_optm/util/BF16_Stochastic_Rounding.py
  adv_optm/util/Effective_Shape.py

{adv_optm-0.1.6 → adv_optm-0.1.8}/setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:

  setup(
  name="adv_optm",
- version="0.1.6",
+ version="0.1.8",
  author="Koratahiu",
  author_email="hiuhonor@gmail.com",
  license='Apache 2.0',

The remaining files listed above are unchanged between 0.1.6 and 0.1.8.