adv-optm 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +13 -4
- adv_optm/optim/Adopt_adv.py +52 -13
- adv_optm/optim/Lion_Prodigy_adv.py +3 -37
- adv_optm/optim/Lion_adv.py +6 -39
- adv_optm/optim/Prodigy_adv.py +76 -39
- adv_optm-0.1.9.dist-info/METADATA +174 -0
- adv_optm-0.1.9.dist-info/RECORD +19 -0
- adv_optm-0.1.7.dist-info/METADATA +0 -130
- adv_optm-0.1.7.dist-info/RECORD +0 -19
- {adv_optm-0.1.7.dist-info → adv_optm-0.1.9.dist-info}/WHEEL +0 -0
- {adv_optm-0.1.7.dist-info → adv_optm-0.1.9.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-0.1.7.dist-info → adv_optm-0.1.9.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdamW_adv.py
CHANGED
@@ -55,7 +55,7 @@ class AdamW_adv(torch.optim.Optimizer):
             the warmup, `alpha` ramps from 0 to its target value. If `None`,
             the scheduler is disabled. (default: None)
         factored (bool): whether to use the factorization or disable it to use
-            the uncompressed optimizer. (default:
+            the uncompressed optimizer. (default: False)
         """

     def __init__(
@@ -76,7 +76,7 @@ class AdamW_adv(torch.optim.Optimizer):
         beta3_ema: float = 0.9999,
         alpha: float = 5.0,
         t_alpha: int | None = None,
-        factored: bool =
+        factored: bool = False,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -86,6 +86,9 @@ class AdamW_adv(torch.optim.Optimizer):
             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
         if not (weight_decay >= 0.0):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+        if use_cautious and use_grams:
+            print("Warning: use_cautious is incompatible with use_grams, Disabling use_cautious.")
+            use_cautious = False

         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
@@ -216,7 +219,10 @@ class AdamW_adv(torch.optim.Optimizer):
                         del unpacked_sign_slow

                     mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
-
+                    if beta1 > 0:
+                        update = torch.add(mt, mt_slow, alpha=alpha_t)
+                    else:
+                        update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
                 else:
                     update = mt.clone() if beta1 > 0 else grad_reshaped.clone()
                 del grad_reshaped
@@ -262,7 +268,10 @@ class AdamW_adv(torch.optim.Optimizer):
                 if self.use_AdEMAMix:
                     exp_avg_slow = state['exp_avg_slow']
                     exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
-
+                    if beta1 > 0:
+                        update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
+                    else:
+                        update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
                 else:
                     update = exp_avg.clone() if beta1 > 0 else grad.clone()

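The AdamW_adv hunks above wire the slow AdEMAMix EMA into the update: it is mixed in with weight `alpha_t`, and when `beta1 == 0` the raw gradient stands in for the fast EMA. A minimal standalone sketch of that mixing rule, for illustration only; the package's real update path also handles the factored state, weight decay, and the cautious/grams options:

```python
import torch

def ademamix_update(grad, exp_avg, exp_avg_slow, beta1, beta3_ema, alpha_t):
    """Sketch of the AdEMAMix numerator as wired up in 0.1.9 (illustrative)."""
    if beta1 > 0:
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)            # fast EMA
    exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)   # slow EMA
    fast = exp_avg if beta1 > 0 else grad                          # 0.1.9: beta1 == 0 falls back to the gradient
    return torch.add(fast, exp_avg_slow, alpha=alpha_t)            # fast + alpha_t * slow
```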
adv_optm/optim/Adopt_adv.py
CHANGED
@@ -62,8 +62,18 @@ class Adopt_adv(torch.optim.Optimizer):
             the warmup, `alpha` ramps from 0 to its target value. If `None`,
             the scheduler is disabled and the full `alpha` value is used from
             the start. (default: None)
+        Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
+            This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
+            more responsive, especially for small batch sizes. Enabling this will
+            automatically disable `use_AdEMAMix`, `use_cautious`, `use_grams`,
+            and `use_atan2`. (default: False)
+        alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
+            (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
+            current gradient. For small batch sizes, use high values (e.g., 10-100) to be
+            more responsive. For large batch sizes, use low values (e.g., 0-1) for
+            stability. (default: 100.0)
         factored (bool): whether to use the factorization or disable it to use
-            the uncompressed optimizer. (default:
+            the uncompressed optimizer. (default: False)
         """

     def __init__(
@@ -77,14 +87,16 @@ class Adopt_adv(torch.optim.Optimizer):
         vector_reshape: bool = True,
         stochastic_rounding: bool = True,
         use_atan2: bool = False,
-        use_cautious: bool =
+        use_cautious: bool = False,
         use_grams: bool = False,
         use_orthograd: bool = False,
         use_AdEMAMix: bool = False,
         beta3_ema: float = 0.9999,
         alpha: float = 5.0,
         t_alpha: int | None = None,
-
+        Simplified_AdEMAMix: bool = False,
+        alpha_grad: float = 100.0,
+        factored: bool = False,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -94,19 +106,34 @@ class Adopt_adv(torch.optim.Optimizer):
             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
         if not (weight_decay >= 0.0):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+        if use_cautious and use_grams:
+            print("Warning: use_cautious is incompatible with use_grams, Disabling use_cautious.")
+            use_cautious = False
+        if betas[0] == 0.0 and Simplified_AdEMAMix:
+            raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
+        if use_AdEMAMix and Simplified_AdEMAMix:
+            print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
+        if use_grams and Simplified_AdEMAMix:
+            print("Warning: use_grams is incompatible with Simplified_AdEMAMix, Disabling use_grams.")
+        if use_cautious and Simplified_AdEMAMix:
+            print("Warning: use_cautious is incompatible with Simplified_AdEMAMix, Disabling use_cautious.")
+        if use_atan2 and Simplified_AdEMAMix:
+            print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
+            use_atan2 = False

         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
             "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
-            "t_alpha": t_alpha,
+            "t_alpha": t_alpha, "alpha_grad": alpha_grad,
         }
         self.clip_lambda = clip_lambda
         self.stochastic_rounding = stochastic_rounding
-        self.use_atan2 = use_atan2
-        self.use_cautious = use_cautious
-        self.use_grams = use_grams
+        self.use_atan2 = use_atan2 and not Simplified_AdEMAMix
+        self.use_cautious = use_cautious and not Simplified_AdEMAMix
+        self.use_grams = use_grams and not Simplified_AdEMAMix
         self.use_orthograd = use_orthograd
-        self.use_AdEMAMix = use_AdEMAMix
+        self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
+        self.Simplified_AdEMAMix = Simplified_AdEMAMix
         self.factored = factored
         super().__init__(params, defaults)

@@ -185,6 +212,8 @@ class Adopt_adv(torch.optim.Optimizer):
                 alpha_t = alpha
                 if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
                     alpha_t = min(current_step * alpha / t_alpha, alpha)
+            if self.Simplified_AdEMAMix:
+                alpha_grad = group["alpha_grad"]

             if state['factored']:
                 d1, d2 = state['effective_shape']
@@ -224,7 +253,10 @@ class Adopt_adv(torch.optim.Optimizer):
                 del denom

                 # ADOPT Step B: Update momentum m_t using normalized gradient
-
+                if self.Simplified_AdEMAMix:
+                    mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
+                else:
+                    mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
                 if self.use_grams:
                     mt = grad_reshaped.sign() * mt.abs()
                 elif self.use_cautious:
@@ -235,8 +267,10 @@ class Adopt_adv(torch.optim.Optimizer):

                 if self.use_AdEMAMix:
                     mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
-                    update = mt
+                    update = torch.add(mt, m_slow, alpha=alpha_t)
                     update = update.view(p.shape)
+                elif self.Simplified_AdEMAMix:
+                    update = torch.add(mt, grad_reshaped, alpha=alpha_grad)
                 else:
                     update = mt.view(p.shape)

@@ -283,7 +317,10 @@ class Adopt_adv(torch.optim.Optimizer):
                 del denom

                 # ADOPT Step B: Update momentum m_t
-
+                if self.Simplified_AdEMAMix:
+                    m.mul_(beta1).add_(normalized_grad, alpha=1.0)
+                else:
+                    m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)

                 if self.use_grams:
                     m = grad.sign() * m.abs()
@@ -295,9 +332,11 @@ class Adopt_adv(torch.optim.Optimizer):

                 if self.use_AdEMAMix:
                     m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
-                    update = m
+                    update = torch.add(m, m_slow, alpha=alpha_t)
+                elif self.Simplified_AdEMAMix:
+                    update = torch.add(m, grad, alpha=alpha_grad)
                 else:
-                    update = m
+                    update = m.clone()

                 if self.use_atan2:
                     update.mul_(group['lr'] * 1.2732395447351628)
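The Adopt_adv changes port the Simplified AdEMAMix rule into this optimizer: with `Simplified_AdEMAMix=True` the first moment becomes an accumulator (the normalized gradient is added with weight 1.0 instead of `1 - beta1`) and the update numerator becomes `mt + alpha_grad * grad`. A minimal sketch of just that rule, assuming the gradient has already been divided by ADOPT's denominator; the real code additionally covers the factored path and the incompatible-option warnings shown above:

```python
import torch

def simplified_ademamix_numerator(normalized_grad, mt, beta1, alpha_grad):
    """Sketch of the Simplified AdEMAMix numerator added to Adopt_adv in 0.1.9."""
    # Accumulator: the new (normalized) gradient enters with weight 1.0, not (1 - beta1)
    mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
    # Update numerator mixes the accumulator with the current gradient
    return torch.add(mt, normalized_grad, alpha=alpha_grad)
```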
adv_optm/optim/Lion_Prodigy_adv.py
CHANGED

@@ -33,8 +33,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
             (default: 0.0).
         factored (bool): whether to use the factorization or use the
             uncompressed optimizer. (default: True)
-        variance_reduction (bool): whether to use the variance reduction technique
-            from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
         d0 (float):
             Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
         d_coef (float):
@@ -66,7 +64,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         use_cautious: bool = False,
         clip_threshold: float = 0.0,
         factored: bool = True,
-        variance_reduction: bool = False,
         # prodigy parameters
         beta3: float = None,
         d0: float = 1e-6,
@@ -97,7 +94,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         self.stochastic_rounding = stochastic_rounding
         self.use_cautious = use_cautious
         self.factored = factored
-        self.variance_reduction = variance_reduction
         self.fsdp_in_use = fsdp_in_use
         super().__init__(params, defaults)
         # Global state for accumulating metrics across parameter updates within a single step.
@@ -183,12 +179,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
                     state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                     packed_d2 = (d2 + 7) // 8
                     state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
-                    if self.variance_reduction:
-                        state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
                 else: # Fallback to standard Lion
                     state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
-                    if self.variance_reduction:
-                        state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)

             if state['factored']:
                 # Factored Path
@@ -215,20 +207,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
                 update_for_param = signed_update.view(p.shape).mul(self.dlr)

                 # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
-
-                if state['step'] == 1:
-                    exp_avg.copy_(grad_reshaped)
-                else:
-                    # Heuristic Prodigy-STORM update
-                    correction = exp_avg.sub(state['prev_grad'])
-                    grad_alpha = self.d * (1 - self.beta2) + self.beta2
-                    exp_avg.copy_(grad_reshaped).mul_(grad_alpha).add_(correction, alpha=self.beta2)
-                    del correction, grad_alpha
-                    state['prev_grad'].copy_(grad_reshaped)
-                else:
-                    # Standard Prodigy-Lion
-                    alpha = self.d * (1 - self.beta2)
-                    exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=alpha)
+                exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1 - self.beta2))
                 del grad_reshaped

                 # Compress new momentum m_t and store factors
@@ -254,20 +233,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
                 update_for_param = signed_update.mul(self.dlr)

                 # Update momentum
-
-                if state['step'] == 1:
-                    exp_avg.copy_(grad)
-                else:
-                    # Heuristic Prodigy-STORM update
-                    correction = exp_avg.sub(state['prev_grad'])
-                    grad_alpha = self.d * (1 - self.beta2) + self.beta2
-                    exp_avg.copy_(grad).mul_(grad_alpha).add_(correction, alpha=self.beta2)
-                    del grad_alpha, correction
-                    state['prev_grad'].copy_(grad)
-                else:
-                    # Standard Prodigy-Lion
-                    alpha = self.d * (1 - self.beta2)
-                    exp_avg.mul_(self.beta2).add_(grad, alpha=alpha)
+                exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))

                 # --- Accumulate Prodigy stats ---
                 d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
@@ -298,7 +264,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
                 else:
                     p.data.add_(-update_for_param)

-
+                del update_for_param

     @torch.no_grad()
     def step(self, closure: Optional[callable] = None):
adv_optm/optim/Lion_adv.py
CHANGED
@@ -33,8 +33,6 @@ class Lion_adv(torch.optim.Optimizer):
             (default: 0.0).
         factored (bool): whether to use the factorization or use the
             uncompressed optimizer. (default: True)
-        variance_reduction (bool): whether to use the variance reduction technique
-            from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
         """

     def __init__(
@@ -49,7 +47,6 @@ class Lion_adv(torch.optim.Optimizer):
         use_cautious: bool = False,
         clip_threshold: float = 0.0,
         factored: bool = True,
-        variance_reduction: bool = False,
     ):
         if not lr > 0.0:
             raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -69,7 +66,6 @@ class Lion_adv(torch.optim.Optimizer):
         self.stochastic_rounding = stochastic_rounding
         self.use_cautious = use_cautious
         self.factored = factored
-        self.variance_reduction = variance_reduction
         super().__init__(params, defaults)

     @property
@@ -122,12 +118,8 @@ class Lion_adv(torch.optim.Optimizer):
                     state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                     packed_d2 = (d2 + 7) // 8
                     state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
-                    if self.variance_reduction:
-                        state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
                 else: # Fallback to standard Lion
                     state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
-                    if self.variance_reduction:
-                        state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)

             state['step'] += 1
             beta1, beta2 = group["betas"]
@@ -157,21 +149,9 @@ class Lion_adv(torch.optim.Optimizer):
                 # Parameter update
                 update_for_param = signed_update.view(p.shape).mul_(lr)

-                #
-
-
-                    exp_avg.copy_(grad_reshaped)
-                else:
-                    # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
-                    correction = exp_avg.sub(state['prev_grad'])
-                    # Calculate the new momentum and store it back into exp_avg
-                    exp_avg.copy_(grad_reshaped).add_(correction, alpha=beta2)
-                    del correction
-                    # Update prev_grad for the next iteration
-                    state['prev_grad'].copy_(grad_reshaped)
-                else:
-                    # Standard Lion momentum update
-                    exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
+                # Standard Lion momentum update
+                exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
+                del grad_reshaped

                 # Compress new momentum m_t and store factors
                 state['sign'] = _pack_bools(exp_avg > 0)
@@ -195,21 +175,8 @@ class Lion_adv(torch.optim.Optimizer):

                 update_for_param = signed_update.mul_(lr)

-                #
-
-                if state['step'] == 1:
-                    exp_avg.copy_(grad)
-                else:
-                    # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
-                    correction = exp_avg.sub(state['prev_grad'])
-                    # Calculate the new momentum and store it back into exp_avg
-                    exp_avg.copy_(grad).add_(correction, alpha=beta2)
-                    del correction
-                    # Update prev_grad for the next iteration
-                    state['prev_grad'].copy_(grad)
-                else:
-                    # Standard Lion momentum update
-                    exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
+                # Standard Lion momentum update
+                exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)

                 if group["weight_decay"] != 0:
                     if p.dtype == torch.bfloat16 and self.stochastic_rounding:
@@ -225,7 +192,7 @@ class Lion_adv(torch.optim.Optimizer):
                 else:
                     p.data.add_(-update_for_param)

-
+                del update_for_param

     @torch.no_grad()
     def step(self, closure: Optional[callable] = None):
adv_optm/optim/Prodigy_adv.py
CHANGED
@@ -64,7 +64,7 @@ class Prodigy_adv(torch.optim.Optimizer):
             more responsive. For large batch sizes, use low values (e.g., 0-1) for
             stability. (default: 100.0)
         factored (bool): whether to use the factorization or disable it to use
-            the uncompressed optimizer. (default:
+            the uncompressed optimizer. (default: False)
         d0 (float):
             Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
         d_coef (float):
@@ -82,6 +82,9 @@ class Prodigy_adv(torch.optim.Optimizer):
         slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
             pth entry of each tensor. For values greater than 1 this an an approximation to standard
             Prodigy. Values ~11 are reasonable (default 11).
+        prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
+            after the specified optimiser step and release all state memory required by Prodigy
+            (default: 0).
         """

     def __init__(
@@ -103,7 +106,7 @@ class Prodigy_adv(torch.optim.Optimizer):
         t_alpha: int | None = None,
         Simplified_AdEMAMix: bool = False,
         alpha_grad: float = 100.0,
-        factored: bool =
+        factored: bool = False,
         # prodigy parameters
         beta3: float = None,
         d0: float = 1e-6,
@@ -112,6 +115,7 @@ class Prodigy_adv(torch.optim.Optimizer):
         safeguard_warmup: bool = False,
         fsdp_in_use: bool = False,
         slice_p: int = 11,
+        prodigy_steps: int = 0,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -121,8 +125,13 @@ class Prodigy_adv(torch.optim.Optimizer):
             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
         if not (weight_decay >= 0.0):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+        if not (prodigy_steps >= 0):
+            raise ValueError(f"prodigy_steps should be >= 0. Got {prodigy_steps}")
+        if use_cautious and use_grams:
+            print("Warning: use_cautious is incompatible with use_grams, Disabling use_cautious.")
+            use_cautious = False
         if betas[0] == 0.0 and Simplified_AdEMAMix:
-            raise ValueError(f"
+            raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
         if use_AdEMAMix and Simplified_AdEMAMix:
             print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
         if use_grams and Simplified_AdEMAMix:
@@ -140,7 +149,7 @@ class Prodigy_adv(torch.optim.Optimizer):
             "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
             "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
             "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
-            "fsdp_in_use": fsdp_in_use,
+            "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
             "alpha_grad": alpha_grad,
         }
         self.stochastic_rounding = stochastic_rounding
@@ -293,7 +302,10 @@ class Prodigy_adv(torch.optim.Optimizer):
                     torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
                     del unpacked_sign_slow
                     mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
-
+                    if self.beta1 > 0:
+                        update = torch.add(mt, mt_slow, alpha=alpha_t)
+                    else:
+                        update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
                 elif self.Simplified_AdEMAMix:
                     update = torch.add(mt, grad_reshaped, alpha=alpha_grad * self.d)
                 else:
@@ -344,7 +356,10 @@ class Prodigy_adv(torch.optim.Optimizer):
                 if self.use_AdEMAMix:
                     exp_avg_slow = state['exp_avg_slow']
                     exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
-
+                    if self.beta1 > 0:
+                        update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
+                    else:
+                        update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
                 elif self.Simplified_AdEMAMix:
                     update = torch.add(exp_avg, grad, alpha=alpha_grad * self.d)
                 else:
@@ -364,19 +379,27 @@ class Prodigy_adv(torch.optim.Optimizer):
             update.mul_(self.dlr)

             # --- Accumulate Prodigy stats ---
-
-
-
-
-
+            prodigy_steps = group['prodigy_steps']
+            if prodigy_steps <= 0 or group['k'] < prodigy_steps:
+                d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
+                s, p0 = state['s'], state['p0']
+                grad_flat = grad.flatten().float()
+                p_flat = p.data.flatten().float()
+                p0 = p0.float()

-
+                self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()

-
-
-
+                alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
+                s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
+                self.d_denom += s.abs().sum().item()

-
+                del s, p0, grad_flat, p_flat, alpha
+            else:
+                # Free memory if prodigy_steps is reached
+                if 's' in state:
+                    del state['s']
+                if 'p0' in state:
+                    del state['p0']

             # Decoupled weight decay
             if group["weight_decay"] != 0:
@@ -413,29 +436,43 @@ class Prodigy_adv(torch.optim.Optimizer):
     def calculate_d(self):
         """Calculates the new `d` based on the accumulated stats."""
         g_group = self.param_groups[0]
-        d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']

-        if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Only perform d-adaptation if prodigy_steps has not been reached
+        prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])
+
+        if prodigy_active:
+            d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
+
+            if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
+                # Use the device of the first parameter to avoid hardcoding '.cuda()'
+                device = self.param_groups[0]['params'][0].device
+                dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+                dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+                global_d_numerator = dist_tensor[0].item()
+                global_d_denom = dist_tensor[1].item()
+            else:
+                global_d_numerator = self.d_numerator
+                global_d_denom = self.d_denom
+
+            d_hat = self.d
+            if global_d_denom > 0:
+                if self.Simplified_AdEMAMix and g_group['alpha_grad'] > 0:
+                    # A simple and effective hack to make prodigy compatible with Simplified_AdEMAMix large step sizes
+                    # by diving by alpha_grad we make sure that d_numerator that was influenced by (alpha_grad * grad)
+                    # are now normalized by /alpha_grad. this is a heuristic way since the update is also influenced by
+                    # the increasing and decaying accumulator but it's effective and it worked for me (for Lora/Finetune).
+                    global_d_numerator /= g_group['alpha_grad']
+                d_hat = d_coef * global_d_numerator / global_d_denom
+                if self.d == g_group['d0']:
+                    self.d = max(self.d, d_hat)
+                d_max = max(d_max, d_hat)
+                self.d = min(d_max, self.d * growth_rate)
+
+            for group in self.param_groups:
+                group['d_numerator'] = global_d_numerator
+                group['d'] = self.d
+                group['d_max'] = d_max
+
+        # Increment step counter for all groups, regardless of whether d was updated
         for group in self.param_groups:
-            group['d_numerator'] = global_d_numerator
-            group['d'] = self.d
-            group['d_max'] = d_max
             group['k'] += 1
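Prodigy_adv's largest change adds `prodigy_steps`: the D-adaptation statistics are only accumulated while `group['k']` is below that threshold (or always, when it is 0); afterwards the `s` and `p0` state tensors are deleted to free memory and `calculate_d` skips the update entirely. A condensed sketch of the accumulation and the `d` update restored in the diff, written against plain flattened tensors; it mirrors the lines above but omits the FSDP all-reduce, dtype handling, and the factored path, and the dict-based bookkeeping is only illustrative:

```python
import torch

def accumulate_prodigy_stats(stats, s, grad_flat, p_flat, p0_flat,
                             d, d0, dlr, beta3, slice_p, safeguard_warmup):
    """Accumulate Prodigy's numerator/denominator for one parameter (sketch)."""
    # Numerator: correlation between the gradient and the distance from the initial weights
    stats["numerator"] += (d / d0) * dlr * torch.dot(
        grad_flat[::slice_p], p0_flat[::slice_p] - p_flat[::slice_p]
    ).item()
    # Denominator: decayed sum of sliced gradients, as in the diff above
    alpha = (d / d0) * (d if safeguard_warmup else dlr)
    s.mul_(beta3).add_(grad_flat[::slice_p], alpha=alpha)
    stats["denominator"] += s.abs().sum().item()

def calculate_d(d, d_max, stats, d0, d_coef, growth_rate, alpha_grad=None):
    """Turn the accumulated stats into a new d estimate (sketch of calculate_d)."""
    if stats["denominator"] <= 0:
        return d, d_max
    numerator = stats["numerator"]
    if alpha_grad:                       # heuristic normalization used with Simplified_AdEMAMix
        numerator /= alpha_grad
    d_hat = d_coef * numerator / stats["denominator"]
    if d == d0:                          # bootstrap: allow an immediate jump away from d0
        d = max(d, d_hat)
    d_max = max(d_max, d_hat)
    return min(d_max, d * growth_rate), d_max
```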
adv_optm-0.1.9.dist-info/METADATA
ADDED

@@ -0,0 +1,174 @@
+Metadata-Version: 2.4
+Name: adv_optm
+Version: 0.1.9
+Summary: A family of highly efficient, lightweight yet powerful optimizers.
+Home-page: https://github.com/Koratahiu/Advanced_Optimizers
+Author: Koratahiu
+Author-email: hiuhonor@gmail.com
+License: Apache 2.0
+Keywords: llm,fine-tuning,memory-efficient,low-rank,compression,pytorch,optimizer,adam
+Classifier: Programming Language :: Python :: 3
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: Operating System :: OS Independent
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Classifier: Topic :: Software Development :: Libraries :: Python Modules
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: torch>=2.0
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license
+Dynamic: license-file
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
+
+# Advanced Optimizers (AIO)
+
+A comprehensive, all-in-one collection of optimization algorithms for deep learning, designed for maximum efficiency, minimal memory footprint, and superior performance across diverse model architectures and training scenarios.
+
+[](https://pypi.org/project/adv_optm/)
+
+---
+
+## 📦 Installation
+
+```bash
+pip install adv_optm
+```
+
+---
+
+## 🧠 Core Innovations
+
+This library integrates multiple state-of-the-art optimization techniques validated through extensive research and practical training, with 1-bit compression for optimizer states:
+
+### **Memory-Efficient Optimization (SMMF-inspired)**
+- **Paper**: [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
+- **Approach**: Uses rank-1 non-negative matrix factorization with reconstruction cycle (factor → reconstruct → update → factor)
+- **Innovation**:
+  - First moment split into **1-bit sign + absolute value**
+  - Final storage: **four factored vectors + one 1-bit sign state**
+  - Preserves Adam-like update quality with drastically reduced memory
+
+---
+
+## ⚡ Performance Characteristics
+
+### Memory Efficiency (SDXL Model - 6.5GB)
+| Optimizer | Memory Usage | Description |
+|-----------|--------------|-------------|
+| `Adopt_Factored` | 328 MB | 4 small vectors + 1-bit state |
+| `Adopt_Factored + AdEMAMix` | 625 MB | 6 small vectors + two 1-bit states |
+| `Simplified_AdEMAMix` | 328 MB | Same as standard factored (no extra state) |
+
+### Speed Comparison (SDXL, Batch Size 4)
+| Optimizer | Speed | Notes |
+|-----------|-------|-------|
+| `Adafactor` | ~8.5s/it | Baseline |
+| `Adopt_Factored` | ~10s/it | +18% overhead from compression |
+| `Adopt_Factored + AdEMAMix` | ~12s/it | +41% overhead (3 factored states) |
+
+---
+
+## 🧪 Available Optimizers
+
+### Standard Optimizers (All support `factored=True/False`)
+| Optimizer | Description | Best For |
+|-----------|-------------|----------|
+| `Adam_Adv` | Advanced Adam implementation | General purpose |
+| `Adopt_Adv` | Adam-variant with independent beta2 | Stable training for small batch size regimes |
+| `Prodigy_Adv` | Prodigy with D-Adaptation | Adam with automatic LR tuning |
+| `Simplified_AdEMAMix` | Adam variant with accumulator momentum | Small/large batch training when tuned correctly |
+| `Lion_Adv` | Advanced Lion implementation | Memory-constrained environments |
+| `Prodigy_Lion_Adv` | Prodigy + Lion combination | Lion with automatic LR tuning |
+
+### Feature Matrix
+| Feature | Adam_Adv | Adopt_Adv | Prodigy_Adv | Simplified_AdEMAMix | Lion_Adv |
+|---------|----------|-----------|-------------|---------------------|----------|
+| Factored | ✓ | ✓ | ✓ | ✓ | ✓ |
+| AdEMAMix | ✓ | ✓ | ✓ | ✗ | ✗ |
+| Simplified_AdEMAMix | ✗ | ✗ | ✓ | ✓ | ✗ |
+| OrthoGrad | ✓ | ✓ | ✓ | ✓ | ✓ |
+| Grams | ✓ | ✓ | ✓ | ✗ | ✗ |
+| Cautious | ✓ | ✓ | ✓ | ✗ | ✓ |
+| atan2 | ✓ | ✓ | ✓ | ✗ | ✗ |
+| Stochastic Rounding | ✓ | ✓ | ✓ | ✓ | ✓ |
+| Fused Backward Pass | ✓ | ✓ | ✓ | ✓ | ✓ |
+
+---
+
+## ⚙️ Key Features & Parameters
+
+### Comprehensive Feature Guide
+
+| Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
+|---------|-------------|-------------------|--------------------|-------------------|--------------|
+| **Factored** | Memory-efficient optimization using rank-1 factorization | Enable for large models (>1B params) or limited VRAM | +12-41% time overhead, 1-bit memory usage | [SMMF](https://arxiv.org/abs/2412.08894) | All optimizers |
+| **AdEMAMix** | Dual EMA system for momentum | Use for long training runs (10k+ steps) | +1 state memory. | [AdEMAMix](https://arxiv.org/abs/2409.03137) | Adam/Adopt/Prodigy |
+| **Simplified_AdEMAMix** | Accumulator-based momentum | Small batch training (≤32) | Same memory as standard, no extra overhead | [Schedule-Free Connections](https://arxiv.org/abs/2502.02431) | Adam/Prodigy |
+| **OrthoGrad** | Removes gradient component parallel to weights | Full finetuning without weight decay | +33% time overhead, no memory impact | [Grokking at Edge](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | All optimizers |
+| **Stochastic Rounding** | Improves precision for BF16 training | BF16 training | Minimal overhead (<5%) | [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192) | All optimizers |
+| **atan2** | Robust eps replacement + built-in clipping | Use with Adopt or unstable training | No overhead | [Adam-atan2](https://github.com/lucidrains/adam-atan2-pytorch) | Adam/Adopt/prodigy |
+| **Cautious** | Update only when the direction align with the gradients | should faster the convergence | No overhead | [C-Optim](https://github.com/kyleliang919/C-Optim) | Adam/Adopt/prodigy |
+| **Grams** | Update direction from the gradients | should have a stronger effect than cautious | No overhead | [Grams](https://github.com/Gunale0926/Grams) | Adam/Adopt/prodigy |
+
+---
+
+## Simplified_AdEMAMix Parameters
+Simplified_AdEMAMix replaces standard momentum with an accumulator for better small-large batch performance.
+
+| Parameter | Recommended Values | Description |
+|-----------|---------------------|-------------|
+| `beta1` | 0.9 (large BS), 0.99-0.9999 (small BS) | Determines memory length of accumulator |
+| `alpha` | 100-10 (small BS), 1-0 (large BS) | Gradient smoothing factor |
+
+**Alpha Tuning Guide**:
+| Batch Size | Recommended α | Rationale |
+|------------|---------------|-----------|
+| Small (≤32) | 100, 50, 20, 10 | Emphasizes recent gradients for quick adaptation |
+| Medium (32-512) | 10, 5, 2, 1 | Balanced approach |
+| Large (≥512) | 1, 0.5, 0 | Emphasizes historical gradients for stability |
+
+⚠️ **Important**: Use **~100x smaller learning rate** with Simplified_AdEMAMix compared to AdamW (e.g., 1e-6 instead of 1e-4)
+
+### 📊 Performance Validation
+Small Batch Training (SDXL, BS=2, 1.8K steps)
+
+
+- **🟢 Prodigy_adv** (beta1=0.9, d0=1e-5): Final LR=2.9e-4
+- **🔵 Prodigy_adv + Simplified_AdEMAMix** (beta1=0.99, α=100, d0=1e-7): Final LR=5.8e-6
+
+**Results**:
+- Simplified_AdEMAMix shows faster convergence and better final performance
+- D-Adaptation automatically handles aggressive updates (50x smaller LR)
+- Generated samples show significantly better quality with Simplified_AdEMAMix
+
+---
+
+## ⚠️ Known Limitations
+
+### 1. Prodigy_Adv Sensitivity
+- Highly sensitive to gradient modifications (Adopt normalization, low-rank factorization)
+- May fail to increase learning rate in some LoRA scenarios
+- **Fix**: Disable factorization or set beta1=0
+
+### 2. Aggressive Learning Rates
+- Can destabilize factored first moment
+- **Recommendation**: Check Prodigy learning rate as reference for safe LR threshold
+
+---
+
+## 📚 References
+
+1. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
+2. [The AdEMAMix Optimizer: Better, Faster, Older](https://arxiv.org/abs/2409.03137)
+3. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants](https://arxiv.org/abs/2502.02431)
+
+---
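The new README documents the optimizer family and its options but stops short of a code snippet. Below is a minimal usage sketch based on the constructor arguments visible in the 0.1.9 Prodigy_adv diff; the import path, the exported class name, and the explicit `lr` value are assumptions, so check `adv_optm/__init__.py` for what the package actually exposes:

```python
import torch
from adv_optm.optim.Prodigy_adv import Prodigy_adv  # import path assumed from the wheel's file layout

model = torch.nn.Linear(128, 10)

# Keyword arguments taken from the 0.1.9 diff; lr=1.0 is an assumption (Prodigy convention).
opt = Prodigy_adv(
    model.parameters(),
    lr=1.0,
    factored=False,            # new default in 0.1.9
    Simplified_AdEMAMix=True,  # accumulator-style momentum
    alpha_grad=100.0,          # weight of the current gradient (small-batch setting)
    prodigy_steps=1000,        # stop d-adaptation and free Prodigy state after 1k steps
)

for _ in range(3):
    loss = model(torch.randn(4, 128)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
```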
adv_optm-0.1.9.dist-info/RECORD
ADDED

@@ -0,0 +1,19 @@
+adv_optm/__init__.py,sha256=hHL2QwlnQMvIggC9ejOxGOKq65DnnYaHC1ScPQMuIIw,306
+adv_optm/optim/AdamW_adv.py,sha256=Pu0TB14dOhcq9kwXclMIeKCI6ef_P0emwzxPu6xuBM0,14252
+adv_optm/optim/Adopt_adv.py,sha256=71o9BHV3XFefJX21G37PKG96D09x-PSU0eW3Q7WkAjs,17427
+adv_optm/optim/Lion_Prodigy_adv.py,sha256=kIAGXoMbDNRg5reKXtUC_vQQ2gyM-NXPB-Pv9zSpiE8,12787
+adv_optm/optim/Lion_adv.py,sha256=05j_j6LIzHW5b79DVwMIf1FZHVNB8xnStNVjlOdVkCE,8256
+adv_optm/optim/Prodigy_adv.py,sha256=NykG5gcAHjmhlMutknOjAoYKI-K6e5lA3Q9J9vkqnz0,22357
+adv_optm/optim/Simplified_AdEMAMix.py,sha256=opIZjnGJ03-DDAIHTZyJBMReVfgusGDb8FZSWMU3-UM,9774
+adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
+adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
+adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
+adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
+adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
+adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
+adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
+adv_optm-0.1.9.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-0.1.9.dist-info/METADATA,sha256=IvocLvlwTsZ5WPmO6ZsVffmybwZRf3tr_ALojuwL6dw,8422
+adv_optm-0.1.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-0.1.9.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-0.1.9.dist-info/RECORD,,
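Among the files listed above is `adv_optm/util/OrthoGrad.py`, which both the 0.1.7 and 0.1.9 READMEs describe as removing the gradient component parallel to the weights before the update. A minimal sketch of that projection, assuming a plain per-tensor implementation; the package's helper may differ in details such as re-scaling the result:

```python
import torch

def orthogonalize_grad(p: torch.Tensor, grad: torch.Tensor, eps: float = 1e-30) -> torch.Tensor:
    """Remove the component of `grad` that is parallel to the weights `p` (sketch)."""
    w = p.flatten()
    g = grad.flatten()
    proj = torch.dot(w, g) / (torch.dot(w, w) + eps)   # scalar projection coefficient
    return (g - proj * w).view_as(grad)                # g minus its projection onto w
```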
adv_optm-0.1.7.dist-info/METADATA
DELETED

@@ -1,130 +0,0 @@
-Metadata-Version: 2.4
-Name: adv_optm
-Version: 0.1.7
-Summary: A family of highly efficient, lightweight yet powerful optimizers.
-Home-page: https://github.com/Koratahiu/Advanced_Optimizers
-Author: Koratahiu
-Author-email: hiuhonor@gmail.com
-License: Apache 2.0
-Keywords: llm,fine-tuning,memory-efficient,low-rank,compression,pytorch,optimizer,adam
-Classifier: Programming Language :: Python :: 3
-Classifier: License :: OSI Approved :: Apache Software License
-Classifier: Operating System :: OS Independent
-Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
-Classifier: Topic :: Software Development :: Libraries :: Python Modules
-Requires-Python: >=3.8
-Description-Content-Type: text/markdown
-License-File: LICENSE
-Requires-Dist: torch>=2.0
-Dynamic: author
-Dynamic: author-email
-Dynamic: classifier
-Dynamic: description
-Dynamic: description-content-type
-Dynamic: home-page
-Dynamic: keywords
-Dynamic: license
-Dynamic: license-file
-Dynamic: requires-dist
-Dynamic: requires-python
-Dynamic: summary
-
-# Advanced Optimizers
-
-This repo introduces a new family of highly efficient, lightweight yet powerful optimizers, born from extensive research into recent academic literature and validated through practical training runs across diverse models.
-
----
-
-### Install
-
-`pip install adv_optm`
-
----
-
-### Theory (Inspired by SMMF)
-
-Based primarily on:
-**[SMMF: Square-Matricized Momentum Factorization for Memory-Efficient Optimization](https://arxiv.org/abs/2412.08894)**
-
-The core innovation:
-- Uses fast, non-negative matrix factorization (NNMF - rank 1), but **reconstructs the full state before each update** to preserve momentum accuracy, then re-factors afterward (factor → reconstruct → update → factor cycle).
-- For the *signed first moment*, we split into **sign + absolute value**:
-  - Sign is stored as **1-bit state** via bitwise ops (SMMF originally used 8-bit with 7 bits wasted).
-  - Absolute value goes through the factor/reconstruct cycle using two factored vectors + the signed state.
-  - Final storage: **four factored vectors + one 1-bit sign**.
-- Updates behave like full-state Adam but with drastically reduced memory.
-
-> ✅ **TL;DR**: Lightweight, strong, memory-efficient optimizer.
-
----
-
-### Memory Cost
-
-- **Adopt_Factored** for full SDXL finetune: **328 MB** (4 small vectors + 1-bit state)
-- **Adopt_Factored with AdEMAMix** for full SDXL finetune: **625 MB** (6 small vectors + two 1-bit states)
-> SDXL is 6.5GB model.
-
----
-
-### ⏱️ Speed (my tests in SDXL - BS 4)
-
-- **Adopt_Factored**: ~10s/it
-- **Adopt_Factored with AdEMAMix**: ~12s/it
-- **Adafactor**: ~8.5s/it
-→ Overhead from compression/reconstruction cycles.
-→ It's faster than [MLorc](https://arxiv.org/abs/2506.01897) (~12s/it), which uses RSVD compression, and should be the fastest momentum compression (AFAIK).
-
----
-
-### 📈 Performance
-
-- **Better than Adafactor, and CAME factorzation methods**
-- **Comparable or identical to Adam** (see SMMF paper results)
-
----
-
-### Available Optimizers (all support `Factored` toggle)
-
-Set `Factored=False` to disable factorization and run as a full uncompressed optimizer (like vanilla Adam).
-
-1. **Adam**
-2. **Prodigy**
-3. **Adopt**
-
----
-
-### Bonus Features (Built-in)
-
-- **Fused Backward Pass**
-
-- **Stochastic Rounding (SR)**: Improves quality and convergence for **BF16 training**.
-
-- **[AdEMAMix](https://arxiv.org/abs/2409.03137)**
-→ This adds a second, slow-moving EMA, which is combined with the primary momentum to stabilize updates, especially during long runs of full finetuning.
-→ A higher value of beta3 (e.g., 0.9999) gives the EMA a longer memory, making it more stable but slower to adapt. A lower value (e.g., 0.999) is often better for shorter training runs (2k-4k steps).
-→ When `factored` is true, it compresses the new momentum in the same way as the first moment (1-bit state + 2 vectors). However, this introduces noticeable overhead as we are compressing/reconstructing a third state each step.
-
-⚠️ **Note**: AdEMAMix updates are more aggressive than normal Adam/Adopt, so use a x2-x5 smaller LR than usual (or use Prodigy).
-
-- **[`atan2` smoothing & scaling](https://github.com/lucidrains/adam-atan2-pytorch)**
-→ Robust `eps` replacement (no tuning!) + built-in gradient clipping
-→ *Ideal for ADOPT* (which normally needs higher `eps` and clipping), so `use_atan2` is all-in-one for it.
-
-- **[OrthoGrad](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability)**
-→ Removes gradient component parallel to weights → prevents "naïve loss minimization" (NLM) → reduces natural overfitting
-→ Perfect for fine-tuning the direction of existing features (e.g., full finetune or training a trained LoRA) without weight decay erasing prior knowledge.
-
-⚠️ **Note**: OrthoGrad introduces **~33% time overhead**, so take this into account.
-
-- **[Grams: Gradient Descent with Adaptive Momentum Scaling](https://github.com/Gunale0926/Grams)**
-→ Eliminates the need for 1-bit momentum sign storage by using the **sign of gradients** for the first moment.
-
-⚠️ **Not recommended for small batch sizes**: gradients are too noisy, which can destabilize momentum (tested for Prodigy and it made the optimizer slower to find the LR or converge in BS 4).
-
-### Other Notes
-
-- **Adopt** skips the first step (only initializes the states) and has built-in clipping (sticking to the original optimizer), but we skip both of these when you enable `use_atan2`; as the optimizer becomes scale-invariant and the values of the states won't cause any issues or instability.
-
-- When `use_atan2` is True, `eps` will be ignored and you should also disable any gradient clipping.
-
----

adv_optm-0.1.7.dist-info/RECORD
DELETED

@@ -1,19 +0,0 @@
-adv_optm/__init__.py,sha256=CZ_tjWWk5d5D8q_R0rcr8vvwlZyY_44zyAcIAmN_SDY,306
-adv_optm/optim/AdamW_adv.py,sha256=ZeNzk2tWbyd2QDI5hp4InwG3iuHHfqLrlhr_VmcQfRM,13884
-adv_optm/optim/Adopt_adv.py,sha256=rzBWfFOPrMuC6vwETsw7QPKmVXcv4IJRDCTj-6eU1Qk,14798
-adv_optm/optim/Lion_Prodigy_adv.py,sha256=JMss9X8lRpIU4E34PfFpWMMal_XNvZ8Yuqc6i7R5wIQ,14588
-adv_optm/optim/Lion_adv.py,sha256=BA4bSEhJiQ7BhGLDRn9nuMlBrLVh-OMscbmSTeGgRmI,10137
-adv_optm/optim/Prodigy_adv.py,sha256=gJL2r32R3xGD62jMR55ZyKxRv0yL70XHxj4FzEJbFc4,20196
-adv_optm/optim/Simplified_AdEMAMix.py,sha256=opIZjnGJ03-DDAIHTZyJBMReVfgusGDb8FZSWMU3-UM,9774
-adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
-adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
-adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
-adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
-adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
-adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
-adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
-adv_optm-0.1.7.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-adv_optm-0.1.7.dist-info/METADATA,sha256=BEKyVG9zVdb9WThOw9YtgWZ_zqDmErumpY5Fr-AkbX0,5846
-adv_optm-0.1.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-adv_optm-0.1.7.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
-adv_optm-0.1.7.dist-info/RECORD,,

{adv_optm-0.1.7.dist-info → adv_optm-0.1.9.dist-info}/WHEEL
File without changes

{adv_optm-0.1.7.dist-info → adv_optm-0.1.9.dist-info}/licenses/LICENSE
File without changes

{adv_optm-0.1.7.dist-info → adv_optm-0.1.9.dist-info}/top_level.txt
File without changes