adv-optm 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of adv-optm has been flagged as possibly problematic by the registry scanner.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +296 -296
- adv_optm/optim/Lion_Prodigy_adv.py +22 -8
- adv_optm/optim/Lion_adv.py +242 -230
- adv_optm/optim/Prodigy_adv.py +56 -51
- {adv_optm-0.1.2.dist-info → adv_optm-0.1.4.dist-info}/METADATA +1 -1
- {adv_optm-0.1.2.dist-info → adv_optm-0.1.4.dist-info}/RECORD +10 -10
- {adv_optm-0.1.2.dist-info → adv_optm-0.1.4.dist-info}/WHEEL +0 -0
- {adv_optm-0.1.2.dist-info → adv_optm-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-0.1.2.dist-info → adv_optm-0.1.4.dist-info}/top_level.txt +0 -0
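The substantive changes are confined to the optimizer modules listed above. For orientation, a minimal usage sketch of the updated Lion_adv class follows; the constructor arguments are taken from the 0.1.4 source shown in the diff below, while the import path is assumed from the package layout and the example itself is illustrative rather than part of the release.

import torch
from adv_optm.optim.Lion_adv import Lion_adv  # path assumed from the file list above

model = torch.nn.Linear(128, 64)
optimizer = Lion_adv(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.99),
    weight_decay=0.0,
    factored=True,             # SMMF low-rank compression of the momentum
    variance_reduction=False,  # STORM-style momentum update (see Lion_adv.py below)
)

loss = model(torch.randn(8, 128)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()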
adv_optm/optim/Lion_adv.py
CHANGED
@@ -1,231 +1,243 @@
+import torch
+
+from typing import Tuple, Optional
+
+from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.Effective_Shape import _get_effective_shape
+from ..util.NNMF import _nnmf,_unnmf
+from ..util.OrthoGrad import _orthogonalize_gradient
+from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+class Lion_adv(torch.optim.Optimizer):
+    """
+    Implements the SMMF technique for Lion algorithm.
+
+    This optimizer combines the Lion update rule with the memory-saving low-rank
+    compression (SMMF) technique from https://arxiv.org/abs/2412.08894.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float, optional): learning rate (default: 1e-4).
+        betas (Tuple[float, float], optional): coefficients for computing
+            running averages of the update (default: (0.9, 0.99)).
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
+        vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
+            matrices to apply low-rank compression (default: True).
+        stochastic_rounding (bool, optional): whether to use stochastic
+            rounding for BF16 parameter updates (default: True).
+        use_cautious (bool): whether to use the cautious masking technique. (default: False).
+        clip_threshold (float, optional): whether to clip the gradients norm
+            per-parameter as proposed in the paper `Lions and Muons: Optimization via
+            Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
+            (default: 0.0).
+        factored (bool): whether to use the factorization or use the
+            uncompressed optimizer. (default: True)
+        variance_reduction (bool): whether to use the variance reduction technique
+            from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-4,
+        betas: Tuple[float, float] = (0.9, 0.99),
+        weight_decay: float = 0.0,
+        vector_reshape: bool = True,
+        stochastic_rounding: bool = True,
+        use_orthograd: bool = False,
+        use_cautious: bool = False,
+        clip_threshold: float = 0.0,
+        factored: bool = True,
+        variance_reduction: bool = False,
+    ):
+        if not lr > 0.0:
+            raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
+        if not all(0.0 <= beta <= 1.0 for beta in betas):
+            raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
+        if not weight_decay >= 0.0:
+            raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
+
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            weight_decay=weight_decay,
+            vector_reshape=vector_reshape,
+            use_orthograd=use_orthograd,
+            clip_threshold=clip_threshold,
+        )
+        self.stochastic_rounding = stochastic_rounding
+        self.use_cautious = use_cautious
+        self.factored = factored
+        self.variance_reduction = variance_reduction
+        super().__init__(params, defaults)
+
+    @property
+    def supports_fused_back_pass(self) -> bool:
+        return True
+
+    @property
+    def supports_memory_efficient_fp16(self) -> bool:
+        return True
+
+    @property
+    def supports_flat_params(self) -> bool:
+        return False
+
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
+        """Performs a single optimization step on a single parameter."""
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        if grad.dtype != torch.float32 and self.factored:
+            grad = grad.float()
+        if group["clip_threshold"] > 0.0:
+            grad_norm = torch.norm(grad.detach())
+            if grad_norm > group["clip_threshold"]:
+                clip_coef = group["clip_threshold"] / grad_norm
+                grad.mul_(clip_coef)
+        if group["use_orthograd"]:
+            grad = _orthogonalize_gradient(p, grad)
+        state = self.state[p]
+
+        # State Initialization
+        if len(state) == 0:
+            state['step'] = 0
+
+            should_factor = (
+                self.factored and
+                not (len(p.shape) == 1 and not group['vector_reshape'])
+            )
+
+            state['factored'] = should_factor
+
+            dtype = torch.float32 if self.factored else p.dtype
+
+            if state['factored']:
+                state['effective_shape'] = _get_effective_shape(p.numel())
+                d1, d2 = state['effective_shape']
+                state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                packed_d2 = (d2 + 7) // 8
+                state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                if self.variance_reduction:
+                    state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
+            else: # Fallback to standard Lion
+                state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+                if self.variance_reduction:
+                    state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+
+        state['step'] += 1
+        beta1, beta2 = group["betas"]
+        lr = group["lr"]
+
+        if state['factored']:
+            # Factored Path
+            d1, d2 = state['effective_shape']
+            grad_reshaped = grad.view(d1, d2)
+            # Reconstruct momentum m_{t-1}
+            exp_avg = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+            unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+            torch.where(unpacked_sign, exp_avg, -exp_avg, out=exp_avg)
+            del unpacked_sign
+            if exp_avg.dtype != torch.float32:
+                exp_avg = exp_avg.float()
+
+            # Compute update term c_t
+            signed_update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()
+
+            if self.use_cautious:
+                mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
+                mask.div_(mask.mean().clamp_(min=1e-3))
+                signed_update.mul_(mask)
+                del mask
+
+            # Parameter update
+            update_for_param = signed_update.view(p.shape).mul_(lr)
+
+            # Update momentum
+            if self.variance_reduction:
+                if state['step'] == 1:
+                    exp_avg.copy_(grad_reshaped)
+                else:
+                    # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
+                    correction = exp_avg.sub(state['prev_grad'])
+                    # Calculate the new momentum and store it back into exp_avg
+                    exp_avg.copy_(grad_reshaped).add_(correction, alpha=beta2)
+                    del correction
+                # Update prev_grad for the next iteration
+                state['prev_grad'].copy_(grad_reshaped)
+            else:
+                # Standard Lion momentum update
+                exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
+
+            # Compress new momentum m_t and store factors
+            state['sign'] = _pack_bools(exp_avg > 0)
+            _nnmf(exp_avg.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+            del exp_avg
+
+        else:
+            # Fallback to standard Lion logic
+            exp_avg = state["exp_avg"]
+
+            # Compute update term and sign for the update
+            if exp_avg.dtype != torch.float32 and self.factored:
+                exp_avg = exp_avg.float()
+            signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()
+
+            if self.use_cautious:
+                mask = (signed_update * grad > 0).to(grad.dtype)
+                mask.div_(mask.mean().clamp_(min=1e-3))
+                signed_update.mul_(mask)
+                del mask
+
+            update_for_param = signed_update.mul_(lr)
+
+            # Update momentum
+            if self.variance_reduction:
+                if state['step'] == 1:
+                    exp_avg.copy_(grad)
+                else:
+                    # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
+                    correction = exp_avg.sub(state['prev_grad'])
+                    # Calculate the new momentum and store it back into exp_avg
+                    exp_avg.copy_(grad).add_(correction, alpha=beta2)
+                    del correction
+                # Update prev_grad for the next iteration
+                state['prev_grad'].copy_(grad)
+            else:
+                # Standard Lion momentum update
+                exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
+
+        if group["weight_decay"] != 0:
+            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                add_stochastic_(p.data, p.data,
+                                alpha=-group["weight_decay"] * lr)
+            else:
+                p.data.add_(
+                    p.data, alpha=-group["weight_decay"] * lr
+                )
+
+        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+            add_stochastic_(p.data, -update_for_param)
+        else:
+            p.data.add_(-update_for_param)
+
+        del update_for_param
+
+    @torch.no_grad()
+    def step(self, closure: Optional[callable] = None):
+        """Performs a single optimization step."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for i, p in enumerate(group["params"]):
+                if p.grad is not None:
+                    self.step_parameter(p, group, i)
+
         return loss
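A note on the one-bit sign compression used above: the optimizer stores only the element-wise sign of the momentum (packed 8 signs per byte in state['sign'], hence the (d1, (d2 + 7) // 8) uint8 buffer) together with the non-negative NNMF factors of its magnitude, and reconstructs the momentum via torch.where at each step. The standalone sketch below reimplements that pack/unpack round trip with plain PyTorch ops; it is illustrative only and is not the package's _pack_bools/_unpack_bools.

import torch

def pack_signs(signs: torch.Tensor) -> torch.Tensor:
    # signs: bool tensor of shape (d1, d2) -> uint8 tensor of shape (d1, (d2 + 7) // 8)
    d1, d2 = signs.shape
    packed_d2 = (d2 + 7) // 8
    padded = torch.zeros(d1, packed_d2 * 8, dtype=torch.bool, device=signs.device)
    padded[:, :d2] = signs
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=signs.device)
    return (padded.view(d1, packed_d2, 8).to(torch.uint8) * weights).sum(dim=-1).to(torch.uint8)

def unpack_signs(packed: torch.Tensor, original_m: int) -> torch.Tensor:
    # Inverse of pack_signs; original_m recovers the true column count d2.
    d1, packed_d2 = packed.shape
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=packed.device)
    bits = packed.unsqueeze(-1).bitwise_and(weights).ne(0)
    return bits.view(d1, packed_d2 * 8)[:, :original_m]

m = torch.randn(4, 10)                                    # a toy momentum matrix
packed = pack_signs(m > 0)                                # 10 sign bits -> 2 bytes per row
restored = torch.where(unpack_signs(packed, 10), m.abs(), -m.abs())
assert torch.equal(restored.sign(), m.sign())             # signs survive the round trip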
adv_optm/optim/Prodigy_adv.py
CHANGED
@@ -194,11 +194,12 @@ class Prodigy_adv(torch.optim.Optimizer):
                 d1, d2 = state['effective_shape']
 
                 # First moment (m)
+                if self.beta1 > 0:
+                    state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                    state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                    if not self.use_grams:
+                        packed_d2 = (d2 + 7) // 8
+                        state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
                 if self.use_AdEMAMix:
                     state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                     state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
@@ -208,7 +209,8 @@ class Prodigy_adv(torch.optim.Optimizer):
                 state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                 state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
             else: # Fallback to standard AdamW for non-factored tensors
+                if self.beta1 > 0:
+                    state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
                 if self.use_AdEMAMix:
                     state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
                 state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
@@ -231,22 +233,24 @@ class Prodigy_adv(torch.optim.Optimizer):
         if state['factored']:
             d1, d2 = state['effective_shape']
 
-            # Reconstruct momentum from previous step's factors
-            mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
-            if not self.use_grams:
-                unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
-                torch.where(unpacked_sign, mt, -mt, out=mt)
-                del unpacked_sign
-            # Update momentum in full-size
             grad_reshaped = grad.view(d1, d2)
+
+            # Reconstruct momentum from previous step's factors
+            if self.beta1 > 0:
+                mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+                if not self.use_grams:
+                    unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+                    torch.where(unpacked_sign, mt, -mt, out=mt)
+                    del unpacked_sign
+                # Update momentum in full-size
+                mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
+                if self.use_grams:
+                    mt.copy_(grad_reshaped.sign() * mt.abs())
+                elif self.use_cautious:
+                    mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    mt.mul_(mask)
+                    del mask
 
             vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
             vt.mul_(self.beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - self.beta2))
@@ -258,30 +262,29 @@ class Prodigy_adv(torch.optim.Optimizer):
                 unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
                 torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
                 del unpacked_sign_slow
                 mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
+                update = mt + (alpha_t * mt_slow) if self.beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
             else:
+                update = mt if self.beta1 > 0 else grad_reshaped
             del grad_reshaped
 
             if group['use_atan2']:
                 a = 1.2732395
                 denom = vt.sqrt()
+                update.atan2_(denom).mul_(a)
             else:
-                denom = vt.sqrt()
+                denom = vt.sqrt()
+                update.div_(denom.add_(self.d * group['eps']))
+                del denom
 
-            update.mul_(self.dlr)
+            update.view(p.shape).mul_(self.dlr)
 
             # Compress updated moments and store new factors
+            if self.beta1 > 0:
+                if not self.use_grams:
+                    state['sign'] = _pack_bools(mt > 0)
+                _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+                del mt
             if self.use_AdEMAMix:
                 state['sign_slow'] = _pack_bools(mt_slow > 0)
                 _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
@@ -290,36 +293,38 @@ class Prodigy_adv(torch.optim.Optimizer):
             del vt
 
         else: # Standard AdamW logic for non-factored tensors
+            exp_avg_sq = state['exp_avg_sq']
+
+            if self.beta1 > 0:
+                exp_avg = state['exp_avg']
+                exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
+                if self.use_grams:
+                    exp_avg = grad.sign() * exp_avg.abs()
+                elif self.use_cautious:
+                    mask = (exp_avg * grad > 0).to(grad.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    exp_avg.mul_(mask)
+                    del mask
 
             if self.use_AdEMAMix:
                 exp_avg_slow = state['exp_avg_slow']
                 exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
+                update = exp_avg + (alpha_t * exp_avg_slow) if self.beta1 > 0 else grad + (alpha_t * exp_avg_slow)
             else:
+                update = exp_avg if self.beta1 > 0 else grad
 
             exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - self.beta2))
 
             if group['use_atan2']:
                 a = 1.2732395
                 denom = exp_avg_sq.sqrt()
+                update.atan2_(denom).mul_(a)
             else:
-                denom = exp_avg_sq.sqrt()
+                denom = exp_avg_sq.sqrt()
+                update.div_(denom.add_(self.d * group['eps']))
+                del denom
 
+            update.mul_(self.dlr)
 
             # --- Accumulate Prodigy stats ---
             d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']