adv-optm 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +10 -4
- adv_optm/optim/Adopt_adv.py +5 -5
- adv_optm/optim/Lion_Prodigy_adv.py +3 -37
- adv_optm/optim/Lion_adv.py +6 -39
- adv_optm/optim/Prodigy_adv.py +69 -38
- {adv_optm-0.1.7.dist-info → adv_optm-0.1.8.dist-info}/METADATA +1 -1
- adv_optm-0.1.8.dist-info/RECORD +19 -0
- adv_optm-0.1.7.dist-info/RECORD +0 -19
- {adv_optm-0.1.7.dist-info → adv_optm-0.1.8.dist-info}/WHEEL +0 -0
- {adv_optm-0.1.7.dist-info → adv_optm-0.1.8.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-0.1.7.dist-info → adv_optm-0.1.8.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdamW_adv.py
CHANGED
@@ -55,7 +55,7 @@ class AdamW_adv(torch.optim.Optimizer):
         the warmup, `alpha` ramps from 0 to its target value. If `None`,
         the scheduler is disabled. (default: None)
     factored (bool): whether to use the factorization or disable it to use
-        the uncompressed optimizer. (default:
+        the uncompressed optimizer. (default: False)
     """

     def __init__(

@@ -76,7 +76,7 @@ class AdamW_adv(torch.optim.Optimizer):
     beta3_ema: float = 0.9999,
     alpha: float = 5.0,
     t_alpha: int | None = None,
-    factored: bool =
+    factored: bool = False,
 ):
     if not (lr >= 0.0):
         raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")

@@ -216,7 +216,10 @@ class AdamW_adv(torch.optim.Optimizer):
     del unpacked_sign_slow

     mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
-
+    if beta1 > 0:
+        update = torch.add(mt, mt_slow, alpha=alpha_t)
+    else:
+        update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
 else:
     update = mt.clone() if beta1 > 0 else grad_reshaped.clone()
 del grad_reshaped

@@ -262,7 +265,10 @@ class AdamW_adv(torch.optim.Optimizer):
 if self.use_AdEMAMix:
     exp_avg_slow = state['exp_avg_slow']
     exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
-
+    if beta1 > 0:
+        update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
+    else:
+        update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
 else:
     update = exp_avg.clone() if beta1 > 0 else grad.clone()

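The AdamW_adv hunks above extend the AdEMAMix path so that when beta1 == 0 the slow EMA is blended with the raw gradient instead of with a disabled fast EMA. A minimal sketch of that blended numerator, assuming tensor and argument names that are not taken from the package:

```python
import torch

def ademamix_numerator(grad, exp_avg, exp_avg_slow, beta1, beta3_ema, alpha_t):
    """Sketch of the blended AdEMAMix numerator; names are assumptions, not the package API."""
    if beta1 > 0:
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)            # fast EMA (skipped when beta1 == 0)
    exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)   # slow EMA, as in the hunk
    fast_term = exp_avg if beta1 > 0 else grad                     # 0.1.8: fall back to the raw gradient
    return torch.add(fast_term, exp_avg_slow, alpha=alpha_t)       # fast + alpha_t * slow
```
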
adv_optm/optim/Adopt_adv.py
CHANGED
@@ -63,7 +63,7 @@ class Adopt_adv(torch.optim.Optimizer):
         the scheduler is disabled and the full `alpha` value is used from
         the start. (default: None)
     factored (bool): whether to use the factorization or disable it to use
-        the uncompressed optimizer. (default:
+        the uncompressed optimizer. (default: False)
     """

     def __init__(

@@ -84,7 +84,7 @@ class Adopt_adv(torch.optim.Optimizer):
     beta3_ema: float = 0.9999,
     alpha: float = 5.0,
     t_alpha: int | None = None,
-    factored: bool =
+    factored: bool = False,
 ):
     if not (lr >= 0.0):
         raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")

@@ -235,7 +235,7 @@ class Adopt_adv(torch.optim.Optimizer):

 if self.use_AdEMAMix:
     mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
-    update = mt
+    update = torch.add(mt, m_slow, alpha=alpha_t)
     update = update.view(p.shape)
 else:
     update = mt.view(p.shape)

@@ -295,9 +295,9 @@ class Adopt_adv(torch.optim.Optimizer):

 if self.use_AdEMAMix:
     m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
-    update = m
+    update = torch.add(m, m_slow, alpha=alpha_t)
 else:
-    update = m
+    update = m.clone()

 if self.use_atan2:
     update.mul_(group['lr'] * 1.2732395447351628)

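Two of the Adopt_adv changes are behavioural: the slow AdEMAMix EMA is now actually mixed into the update (previously `update = mt` discarded it), and the non-AdEMAMix branch clones the momentum before the later in-place `update.mul_(...)` scaling. The aliasing problem that the `m.clone()` change appears to address can be reproduced in isolation:

```python
import torch

m = torch.ones(3)      # stands in for the momentum buffer kept in optimizer state
update = m             # old behaviour: `update` aliases the state tensor
update.mul_(0.5)       # later in-place scaling (e.g. by the learning rate)...
print(m)               # tensor([0.5000, 0.5000, 0.5000]) -- the state buffer was corrupted

m = torch.ones(3)
update = m.clone()     # new behaviour: work on a copy
update.mul_(0.5)
print(m)               # tensor([1., 1., 1.]) -- the state buffer is untouched
```
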
adv_optm/optim/Lion_Prodigy_adv.py
CHANGED

@@ -33,8 +33,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         (default: 0.0).
     factored (bool): whether to use the factorization or use the
         uncompressed optimizer. (default: True)
-    variance_reduction (bool): whether to use the variance reduction technique
-        from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
     d0 (float):
         Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
     d_coef (float):

@@ -66,7 +64,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
     use_cautious: bool = False,
     clip_threshold: float = 0.0,
     factored: bool = True,
-    variance_reduction: bool = False,
     # prodigy parameters
     beta3: float = None,
     d0: float = 1e-6,

@@ -97,7 +94,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
 self.stochastic_rounding = stochastic_rounding
 self.use_cautious = use_cautious
 self.factored = factored
-self.variance_reduction = variance_reduction
 self.fsdp_in_use = fsdp_in_use
 super().__init__(params, defaults)
 # Global state for accumulating metrics across parameter updates within a single step.

@@ -183,12 +179,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
     state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
     packed_d2 = (d2 + 7) // 8
     state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
-    if self.variance_reduction:
-        state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
 else: # Fallback to standard Lion
     state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
-    if self.variance_reduction:
-        state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)

 if state['factored']:
     # Factored Path

@@ -215,20 +207,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
 update_for_param = signed_update.view(p.shape).mul(self.dlr)

 # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
-
-    if state['step'] == 1:
-        exp_avg.copy_(grad_reshaped)
-    else:
-        # Heuristic Prodigy-STORM update
-        correction = exp_avg.sub(state['prev_grad'])
-        grad_alpha = self.d * (1 - self.beta2) + self.beta2
-        exp_avg.copy_(grad_reshaped).mul_(grad_alpha).add_(correction, alpha=self.beta2)
-        del correction, grad_alpha
-        state['prev_grad'].copy_(grad_reshaped)
-else:
-    # Standard Prodigy-Lion
-    alpha = self.d * (1 - self.beta2)
-    exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=alpha)
+exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1 - self.beta2))
 del grad_reshaped

 # Compress new momentum m_t and store factors

@@ -254,20 +233,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
 update_for_param = signed_update.mul(self.dlr)

 # Update momentum
-
-    if state['step'] == 1:
-        exp_avg.copy_(grad)
-    else:
-        # Heuristic Prodigy-STORM update
-        correction = exp_avg.sub(state['prev_grad'])
-        grad_alpha = self.d * (1 - self.beta2) + self.beta2
-        exp_avg.copy_(grad).mul_(grad_alpha).add_(correction, alpha=self.beta2)
-        del grad_alpha, correction
-        state['prev_grad'].copy_(grad)
-else:
-    # Standard Prodigy-Lion
-    alpha = self.d * (1 - self.beta2)
-    exp_avg.mul_(self.beta2).add_(grad, alpha=alpha)
+exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))

 # --- Accumulate Prodigy stats ---
 d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']

@@ -298,7 +264,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
 else:
     p.data.add_(-update_for_param)

-
+del update_for_param

 @torch.no_grad()
 def step(self, closure: Optional[callable] = None):

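With the variance_reduction (Prodigy-STORM) branch removed, the momentum update in Lion_Prodigy_adv reduces to the single standard Prodigy-scaled Lion rule shown in the hunks. A self-contained sketch of one such step; only the momentum line comes from the diff, the sign-update part is assumed from standard Lion:

```python
import torch

def prodigy_lion_step(p, grad, exp_avg, d, dlr, beta1=0.9, beta2=0.99):
    """Sketch of a Prodigy-Lion step; only the momentum line is taken from the hunks above."""
    # Assumed Lion-style direction: sign of an interpolation of momentum and (d-scaled) gradient.
    signed_update = torch.sign(exp_avg.mul(beta1).add_(grad, alpha=d * (1 - beta1)))
    p.add_(signed_update, alpha=-dlr)                       # dlr = d * lr in Prodigy-style optimizers
    # Momentum update kept in 0.1.8: m_t = beta2 * m_{t-1} + d * (1 - beta2) * g_t
    exp_avg.mul_(beta2).add_(grad, alpha=d * (1 - beta2))

p = torch.randn(4, 4)
exp_avg = torch.zeros_like(p)
prodigy_lion_step(p, torch.randn_like(p), exp_avg, d=1e-6, dlr=1e-6 * 1e-4)
```
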
adv_optm/optim/Lion_adv.py
CHANGED
@@ -33,8 +33,6 @@ class Lion_adv(torch.optim.Optimizer):
         (default: 0.0).
     factored (bool): whether to use the factorization or use the
         uncompressed optimizer. (default: True)
-    variance_reduction (bool): whether to use the variance reduction technique
-        from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
     """

     def __init__(

@@ -49,7 +47,6 @@ class Lion_adv(torch.optim.Optimizer):
     use_cautious: bool = False,
     clip_threshold: float = 0.0,
     factored: bool = True,
-    variance_reduction: bool = False,
 ):
     if not lr > 0.0:
         raise ValueError(f"Learning rate must be > 0.0, but got {lr}")

@@ -69,7 +66,6 @@ class Lion_adv(torch.optim.Optimizer):
 self.stochastic_rounding = stochastic_rounding
 self.use_cautious = use_cautious
 self.factored = factored
-self.variance_reduction = variance_reduction
 super().__init__(params, defaults)

 @property

@@ -122,12 +118,8 @@ class Lion_adv(torch.optim.Optimizer):
     state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
     packed_d2 = (d2 + 7) // 8
     state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
-    if self.variance_reduction:
-        state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
 else: # Fallback to standard Lion
     state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
-    if self.variance_reduction:
-        state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)

 state['step'] += 1
 beta1, beta2 = group["betas"]

@@ -157,21 +149,9 @@ class Lion_adv(torch.optim.Optimizer):
 # Parameter update
 update_for_param = signed_update.view(p.shape).mul_(lr)

-#
-
-
-        exp_avg.copy_(grad_reshaped)
-    else:
-        # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
-        correction = exp_avg.sub(state['prev_grad'])
-        # Calculate the new momentum and store it back into exp_avg
-        exp_avg.copy_(grad_reshaped).add_(correction, alpha=beta2)
-        del correction
-        # Update prev_grad for the next iteration
-        state['prev_grad'].copy_(grad_reshaped)
-else:
-    # Standard Lion momentum update
-    exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
+# Standard Lion momentum update
+exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
+del grad_reshaped

 # Compress new momentum m_t and store factors
 state['sign'] = _pack_bools(exp_avg > 0)

@@ -195,21 +175,8 @@ class Lion_adv(torch.optim.Optimizer):

 update_for_param = signed_update.mul_(lr)

-#
-
-    if state['step'] == 1:
-        exp_avg.copy_(grad)
-    else:
-        # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
-        correction = exp_avg.sub(state['prev_grad'])
-        # Calculate the new momentum and store it back into exp_avg
-        exp_avg.copy_(grad).add_(correction, alpha=beta2)
-        del correction
-        # Update prev_grad for the next iteration
-        state['prev_grad'].copy_(grad)
-else:
-    # Standard Lion momentum update
-    exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
+# Standard Lion momentum update
+exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)

 if group["weight_decay"] != 0:
     if p.dtype == torch.bfloat16 and self.stochastic_rounding:

@@ -225,7 +192,7 @@ class Lion_adv(torch.optim.Optimizer):
 else:
     p.data.add_(-update_for_param)

-
+del update_for_param

 @torch.no_grad()
 def step(self, closure: Optional[callable] = None):

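Lion_adv drops the same variance_reduction machinery, so `variance_reduction=...` is no longer an accepted constructor argument. A hypothetical usage sketch; the import path and the minimal constructor call are assumptions based on the file layout and the hunks above:

```python
import torch
from adv_optm import Lion_adv  # assumed export; the class lives in adv_optm/optim/Lion_adv.py

model = torch.nn.Linear(4, 2)
opt = Lion_adv(model.parameters(), lr=1e-4)  # passing variance_reduction=... would now raise TypeError

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
```
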
adv_optm/optim/Prodigy_adv.py
CHANGED
@@ -64,7 +64,7 @@ class Prodigy_adv(torch.optim.Optimizer):
         more responsive. For large batch sizes, use low values (e.g., 0-1) for
         stability. (default: 100.0)
     factored (bool): whether to use the factorization or disable it to use
-        the uncompressed optimizer. (default:
+        the uncompressed optimizer. (default: False)
     d0 (float):
         Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
     d_coef (float):

@@ -82,6 +82,9 @@ class Prodigy_adv(torch.optim.Optimizer):
     slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
         pth entry of each tensor. For values greater than 1 this an an approximation to standard
         Prodigy. Values ~11 are reasonable (default 11).
+    prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
+        after the specified optimiser step and release all state memory required by Prodigy
+        (default: 0).
     """

     def __init__(

@@ -103,7 +106,7 @@ class Prodigy_adv(torch.optim.Optimizer):
     t_alpha: int | None = None,
     Simplified_AdEMAMix: bool = False,
     alpha_grad: float = 100.0,
-    factored: bool =
+    factored: bool = False,
     # prodigy parameters
     beta3: float = None,
     d0: float = 1e-6,

@@ -112,6 +115,7 @@ class Prodigy_adv(torch.optim.Optimizer):
     safeguard_warmup: bool = False,
     fsdp_in_use: bool = False,
     slice_p: int = 11,
+    prodigy_steps: int = 0,
 ):
     if not (lr >= 0.0):
         raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")

@@ -121,6 +125,8 @@ class Prodigy_adv(torch.optim.Optimizer):
     raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
 if not (weight_decay >= 0.0):
     raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+if not (prodigy_steps >= 0):
+    raise ValueError(f"prodigy_steps should be >= 0. Got {prodigy_steps}")
 if betas[0] == 0.0 and Simplified_AdEMAMix:
     raise ValueError(f"Beta 1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
 if use_AdEMAMix and Simplified_AdEMAMix:

@@ -132,6 +138,9 @@ class Prodigy_adv(torch.optim.Optimizer):
 if use_atan2 and Simplified_AdEMAMix:
     print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
     use_atan2 = False
+if Simplified_AdEMAMix and alpha_grad > 0:
+    # scales d_coef by alpha_grad, this force prodigy to behave well with Simplified_AdEMAMix
+    d_coef = d_coef/alpha_grad

 defaults = {
     "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,

@@ -140,7 +149,7 @@ class Prodigy_adv(torch.optim.Optimizer):
     "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
     "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
     "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
-    "fsdp_in_use": fsdp_in_use,
+    "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
     "alpha_grad": alpha_grad,
 }
 self.stochastic_rounding = stochastic_rounding

@@ -293,7 +302,10 @@ class Prodigy_adv(torch.optim.Optimizer):
     torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
     del unpacked_sign_slow
     mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
-
+    if self.beta1 > 0:
+        update = torch.add(mt, mt_slow, alpha=alpha_t)
+    else:
+        update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
 elif self.Simplified_AdEMAMix:
     update = torch.add(mt, grad_reshaped, alpha=alpha_grad * self.d)
 else:

@@ -344,7 +356,10 @@ class Prodigy_adv(torch.optim.Optimizer):
 if self.use_AdEMAMix:
     exp_avg_slow = state['exp_avg_slow']
     exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
-
+    if self.beta1 > 0:
+        update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
+    else:
+        update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
 elif self.Simplified_AdEMAMix:
     update = torch.add(exp_avg, grad, alpha=alpha_grad * self.d)
 else:

@@ -364,19 +379,27 @@ class Prodigy_adv(torch.optim.Optimizer):
 update.mul_(self.dlr)

 # --- Accumulate Prodigy stats ---
-
-
-
-
-
+prodigy_steps = group['prodigy_steps']
+if prodigy_steps <= 0 or group['k'] < prodigy_steps:
+    d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
+    s, p0 = state['s'], state['p0']
+    grad_flat = grad.flatten().float()
+    p_flat = p.data.flatten().float()
+    p0 = p0.float()

-
+    self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()

-
-
-
+    alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
+    s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
+    self.d_denom += s.abs().sum().item()

-
+    del s, p0, grad_flat, p_flat, alpha
+else:
+    # Free memory if prodigy_steps is reached
+    if 's' in state:
+        del state['s']
+    if 'p0' in state:
+        del state['p0']

 # Decoupled weight decay
 if group["weight_decay"] != 0:

@@ -413,29 +436,37 @@ class Prodigy_adv(torch.optim.Optimizer):
 def calculate_d(self):
     """Calculates the new `d` based on the accumulated stats."""
     g_group = self.param_groups[0]
-    d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']

-    if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # Only perform d-adaptation if prodigy_steps has not been reached
+    prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])
+
+    if prodigy_active:
+        d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
+
+        if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
+            # Use the device of the first parameter to avoid hardcoding '.cuda()'
+            device = self.param_groups[0]['params'][0].device
+            dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+            dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+            global_d_numerator = dist_tensor[0].item()
+            global_d_denom = dist_tensor[1].item()
+        else:
+            global_d_numerator = self.d_numerator
+            global_d_denom = self.d_denom
+
+        d_hat = self.d
+        if global_d_denom > 0:
+            d_hat = d_coef * global_d_numerator / global_d_denom
+            if self.d == g_group['d0']:
+                self.d = max(self.d, d_hat)
+            d_max = max(d_max, d_hat)
+            self.d = min(d_max, self.d * growth_rate)
+
+        for group in self.param_groups:
+            group['d_numerator'] = global_d_numerator
+            group['d'] = self.d
+            group['d_max'] = d_max
+
+    # Increment step counter for all groups, regardless of whether d was updated
 for group in self.param_groups:
-    group['d_numerator'] = global_d_numerator
-    group['d'] = self.d
-    group['d_max'] = d_max
     group['k'] += 1

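The new prodigy_steps option caps how long Prodigy keeps adapting d: once group['k'] reaches the threshold, the stats accumulation above is skipped, the per-parameter 's' and 'p0' state is deleted, and calculate_d stops changing d. A hypothetical usage sketch; the import path is assumed and all other arguments are left at their defaults:

```python
import torch
from adv_optm import Prodigy_adv  # assumed export; the class lives in adv_optm/optim/Prodigy_adv.py

model = torch.nn.Linear(4, 2)
# After 1000 optimizer steps, d is frozen and the Prodigy-specific state is released.
opt = Prodigy_adv(model.parameters(), lr=1.0, prodigy_steps=1000)

for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
```
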
adv_optm-0.1.8.dist-info/RECORD
ADDED

@@ -0,0 +1,19 @@
+adv_optm/__init__.py,sha256=csc19AmU_h7daI3bo4hDVBouMqGiHejfipPIOGFAUQ8,306
+adv_optm/optim/AdamW_adv.py,sha256=Had6kzSBI0eEMiL2yI1wa1nEBoPfgwHQGtnRcDJ8tXI,14078
+adv_optm/optim/Adopt_adv.py,sha256=-iAKhPbEnzdL0Mx96h2BBlJB85TyHdkjULRjWvNbTyY,14833
+adv_optm/optim/Lion_Prodigy_adv.py,sha256=kIAGXoMbDNRg5reKXtUC_vQQ2gyM-NXPB-Pv9zSpiE8,12787
+adv_optm/optim/Lion_adv.py,sha256=05j_j6LIzHW5b79DVwMIf1FZHVNB8xnStNVjlOdVkCE,8256
+adv_optm/optim/Prodigy_adv.py,sha256=U4grKRumzDJRYSI-QHmmZZ7ed_67tyiC3OPSXqJVBx8,21759
+adv_optm/optim/Simplified_AdEMAMix.py,sha256=opIZjnGJ03-DDAIHTZyJBMReVfgusGDb8FZSWMU3-UM,9774
+adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
+adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
+adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
+adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
+adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
+adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
+adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
+adv_optm-0.1.8.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-0.1.8.dist-info/METADATA,sha256=Ydu5_f_d19hoYMf9zvP3eu9ci8XsLWyDuY99JYJVR9o,5846
+adv_optm-0.1.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-0.1.8.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-0.1.8.dist-info/RECORD,,

adv_optm-0.1.7.dist-info/RECORD
DELETED
@@ -1,19 +0,0 @@
-adv_optm/__init__.py,sha256=CZ_tjWWk5d5D8q_R0rcr8vvwlZyY_44zyAcIAmN_SDY,306
-adv_optm/optim/AdamW_adv.py,sha256=ZeNzk2tWbyd2QDI5hp4InwG3iuHHfqLrlhr_VmcQfRM,13884
-adv_optm/optim/Adopt_adv.py,sha256=rzBWfFOPrMuC6vwETsw7QPKmVXcv4IJRDCTj-6eU1Qk,14798
-adv_optm/optim/Lion_Prodigy_adv.py,sha256=JMss9X8lRpIU4E34PfFpWMMal_XNvZ8Yuqc6i7R5wIQ,14588
-adv_optm/optim/Lion_adv.py,sha256=BA4bSEhJiQ7BhGLDRn9nuMlBrLVh-OMscbmSTeGgRmI,10137
-adv_optm/optim/Prodigy_adv.py,sha256=gJL2r32R3xGD62jMR55ZyKxRv0yL70XHxj4FzEJbFc4,20196
-adv_optm/optim/Simplified_AdEMAMix.py,sha256=opIZjnGJ03-DDAIHTZyJBMReVfgusGDb8FZSWMU3-UM,9774
-adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
-adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
-adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
-adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
-adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
-adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
-adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
-adv_optm-0.1.7.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-adv_optm-0.1.7.dist-info/METADATA,sha256=BEKyVG9zVdb9WThOw9YtgWZ_zqDmErumpY5Fr-AkbX0,5846
-adv_optm-0.1.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-adv_optm-0.1.7.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
-adv_optm-0.1.7.dist-info/RECORD,,

{adv_optm-0.1.7.dist-info → adv_optm-0.1.8.dist-info}/WHEEL
File without changes

{adv_optm-0.1.7.dist-info → adv_optm-0.1.8.dist-info}/licenses/LICENSE
File without changes

{adv_optm-0.1.7.dist-info → adv_optm-0.1.8.dist-info}/top_level.txt
File without changes