adv-optm 2.3.dev2.tar.gz → 2.4.dev1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/PKG-INFO +1 -1
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/__init__.py +1 -7
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/AdaMuon_adv.py +28 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/AdamW_adv.py +86 -27
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/Adopt_adv.py +95 -33
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/Lion_adv.py +80 -5
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/Muon_adv.py +28 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/Prodigy_adv.py +74 -25
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/SignSGD_adv.py +94 -6
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/Simplified_AdEMAMix.py +62 -13
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/util/Muon_AuxAdam.py +3 -0
- adv_optm-2.4.dev1/adv_optm/util/centered_decay.py +112 -0
- adv_optm-2.4.dev1/adv_optm/util/param_update.py +286 -0
- adv_optm-2.4.dev1/adv_optm/util/scaled_optm.py +137 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/util/update_util.py +3 -1
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm.egg-info/SOURCES.txt +2 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/setup.py +1 -1
- adv_optm-2.3.dev2/adv_optm/util/param_update.py +0 -177
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/LICENSE +0 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/README.md +0 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/setup.cfg +0 -0

{adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/__init__.py

```diff
@@ -10,11 +10,6 @@ from .optim import (
     SignSGD_adv,
 )
 
-from .stiefel_optm.Stiefel_LoRA import (
-    Stiefel_LoRA,
-)
-
-
 __all__ = [
     "AdamW_adv",
     "Prodigy_adv",
@@ -25,7 +20,6 @@ __all__ = [
     "Muon_adv",
     "AdaMuon_adv",
    "SignSGD_adv",
-    "Stiefel_LoRA",
 ]
 
-__version__ = "2.3.dev2"
+__version__ = "2.4.dev1"
```
{adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/AdaMuon_adv.py

```diff
@@ -8,6 +8,7 @@ from ..util.factorization_util import _get_effective_shape, _factorize_state, _r
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.Kourkoutas import KourkoutasHelper
 from ..util import Muon_AuxAdam
+from ..util.centered_decay import _init_anchor
 
 A = 4 / math.pi
 
@@ -87,6 +88,15 @@ class AdaMuon_adv(torch.optim.Optimizer):
             (default: False)
         mars_gamma (float): The scaling coefficient for MARS gradient correction.
             (default: 0.025)
+        centered_wd (float): Centered Weight Decay coefficient. Instead of decaying weights
+            toward zero, they are decayed toward their initial values (anchors). This
+            can be used together with standard weight decay. (default: 0.0)
+        centered_wd_mode (str): The quantization format used to store the anchor
+            weights to save VRAM. Options include:
+            'full': Stores anchors in the original parameter's precision.
+            'float8': Uses torch.float8_e4m3fn for a balance of precision and memory.
+            'int8': Uses 8-bit block-wise quantization (block size 128).
+            'int4': Uses 4-bit block-wise quantization (block size 32).
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
         use_muon (bool | None): whether to use Muon or AuxAdamW. MUST be provided
```
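The new `centered_wd` option documented here decays each weight toward a stored copy of its initial value (the anchor captured by `_init_anchor`) rather than toward zero. A minimal sketch of that rule in decoupled, AdamW-style form; the helper below is an illustration, not the package's `centered_decay` module:

```python
import torch

def centered_weight_decay_(p: torch.Tensor, anchor: torch.Tensor,
                           lr: float, centered_wd: float) -> None:
    """Decoupled decay toward an anchor instead of toward zero (in-place).

    With anchor == 0 this reduces to ordinary decoupled weight decay:
        p <- p - lr * centered_wd * (p - anchor)
    """
    p.add_(p - anchor, alpha=-lr * centered_wd)

# Usage: capture the anchor once when the state is initialized, then decay toward it.
param = torch.randn(4, 4)
anchor = param.clone()        # 'full' mode would store this as-is;
                              # 'float8'/'int8'/'int4' modes would quantize it first
centered_weight_decay_(param, anchor, lr=1e-3, centered_wd=0.1)
```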
```diff
@@ -157,6 +167,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
         # Spectral Normalization
         n_layers: int = 1,
         spectral_normalization: bool = False,
+        # Centered WD
+        centered_wd: float = 0.0,
+        centered_wd_mode: str = 'float8',
         # torch.compile
         compiled_optimizer: bool = False,
         # --- AdamW_adv specific parameters ---
@@ -214,6 +227,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
             "approx_mars": approx_mars, "mars_gamma": mars_gamma,
             # Spectral Normalization
             "n_layers": n_layers, "spectral_normalization": spectral_normalization,
+            # Centered WD
+            "centered_wd": centered_wd,
+            "centered_wd_mode": centered_wd_mode,
             # AdamW_adv defaults
             "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
             "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
@@ -261,6 +277,16 @@ class AdaMuon_adv(torch.optim.Optimizer):
         if compiled_optimizer:
             self.compile(fullgraph=True)
 
+    def load_state_dict(self, state_dict: dict) -> None:
+        """
+        Overrides default load_state_dict to implement a workaround for PyTorch's
+        automatic dtype casting. It ensures factorized states remain float32 for
+        stability, preserves integer/float8 quantized anchor states, and forces
+        standard states onto the parameter's current dtype/device.
+        """
+        super().load_state_dict(state_dict)
+        param_update.post_process_loaded_state(self)
+
     @property
     def supports_fused_back_pass(self):
         return True
```
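For context: the stock `torch.optim.Optimizer.load_state_dict` casts every floating-point state tensor to the dtype of its parameter, which would demote float32 factor vectors on bf16/fp16 models and disturb float8 anchors. The override above loads normally and then repairs the dtypes. A hypothetical sketch of that post-processing step (key names and rules here are illustrative; the real logic lives in `adv_optm/util/param_update.post_process_loaded_state` and also handles the quantized anchor formats):

```python
import torch

def post_process_loaded_state_sketch(optimizer: torch.optim.Optimizer) -> None:
    """Illustrative only: re-assert per-state dtypes after Optimizer.load_state_dict."""
    fp32_prefixes = ('mu_', 'mv_')                  # factorized states stay float32
    preserve_keys = ('sign', 'sign_slow', 'step')   # packed/counter buffers: leave untouched
    for group in optimizer.param_groups:
        for p in group['params']:
            state = optimizer.state.get(p)
            if not state:
                continue
            for key, val in state.items():
                if not torch.is_tensor(val) or key in preserve_keys:
                    continue
                if key.startswith(fp32_prefixes):
                    state[key] = val.to(device=p.device, dtype=torch.float32)
                else:
                    state[key] = val.to(device=p.device, dtype=p.dtype)
```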
```diff
@@ -344,6 +370,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
             # Note: This requires full-rank memory even if factored
             state['last_grad'] = torch.zeros_like(p, device=device, dtype=p.dtype)
 
+            _init_anchor(p, state, group)
+
             group['adam_kourkoutas_beta'] = False
             state['is_muon'] = True # Workaround as group was acting weirdly; passing muon params in adam path
 
```
{adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/AdamW_adv.py

```diff
@@ -9,6 +9,8 @@ from ..util.factorization_util import _get_effective_shape, _reconstruct_state,
 from ..util.update_util import _grams_update, _cautious_update
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.Kourkoutas import KourkoutasHelper
+from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
+from ..util.centered_decay import _init_anchor
 
 A = 4 / math.pi
 
@@ -78,8 +80,19 @@ class AdamW_adv(torch.optim.Optimizer):
             and returns a unique, hashable key representing its "layer" or "bucket".
             If `None`, parameters are bucketed by their memory ID (tensor-wise).
             (default: None)
+        centered_wd (float): Centered Weight Decay coefficient. Instead of decaying weights
+            toward zero, they are decayed toward their initial values (anchors). This
+            can be used together with standard weight decay. (default: 0.0)
+        centered_wd_mode (str): The quantization format used to store the anchor
+            weights to save VRAM. Options include:
+            'full': Stores anchors in the original parameter's precision.
+            'float8': Uses torch.float8_e4m3fn for a balance of precision and memory.
+            'int8': Uses 8-bit block-wise quantization (block size 128).
+            'int4': Uses 4-bit block-wise quantization (block size 32).
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
+        factored_2nd (bool): whether to keep the first moment uncompressed (dense)
+            while only factorizing the second moment. (default: True)
     """
 
     def __init__(
```
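`factored_2nd` keeps the first moment dense and only compresses the second moment into the per-row and per-column factors (`mu_v_nmf`, `mv_v_nmf`) that appear in the state initialization below. To make the memory trade-off concrete, here is a rank-1 nonnegative approximation in the style of Adafactor's row/column-sum rule, used purely as a stand-in for the package's own SMMF routines (`_factorize_state` / `_reconstruct_state`):

```python
import torch

def factorize_rank1(v: torch.Tensor):
    """Store a nonnegative (d1, d2) matrix as two vectors: d1 + d2 numbers instead of d1 * d2."""
    row = v.sum(dim=1)                       # shape (d1,)
    col = v.sum(dim=0)                       # shape (d2,)
    total = v.sum().clamp_min(1e-30)
    return row, col / total                  # v ≈ outer(row, col / total)

def reconstruct_rank1(row: torch.Tensor, col: torch.Tensor) -> torch.Tensor:
    return torch.outer(row, col)

v = torch.rand(64, 128) ** 2                   # stand-in for the second moment
mu_v, mv_v = factorize_rank1(v)
v_hat = reconstruct_rank1(mu_v, mv_v)          # rank-1 approximation, same shape as v
print(v.numel(), mu_v.numel() + mv_v.numel())  # 8192 stored values vs. 192
```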
```diff
@@ -114,9 +127,15 @@ class AdamW_adv(torch.optim.Optimizer):
         k_warmup_steps: int = 0,
         k_logging: int = 0,
         layer_key_fn: Optional[Callable] = None,
+        # Scaled Optimizer
+        scaled_optm: bool = False,
+        # Centered WD
+        centered_wd: float = 0.0,
+        centered_wd_mode: str = 'float8',
         # SMMF factorization
         nnmf_factor: bool = False,
         vector_reshape: bool = False,
+        factored_2nd: bool = False,
         # torch.compile
         compiled_optimizer: bool = False,
     ):
@@ -137,12 +156,14 @@ class AdamW_adv(torch.optim.Optimizer):
 
         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
-            "
+            "use_atan2": use_atan2,
             "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
             "beta3_ema": beta3_ema, "alpha": alpha, "compiled_optimizer": compiled_optimizer,
             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
             "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
-            "
+            "scaled_optm": scaled_optm,
+            "centered_wd": centered_wd, "centered_wd_mode": centered_wd_mode,
+            "nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd
         }
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask
@@ -150,6 +171,7 @@ class AdamW_adv(torch.optim.Optimizer):
         self.use_AdEMAMix = use_AdEMAMix
         self.kourkoutas_beta = kourkoutas_beta
         self.layer_key_fn = layer_key_fn
+        self._init_lr = lr
         super().__init__(params, defaults)
 
         if self.kourkoutas_beta:
@@ -167,6 +189,16 @@ class AdamW_adv(torch.optim.Optimizer):
         if compiled_optimizer:
             self.compile(fullgraph=True)
 
+    def load_state_dict(self, state_dict: dict) -> None:
+        """
+        Overrides default load_state_dict to implement a workaround for PyTorch's
+        automatic dtype casting. It ensures factorized states remain float32 for
+        stability, preserves integer/float8 quantized anchor states, and forces
+        standard states onto the parameter's current dtype/device.
+        """
+        super().load_state_dict(state_dict)
+        param_update.post_process_loaded_state(self)
+
     @property
     def supports_fused_back_pass(self):
         return True
@@ -194,6 +226,7 @@ class AdamW_adv(torch.optim.Optimizer):
             state['factored'] = (
                 group['nnmf_factor'] and
                 not (len(p.shape) == 1 and not group['vector_reshape'])
+                or group["factored_2nd"]
             )
 
             dtype = torch.float32 if state['factored'] else p.dtype
@@ -203,18 +236,25 @@ class AdamW_adv(torch.optim.Optimizer):
                 state['effective_shape'] = _get_effective_shape(p.numel())
                 d1, d2 = state['effective_shape']
 
-
-
-
-
-
-
-
-
-
-
-
-
+                if not group.get('factored_2nd', False):
+                    # First moment (m)
+                    if group['betas'][0] > 0:
+                        state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                        state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                        packed_d2 = (d2 + 7) // 8
+                        state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                    # AdEMAMix slow moment (m_slow)
+                    if self.use_AdEMAMix:
+                        state['mu_m_slow_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                        state['mv_m_slow_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                        packed_d2 = (d2 + 7) // 8
+                        state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                else:
+                    if group['betas'][0] > 0:
+                        state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+                    if self.use_AdEMAMix:
+                        state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
+
                 # Second moment (v)
                 state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                 state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
```
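The `(d1, (d2 + 7) // 8)` uint8 `sign` buffers above exist because the factorized first moment is signed: magnitudes live in the nonnegative rank-1 factors, while the signs are bit-packed eight per byte. A generic pack/unpack sketch of that layout (not the package's internal routine):

```python
import torch
import torch.nn.functional as F

_WEIGHTS = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)

def pack_signs(m: torch.Tensor) -> torch.Tensor:
    """(d1, d2) float -> (d1, ceil(d2 / 8)) uint8, one bit per negative entry."""
    d1, d2 = m.shape
    bits = (m < 0).to(torch.uint8)
    pad = (-d2) % 8
    if pad:
        bits = F.pad(bits, (0, pad))
    return (bits.view(d1, -1, 8) * _WEIGHTS).sum(dim=-1, dtype=torch.uint8)

def unpack_signs(packed: torch.Tensor, d2: int) -> torch.Tensor:
    """(d1, ceil(d2 / 8)) uint8 -> (d1, d2) tensor of +/-1."""
    bits = packed.unsqueeze(-1).bitwise_and(_WEIGHTS).ne(0)
    return torch.where(bits.view(packed.shape[0], -1)[:, :d2], -1.0, 1.0)

m = torch.randn(4, 10)
packed = pack_signs(m)                                   # shape (4, 2)
assert torch.equal(unpack_signs(packed, 10), torch.where(m < 0, -1.0, 1.0))
```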
```diff
@@ -228,6 +268,11 @@ class AdamW_adv(torch.optim.Optimizer):
                 # Second moment (v)
                 state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
 
+            if group.get('scaled_optm', False) and is_spectral(p):
+                init_spectral_norm(group, state, p)
+
+            _init_anchor(p, state, group)
+
         beta1, beta2 = group['betas']
 
         current_step = state['step']
@@ -275,32 +320,42 @@ class AdamW_adv(torch.optim.Optimizer):
             # Accumulate current grad's norm for the *next* step
             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
 
+        # Determine if we are using dense first-moments alongside a factored second-order second-moment
+        factored_2nd = group.get('factored_2nd', False)
+
         if state['factored']:
             d1, d2 = state['effective_shape']
             grad_reshaped = grad.view(d1, d2)
 
             # Reconstruct momentum from previous step's factors
             if beta1 > 0:
-
+                if factored_2nd:
+                    mt = state['exp_avg'].view(d1, d2)
+                else:
+                    mt = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True)
 
                 # Update momentum in full-size
                 mt.lerp_(grad_reshaped, 1.0 - beta1)
 
-
-
+                if not factored_2nd:
+                    # Factorize
+                    state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(mt.clone(), signed=True)
 
                 if self.grams_moment:
-                    update_mt = _grams_update(mt, grad_reshaped, inplace=
+                    update_mt = _grams_update(mt, grad_reshaped, inplace=not factored_2nd)
                 elif self.cautious_mask:
-                    update_mt = _cautious_update(mt, grad_reshaped, inplace=
+                    update_mt = _cautious_update(mt, grad_reshaped, inplace=not factored_2nd)
                 else:
-                    update_mt = mt
+                    update_mt = mt if not factored_2nd else mt.clone()
 
             vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
             vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
 
             if self.use_AdEMAMix:
-
+                if factored_2nd:
+                    mt_slow = state['exp_avg_slow'].view(d1, d2)
+                else:
+                    mt_slow = _reconstruct_state((state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'], d2), signed=True)
 
                 mt_slow.lerp_(grad_reshaped, 1.0 - beta3_ema)
 
@@ -308,9 +363,11 @@ class AdamW_adv(torch.optim.Optimizer):
                     update = update_mt.add_(mt_slow, alpha=alpha)
                 else:
                     update = grad_reshaped.add(mt_slow, alpha=alpha)
-
-
-
+
+                if not factored_2nd:
+                    # Factorize
+                    state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
+                del mt_slow
             else:
                 if beta1 > 0:
                     update = update_mt
@@ -330,8 +387,7 @@ class AdamW_adv(torch.optim.Optimizer):
             update.div_(denom)
             del vt
 
-
-            update = update.view(p.shape).mul_(update_scaling)
+            update = update.view(p.shape)
 
         else: # Standard AdamW logic for non-factored tensors
             if beta1 > 0:
@@ -369,7 +425,10 @@ class AdamW_adv(torch.optim.Optimizer):
             update.div_(denom)
             del denom
 
-
+            update_scaling = step_size * A if group['use_atan2'] else step_size
+            if group.get('scaled_optm', False):
+                update = scale_update(p, update, update_scaling, vector_state=state.get('spectral_v'))
+            else:
                 update.mul_(update_scaling)
 
         param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor)
```
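`scale_update` receives a persistent `state['spectral_v']` vector, which is the usual signature of a warm-started power iteration: `is_spectral(p)` picks out the matrix-shaped parameters, `init_spectral_norm` seeds the vector, and the update is then sized relative to the weight's spectral norm. The sketch below shows that mechanism under those assumptions; the actual scaling rule in `adv_optm/util/scaled_optm.py` is not visible in this diff:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def power_iteration_sigma(p: torch.Tensor, v: torch.Tensor, iters: int = 1):
    """Estimate sigma_max(p) with a warm-started right singular vector v."""
    for _ in range(iters):
        u = F.normalize(p @ v, dim=0)
        v = F.normalize(p.t() @ u, dim=0)
    sigma = (u @ p @ v).abs()
    return sigma, v                      # v is stored and reused next step

p = torch.randn(256, 128)
spectral_v = F.normalize(torch.randn(128), dim=0)   # plays the role of state['spectral_v']
update = torch.randn_like(p)

sigma, spectral_v = power_iteration_sigma(p, spectral_v)
lr = 1e-3
# One plausible rule: express the step size relative to the weight's spectral norm.
scaled = update * (lr * sigma / update.norm().clamp_min(1e-12))
```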
{adv_optm-2.3.dev2 → adv_optm-2.4.dev1}/adv_optm/optim/Adopt_adv.py

```diff
@@ -8,6 +8,8 @@ from ..util.factorization_util import _get_effective_shape, _reconstruct_state,
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.Kourkoutas import KourkoutasHelper
 from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update
+from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
+from ..util.centered_decay import _init_anchor
 
 A = 4 / math.pi
 
@@ -94,8 +96,19 @@ class Adopt_adv(torch.optim.Optimizer):
             and returns a unique, hashable key representing its "layer" or "bucket".
             If `None`, parameters are bucketed by their memory ID (tensor-wise).
             (default: None)
+        centered_wd (float): Centered Weight Decay coefficient. Instead of decaying weights
+            toward zero, they are decayed toward their initial values (anchors). This
+            can be used together with standard weight decay. (default: 0.0)
+        centered_wd_mode (str): The quantization format used to store the anchor
+            weights to save VRAM. Options include:
+            'full': Stores anchors in the original parameter's precision.
+            'float8': Uses torch.float8_e4m3fn for a balance of precision and memory.
+            'int8': Uses 8-bit block-wise quantization (block size 128).
+            'int4': Uses 4-bit block-wise quantization (block size 32).
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
+        factored_2nd (bool): whether to keep the first moment uncompressed (dense)
+            while only factorizing the second moment. (default: True)
     """
 
     def __init__(
```
```diff
@@ -133,9 +146,15 @@ class Adopt_adv(torch.optim.Optimizer):
         k_warmup_steps: int = 0,
         k_logging: int = 0,
         layer_key_fn: Optional[Callable] = None,
+        # Scaled Optimizer
+        scaled_optm: bool = False,
+        # Centered WD
+        centered_wd: float = 0.0,
+        centered_wd_mode: str = 'float8',
         # SMMF factorization
         nnmf_factor: bool = False,
-        vector_reshape: bool =
+        vector_reshape: bool = True,
+        factored_2nd: bool = False,
         # torch.compile
         compiled_optimizer: bool = False,
     ):
@@ -163,11 +182,14 @@ class Adopt_adv(torch.optim.Optimizer):
 
         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
-            "
+            "beta3_ema": beta3_ema, "alpha": alpha,
             "alpha_grad": alpha_grad,
             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
             "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
-            "
+            "scaled_optm": scaled_optm,
+            "centered_wd": centered_wd,
+            "centered_wd_mode": centered_wd_mode,
+            "nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd,
             "compiled_optimizer": compiled_optimizer,
         }
         self.clip_lambda = clip_lambda
@@ -180,6 +202,7 @@ class Adopt_adv(torch.optim.Optimizer):
         self.Simplified_AdEMAMix = Simplified_AdEMAMix
         self.kourkoutas_beta = kourkoutas_beta
         self.layer_key_fn = layer_key_fn
+        self._init_lr = lr
         super().__init__(params, defaults)
 
         if self.kourkoutas_beta:
@@ -196,6 +219,16 @@ class Adopt_adv(torch.optim.Optimizer):
         if compiled_optimizer:
             self.compile(fullgraph=True)
 
+    def load_state_dict(self, state_dict: dict) -> None:
+        """
+        Overrides default load_state_dict to implement a workaround for PyTorch's
+        automatic dtype casting. It ensures factorized states remain float32 for
+        stability, preserves integer/float8 quantized anchor states, and forces
+        standard states onto the parameter's current dtype/device.
+        """
+        super().load_state_dict(state_dict)
+        param_update.post_process_loaded_state(self)
+
     @property
     def supports_fused_back_pass(self): return True
     @property
@@ -218,6 +251,7 @@ class Adopt_adv(torch.optim.Optimizer):
             state['factored'] = (
                 group['nnmf_factor'] and
                 not (len(p.shape) == 1 and not group['vector_reshape'])
+                or group["factored_2nd"]
             )
 
             dtype = torch.float32 if state['factored'] else p.dtype
@@ -226,18 +260,24 @@ class Adopt_adv(torch.optim.Optimizer):
                 state['effective_shape'] = _get_effective_shape(p.numel())
                 d1, d2 = state['effective_shape']
 
-
-
-
-
-
-
-
-
-
-
-
-
+                if not group.get('factored_2nd', False):
+                    # First moment (m)
+                    if group['betas'][0] > 0:
+                        state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                        state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                        packed_d2 = (d2 + 7) // 8
+                        state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                    # AdEMAMix slow moment (m_slow)
+                    if self.use_AdEMAMix:
+                        state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                        state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                        packed_d2 = (d2 + 7) // 8
+                        state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                else:
+                    if group['betas'][0] > 0:
+                        state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+                    if self.use_AdEMAMix:
+                        state['exp_avg_slow'] = torch.zeros_like(p, device=p.device, dtype=dtype)
                 # Second moment (v)
                 vt_init = grad.to(dtype).view(d1, d2).square()
                 # Allocate NMF factors for vt
```
```diff
@@ -253,6 +293,11 @@ class Adopt_adv(torch.optim.Optimizer):
                     state['exp_avg_slow'] = torch.zeros_like(p, device=p.device, dtype=dtype)
                 state['exp_avg_sq'] = grad.to(dtype).square()
 
+            if group.get('scaled_optm', False) and is_spectral(p):
+                init_spectral_norm(group, state, p)
+
+            _init_anchor(p, state, group)
+
         beta1, beta2 = group['betas']
 
         current_step = state['step']
@@ -280,7 +325,7 @@ class Adopt_adv(torch.optim.Optimizer):
         step_param_fn = self._step_parameter
 
         if self.Simplified_AdEMAMix:
-            lr = _scale_sim_AdEMAMix_update(beta1, state['step'] + 1, group["alpha_grad"], lr)
+            lr = _scale_sim_AdEMAMix_update(beta1, state['step'] + 1, group["alpha_grad"], lr, group.get('scaled_optm', False))
 
         step_param_fn(p, grad, state, group, lr, beta1, beta2, random_int_tensor)
 
@@ -302,6 +347,9 @@ class Adopt_adv(torch.optim.Optimizer):
             # Accumulate current grad's norm for the *next* step
             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
 
+        # Determine if we are using dense first-moments alongside a factored second-order second-moment
+        factored_2nd = group.get('factored_2nd', False)
+
         if state['factored']:
             d1, d2 = state['effective_shape']
             grad_reshaped = grad.view(d1, d2)
@@ -328,35 +376,47 @@ class Adopt_adv(torch.optim.Optimizer):
 
             # ADOPT Step B: Update momentum m_t using normalized gradient
             if beta1 > 0:
-
-
+                if factored_2nd:
+                    mt = state['exp_avg'].view(d1, d2)
+                else:
+                    # Reconstruct m_{t-1}
+                    mt = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True)
+
                 if self.Simplified_AdEMAMix:
                     mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
                 else:
                     mt.lerp_(normalized_grad, 1.0 - beta1)
 
-
-
+                if not factored_2nd:
+                    # Factorize
+                    state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(mt.clone(), signed=True)
 
                 if self.grams_moment:
-                    update_mt = _grams_update(mt, grad_reshaped, inplace=
+                    update_mt = _grams_update(mt, grad_reshaped, inplace=not factored_2nd)
                 elif self.cautious_mask:
-                    update_mt = _cautious_update(mt, grad_reshaped, inplace=
+                    update_mt = _cautious_update(mt, grad_reshaped, inplace=not factored_2nd)
                 else:
-                    update_mt = mt
+                    update_mt = mt if not factored_2nd else mt.clone()
 
             if self.use_AdEMAMix:
-
-
+                if factored_2nd:
+                    mt_slow = state['exp_avg_slow'].view(d1, d2)
+                else:
+                    # Reconstruct AdEMAMix EMA
+                    mt_slow = _reconstruct_state((state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'], d2), signed=True)
+
                 mt_slow.lerp_(normalized_grad, 1.0 - beta3_ema)
+
                 if beta1 > 0:
                     update = update_mt.add_(mt_slow, alpha=alpha)
                     del normalized_grad
                 else:
                     update = normalized_grad.add_(mt_slow, alpha=alpha)
-
-
-
+                if not factored_2nd:
+                    # Factorize
+                    state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
+                del mt_slow
+
             elif self.Simplified_AdEMAMix:
                 update = update_mt.add_(normalized_grad, alpha=alpha_grad)
                 del normalized_grad
@@ -369,9 +429,6 @@ class Adopt_adv(torch.optim.Optimizer):
 
             update = update.view(p.shape)
 
-            update_scaling = lr * A if self.use_atan2 else lr
-            update.mul_(update_scaling)
-
         else: # Standard ADOPT logic for non-factored tensors
             vt = state['exp_avg_sq'] # v_{t-1}
 
@@ -418,12 +475,17 @@ class Adopt_adv(torch.optim.Optimizer):
             else:
                 update = normalized_grad
 
-            update_scaling = lr * A if self.use_atan2 else lr
-            update.mul_(update_scaling)
 
             # Update second moment v_t for the next step using raw g_t
             vt.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
 
+            update_scaling = lr * A if self.use_atan2 else lr
+
+            if group.get('scaled_optm', False):
+                update = scale_update(p, update, update_scaling, vector_state=state.get('spectral_v'))
+            else:
+                update.mul_(update_scaling)
+
             # Parameter Update
             param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
```
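Both the factored and standard paths now gate the final scaling on `use_atan2`, with the module-level constant `A = 4 / math.pi`. One way to read this, offered as an assumption since the denominator branch itself sits outside the hunks shown here: the epsilon-free atan2 form of the Adam step, where the 4/pi factor makes `m == sqrt(v)` produce a unit-sized direction, just as plain division would:

```python
import math
import torch

A = 4 / math.pi  # same constant the modules above define

def adam_style_direction(m: torch.Tensor, v: torch.Tensor,
                         use_atan2: bool, eps: float = 1e-8) -> torch.Tensor:
    if use_atan2:
        return torch.atan2(m, v.sqrt())      # bounded, no eps; caller scales by lr * A
    return m / (v.sqrt() + eps)              # caller scales by lr

m, v = torch.randn(10), torch.rand(10)
lr = 1e-3
step_atan2 = adam_style_direction(m, v, True) * (lr * A)
step_plain = adam_style_direction(m, v, False) * lr
```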