adv-optm 0.1.6__tar.gz → 0.1.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic.
- {adv_optm-0.1.6 → adv_optm-0.1.8}/PKG-INFO +1 -1
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/__init__.py +3 -1
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/AdamW_adv.py +10 -4
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/Adopt_adv.py +5 -5
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/Lion_Prodigy_adv.py +3 -37
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/Lion_adv.py +6 -39
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/Prodigy_adv.py +112 -44
- adv_optm-0.1.8/adv_optm/optim/Simplified_AdEMAMix.py +246 -0
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/optim/__init__.py +2 -0
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm.egg-info/SOURCES.txt +1 -0
- {adv_optm-0.1.6 → adv_optm-0.1.8}/setup.py +1 -1
- {adv_optm-0.1.6 → adv_optm-0.1.8}/LICENSE +0 -0
- {adv_optm-0.1.6 → adv_optm-0.1.8}/README.md +0 -0
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm/util/__init__.py +0 -0
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-0.1.6 → adv_optm-0.1.8}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-0.1.6 → adv_optm-0.1.8}/setup.cfg +0 -0
adv_optm/__init__.py
@@ -2,6 +2,7 @@ from .optim import (
     AdamW_adv,
     Prodigy_adv,
     Adopt_adv,
+    Simplified_AdEMAMix,
     Lion_adv,
     Lion_Prodigy_adv,
 )
@@ -10,8 +11,9 @@ __all__ = [
     "AdamW_adv",
     "Prodigy_adv",
     "Adopt_adv",
+    "Simplified_AdEMAMix",
     "Lion_adv",
     "Lion_Prodigy_adv",
 ]
 
-__version__ = "0.1.6"
+__version__ = "0.1.8"
adv_optm/optim/AdamW_adv.py
@@ -55,7 +55,7 @@ class AdamW_adv(torch.optim.Optimizer):
         the warmup, `alpha` ramps from 0 to its target value. If `None`,
         the scheduler is disabled. (default: None)
     factored (bool): whether to use the factorization or disable it to use
-        the uncompressed optimizer. (default:
+        the uncompressed optimizer. (default: False)
     """
 
     def __init__(
@@ -76,7 +76,7 @@ class AdamW_adv(torch.optim.Optimizer):
         beta3_ema: float = 0.9999,
         alpha: float = 5.0,
         t_alpha: int | None = None,
-        factored: bool =
+        factored: bool = False,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -216,7 +216,10 @@ class AdamW_adv(torch.optim.Optimizer):
                     del unpacked_sign_slow
 
                     mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
-
+                    if beta1 > 0:
+                        update = torch.add(mt, mt_slow, alpha=alpha_t)
+                    else:
+                        update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
                 else:
                     update = mt.clone() if beta1 > 0 else grad_reshaped.clone()
                 del grad_reshaped
@@ -262,7 +265,10 @@ class AdamW_adv(torch.optim.Optimizer):
                 if self.use_AdEMAMix:
                     exp_avg_slow = state['exp_avg_slow']
                     exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
-
+                    if beta1 > 0:
+                        update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
+                    else:
+                        update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
                 else:
                     update = exp_avg.clone() if beta1 > 0 else grad.clone()
 
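Both AdamW_adv hunks above replace a single assignment with a branch on `beta1`, so the AdEMAMix blend falls back to the raw gradient when the fast EMA is disabled. A minimal standalone sketch of that blend (the tensors and scalars below are placeholders, not the optimizer's real state):

import torch

mt = torch.randn(4, 4)        # fast EMA of the gradient (placeholder)
mt_slow = torch.randn(4, 4)   # slow AdEMAMix EMA (placeholder)
grad = torch.randn(4, 4)      # current gradient
beta1, alpha_t = 0.9, 5.0     # invented hyper-parameters

if beta1 > 0:
    update = torch.add(mt, mt_slow, alpha=alpha_t)    # mt + alpha_t * mt_slow
else:
    update = torch.add(grad, mt_slow, alpha=alpha_t)  # grad + alpha_t * mt_slow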
adv_optm/optim/Adopt_adv.py
@@ -63,7 +63,7 @@ class Adopt_adv(torch.optim.Optimizer):
         the scheduler is disabled and the full `alpha` value is used from
         the start. (default: None)
     factored (bool): whether to use the factorization or disable it to use
-        the uncompressed optimizer. (default:
+        the uncompressed optimizer. (default: False)
     """
 
     def __init__(
@@ -84,7 +84,7 @@ class Adopt_adv(torch.optim.Optimizer):
         beta3_ema: float = 0.9999,
         alpha: float = 5.0,
         t_alpha: int | None = None,
-        factored: bool =
+        factored: bool = False,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -235,7 +235,7 @@ class Adopt_adv(torch.optim.Optimizer):
 
                 if self.use_AdEMAMix:
                     mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
-                    update = mt
+                    update = torch.add(mt, m_slow, alpha=alpha_t)
                     update = update.view(p.shape)
                 else:
                     update = mt.view(p.shape)
@@ -295,9 +295,9 @@ class Adopt_adv(torch.optim.Optimizer):
 
                 if self.use_AdEMAMix:
                     m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
-                    update = m
+                    update = torch.add(m, m_slow, alpha=alpha_t)
                 else:
-                    update = m
+                    update = m.clone()
 
                 if self.use_atan2:
                     update.mul_(group['lr'] * 1.2732395447351628)
adv_optm/optim/Lion_Prodigy_adv.py
@@ -33,8 +33,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         (default: 0.0).
     factored (bool): whether to use the factorization or use the
         uncompressed optimizer. (default: True)
-    variance_reduction (bool): whether to use the variance reduction technique
-        from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
     d0 (float):
         Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
     d_coef (float):
@@ -66,7 +64,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         use_cautious: bool = False,
         clip_threshold: float = 0.0,
         factored: bool = True,
-        variance_reduction: bool = False,
         # prodigy parameters
         beta3: float = None,
         d0: float = 1e-6,
@@ -97,7 +94,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         self.stochastic_rounding = stochastic_rounding
         self.use_cautious = use_cautious
         self.factored = factored
-        self.variance_reduction = variance_reduction
         self.fsdp_in_use = fsdp_in_use
         super().__init__(params, defaults)
         # Global state for accumulating metrics across parameter updates within a single step.
@@ -183,12 +179,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
                 state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                 packed_d2 = (d2 + 7) // 8
                 state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
-                if self.variance_reduction:
-                    state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
             else: # Fallback to standard Lion
                 state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
-                if self.variance_reduction:
-                    state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
 
             if state['factored']:
                 # Factored Path
@@ -215,20 +207,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
                 update_for_param = signed_update.view(p.shape).mul(self.dlr)
 
                 # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
-                if self.variance_reduction:
-                    if state['step'] == 1:
-                        exp_avg.copy_(grad_reshaped)
-                    else:
-                        # Heuristic Prodigy-STORM update
-                        correction = exp_avg.sub(state['prev_grad'])
-                        grad_alpha = self.d * (1 - self.beta2) + self.beta2
-                        exp_avg.copy_(grad_reshaped).mul_(grad_alpha).add_(correction, alpha=self.beta2)
-                        del correction, grad_alpha
-                        state['prev_grad'].copy_(grad_reshaped)
-                else:
-                    # Standard Prodigy-Lion
-                    alpha = self.d * (1 - self.beta2)
-                    exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=alpha)
+                exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1 - self.beta2))
                 del grad_reshaped
 
                 # Compress new momentum m_t and store factors
@@ -254,20 +233,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
                 update_for_param = signed_update.mul(self.dlr)
 
                 # Update momentum
-                if self.variance_reduction:
-                    if state['step'] == 1:
-                        exp_avg.copy_(grad)
-                    else:
-                        # Heuristic Prodigy-STORM update
-                        correction = exp_avg.sub(state['prev_grad'])
-                        grad_alpha = self.d * (1 - self.beta2) + self.beta2
-                        exp_avg.copy_(grad).mul_(grad_alpha).add_(correction, alpha=self.beta2)
-                        del grad_alpha, correction
-                        state['prev_grad'].copy_(grad)
-                else:
-                    # Standard Prodigy-Lion
-                    alpha = self.d * (1 - self.beta2)
-                    exp_avg.mul_(self.beta2).add_(grad, alpha=alpha)
+                exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))
 
                 # --- Accumulate Prodigy stats ---
                 d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
@@ -298,7 +264,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
             else:
                 p.data.add_(-update_for_param)
 
-
+            del update_for_param
 
     @torch.no_grad()
     def step(self, closure: Optional[callable] = None):
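With `variance_reduction` gone, the two momentum hunks above reduce to the single standard Prodigy-Lion rule, where the gradient contribution is scaled by the adapted step size `d`. A rough standalone sketch on plain tensors (every value here is invented, not the optimizer's real state):

import torch

exp_avg = torch.zeros(4, 4)   # Lion momentum buffer (placeholder)
grad = torch.randn(4, 4)
d, beta2 = 1e-6, 0.99         # Prodigy scale and Lion beta2 (invented)

# m_t = beta2 * m_{t-1} + d * (1 - beta2) * g_t
exp_avg.mul_(beta2).add_(grad, alpha=d * (1 - beta2))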
adv_optm/optim/Lion_adv.py
@@ -33,8 +33,6 @@ class Lion_adv(torch.optim.Optimizer):
         (default: 0.0).
     factored (bool): whether to use the factorization or use the
         uncompressed optimizer. (default: True)
-    variance_reduction (bool): whether to use the variance reduction technique
-        from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
     """
 
     def __init__(
@@ -49,7 +47,6 @@ class Lion_adv(torch.optim.Optimizer):
         use_cautious: bool = False,
         clip_threshold: float = 0.0,
         factored: bool = True,
-        variance_reduction: bool = False,
     ):
         if not lr > 0.0:
             raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -69,7 +66,6 @@ class Lion_adv(torch.optim.Optimizer):
         self.stochastic_rounding = stochastic_rounding
         self.use_cautious = use_cautious
         self.factored = factored
-        self.variance_reduction = variance_reduction
         super().__init__(params, defaults)
 
     @property
@@ -122,12 +118,8 @@ class Lion_adv(torch.optim.Optimizer):
                 state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                 packed_d2 = (d2 + 7) // 8
                 state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
-                if self.variance_reduction:
-                    state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
             else: # Fallback to standard Lion
                 state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
-                if self.variance_reduction:
-                    state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
 
             state['step'] += 1
             beta1, beta2 = group["betas"]
@@ -157,21 +149,9 @@ class Lion_adv(torch.optim.Optimizer):
                 # Parameter update
                 update_for_param = signed_update.view(p.shape).mul_(lr)
 
-                #
-                if self.variance_reduction:
-                    if state['step'] == 1:
-                        exp_avg.copy_(grad_reshaped)
-                    else:
-                        # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
-                        correction = exp_avg.sub(state['prev_grad'])
-                        # Calculate the new momentum and store it back into exp_avg
-                        exp_avg.copy_(grad_reshaped).add_(correction, alpha=beta2)
-                        del correction
-                        # Update prev_grad for the next iteration
-                        state['prev_grad'].copy_(grad_reshaped)
-                else:
-                    # Standard Lion momentum update
-                    exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
+                # Standard Lion momentum update
+                exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
+                del grad_reshaped
 
                 # Compress new momentum m_t and store factors
                 state['sign'] = _pack_bools(exp_avg > 0)
@@ -195,21 +175,8 @@ class Lion_adv(torch.optim.Optimizer):
 
                 update_for_param = signed_update.mul_(lr)
 
-                #
-                if self.variance_reduction:
-                    if state['step'] == 1:
-                        exp_avg.copy_(grad)
-                    else:
-                        # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
-                        correction = exp_avg.sub(state['prev_grad'])
-                        # Calculate the new momentum and store it back into exp_avg
-                        exp_avg.copy_(grad).add_(correction, alpha=beta2)
-                        del correction
-                        # Update prev_grad for the next iteration
-                        state['prev_grad'].copy_(grad)
-                else:
-                    # Standard Lion momentum update
-                    exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
+                # Standard Lion momentum update
+                exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
 
                 if group["weight_decay"] != 0:
                     if p.dtype == torch.bfloat16 and self.stochastic_rounding:
@@ -225,7 +192,7 @@ class Lion_adv(torch.optim.Optimizer):
             else:
                 p.data.add_(-update_for_param)
 
-
+            del update_for_param
 
     @torch.no_grad()
     def step(self, closure: Optional[callable] = None):
adv_optm/optim/Prodigy_adv.py
@@ -52,9 +52,19 @@ class Prodigy_adv(torch.optim.Optimizer):
         highly recommended to prevent instability at the beginning of training,
         as it gradually introduces the stabilizing slow momentum term. During
         the warmup, `alpha` ramps from 0 to its target value. If `None`,
-        the scheduler is disabled
+        the scheduler is disabled.
+    Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
+        This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
+        more responsive, especially for small batch sizes. Enabling this will
+        automatically disable `use_AdEMAMix`, `use_cautious`, `use_grams`,
+        and `use_atan2`. (default: False)
+    alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
+        (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
+        current gradient. For small batch sizes, use high values (e.g., 10-100) to be
+        more responsive. For large batch sizes, use low values (e.g., 0-1) for
+        stability. (default: 100.0)
     factored (bool): whether to use the factorization or disable it to use
-        the uncompressed optimizer. (default:
+        the uncompressed optimizer. (default: False)
     d0 (float):
         Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
     d_coef (float):
@@ -72,6 +82,9 @@ class Prodigy_adv(torch.optim.Optimizer):
     slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
         pth entry of each tensor. For values greater than 1 this an an approximation to standard
         Prodigy. Values ~11 are reasonable (default 11).
+    prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
+        after the specified optimiser step and release all state memory required by Prodigy
+        (default: 0).
     """
 
     def __init__(
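Assuming the constructor shown in the hunks that follow (the `Simplified_AdEMAMix`, `alpha_grad`, and `prodigy_steps` keyword arguments added in 0.1.8), the new options would be passed like any other keyword. A hedged usage sketch, with a throwaway model and purely illustrative hyper-parameter values:

import torch
from adv_optm import Prodigy_adv  # exported from the package root (see the __init__.py hunk above)

model = torch.nn.Linear(128, 64)  # any module; illustrative only

optimizer = Prodigy_adv(
    model.parameters(),
    lr=1.0,                    # Prodigy-style optimizers are normally run near lr=1.0
    Simplified_AdEMAMix=True,  # new in 0.1.8: accumulator-style momentum mixed with the gradient
    alpha_grad=100.0,          # documented suggestion for small batch sizes
    prodigy_steps=1000,        # new in 0.1.8: freeze d-adaptation after 1000 steps
)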
@@ -91,7 +104,9 @@ class Prodigy_adv(torch.optim.Optimizer):
         beta3_ema: float = 0.9999,
         alpha: float = 5.0,
         t_alpha: int | None = None,
-
+        Simplified_AdEMAMix: bool = False,
+        alpha_grad: float = 100.0,
+        factored: bool = False,
         # prodigy parameters
         beta3: float = None,
         d0: float = 1e-6,
@@ -100,6 +115,7 @@ class Prodigy_adv(torch.optim.Optimizer):
         safeguard_warmup: bool = False,
         fsdp_in_use: bool = False,
         slice_p: int = 11,
+        prodigy_steps: int = 0,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -109,6 +125,22 @@ class Prodigy_adv(torch.optim.Optimizer):
             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
         if not (weight_decay >= 0.0):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+        if not (prodigy_steps >= 0):
+            raise ValueError(f"prodigy_steps should be >= 0. Got {prodigy_steps}")
+        if betas[0] == 0.0 and Simplified_AdEMAMix:
+            raise ValueError(f"Beta 1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
+        if use_AdEMAMix and Simplified_AdEMAMix:
+            print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
+        if use_grams and Simplified_AdEMAMix:
+            print("Warning: use_grams is incompatible with Simplified_AdEMAMix, Disabling use_grams.")
+        if use_cautious and Simplified_AdEMAMix:
+            print("Warning: use_cautious is incompatible with Simplified_AdEMAMix, Disabling use_cautious.")
+        if use_atan2 and Simplified_AdEMAMix:
+            print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
+            use_atan2 = False
+        if Simplified_AdEMAMix and alpha_grad > 0:
+            # scales d_coef by alpha_grad, this force prodigy to behave well with Simplified_AdEMAMix
+            d_coef = d_coef/alpha_grad
 
         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
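The last added check rescales `d_coef` whenever the Simplified AdEMAMix path is active, presumably so that Prodigy's step-size estimate is not inflated by the `alpha_grad`-weighted gradient term in the update. A small numeric illustration of the effect (all numbers invented; `d_coef` later feeds into `d_hat = d_coef * d_numerator / d_denom` in `calculate_d` further down):

d_coef, alpha_grad = 1.0, 100.0
d_numerator, d_denom = 2e-4, 0.05

d_coef = d_coef / alpha_grad              # 0.01
d_hat = d_coef * d_numerator / d_denom    # 4e-05 instead of 4e-03 without the rescaling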
@@ -117,12 +149,14 @@ class Prodigy_adv(torch.optim.Optimizer):
             "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
             "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
             "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
-            "fsdp_in_use": fsdp_in_use,
+            "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
+            "alpha_grad": alpha_grad,
         }
         self.stochastic_rounding = stochastic_rounding
-        self.use_cautious = use_cautious
-        self.use_grams = use_grams
-        self.use_AdEMAMix = use_AdEMAMix
+        self.use_cautious = use_cautious and not Simplified_AdEMAMix
+        self.use_grams = use_grams and not Simplified_AdEMAMix
+        self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
+        self.Simplified_AdEMAMix = Simplified_AdEMAMix
         self.factored = factored
         self.fsdp_in_use = fsdp_in_use
         super().__init__(params, defaults)
@@ -229,6 +263,8 @@ class Prodigy_adv(torch.optim.Optimizer):
             alpha_t = alpha
             if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
                 alpha_t = min(current_step * alpha / t_alpha, alpha)
+        if self.Simplified_AdEMAMix:
+            alpha_grad = group["alpha_grad"]
 
         if state['factored']:
             d1, d2 = state['effective_shape']
@@ -243,7 +279,10 @@ class Prodigy_adv(torch.optim.Optimizer):
                 torch.where(unpacked_sign, mt, -mt, out=mt)
                 del unpacked_sign
                 # Update momentum in full-size
-
+                if self.Simplified_AdEMAMix:
+                    mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d)
+                else:
+                    mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
                 if self.use_grams:
                     mt.copy_(grad_reshaped.sign() * mt.abs())
                 elif self.use_cautious:
@@ -263,7 +302,12 @@ class Prodigy_adv(torch.optim.Optimizer):
                     torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
                     del unpacked_sign_slow
                     mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
-
+                    if self.beta1 > 0:
+                        update = torch.add(mt, mt_slow, alpha=alpha_t)
+                    else:
+                        update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
+                elif self.Simplified_AdEMAMix:
+                    update = torch.add(mt, grad_reshaped, alpha=alpha_grad * self.d)
                 else:
                     update = mt.clone() if self.beta1 > 0 else grad_reshaped.clone()
                 del grad_reshaped
@@ -297,7 +341,10 @@ class Prodigy_adv(torch.optim.Optimizer):
 
             if self.beta1 > 0:
                 exp_avg = state['exp_avg']
-
+                if self.Simplified_AdEMAMix:
+                    exp_avg.mul_(self.beta1).add_(grad, alpha=self.d)
+                else:
+                    exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
                 if self.use_grams:
                     exp_avg = grad.sign() * exp_avg.abs()
                 elif self.use_cautious:
@@ -309,7 +356,12 @@ class Prodigy_adv(torch.optim.Optimizer):
             if self.use_AdEMAMix:
                 exp_avg_slow = state['exp_avg_slow']
                 exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
-
+                if self.beta1 > 0:
+                    update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
+                else:
+                    update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
+            elif self.Simplified_AdEMAMix:
+                update = torch.add(exp_avg, grad, alpha=alpha_grad * self.d)
            else:
                 update = exp_avg.clone() if self.beta1 > 0 else grad.clone()
 
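In both the factored and the fallback path, the new Simplified AdEMAMix branch keeps an accumulator-style momentum (no `1 - beta1` damping) and mixes a heavily weighted current gradient into the numerator. A toy sketch of those two steps on plain tensors (all names and values below are placeholders, not the optimizer's real state):

import torch

exp_avg = torch.zeros(4, 4)               # accumulator-style momentum (placeholder)
grad = torch.randn(4, 4)
beta1, alpha_grad, d = 0.99, 100.0, 1e-3  # invented values

# Accumulator update: the gradient is scaled only by Prodigy's d, not by (1 - beta1).
exp_avg.mul_(beta1).add_(grad, alpha=d)

# Numerator used by the Simplified_AdEMAMix branch: exp_avg + alpha_grad * d * grad.
update = torch.add(exp_avg, grad, alpha=alpha_grad * d)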
@@ -327,19 +379,27 @@ class Prodigy_adv(torch.optim.Optimizer):
             update.mul_(self.dlr)
 
             # --- Accumulate Prodigy stats ---
-
-
-
-
-
+            prodigy_steps = group['prodigy_steps']
+            if prodigy_steps <= 0 or group['k'] < prodigy_steps:
+                d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
+                s, p0 = state['s'], state['p0']
+                grad_flat = grad.flatten().float()
+                p_flat = p.data.flatten().float()
+                p0 = p0.float()
 
-
+                self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
 
-
-
-
+                alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
+                s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
+                self.d_denom += s.abs().sum().item()
 
-
+                del s, p0, grad_flat, p_flat, alpha
+            else:
+                # Free memory if prodigy_steps is reached
+                if 's' in state:
+                    del state['s']
+                if 'p0' in state:
+                    del state['p0']
 
             # Decoupled weight decay
             if group["weight_decay"] != 0:
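The block above now gates the Prodigy statistics behind `prodigy_steps`. The accumulation itself works on a sliced view of the flattened gradient and parameters; a loose standalone sketch of one parameter's contribution (every tensor and scalar below is fabricated, and `p0` is simply sliced the same way here for convenience):

import torch

grad_flat = torch.randn(100)   # flattened gradient of one parameter
p_flat = torch.randn(100)      # flattened current parameter values
p0 = torch.randn(100)          # snapshot of the initial parameter values
s = torch.zeros(10)            # sliced Prodigy accumulator (ceil(100 / slice_p) entries)
d, d0, dlr, beta3, slice_p = 1e-6, 1e-6, 1e-6, 0.999, 11

# Numerator term: correlation between the gradient and the distance travelled from p0.
d_numerator = (d / d0) * dlr * torch.dot(grad_flat[::slice_p], p0[::slice_p] - p_flat[::slice_p]).item()

# Denominator term: EMA of the sliced gradients, summed in absolute value.
s.mul_(beta3).add_(grad_flat[::slice_p], alpha=(d / d0) * dlr)
d_denom = s.abs().sum().item()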
@@ -376,29 +436,37 @@ class Prodigy_adv(torch.optim.Optimizer):
     def calculate_d(self):
         """Calculates the new `d` based on the accumulated stats."""
         g_group = self.param_groups[0]
-        d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
 
-        if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Only perform d-adaptation if prodigy_steps has not been reached
+        prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])
+
+        if prodigy_active:
+            d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
+
+            if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
+                # Use the device of the first parameter to avoid hardcoding '.cuda()'
+                device = self.param_groups[0]['params'][0].device
+                dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+                dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+                global_d_numerator = dist_tensor[0].item()
+                global_d_denom = dist_tensor[1].item()
+            else:
+                global_d_numerator = self.d_numerator
+                global_d_denom = self.d_denom
+
+            d_hat = self.d
+            if global_d_denom > 0:
+                d_hat = d_coef * global_d_numerator / global_d_denom
+                if self.d == g_group['d0']:
+                    self.d = max(self.d, d_hat)
+                d_max = max(d_max, d_hat)
+                self.d = min(d_max, self.d * growth_rate)
+
+            for group in self.param_groups:
+                group['d_numerator'] = global_d_numerator
+                group['d'] = self.d
+                group['d_max'] = d_max
+
+        # Increment step counter for all groups, regardless of whether d was updated
         for group in self.param_groups:
-            group['d_numerator'] = global_d_numerator
-            group['d'] = self.d
-            group['d_max'] = d_max
             group['k'] += 1
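Inside the `prodigy_active` branch the estimate is still the usual Prodigy rule: `d_hat = d_coef * numerator / denominator`, after which `d` is raised monotonically and capped by `growth_rate`. A tiny numeric walk-through of that clamping logic (all numbers invented):

d, d0, d_max = 1e-6, 1e-6, 1e-6
d_coef, growth_rate = 1.0, float('inf')
global_d_numerator, global_d_denom = 3e-4, 0.02

d_hat = d
if global_d_denom > 0:
    d_hat = d_coef * global_d_numerator / global_d_denom   # 0.015
    if d == d0:                       # first adaptation: allowed to jump straight to d_hat
        d = max(d, d_hat)
    d_max = max(d_max, d_hat)
    d = min(d_max, d * growth_rate)   # an infinite growth_rate leaves d capped only by d_max
print(d)                              # ~0.015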
adv_optm/optim/Simplified_AdEMAMix.py (new file)
@@ -0,0 +1,246 @@
+import torch
+
+import math
+
+from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.Effective_Shape import _get_effective_shape
+from ..util.NNMF import _nnmf,_unnmf
+from ..util.OrthoGrad import _orthogonalize_gradient
+from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+# A little helper from the original simplified_AdEMAMix
+def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):
+
+    def f(beta, eps=1e-8):
+        return math.log(0.5)/math.log(beta+eps)-1
+
+    def f_inv(t):
+        return math.pow(0.5, 1/(t+1))
+
+    if step < warmup:
+        a = step / float(warmup)
+        return f_inv((1.0-a) * f(beta_start) + a * f(beta_end))
+    return beta_end
+
+class Simplified_AdEMAMix(torch.optim.Optimizer):
+    """
+    Implements the Simplified AdEMAMix algorithm.
+    Refactored from:
+    https://github.com/DepenM/Simplified-AdEMAMix/blob/main/simplified_AdEMAMix.py
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float): learning rate (default: 1e-5)
+        betas (tuple[float, float]): coefficients used for computing running
+            averages of gradient and its square (default: (0.99, 0.999))
+        eps (float): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float): weight decay (L2 penalty) (default: 0).
+        alpha_grad (float): Coeficient for mixing the current gradient and EMA. for small batch
+            sizes set it to high values, up to 100. And for large batch sized set it to small
+            value, down to 0. (default: 100)
+        beta1_warmup (int, optional): number of warmup steps used to increase beta1 (default: None)
+        min_beta1 (float, optional): minimum value of beta1 to start from (default 0.9)
+        vector_reshape (bool): whether to reshape 1D vectors into 2D
+            matrices to apply low-rank compression (default: True).
+        stochastic_rounding (bool): whether to use stochastic
+            rounding for BF16 parameter updates (default: True).
+        use_orthograd (bool): whether to use OrthoGrad. (default: False)
+        factored (bool): whether to use the factorization or disable it to use
+            the uncompressed optimizer. (default: False)
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-5,
+        betas: tuple[float, float] = (0.99, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+        alpha_grad: float = 100.0,
+        beta1_warmup: int | None = None,
+        min_beta1: float | None = 0.9,
+        use_bias_correction: bool = True,
+        vector_reshape: bool = True,
+        stochastic_rounding: bool = True,
+        use_orthograd: bool = False,
+        factored: bool = False,
+    ):
+        if not (lr >= 0.0):
+            raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+        if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
+            raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
+        if not (eps >= 0.0):
+            raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
+        if not (weight_decay >= 0.0):
+            raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+        if not 0.0 <= alpha_grad:
+            raise ValueError("Invalid alpha value: {}".format(alpha_grad))
+
+        defaults = {
+            "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
+            "alpha_grad": alpha_grad, "beta1_warmup": beta1_warmup, "min_beta1": min_beta1,
+            "vector_reshape": vector_reshape,
+            "use_orthograd": use_orthograd, "use_bias_correction": use_bias_correction,
+        }
+        self.stochastic_rounding = stochastic_rounding
+        self.factored = factored
+        super().__init__(params, defaults)
+
+    @property
+    def supports_fused_back_pass(self):
+        return True
+
+    @property
+    def supports_memory_efficient_fp16(self):
+        return True
+
+    @property
+    def supports_flat_params(self):
+        return False
+
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        if grad.dtype != torch.float32 and self.factored:
+            grad = grad.float()
+        if group["use_orthograd"]:
+            grad = _orthogonalize_gradient(p, grad)
+        state = self.state[p]
+
+        # State Initialization
+        if len(state) == 0:
+            state['step'] = 0
+
+            should_factor = (
+                self.factored and
+                not (len(p.shape) == 1 and not group['vector_reshape'])
+            )
+
+            state['factored'] = should_factor
+
+            dtype = torch.float32 if self.factored else p.dtype
+            device = p.device
+
+            if state['factored']:
+                state['effective_shape'] = _get_effective_shape(p.numel())
+                d1, d2 = state['effective_shape']
+
+                # First moment (m)
+                state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                packed_d2 = (d2 + 7) // 8
+                state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                # Second moment (v)
+                state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+            else: # Fallback to standard optimizer for non-factored tensors
+                state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+                state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
+
+            if group['use_bias_correction']:
+                state['num_sum'] = 0.0
+                state['den_sum'] = 0.0
+            else:
+                state['num_sum'] = 1.0
+                state['den_sum'] = 1.0
+
+        beta1_final, beta2 = group["betas"]
+        beta1_warmup = group["beta1_warmup"]
+        alpha_grad = group["alpha_grad"]
+
+        if beta1_warmup is not None:
+            step = state['step'] + 1
+            beta1 = linear_hl_warmup_scheduler(step, beta_end=beta1_final, beta_start=group['min_beta1'], warmup=beta1_warmup)
+        else:
+            beta1 = beta1_final
+
+        if group['use_bias_correction']:
+            state['num_sum'] = beta1 * state['num_sum'] + 1.0
+            state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)
+
+        if state['factored']:
+            d1, d2 = state['effective_shape']
+
+            # Reconstruct momentum from previous step's factors
+            mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+            unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+            torch.where(unpacked_sign, mt, -mt, out=mt)
+            del unpacked_sign
+            # Update momentum in full-size
+            grad_reshaped = grad.view(d1, d2)
+            mt.mul_(beta1).add_(grad_reshaped, alpha=1.0)
+
+            vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+            vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
+
+            update = torch.add(mt, grad_reshaped, alpha=alpha_grad)
+            del grad_reshaped
+
+            denom = vt.sqrt().add_(group['eps'] * math.sqrt(state['den_sum']))
+            update.div_(denom)
+            del denom
+
+            if group['use_bias_correction']:
+                update = (update / state['num_sum']) * math.sqrt(state['den_sum'])
+
+            update = update.view(p.shape).mul_(group['lr'])
+
+            # Compress updated moments and store new factors
+            state['sign'] = _pack_bools(mt > 0)
+            _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+            del mt
+            _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+            del vt
+
+        else: # Standard optimizer logic for non-factored tensors
+            exp_avg_sq = state['exp_avg_sq']
+
+            exp_avg = state['exp_avg']
+            exp_avg.mul_(beta1).add_(grad, alpha=1.0)
+
+            update = torch.add(exp_avg, grad, alpha=alpha_grad)
+
+            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+            denom = exp_avg_sq.sqrt().add_(group['eps'] * math.sqrt(state['den_sum']))
+            update.div_(denom)
+            del denom
+
+            if group['use_bias_correction']:
+                update = (update / state['num_sum']) * math.sqrt(state['den_sum'])
+
+            update.mul_(group['lr'])
+
+        # Decoupled weight decay
+        if group["weight_decay"] != 0:
+            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+            else:
+                p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+
+        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+            add_stochastic_(p.data, -update)
+        else:
+            p.data.add_(-update)
+        del update
+
+        state['step'] += 1
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.step_parameter(p, group, i)
+
+        return loss
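Because the new class is re-exported from `adv_optm.optim` and from the package root (see the `__init__.py` hunks at the top of this diff and just below), it can be driven like any other `torch.optim.Optimizer`. A hedged usage sketch with an arbitrary model and illustrative values; the defaults quoted in the comments come from the new file above:

import torch
from adv_optm import Simplified_AdEMAMix

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))

optimizer = Simplified_AdEMAMix(
    model.parameters(),
    lr=1e-5,            # default from the new file
    betas=(0.99, 0.999),
    alpha_grad=100.0,   # high gradient weight, suggested for small batch sizes
    beta1_warmup=2000,  # optional half-life style warmup of beta1 (illustrative value)
)

for _ in range(10):  # dummy training loop with random data
    x, y = torch.randn(8, 32), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()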
adv_optm/optim/__init__.py
@@ -1,6 +1,7 @@
 from .AdamW_adv import AdamW_adv
 from .Prodigy_adv import Prodigy_adv
 from .Adopt_adv import Adopt_adv
+from .Simplified_AdEMAMix import Simplified_AdEMAMix
 from .Lion_adv import Lion_adv
 from .Lion_Prodigy_adv import Lion_Prodigy_adv
 
@@ -8,6 +9,7 @@ __all__ = [
     "AdamW_adv",
     "Prodigy_adv",
     "Adopt_adv",
+    "Simplified_AdEMAMix",
     "Lion_adv",
     "Lion_Prodigy_adv",
 ]
adv_optm.egg-info/SOURCES.txt
@@ -12,6 +12,7 @@ adv_optm/optim/Adopt_adv.py
 adv_optm/optim/Lion_Prodigy_adv.py
 adv_optm/optim/Lion_adv.py
 adv_optm/optim/Prodigy_adv.py
+adv_optm/optim/Simplified_AdEMAMix.py
 adv_optm/optim/__init__.py
 adv_optm/util/BF16_Stochastic_Rounding.py
 adv_optm/util/Effective_Shape.py