adv-optm 0.1.5.tar.gz → 0.1.7.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic.
- {adv_optm-0.1.5 → adv_optm-0.1.7}/PKG-INFO +1 -1
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm/__init__.py +3 -1
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm/optim/AdamW_adv.py +1 -1
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm/optim/Prodigy_adv.py +44 -7
- adv_optm-0.1.7/adv_optm/optim/Simplified_AdEMAMix.py +246 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm/optim/__init__.py +2 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm.egg-info/SOURCES.txt +1 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/setup.py +1 -1
- {adv_optm-0.1.5 → adv_optm-0.1.7}/LICENSE +0 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/README.md +0 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm/util/__init__.py +0 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-0.1.5 → adv_optm-0.1.7}/setup.cfg +0 -0
{adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm/__init__.py

@@ -2,6 +2,7 @@ from .optim import (
     AdamW_adv,
     Prodigy_adv,
     Adopt_adv,
+    Simplified_AdEMAMix,
     Lion_adv,
     Lion_Prodigy_adv,
 )
@@ -10,8 +11,9 @@ __all__ = [
     "AdamW_adv",
     "Prodigy_adv",
     "Adopt_adv",
+    "Simplified_AdEMAMix",
     "Lion_adv",
     "Lion_Prodigy_adv",
 ]
 
-__version__ = "0.1.5"
+__version__ = "0.1.7"
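With the export added at the package root, the new class is importable alongside the existing optimizers. A minimal sketch, assuming the 0.1.7 release is installed (the print only confirms the bumped version string):

import adv_optm
from adv_optm import Prodigy_adv, Simplified_AdEMAMix

print(adv_optm.__version__)  # expected: "0.1.7"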
{adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm/optim/Prodigy_adv.py

@@ -52,7 +52,17 @@ class Prodigy_adv(torch.optim.Optimizer):
             highly recommended to prevent instability at the beginning of training,
             as it gradually introduces the stabilizing slow momentum term. During
             the warmup, `alpha` ramps from 0 to its target value. If `None`,
-            the scheduler is disabled
+            the scheduler is disabled.
+        Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
+            This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
+            more responsive, especially for small batch sizes. Enabling this will
+            automatically disable `use_AdEMAMix`, `use_cautious`, `use_grams`,
+            and `use_atan2`. (default: False)
+        alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
+            (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
+            current gradient. For small batch sizes, use high values (e.g., 10-100) to be
+            more responsive. For large batch sizes, use low values (e.g., 0-1) for
+            stability. (default: 100.0)
         factored (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: True)
         d0 (float):
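In symbols, one reading of the two numerators described in this docstring (with g_t the current gradient, m_t the first moment, and d Prodigy's adaptive step-size scale, matching the code hunks further below):

\[
\text{default EMA path:}\qquad m_t = \beta_1\, m_{t-1} + d\,(1-\beta_1)\, g_t, \qquad u_t = m_t
\]
\[
\text{Simplified AdEMAMix path:}\qquad m_t = \beta_1\, m_{t-1} + d\, g_t, \qquad u_t = m_t + \alpha_{\text{grad}}\, d\, g_t
\]

In both cases u_t is then divided by the second-moment denominator before being scaled by the adaptive learning rate.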
@@ -91,6 +101,8 @@ class Prodigy_adv(torch.optim.Optimizer):
         beta3_ema: float = 0.9999,
         alpha: float = 5.0,
         t_alpha: int | None = None,
+        Simplified_AdEMAMix: bool = False,
+        alpha_grad: float = 100.0,
         factored: bool = True,
         # prodigy parameters
         beta3: float = None,
@@ -109,6 +121,17 @@ class Prodigy_adv(torch.optim.Optimizer):
             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
         if not (weight_decay >= 0.0):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+        if betas[0] == 0.0 and Simplified_AdEMAMix:
+            raise ValueError(f"Beta 1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
+        if use_AdEMAMix and Simplified_AdEMAMix:
+            print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
+        if use_grams and Simplified_AdEMAMix:
+            print("Warning: use_grams is incompatible with Simplified_AdEMAMix, Disabling use_grams.")
+        if use_cautious and Simplified_AdEMAMix:
+            print("Warning: use_cautious is incompatible with Simplified_AdEMAMix, Disabling use_cautious.")
+        if use_atan2 and Simplified_AdEMAMix:
+            print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
+            use_atan2 = False
 
         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
@@ -118,11 +141,13 @@ class Prodigy_adv(torch.optim.Optimizer):
             "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
             "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
             "fsdp_in_use": fsdp_in_use,
+            "alpha_grad": alpha_grad,
         }
         self.stochastic_rounding = stochastic_rounding
-        self.use_cautious = use_cautious
-        self.use_grams = use_grams
-        self.use_AdEMAMix = use_AdEMAMix
+        self.use_cautious = use_cautious and not Simplified_AdEMAMix
+        self.use_grams = use_grams and not Simplified_AdEMAMix
+        self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
+        self.Simplified_AdEMAMix = Simplified_AdEMAMix
         self.factored = factored
         self.fsdp_in_use = fsdp_in_use
         super().__init__(params, defaults)
@@ -229,6 +254,8 @@ class Prodigy_adv(torch.optim.Optimizer):
         alpha_t = alpha
         if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
             alpha_t = min(current_step * alpha / t_alpha, alpha)
+        if self.Simplified_AdEMAMix:
+            alpha_grad = group["alpha_grad"]
 
         if state['factored']:
             d1, d2 = state['effective_shape']
@@ -243,7 +270,10 @@ class Prodigy_adv(torch.optim.Optimizer):
             torch.where(unpacked_sign, mt, -mt, out=mt)
             del unpacked_sign
             # Update momentum in full-size
-            mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
+            if self.Simplified_AdEMAMix:
+                mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d)
+            else:
+                mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
             if self.use_grams:
                 mt.copy_(grad_reshaped.sign() * mt.abs())
             elif self.use_cautious:
@@ -264,6 +294,8 @@ class Prodigy_adv(torch.optim.Optimizer):
                 del unpacked_sign_slow
                 mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
                 update = mt + (alpha_t * mt_slow) if self.beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
+            elif self.Simplified_AdEMAMix:
+                update = torch.add(mt, grad_reshaped, alpha=alpha_grad * self.d)
             else:
                 update = mt.clone() if self.beta1 > 0 else grad_reshaped.clone()
             del grad_reshaped
@@ -277,7 +309,7 @@ class Prodigy_adv(torch.optim.Optimizer):
             update.div_(denom.add_(self.d * group['eps']))
             del denom
 
-            update.view(p.shape).mul_(self.dlr)
+            update = update.view(p.shape).mul_(self.dlr)
 
             # Compress updated moments and store new factors
             if self.beta1 > 0:
@@ -297,7 +329,10 @@ class Prodigy_adv(torch.optim.Optimizer):
 
            if self.beta1 > 0:
                exp_avg = state['exp_avg']
-               if self.Simplified_AdEMAMix is removed here in spirit only
+               exp_avg placeholder
            if placeholder
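Putting the Prodigy_adv hunks together, the new behaviour is opt-in via the two keyword arguments added above. A hedged usage sketch, assuming Prodigy_adv exposes the usual torch.optim.Optimizer interface (step/zero_grad) and that all other constructor arguments keep their defaults; the model and data below are placeholders:

import torch
import torch.nn.functional as F
from adv_optm import Prodigy_adv

model = torch.nn.Linear(64, 8)
opt = Prodigy_adv(
    model.parameters(),
    Simplified_AdEMAMix=True,  # first moment becomes an accumulator instead of an EMA
    alpha_grad=100.0,          # weight on the raw gradient; high values suit small batches
)

x, y = torch.randn(4, 64), torch.randint(0, 8, (4,))
loss = F.cross_entropy(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()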
adv_optm-0.1.7/adv_optm/optim/Simplified_AdEMAMix.py (new file)

@@ -0,0 +1,246 @@
+import torch
+
+import math
+
+from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.Effective_Shape import _get_effective_shape
+from ..util.NNMF import _nnmf,_unnmf
+from ..util.OrthoGrad import _orthogonalize_gradient
+from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+# A little helper from the original simplified_AdEMAMix
+def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):
+
+    def f(beta, eps=1e-8):
+        return math.log(0.5)/math.log(beta+eps)-1
+
+    def f_inv(t):
+        return math.pow(0.5, 1/(t+1))
+
+    if step < warmup:
+        a = step / float(warmup)
+        return f_inv((1.0-a) * f(beta_start) + a * f(beta_end))
+    return beta_end
+
+class Simplified_AdEMAMix(torch.optim.Optimizer):
+    """
+    Implements the Simplified AdEMAMix algorithm.
+    Refactored from:
+    https://github.com/DepenM/Simplified-AdEMAMix/blob/main/simplified_AdEMAMix.py
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float): learning rate (default: 1e-5)
+        betas (tuple[float, float]): coefficients used for computing running
+            averages of gradient and its square (default: (0.99, 0.999))
+        eps (float): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float): weight decay (L2 penalty) (default: 0).
+        alpha_grad (float): Coeficient for mixing the current gradient and EMA. for small batch
+            sizes set it to high values, up to 100. And for large batch sized set it to small
+            value, down to 0. (default: 100)
+        beta1_warmup (int, optional): number of warmup steps used to increase beta1 (default: None)
+        min_beta1 (float, optional): minimum value of beta1 to start from (default 0.9)
+        vector_reshape (bool): whether to reshape 1D vectors into 2D
+            matrices to apply low-rank compression (default: True).
+        stochastic_rounding (bool): whether to use stochastic
+            rounding for BF16 parameter updates (default: True).
+        use_orthograd (bool): whether to use OrthoGrad. (default: False)
+        factored (bool): whether to use the factorization or disable it to use
+            the uncompressed optimizer. (default: False)
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-5,
+        betas: tuple[float, float] = (0.99, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+        alpha_grad: float = 100.0,
+        beta1_warmup: int | None = None,
+        min_beta1: float | None = 0.9,
+        use_bias_correction: bool = True,
+        vector_reshape: bool = True,
+        stochastic_rounding: bool = True,
+        use_orthograd: bool = False,
+        factored: bool = False,
+    ):
+        if not (lr >= 0.0):
+            raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+        if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
+            raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
+        if not (eps >= 0.0):
+            raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
+        if not (weight_decay >= 0.0):
+            raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+        if not 0.0 <= alpha_grad:
+            raise ValueError("Invalid alpha value: {}".format(alpha_grad))
+
+        defaults = {
+            "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
+            "alpha_grad": alpha_grad, "beta1_warmup": beta1_warmup, "min_beta1": min_beta1,
+            "vector_reshape": vector_reshape,
+            "use_orthograd": use_orthograd, "use_bias_correction": use_bias_correction,
+        }
+        self.stochastic_rounding = stochastic_rounding
+        self.factored = factored
+        super().__init__(params, defaults)
+
+    @property
+    def supports_fused_back_pass(self):
+        return True
+
+    @property
+    def supports_memory_efficient_fp16(self):
+        return True
+
+    @property
+    def supports_flat_params(self):
+        return False
+
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        if grad.dtype != torch.float32 and self.factored:
+            grad = grad.float()
+        if group["use_orthograd"]:
+            grad = _orthogonalize_gradient(p, grad)
+        state = self.state[p]
+
+        # State Initialization
+        if len(state) == 0:
+            state['step'] = 0
+
+            should_factor = (
+                self.factored and
+                not (len(p.shape) == 1 and not group['vector_reshape'])
+            )
+
+            state['factored'] = should_factor
+
+            dtype = torch.float32 if self.factored else p.dtype
+            device = p.device
+
+            if state['factored']:
+                state['effective_shape'] = _get_effective_shape(p.numel())
+                d1, d2 = state['effective_shape']
+
+                # First moment (m)
+                state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                packed_d2 = (d2 + 7) // 8
+                state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                # Second moment (v)
+                state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+            else: # Fallback to standard optimizer for non-factored tensors
+                state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+                state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
+
+            if group['use_bias_correction']:
+                state['num_sum'] = 0.0
+                state['den_sum'] = 0.0
+            else:
+                state['num_sum'] = 1.0
+                state['den_sum'] = 1.0
+
+        beta1_final, beta2 = group["betas"]
+        beta1_warmup = group["beta1_warmup"]
+        alpha_grad = group["alpha_grad"]
+
+        if beta1_warmup is not None:
+            step = state['step'] + 1
+            beta1 = linear_hl_warmup_scheduler(step, beta_end=beta1_final, beta_start=group['min_beta1'], warmup=beta1_warmup)
+        else:
+            beta1 = beta1_final
+
+        if group['use_bias_correction']:
+            state['num_sum'] = beta1 * state['num_sum'] + 1.0
+            state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)
+
+        if state['factored']:
+            d1, d2 = state['effective_shape']
+
+            # Reconstruct momentum from previous step's factors
+            mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+            unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+            torch.where(unpacked_sign, mt, -mt, out=mt)
+            del unpacked_sign
+            # Update momentum in full-size
+            grad_reshaped = grad.view(d1, d2)
+            mt.mul_(beta1).add_(grad_reshaped, alpha=1.0)
+
+            vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+            vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
+
+            update = torch.add(mt, grad_reshaped, alpha=alpha_grad)
+            del grad_reshaped
+
+            denom = vt.sqrt().add_(group['eps'] * math.sqrt(state['den_sum']))
+            update.div_(denom)
+            del denom
+
+            if group['use_bias_correction']:
+                update = (update / state['num_sum']) * math.sqrt(state['den_sum'])
+
+            update = update.view(p.shape).mul_(group['lr'])
+
+            # Compress updated moments and store new factors
+            state['sign'] = _pack_bools(mt > 0)
+            _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+            del mt
+            _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+            del vt
+
+        else: # Standard optimizer logic for non-factored tensors
+            exp_avg_sq = state['exp_avg_sq']
+
+            exp_avg = state['exp_avg']
+            exp_avg.mul_(beta1).add_(grad, alpha=1.0)
+
+            update = torch.add(exp_avg, grad, alpha=alpha_grad)
+
+            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+            denom = exp_avg_sq.sqrt().add_(group['eps'] * math.sqrt(state['den_sum']))
+            update.div_(denom)
+            del denom
+
+            if group['use_bias_correction']:
+                update = (update / state['num_sum']) * math.sqrt(state['den_sum'])
+
+            update.mul_(group['lr'])
+
+        # Decoupled weight decay
+        if group["weight_decay"] != 0:
+            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+            else:
+                p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+
+        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+            add_stochastic_(p.data, -update)
+        else:
+            p.data.add_(-update)
+        del update
+
+        state['step'] += 1
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.step_parameter(p, group, i)
+
+        return loss
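For readers skimming the new file, the heart of the algorithm is easiest to see in the non-factored branch. Below is a condensed, self-contained sketch of that branch only (bias correction and decoupled weight decay kept; the NNMF factorization, BF16 stochastic rounding, OrthoGrad, and beta1 warmup are omitted). The function name and driver code are illustrative, not part of the package API:

import math
import torch

def simplified_ademamix_step(p, grad, state, lr=1e-5, beta1=0.99, beta2=0.999,
                             eps=1e-8, alpha_grad=100.0, weight_decay=0.0):
    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
    # First moment is an accumulator: no (1 - beta1) factor on the gradient term.
    exp_avg.mul_(beta1).add_(grad)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    # Running sums used for bias correction, mirroring num_sum/den_sum above.
    state["num_sum"] = beta1 * state["num_sum"] + 1.0
    state["den_sum"] = beta2 * state["den_sum"] + (1.0 - beta2)
    # Numerator mixes the accumulator with alpha_grad times the raw gradient.
    update = exp_avg + alpha_grad * grad
    update /= exp_avg_sq.sqrt() + eps * math.sqrt(state["den_sum"])
    update = update / state["num_sum"] * math.sqrt(state["den_sum"])
    if weight_decay != 0.0:
        p.add_(p, alpha=-weight_decay * lr)  # decoupled weight decay
    p.add_(update, alpha=-lr)

# Tiny driver showing the expected state layout.
p = torch.randn(10)
state = {"exp_avg": torch.zeros_like(p), "exp_avg_sq": torch.zeros_like(p),
         "num_sum": 0.0, "den_sum": 0.0}
simplified_ademamix_step(p, torch.randn(10), state)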
{adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm/optim/__init__.py

@@ -1,6 +1,7 @@
 from .AdamW_adv import AdamW_adv
 from .Prodigy_adv import Prodigy_adv
 from .Adopt_adv import Adopt_adv
+from .Simplified_AdEMAMix import Simplified_AdEMAMix
 from .Lion_adv import Lion_adv
 from .Lion_Prodigy_adv import Lion_Prodigy_adv
 
@@ -8,6 +9,7 @@ __all__ = [
     "AdamW_adv",
     "Prodigy_adv",
     "Adopt_adv",
+    "Simplified_AdEMAMix",
     "Lion_adv",
     "Lion_Prodigy_adv",
 ]
{adv_optm-0.1.5 → adv_optm-0.1.7}/adv_optm.egg-info/SOURCES.txt

@@ -12,6 +12,7 @@ adv_optm/optim/Adopt_adv.py
 adv_optm/optim/Lion_Prodigy_adv.py
 adv_optm/optim/Lion_adv.py
 adv_optm/optim/Prodigy_adv.py
+adv_optm/optim/Simplified_AdEMAMix.py
 adv_optm/optim/__init__.py
 adv_optm/util/BF16_Stochastic_Rounding.py
 adv_optm/util/Effective_Shape.py