adv-optm 1.1.0.dev3__py3-none-any.whl → 1.1.0.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
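For readers who want to verify the comparison themselves, below is a minimal sketch of how a wheel-to-wheel diff like this one can be reproduced locally with only the Python standard library. The wheel file names, download directories, and `pip download` commands are assumptions for illustration, not part of the registry tooling that generated this page.

# Minimal sketch (assumption: the two wheels were downloaded beforehand, e.g. with
#   pip download adv-optm==1.1.0.dev3 --no-deps -d old/
#   pip download adv-optm==1.1.0.dev5 --no-deps -d new/ )
import difflib
import zipfile


def read_member(wheel_path: str, member: str) -> list[str]:
    """Return the text lines of one file stored inside a wheel (zip) archive."""
    with zipfile.ZipFile(wheel_path) as zf:
        return zf.read(member).decode("utf-8").splitlines(keepends=True)


old = read_member("old/adv_optm-1.1.0.dev3-py3-none-any.whl", "adv_optm/optim/Adopt_adv.py")
new = read_member("new/adv_optm-1.1.0.dev5-py3-none-any.whl", "adv_optm/optim/Adopt_adv.py")

# Print a unified diff of one module between the two published versions.
for line in difflib.unified_diff(old, new, fromfile="dev3/Adopt_adv.py", tofile="dev5/Adopt_adv.py"):
    print(line, end="")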
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +3 -3
- adv_optm/optim/Adopt_adv.py +435 -439
- adv_optm/optim/Lion_Prodigy_adv.py +315 -315
- adv_optm/optim/Lion_adv.py +1 -1
- adv_optm/optim/Prodigy_adv.py +13 -6
- adv_optm/optim/Simplified_AdEMAMix.py +3 -3
- adv_optm/util/Kourkoutas.py +71 -36
- {adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev5.dist-info}/METADATA +1 -1
- adv_optm-1.1.0.dev5.dist-info/RECORD +20 -0
- adv_optm-1.1.0.dev3.dist-info/RECORD +0 -20
- {adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev5.dist-info}/WHEEL +0 -0
- {adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev5.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev5.dist-info}/top_level.txt +0 -0
adv_optm/optim/Adopt_adv.py
CHANGED
@@ -1,440 +1,436 @@
 import torch
 from typing import Callable, Optional

 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf, _unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
 from ..util.Kourkoutas import KourkoutasHelper

 class Adopt_adv(torch.optim.Optimizer):
     """
     Implements a fusion of SMMF, and the ADOPT algorithm.

     The ADOPT update rule modifies Adam by:
     1. **Initialization:** The second moment `v` is initialized as `v₀ = g₀²`.
     2. **Decorrelation:** The current gradient is normalized using the second-moment estimate
        from the *previous* step (`v_{t-1}`).
     3. **Order of Operations:** This normalization occurs *before* updating the
        first-moment (momentum) estimate.

     Args:
         params (iterable): iterable of parameters to optimize or dicts defining
             parameter groups
         lr (float): learning rate (default: 1e-4)
         betas (tuple[float, float]): coefficients used for computing running
             averages of momentum and variance (default: (0.9, 0.9999))
         eps (float): term added to the denominator to improve
             numerical stability (default: 1e-6)
         weight_decay (float): weight decay (L2 penalty) (default: 0)
         clip_lambda (Callable, optional): A function that takes the current step
             and returns a value to clip the normalized gradient. Only used when
             `use_atan2` is False. (default: `lambda step: step**0.25`)
         vector_reshape (bool): whether to reshape 1D vectors into 2D
             matrices for low-rank compression (default: True).
         stochastic_rounding (bool): whether to use stochastic
             rounding for BF16 parameter updates (default: True).
         use_atan2 (bool): whether to use an atan2-based normalization, which can
             improve stability by removing the need for `eps`. (default: False)
         cautious_mask (bool): whether to use cautious masking to align the gradient's
             direction with the first moment's. (default: False)
         grams_moment (bool): whether to combine the gradient's direction with the
             first moment's magnitude (default: False).
         orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
         use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
             a second, slow-moving average of the momentum (`mt_slow`) which is
             combined with the primary momentum (`mt`) to stabilize updates,
             especially in noisy, small-batch settings. If `False`, the
             optimizer behaves as standard ADOPT. (default: False)
         beta3_ema (float): The decay rate for the slow exponential moving average of
             the momentum (only used when `use_AdEMAMix` is `True`). A higher
             value (e.g., 0.9999) gives the EMA a longer memory, making it more
             stable but slower to adapt. A lower value (e.g., 0.999) is often
             better for shorter training runs. (default: 0.9999)
         alpha (float): The mixing coefficient that scales the slow momentum term
             before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
             A higher value increases the stabilizing influence of the slow
             momentum. (default: 5.0)
         t_alpha (Optional[int]): The number of steps for a linear warmup of the
             `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
             highly recommended to prevent instability at the beginning of training,
             as it gradually introduces the stabilizing slow momentum term. During
             the warmup, `alpha` ramps from 0 to its target value. If `None`,
             the scheduler is disabled and the full `alpha` value is used from
             the start. (default: None)
         Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
             This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
             more responsive, especially for small batch sizes. Enabling this will
             automatically disable `use_AdEMAMix`, `cautious_mask`, `grams_moment`,
             and `use_atan2`. (default: False)
         alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
             (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
             current gradient. For small batch sizes, use high values (e.g., 10-100) to be
             more responsive. For large batch sizes, use low values (e.g., 0-1) for
             stability. (default: 100.0)
         kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
             If `False`, the optimizer behaves as standard Adopt. (default: False)
         beta2_min (float): The minimum value for dynamic β₂, used during periods of
             high gradient variance ("sunspikes"). Must be less than `betas[1]`.
             (default: 0.88)
         ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
             the pooled gradient norms. Corresponds to `α` in the paper.
             (default: 0.93)
         tiny_spike (float): A small constant added to the denominator of the
             "sunspike" ratio calculation to prevent division by zero. Corresponds
             to `ε_spike` in the paper. (default: 1e-9)
         k_warmup_steps (int): The number of initial steps during which β₂ is held
             at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
             dynamic logic activates. (default: 0)
         k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
             logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
             every logging steps. Useful for debugging and tuning. Set to 0 to disable
             logging (default: 0).
         layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
             and returns a unique, hashable key representing its "layer" or "bucket".
             If `None`, parameters are bucketed by their memory ID (tensor-wise).
             (default: None)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
     """

     def __init__(
         self,
         params,
         lr: float = 1e-4,
         betas: tuple[float, float] = (0.9, 0.9999),
         eps: float = 1e-6,
         weight_decay: float = 0.0,
         clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
         vector_reshape: bool = True,
         stochastic_rounding: bool = True,
         use_atan2: bool = False,
         cautious_mask: bool = False,
         grams_moment: bool = False,
         orthogonal_gradient: bool = False,
         use_AdEMAMix: bool = False,
         beta3_ema: float = 0.9999,
         alpha: float = 5.0,
         t_alpha: int | None = None,
         Simplified_AdEMAMix: bool = False,
         alpha_grad: float = 100.0,
         kourkoutas_beta: bool = False,
-        beta2_min: float = 0.
-        ema_alpha: float = 0.
+        beta2_min: float = 0.9,
+        ema_alpha: float = 0.95,
         tiny_spike: float = 1e-9,
         k_warmup_steps: int = 0,
         k_logging: int = 0,
         layer_key_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
         if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
             raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
         if not (eps >= 0.0):
             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
         if not (weight_decay >= 0.0):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
         if cautious_mask and grams_moment:
             print("Warning: cautious is incompatible with grams, Disabling cautious.")
             cautious_mask = False
         if betas[0] == 0.0 and Simplified_AdEMAMix:
             raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
         if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
         if use_AdEMAMix and Simplified_AdEMAMix:
             print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
         if grams_moment and Simplified_AdEMAMix:
             print("Warning: grams is incompatible with Simplified_AdEMAMix, Disabling grams.")
         if cautious_mask and Simplified_AdEMAMix:
             print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
         if use_atan2 and Simplified_AdEMAMix:
             print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
             use_atan2 = False

         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
             "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
             "t_alpha": t_alpha, "alpha_grad": alpha_grad,
             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
             "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
         }
         self.clip_lambda = clip_lambda
         self.stochastic_rounding = stochastic_rounding
         self.use_atan2 = use_atan2 and not Simplified_AdEMAMix
         self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
         self.grams_moment = grams_moment and not Simplified_AdEMAMix
         self.orthogonal_gradient = orthogonal_gradient
         self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
         self.Simplified_AdEMAMix = Simplified_AdEMAMix
         self.factored = nnmf_factor
         self.kourkoutas_beta = kourkoutas_beta
         self.layer_key_fn = layer_key_fn
         super().__init__(params, defaults)

         if self.kourkoutas_beta:
             self.kourkoutas_helper = KourkoutasHelper(self)

     @property
     def supports_fused_back_pass(self): return True
     @property
     def supports_memory_efficient_fp16(self): return True
     @property
     def supports_flat_params(self): return False

     @torch.no_grad()
     def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
         if p.grad is None:
             return

         grad = p.grad
         if self.factored and grad.dtype != torch.float32:
             grad = grad.float()
         if self.orthogonal_gradient:
             grad = _orthogonalize_gradient(p, grad)
         state = self.state[p]

         # State Initialization
-        if
+        if 'step' not in state:
             state['step'] = 0

             should_factor = (
                 self.factored and
                 not (len(p.shape) == 1 and not group['vector_reshape'])
             )

             state['factored'] = should_factor

             dtype = torch.float32 if self.factored else p.dtype

             if state['factored']:
                 state['effective_shape'] = _get_effective_shape(p.numel())
                 d1, d2 = state['effective_shape']

                 # m_0 = 0
                 if group['betas'][0] > 0:
                     state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                     state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                     if not self.grams_moment:
                         packed_d2 = (d2 + 7) // 8
                         state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
                 if self.use_AdEMAMix:
                     state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                     state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                     packed_d2 = (d2 + 7) // 8
                     state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
                 # v_0 = g_0^2 (SMMF_ADOPT NMF storage)
                 vt_init = grad.view(d1, d2).square_()
                 # Allocate NMF factors for v
                 state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                 state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                 # Initialize v_0 using NMF
                 _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
             else:  # Fallback for non-factored tensors
                 if group['betas'][0] > 0:
                     state['exp_avg'] = torch.zeros_like(p, dtype=dtype)  # m_0
                 if self.use_AdEMAMix:
                     state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
                 state['exp_avg_sq'] = grad.square()  # v_0

         beta1, beta2 = group['betas']

         current_step = state['step']
         if group['kourkoutas_beta']:
             # Call prepare_step() once at the beginning of the step for all params
             self.kourkoutas_helper.maybe_prepare_step(current_step)
             # Accumulate current grad's norm for the *next* step
             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
             # Get the dynamic beta2 calculated in prepare_step()
             beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)

         # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
         if state['step'] == 0 and not self.use_atan2:
             state['step'] += 1
             return

         if self.use_AdEMAMix:
             beta3_ema = group['beta3_ema']
             alpha = group['alpha']
             t_alpha = group['t_alpha']
             # Use step+1 for 1-based step count in scheduler
             alpha_step = state['step'] + 1
             alpha_t = alpha
             if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
                 alpha_t = min(alpha_step * alpha / t_alpha, alpha)
         if self.Simplified_AdEMAMix:
             alpha_grad = group["alpha_grad"]

         if state['factored']:
             d1, d2 = state['effective_shape']

             # Reconstruct m_{t-1}
             if beta1 > 0:
                 mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
                 if not self.grams_moment:
                     if state['sign'].dtype != torch.uint8:
                         state['sign'] = state['sign'].to(torch.uint8)
                     unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
                     torch.where(unpacked_sign, mt, -mt, out=mt)
                     del unpacked_sign

             # Reconstruct AdEMAMix EMA
             if self.use_AdEMAMix:
                 mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
                 if state['sign_slow'].dtype != torch.uint8:
                     state['sign_slow'] = state['sign_slow'].to(torch.uint8)
                 unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
                 torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
                 del unpacked_sign_slow

             # Reconstruct v_{t-1} using NNMF
             vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))

             # ADOPT Step A: Decorrelate g_t using v_{t-1}
             grad_reshaped = grad.view(d1, d2)
             denom = vt.sqrt()

             if self.use_atan2:
                 normalized_grad = torch.atan2(grad_reshaped, denom)
             else:
                 normalized_grad = grad_reshaped / denom.add_(group['eps'])
                 if self.clip_lambda is not None:
                     clip_val = self.clip_lambda(state['step'])
                     normalized_grad.clamp_(-clip_val, clip_val)
             del denom

             # ADOPT Step B: Update momentum m_t using normalized gradient
             if beta1 > 0:
                 if self.Simplified_AdEMAMix:
                     mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
                 else:
                     mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
                 if self.grams_moment:
                     mt = grad_reshaped.sign() * mt.abs()
                 elif self.cautious_mask:
                     mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
                     mask.div_(mask.mean().clamp_(min=1e-3))
                     mt.mul_(mask)
                     del mask

             if self.use_AdEMAMix:
                 mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
                 if beta1 > 0:
                     update = torch.add(mt, mt_slow, alpha=alpha_t)
                 else:
                     update = torch.add(normalized_grad, mt_slow, alpha=alpha_t)
             elif self.Simplified_AdEMAMix:
                 update = torch.add(mt, normalized_grad, alpha=alpha_grad)
             else:
                 update = mt.clone() if beta1 > 0 else normalized_grad

             update = update.view(p.shape)

             if self.use_atan2:
                 update.mul_(group['lr'] * 1.2732395447351628)
             else:
                 update.mul_(group['lr'])

             # Update second moment v_t for the *next* step using raw g_t
             vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
             del grad_reshaped

             # Compress and store new factors
             if beta1 > 0:
                 if not self.grams_moment:
                     state['sign'] = _pack_bools(mt > 0)
                 _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
                 del mt

             if self.use_AdEMAMix:
                 state['sign_slow'] = _pack_bools(mt_slow > 0)
                 _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
                 del mt_slow

             # factorize v_t using NMF compression
             _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
             del vt

         else:  # Standard ADOPT logic for non-factored tensors
             v = state['exp_avg_sq']  # v_{t-1}

             # ADOPT Step A: Decorrelate g_t using v_{t-1}
             denom = v.sqrt()

             if self.use_atan2:
                 normalized_grad = torch.atan2(grad, denom)
             else:
                 normalized_grad = grad / denom.add_(group['eps'])
                 if self.clip_lambda is not None:
                     clip_val = self.clip_lambda(state['step'])
                     normalized_grad.clamp_(-clip_val, clip_val)
             del denom

             # ADOPT Step B: Update momentum m_t
             if beta1 > 0:
                 m = state['exp_avg']  # m_{t-1},
                 if self.Simplified_AdEMAMix:
                     m.mul_(beta1).add_(normalized_grad, alpha=1.0)
                 else:
                     m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)

                 if self.grams_moment:
                     m = grad.sign() * m.abs()
                 elif self.cautious_mask:
                     mask = (m * grad > 0).to(grad.dtype)
                     mask.div_(mask.mean().clamp_(min=1e-3))
                     m.mul_(mask)
                     del mask

             if self.use_AdEMAMix:
                 m_slow = state['exp_avg_slow']
                 m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
                 if beta1 > 0:
                     update = torch.add(m, m_slow, alpha=alpha_t)
                 else:
                     update = torch.add(normalized_grad, m_slow, alpha=alpha_t)
             elif self.Simplified_AdEMAMix:
                 update = torch.add(m, normalized_grad, alpha=alpha_grad)
             else:
                 update = m.clone() if beta1 > 0 else normalized_grad

             if self.use_atan2:
                 update.mul_(group['lr'] * 1.2732395447351628)
             else:
                 update.mul_(group['lr'])

             # Update second moment v_t for the next step using raw g_t
             v.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

         # Parameter Update
         if group["weight_decay"] != 0:
             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
                 add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
             else:
                 p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])

         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
             add_stochastic_(p.data, -update)
         else:
             p.data.add_(-update)
         del update

         state['step'] += 1

     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step."""
         loss = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()

         for group in self.param_groups:
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)

-        if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
-            first_param_state = self.state[self.param_groups[0]['params'][0]]
-            step_num = first_param_state['step']
-
         return loss
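The docstring in the diff above describes the optimizer's options at length; as a rough orientation, the sketch below shows how `Adopt_adv` might be constructed and stepped. It is a minimal example under stated assumptions: the import path follows the package layout shown in this diff, `Adopt_adv` is assumed to be re-exported from the `adv_optm` package root (otherwise import it from `adv_optm.optim.Adopt_adv`), and the model, data, and hyperparameter choices are placeholders rather than recommendations from the package authors.

# Minimal usage sketch (assumption: Adopt_adv is re-exported at the package root,
# as suggested by the adv_optm/__init__.py entry in the file list above).
import torch
from adv_optm import Adopt_adv

model = torch.nn.Linear(128, 10)  # placeholder model
optimizer = Adopt_adv(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.9999),
    kourkoutas_beta=True,   # layer-wise dynamic beta2, as described in the docstring
    beta2_min=0.9,          # must stay below betas[1]
    nnmf_factor=False,      # keep the uncompressed (non-factored) state path
)

for _ in range(10):
    inputs = torch.randn(32, 128)             # placeholder batch
    targets = torch.randint(0, 10, (32,))
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    optimizer.step()   # first call only initializes state when use_atan2 is False
    optimizer.zero_grad()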