adv-optm 2.4.dev2__tar.gz → 2.4.dev4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/PKG-INFO +1 -1
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/AdaMuon_adv.py +2 -2
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/AdamW_adv.py +13 -6
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/Adopt_adv.py +33 -21
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/Lion_adv.py +9 -7
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/Muon_adv.py +2 -2
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/Prodigy_adv.py +13 -7
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/SignSGD_adv.py +10 -11
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/Simplified_AdEMAMix.py +11 -5
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/Kourkoutas.py +43 -12
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/Muon_AuxAdam.py +8 -2
- adv_optm-2.4.dev4/adv_optm/util/OrthoGrad.py +50 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/centered_decay.py +1 -1
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/param_update.py +5 -5
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/scaled_optm.py +9 -5
- adv_optm-2.4.dev4/adv_optm/util/update_util.py +73 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/setup.py +1 -1
- adv_optm-2.4.dev2/adv_optm/util/OrthoGrad.py +0 -21
- adv_optm-2.4.dev2/adv_optm/util/update_util.py +0 -32
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/LICENSE +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/README.md +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev2 → adv_optm-2.4.dev4}/setup.cfg +0 -0
|
@@ -280,8 +280,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
280
280
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
281
281
|
"""
|
|
282
282
|
Overrides default load_state_dict to implement a workaround for PyTorch's
|
|
283
|
-
automatic dtype casting. It ensures factorized states remain float32 for
|
|
284
|
-
stability, preserves integer/float8 quantized anchor states, and forces
|
|
283
|
+
automatic dtype casting. It ensures factorized states remain float32 for
|
|
284
|
+
stability, preserves integer/float8 quantized anchor states, and forces
|
|
285
285
|
standard states onto the parameter's current dtype/device.
|
|
286
286
|
"""
|
|
287
287
|
super().load_state_dict(state_dict)
|
|
@@ -91,7 +91,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
91
91
|
'int4': Uses 4-bit block-wise quantization (block size 32).
|
|
92
92
|
nnmf_factor (bool): whether to use the factorization or disable it to use
|
|
93
93
|
the uncompressed optimizer. (default: False)
|
|
94
|
-
factored_2nd (bool): whether to keep the first moment uncompressed (dense)
|
|
94
|
+
factored_2nd (bool): whether to keep the first moment uncompressed (dense)
|
|
95
95
|
while only factorizing the second moment. (default: True)
|
|
96
96
|
"""
|
|
97
97
|
|
|
@@ -192,8 +192,8 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
192
192
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
193
193
|
"""
|
|
194
194
|
Overrides default load_state_dict to implement a workaround for PyTorch's
|
|
195
|
-
automatic dtype casting. It ensures factorized states remain float32 for
|
|
196
|
-
stability, preserves integer/float8 quantized anchor states, and forces
|
|
195
|
+
automatic dtype casting. It ensures factorized states remain float32 for
|
|
196
|
+
stability, preserves integer/float8 quantized anchor states, and forces
|
|
197
197
|
standard states onto the parameter's current dtype/device.
|
|
198
198
|
"""
|
|
199
199
|
super().load_state_dict(state_dict)
|
|
@@ -349,7 +349,11 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
349
349
|
update_mt = mt if not factored_2nd else mt.clone()
|
|
350
350
|
|
|
351
351
|
vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
|
|
352
|
-
|
|
352
|
+
|
|
353
|
+
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
354
|
+
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped * (1.0 - beta2))
|
|
355
|
+
else:
|
|
356
|
+
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
|
|
353
357
|
|
|
354
358
|
if self.use_AdEMAMix:
|
|
355
359
|
if factored_2nd:
|
|
@@ -363,7 +367,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
363
367
|
update = update_mt.add_(mt_slow, alpha=alpha)
|
|
364
368
|
else:
|
|
365
369
|
update = grad_reshaped.add(mt_slow, alpha=alpha)
|
|
366
|
-
|
|
370
|
+
|
|
367
371
|
if not factored_2nd:
|
|
368
372
|
# Factorize
|
|
369
373
|
state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
|
|
@@ -413,7 +417,10 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
413
417
|
update = update_mt if beta1 > 0 else grad.clone()
|
|
414
418
|
|
|
415
419
|
exp_avg_sq = state['exp_avg_sq']
|
|
416
|
-
|
|
420
|
+
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
421
|
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad * (1.0 - beta2))
|
|
422
|
+
else:
|
|
423
|
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
|
|
417
424
|
|
|
418
425
|
if group['use_atan2']:
|
|
419
426
|
denom = exp_avg_sq.sqrt()
|
|
@@ -107,7 +107,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
107
107
|
'int4': Uses 4-bit block-wise quantization (block size 32).
|
|
108
108
|
nnmf_factor (bool): whether to use the factorization or disable it to use
|
|
109
109
|
the uncompressed optimizer. (default: False)
|
|
110
|
-
factored_2nd (bool): whether to keep the first moment uncompressed (dense)
|
|
110
|
+
factored_2nd (bool): whether to keep the first moment uncompressed (dense)
|
|
111
111
|
while only factorizing the second moment. (default: True)
|
|
112
112
|
"""
|
|
113
113
|
|
|
@@ -189,7 +189,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
189
189
|
"scaled_optm": scaled_optm,
|
|
190
190
|
"centered_wd": centered_wd,
|
|
191
191
|
"centered_wd_mode": centered_wd_mode,
|
|
192
|
-
"nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd,
|
|
192
|
+
"nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd,
|
|
193
193
|
"compiled_optimizer": compiled_optimizer,
|
|
194
194
|
}
|
|
195
195
|
self.clip_lambda = clip_lambda
|
|
@@ -222,8 +222,8 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
222
222
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
223
223
|
"""
|
|
224
224
|
Overrides default load_state_dict to implement a workaround for PyTorch's
|
|
225
|
-
automatic dtype casting. It ensures factorized states remain float32 for
|
|
226
|
-
stability, preserves integer/float8 quantized anchor states, and forces
|
|
225
|
+
automatic dtype casting. It ensures factorized states remain float32 for
|
|
226
|
+
stability, preserves integer/float8 quantized anchor states, and forces
|
|
227
227
|
standard states onto the parameter's current dtype/device.
|
|
228
228
|
"""
|
|
229
229
|
super().load_state_dict(state_dict)
|
|
@@ -244,6 +244,19 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
244
244
|
grad = p.grad
|
|
245
245
|
state = self.state[p]
|
|
246
246
|
|
|
247
|
+
|
|
248
|
+
beta1, beta2 = group['betas']
|
|
249
|
+
|
|
250
|
+
if group.get('kourkoutas_beta', False):
|
|
251
|
+
if 'step' not in state:
|
|
252
|
+
current_step = 0
|
|
253
|
+
else:
|
|
254
|
+
current_step = state['step']
|
|
255
|
+
# Call prepare_step() once at the beginning of the step for all params
|
|
256
|
+
self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
|
|
257
|
+
# Get the dynamic beta2 calculated in prepare_step()
|
|
258
|
+
beta2 = self.kourkoutas_helper.get_beta2(p, group)
|
|
259
|
+
|
|
247
260
|
# State Initialization
|
|
248
261
|
if 'step' not in state:
|
|
249
262
|
state['step'] = 0
|
|
@@ -256,6 +269,12 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
256
269
|
|
|
257
270
|
dtype = torch.float32 if state['factored'] else p.dtype
|
|
258
271
|
|
|
272
|
+
vt_init = grad.pow(2).to(dtype)
|
|
273
|
+
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
274
|
+
vt_init.mul_(beta2).addcmul_(grad.to(dtype), grad.to(dtype) * (1.0 - beta2))
|
|
275
|
+
else:
|
|
276
|
+
vt_init.mul_(beta2).addcmul_(grad.to(dtype), grad.to(dtype), value=1.0 - beta2)
|
|
277
|
+
|
|
259
278
|
if state['factored']:
|
|
260
279
|
state['effective_shape'] = _get_effective_shape(p.numel())
|
|
261
280
|
d1, d2 = state['effective_shape']
|
|
@@ -279,33 +298,21 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
279
298
|
if self.use_AdEMAMix:
|
|
280
299
|
state['exp_avg_slow'] = torch.zeros_like(p, device=p.device, dtype=dtype)
|
|
281
300
|
# Second moment (v)
|
|
282
|
-
|
|
283
|
-
# Allocate NMF factors for vt
|
|
284
|
-
state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
|
|
285
|
-
state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
|
|
286
|
-
# Initialize v_0
|
|
287
|
-
state['mu_v_nmf'], state['mv_v_nmf'] = _nnmf(vt_init)
|
|
301
|
+
state['mu_v_nmf'], state['mv_v_nmf'] = _nnmf(vt_init.view(d1, d2))
|
|
288
302
|
del vt_init
|
|
289
303
|
else: # Fallback for non-factored tensors
|
|
290
304
|
if group['betas'][0] > 0:
|
|
291
305
|
state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
|
|
292
306
|
if self.use_AdEMAMix:
|
|
293
307
|
state['exp_avg_slow'] = torch.zeros_like(p, device=p.device, dtype=dtype)
|
|
294
|
-
state['exp_avg_sq'] =
|
|
308
|
+
state['exp_avg_sq'] = vt_init
|
|
295
309
|
|
|
296
310
|
if group.get('scaled_optm', False) and is_spectral(p):
|
|
297
311
|
init_spectral_norm(group, state, p)
|
|
298
312
|
|
|
299
313
|
_init_anchor(p, state, group)
|
|
300
314
|
|
|
301
|
-
beta1, beta2 = group['betas']
|
|
302
|
-
|
|
303
315
|
current_step = state['step']
|
|
304
|
-
if group.get('kourkoutas_beta', False):
|
|
305
|
-
# Call prepare_step() once at the beginning of the step for all params
|
|
306
|
-
self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
|
|
307
|
-
# Get the dynamic beta2 calculated in prepare_step()
|
|
308
|
-
beta2 = self.kourkoutas_helper.get_beta2(p, group)
|
|
309
316
|
|
|
310
317
|
# The first step is for initialization only (skip when use_atan2 as it's scale invariant).
|
|
311
318
|
if state['step'] == 0 and not self.use_atan2:
|
|
@@ -361,7 +368,10 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
361
368
|
denom = vt.sqrt()
|
|
362
369
|
|
|
363
370
|
# Update second moment v_t for the *next* step using raw g_t
|
|
364
|
-
|
|
371
|
+
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
372
|
+
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped * (1.0 - beta2))
|
|
373
|
+
else:
|
|
374
|
+
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
|
|
365
375
|
# Factorize
|
|
366
376
|
state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(vt, signed=False)
|
|
367
377
|
del vt
|
|
@@ -475,9 +485,11 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
475
485
|
else:
|
|
476
486
|
update = normalized_grad
|
|
477
487
|
|
|
478
|
-
|
|
479
488
|
# Update second moment v_t for the next step using raw g_t
|
|
480
|
-
|
|
489
|
+
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
490
|
+
vt.mul_(beta2).addcmul_(grad, grad * (1.0 - beta2))
|
|
491
|
+
else:
|
|
492
|
+
vt.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
|
481
493
|
|
|
482
494
|
update_scaling = lr * A if self.use_atan2 else lr
|
|
483
495
|
|
|
@@ -8,6 +8,7 @@ from ..util.factorization_util import _get_effective_shape, _reconstruct_state,
|
|
|
8
8
|
from ..util.lion_k import _get_lion_k_update
|
|
9
9
|
from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
|
|
10
10
|
from ..util.centered_decay import _init_anchor
|
|
11
|
+
from ..util.update_util import _get_l1_adaptive_lr
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class Lion_adv(torch.optim.Optimizer):
|
|
@@ -46,7 +47,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
46
47
|
updates. Overrides explicit kappa_p value. (default: False).
|
|
47
48
|
freeze_on_flip (bool): Projected SignGD One-hit freeze. Masks updates for
|
|
48
49
|
coordinates where the gradient sign flips compared to the previous step. (default: False)
|
|
49
|
-
l1_adaptive (bool): Scales learning rate dynamically
|
|
50
|
+
l1_adaptive (bool): Scales learning rate dynamically
|
|
50
51
|
by the L1 norm of the gradient to handle gradient heterogeneity. (default: False).
|
|
51
52
|
centered_wd (float): Centered Weight Decay coefficient. Instead of decaying weights
|
|
52
53
|
toward zero, they are decayed toward their initial values (anchors). This
|
|
@@ -137,8 +138,8 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
137
138
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
138
139
|
"""
|
|
139
140
|
Overrides default load_state_dict to implement a workaround for PyTorch's
|
|
140
|
-
automatic dtype casting. It ensures factorized states remain float32 for
|
|
141
|
-
stability, preserves integer/float8 quantized anchor states, and forces
|
|
141
|
+
automatic dtype casting. It ensures factorized states remain float32 for
|
|
142
|
+
stability, preserves integer/float8 quantized anchor states, and forces
|
|
142
143
|
standard states onto the parameter's current dtype/device.
|
|
143
144
|
"""
|
|
144
145
|
super().load_state_dict(state_dict)
|
|
@@ -251,8 +252,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
251
252
|
# Compute update term c_t
|
|
252
253
|
update = torch.lerp(grad_reshaped, exp_avg, beta1)
|
|
253
254
|
|
|
254
|
-
|
|
255
|
-
lr = lr * (update.norm(p=1))
|
|
255
|
+
l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p)
|
|
256
256
|
|
|
257
257
|
# Standard Lion momentum update
|
|
258
258
|
# m_t = beta2 * m_{t-1} + (1-beta2) * g_t
|
|
@@ -286,8 +286,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
286
286
|
# Compute update term
|
|
287
287
|
update = torch.lerp(grad, exp_avg, beta1)
|
|
288
288
|
|
|
289
|
-
|
|
290
|
-
lr = lr * (update.norm(p=1))
|
|
289
|
+
l1_mean = _get_l1_adaptive_lr(p, update, state, group, kappa_p)
|
|
291
290
|
|
|
292
291
|
update = _get_lion_k_update(update, kappa_p)
|
|
293
292
|
|
|
@@ -305,6 +304,9 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
305
304
|
update = torch.where(current_sign == state['prev_sign'], update, 0.0)
|
|
306
305
|
state['prev_sign'] = current_sign
|
|
307
306
|
|
|
307
|
+
if l1_mean is not None:
|
|
308
|
+
update.mul_(l1_mean)
|
|
309
|
+
|
|
308
310
|
if group.get('scaled_optm', False):
|
|
309
311
|
update = scale_update(p, update, lr, vector_state=state.get('spectral_v'))
|
|
310
312
|
else:
|
|
@@ -259,8 +259,8 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
259
259
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
260
260
|
"""
|
|
261
261
|
Overrides default load_state_dict to implement a workaround for PyTorch's
|
|
262
|
-
automatic dtype casting. It ensures factorized states remain float32 for
|
|
263
|
-
stability, preserves integer/float8 quantized anchor states, and forces
|
|
262
|
+
automatic dtype casting. It ensures factorized states remain float32 for
|
|
263
|
+
stability, preserves integer/float8 quantized anchor states, and forces
|
|
264
264
|
standard states onto the parameter's current dtype/device.
|
|
265
265
|
"""
|
|
266
266
|
super().load_state_dict(state_dict)
|
|
@@ -67,7 +67,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
67
67
|
stability. (default: 100.0)
|
|
68
68
|
nnmf_factor (bool): whether to use the factorization or disable it to use
|
|
69
69
|
the uncompressed optimizer. (default: False)
|
|
70
|
-
factored_2nd (bool): whether to keep the first moment uncompressed (dense)
|
|
70
|
+
factored_2nd (bool): whether to keep the first moment uncompressed (dense)
|
|
71
71
|
while only factorizing the second moment. (default: True)
|
|
72
72
|
d0 (float):
|
|
73
73
|
Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
|
|
@@ -255,8 +255,8 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
255
255
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
256
256
|
"""
|
|
257
257
|
Overrides default load_state_dict to implement a workaround for PyTorch's
|
|
258
|
-
automatic dtype casting. It ensures factorized states remain float32 for
|
|
259
|
-
stability, preserves integer/float8 quantized anchor states, and forces
|
|
258
|
+
automatic dtype casting. It ensures factorized states remain float32 for
|
|
259
|
+
stability, preserves integer/float8 quantized anchor states, and forces
|
|
260
260
|
standard states onto the parameter's current dtype/device.
|
|
261
261
|
"""
|
|
262
262
|
super().load_state_dict(state_dict)
|
|
@@ -440,7 +440,10 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
440
440
|
update_mt = mt if not factored_2nd else mt.clone()
|
|
441
441
|
|
|
442
442
|
vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
|
|
443
|
-
|
|
443
|
+
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
444
|
+
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped * (d * d * (1.0 - beta2)))
|
|
445
|
+
else:
|
|
446
|
+
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=d * d * (1.0 - beta2))
|
|
444
447
|
|
|
445
448
|
if self.use_AdEMAMix:
|
|
446
449
|
if factored_2nd:
|
|
@@ -453,7 +456,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
453
456
|
update = update_mt.add_(mt_slow, alpha=alpha)
|
|
454
457
|
else:
|
|
455
458
|
update = grad_reshaped.mul(d).add_(mt_slow, alpha=alpha)
|
|
456
|
-
|
|
459
|
+
|
|
457
460
|
if not factored_2nd:
|
|
458
461
|
# Factorize
|
|
459
462
|
state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
|
|
@@ -514,7 +517,10 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
514
517
|
update = grad.mul(d)
|
|
515
518
|
|
|
516
519
|
exp_avg_sq = state['exp_avg_sq']
|
|
517
|
-
|
|
520
|
+
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
521
|
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad * (d * d * (1.0 - beta2)))
|
|
522
|
+
else:
|
|
523
|
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=d * d * (1.0 - beta2))
|
|
518
524
|
|
|
519
525
|
if group['use_atan2']:
|
|
520
526
|
denom = exp_avg_sq.sqrt()
|
|
@@ -608,4 +614,4 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
608
614
|
|
|
609
615
|
# Increment step counter for all groups, regardless of whether d was updated
|
|
610
616
|
for group in self.param_groups:
|
|
611
|
-
group['k'] += 1
|
|
617
|
+
group['k'] += 1
|
|
@@ -6,8 +6,8 @@ from ..util import param_update
|
|
|
6
6
|
from ..util.OrthoGrad import _orthogonalize_gradient
|
|
7
7
|
from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state, _pack_bools, _unpack_bools
|
|
8
8
|
from ..util.lion_k import _get_lion_k_update
|
|
9
|
+
from ..util.update_util import _get_l1_adaptive_lr
|
|
9
10
|
from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
|
|
10
|
-
from ..util.update_util import _scale_sim_AdEMAMix_update
|
|
11
11
|
from ..util.centered_decay import _init_anchor
|
|
12
12
|
|
|
13
13
|
|
|
@@ -49,8 +49,8 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
49
49
|
stability. (default: 100.0)
|
|
50
50
|
freeze_on_flip (bool): Projected SignGD One-hit freeze. Masks updates for
|
|
51
51
|
coordinates where the gradient sign flips compared to the previous step. (default: False)
|
|
52
|
-
l1_adaptive (bool): Scales
|
|
53
|
-
by the L1 norm of the gradient to handle gradient heterogeneity.
|
|
52
|
+
l1_adaptive (bool): Scales the update step magnitude dynamically
|
|
53
|
+
by the mean L1 norm of the momentum/gradient to handle gradient heterogeneity.(default: False)
|
|
54
54
|
centered_wd (float): Centered Weight Decay coefficient. Instead of decaying weights
|
|
55
55
|
toward zero, they are decayed toward their initial values (anchors). This
|
|
56
56
|
can be used together with standard weight decay. (default: 0.0)
|
|
@@ -140,8 +140,8 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
140
140
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
141
141
|
"""
|
|
142
142
|
Overrides default load_state_dict to implement a workaround for PyTorch's
|
|
143
|
-
automatic dtype casting. It ensures factorized states remain float32 for
|
|
144
|
-
stability, preserves integer/float8 quantized anchor states, and forces
|
|
143
|
+
automatic dtype casting. It ensures factorized states remain float32 for
|
|
144
|
+
stability, preserves integer/float8 quantized anchor states, and forces
|
|
145
145
|
standard states onto the parameter's current dtype/device.
|
|
146
146
|
"""
|
|
147
147
|
super().load_state_dict(state_dict)
|
|
@@ -269,9 +269,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
269
269
|
if freeze_on_flip:
|
|
270
270
|
state['sign'] = _pack_bools(raw_update > 0)
|
|
271
271
|
|
|
272
|
-
|
|
273
|
-
scale_factor = 1 / _scale_sim_AdEMAMix_update(momentum, state["step"] + 1, alpha_grad, 1, False)
|
|
274
|
-
lr = lr * (raw_update.norm(p=1)/scale_factor)
|
|
272
|
+
l1_mean = _get_l1_adaptive_lr(p, raw_update, state, group, kappa_p)
|
|
275
273
|
|
|
276
274
|
update = _get_lion_k_update(raw_update, kappa_p)
|
|
277
275
|
update = update.view(p.shape)
|
|
@@ -296,9 +294,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
296
294
|
else:
|
|
297
295
|
raw_update = grad.clone()
|
|
298
296
|
|
|
299
|
-
|
|
300
|
-
scale_factor = 1 / _scale_sim_AdEMAMix_update(momentum, state["step"] + 1, alpha_grad, 1, False)
|
|
301
|
-
lr = lr * (raw_update.norm(p=1)/scale_factor)
|
|
297
|
+
l1_mean = _get_l1_adaptive_lr(p, raw_update, state, group, kappa_p)
|
|
302
298
|
|
|
303
299
|
update = _get_lion_k_update(raw_update, kappa_p)
|
|
304
300
|
|
|
@@ -307,6 +303,9 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
307
303
|
update = torch.where(current_sign == state['prev_sign'], update, 0.0)
|
|
308
304
|
state['prev_sign'] = current_sign
|
|
309
305
|
|
|
306
|
+
if l1_mean is not None:
|
|
307
|
+
update.mul_(l1_mean)
|
|
308
|
+
|
|
310
309
|
if group.get('scaled_optm', False):
|
|
311
310
|
update = scale_update(p, update, lr, vector_state=state.get('spectral_v'))
|
|
312
311
|
else:
|
|
@@ -86,7 +86,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
|
|
|
86
86
|
'int4': Uses 4-bit block-wise quantization (block size 32).
|
|
87
87
|
nnmf_factor (bool): whether to use the factorization or disable it to use
|
|
88
88
|
the uncompressed optimizer. (default: False)
|
|
89
|
-
factored_2nd (bool): whether to keep the first moment uncompressed (dense)
|
|
89
|
+
factored_2nd (bool): whether to keep the first moment uncompressed (dense)
|
|
90
90
|
while only factorizing the second moment. (default: True)
|
|
91
91
|
"""
|
|
92
92
|
|
|
@@ -176,8 +176,8 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
|
|
|
176
176
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
177
177
|
"""
|
|
178
178
|
Overrides default load_state_dict to implement a workaround for PyTorch's
|
|
179
|
-
automatic dtype casting. It ensures factorized states remain float32 for
|
|
180
|
-
stability, preserves integer/float8 quantized anchor states, and forces
|
|
179
|
+
automatic dtype casting. It ensures factorized states remain float32 for
|
|
180
|
+
stability, preserves integer/float8 quantized anchor states, and forces
|
|
181
181
|
standard states onto the parameter's current dtype/device.
|
|
182
182
|
"""
|
|
183
183
|
super().load_state_dict(state_dict)
|
|
@@ -320,7 +320,10 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
|
|
|
320
320
|
mt.mul_(beta1).add_(grad_reshaped)
|
|
321
321
|
|
|
322
322
|
vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
|
|
323
|
-
|
|
323
|
+
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
324
|
+
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped * (1.0 - beta2))
|
|
325
|
+
else:
|
|
326
|
+
vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
|
|
324
327
|
|
|
325
328
|
# update = mt + (grad_reshaped * alpha_grad)
|
|
326
329
|
update = torch.add(mt, grad_reshaped, alpha=alpha_grad)
|
|
@@ -347,7 +350,10 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
|
|
|
347
350
|
|
|
348
351
|
update = torch.add(exp_avg, grad, alpha=alpha_grad)
|
|
349
352
|
|
|
350
|
-
|
|
353
|
+
if isinstance(beta2, torch.Tensor) and beta2.dim() > 0:
|
|
354
|
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad * (1.0 - beta2))
|
|
355
|
+
else:
|
|
356
|
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
|
|
351
357
|
|
|
352
358
|
denom = exp_avg_sq.sqrt().add_(sqrt_den_eps)
|
|
353
359
|
update.div_(denom)
|
|
@@ -34,8 +34,12 @@ class KourkoutasHelper:
|
|
|
34
34
|
else:
|
|
35
35
|
# No key function was provided. Default to coarse, shape-based bucketing.
|
|
36
36
|
self.optimizer.layer_key_fn = lambda p: \
|
|
37
|
-
(id(p),) if
|
|
38
|
-
|
|
37
|
+
(id(p),) if (
|
|
38
|
+
getattr(p, '_is_oft', False) or
|
|
39
|
+
getattr(p, '_is_lora_A', False) or
|
|
40
|
+
getattr(p, '_is_lora_B', False) or
|
|
41
|
+
getattr(p, '_is_dora_scale', False)
|
|
42
|
+
) else tuple(p.shape)
|
|
39
43
|
# This ensures that we won't mix embeddings with tokens (1 to 10)
|
|
40
44
|
# TODO find a better way to safeguard the embeddings
|
|
41
45
|
|
|
@@ -55,13 +59,21 @@ class KourkoutasHelper:
|
|
|
55
59
|
def _get_or_init_layer_ema_tensor(self, layer_key, layer_params, device):
|
|
56
60
|
"""
|
|
57
61
|
Retrieves the EMA tensor for this layer.
|
|
58
|
-
It handles synchronization between the internal layer_state and
|
|
62
|
+
It handles synchronization between the internal layer_state and
|
|
59
63
|
the external optimizer.state (which is required for state_dict saving/loading).
|
|
60
64
|
"""
|
|
61
65
|
# Initialize container in layer_state if missing
|
|
62
66
|
if layer_key not in self.layer_state:
|
|
67
|
+
p = layer_params[0]
|
|
68
|
+
if getattr(p, '_is_oft', False) or getattr(p, '_is_lora_A', False):
|
|
69
|
+
shape = (p.shape[0], 1)
|
|
70
|
+
elif getattr(p, '_is_lora_B', False):
|
|
71
|
+
shape = (1, p.shape[1])
|
|
72
|
+
else:
|
|
73
|
+
shape = ()
|
|
74
|
+
|
|
63
75
|
self.layer_state[layer_key] = {
|
|
64
|
-
'sum_sq_accumulator': torch.
|
|
76
|
+
'sum_sq_accumulator': torch.zeros(shape, device=device, dtype=torch.float32)
|
|
65
77
|
}
|
|
66
78
|
|
|
67
79
|
internal_ema = self.layer_state[layer_key].get('kourkoutas_r_ema')
|
|
@@ -87,7 +99,15 @@ class KourkoutasHelper:
|
|
|
87
99
|
|
|
88
100
|
# Case B: No state anywhere. Create new.
|
|
89
101
|
if internal_ema is None:
|
|
90
|
-
|
|
102
|
+
p = layer_params[0]
|
|
103
|
+
if getattr(p, '_is_oft', False) or getattr(p, '_is_lora_A', False):
|
|
104
|
+
shape = (p.shape[0], 1)
|
|
105
|
+
elif getattr(p, '_is_lora_B', False):
|
|
106
|
+
shape = (1, p.shape[1])
|
|
107
|
+
else:
|
|
108
|
+
shape = ()
|
|
109
|
+
|
|
110
|
+
new_ema = torch.zeros(shape, device=device, dtype=torch.float32)
|
|
91
111
|
self.layer_state[layer_key]['kourkoutas_r_ema'] = new_ema
|
|
92
112
|
|
|
93
113
|
# Register this tensor in optimizer.state for ALL params so it gets saved
|
|
@@ -107,7 +127,7 @@ class KourkoutasHelper:
|
|
|
107
127
|
|
|
108
128
|
def prepare_step(self, current_step: int, device):
|
|
109
129
|
"""
|
|
110
|
-
Calculates dynamic beta2 for all layers using the completed
|
|
130
|
+
Calculates dynamic beta2 for all layers using the completed accumulators
|
|
111
131
|
from the PREVIOUS step. Should be called once at the start of an optimizer step.
|
|
112
132
|
"""
|
|
113
133
|
beta2_log = []
|
|
@@ -154,7 +174,10 @@ class KourkoutasHelper:
|
|
|
154
174
|
beta2 = beta2_max - (beta2_max - beta2_min) * sun
|
|
155
175
|
|
|
156
176
|
# Store the final calculated beta2 in the helper's transient state for this step.
|
|
157
|
-
|
|
177
|
+
if isinstance(beta2, torch.Tensor) and beta2.numel() == 1 and not group.get('compiled_optimizer', False):
|
|
178
|
+
self.layer_state[layer_key]['dynamic_beta2'] = beta2.item()
|
|
179
|
+
else:
|
|
180
|
+
self.layer_state[layer_key]['dynamic_beta2'] = beta2
|
|
158
181
|
|
|
159
182
|
# Reset the accumulator for the next optimizer step.
|
|
160
183
|
accumulator.zero_()
|
|
@@ -163,10 +186,11 @@ class KourkoutasHelper:
|
|
|
163
186
|
|
|
164
187
|
# Compute stats for TensorBoard
|
|
165
188
|
if beta2_log:
|
|
166
|
-
|
|
189
|
+
# Handles lists containing both standard floats and heterogeneous tensors
|
|
190
|
+
means = [b.mean().item() if isinstance(b, torch.Tensor) else float(b) for b in beta2_log]
|
|
167
191
|
self.last_beta2_stats = {
|
|
168
|
-
'mean':
|
|
169
|
-
|
|
192
|
+
'mean': sum(means) / len(means)
|
|
193
|
+
}
|
|
170
194
|
|
|
171
195
|
def maybe_prepare_step(self, current_step: int, device):
|
|
172
196
|
"""
|
|
@@ -184,9 +208,16 @@ class KourkoutasHelper:
|
|
|
184
208
|
|
|
185
209
|
if layer_key in self.layer_info and layer_key in self.layer_state:
|
|
186
210
|
# Accumulate for the *next* step's prepare_step call
|
|
187
|
-
|
|
211
|
+
if getattr(p, '_is_oft', False) or getattr(p, '_is_lora_A', False):
|
|
212
|
+
sq_norm = torch.sum(grad.detach().pow(2), dim=1, keepdim=True).float()
|
|
213
|
+
elif getattr(p, '_is_lora_B', False):
|
|
214
|
+
sq_norm = torch.sum(grad.detach().pow(2), dim=0, keepdim=True).float()
|
|
215
|
+
else:
|
|
216
|
+
sq_norm = torch.sum(grad.detach().pow(2)).float()
|
|
217
|
+
|
|
218
|
+
self.layer_state[layer_key]['sum_sq_accumulator'] += sq_norm
|
|
188
219
|
|
|
189
|
-
def get_beta2(self, p: torch.Tensor, group: dict) -> float:
|
|
220
|
+
def get_beta2(self, p: torch.Tensor, group: dict) -> float | torch.Tensor:
|
|
190
221
|
"""
|
|
191
222
|
Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
|
|
192
223
|
"""
|
|
@@ -87,7 +87,10 @@ def _adam_step_parameter(self, p, grad, state, group, beta1_adam, beta2_adam, sq
|
|
|
87
87
|
update_mt = mt
|
|
88
88
|
|
|
89
89
|
vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
|
|
90
|
-
|
|
90
|
+
if isinstance(beta2_adam, torch.Tensor) and beta2_adam.dim() > 0:
|
|
91
|
+
vt.mul_(beta2_adam).addcmul_(grad_reshaped, grad_reshaped * (1.0 - beta2_adam))
|
|
92
|
+
else:
|
|
93
|
+
vt.mul_(beta2_adam).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2_adam)
|
|
91
94
|
|
|
92
95
|
if group.get('adam_use_AdEMAMix'):
|
|
93
96
|
mt_slow = _reconstruct_state((state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'], d2), signed=True)
|
|
@@ -148,7 +151,10 @@ def _adam_step_parameter(self, p, grad, state, group, beta1_adam, beta2_adam, sq
|
|
|
148
151
|
update = update_mt if beta1_adam > 0 else grad.clone()
|
|
149
152
|
|
|
150
153
|
exp_avg_sq = state['exp_avg_sq']
|
|
151
|
-
|
|
154
|
+
if isinstance(beta2_adam, torch.Tensor) and beta2_adam.dim() > 0:
|
|
155
|
+
exp_avg_sq.mul_(beta2_adam).addcmul_(grad, grad * (1.0 - beta2_adam))
|
|
156
|
+
else:
|
|
157
|
+
exp_avg_sq.mul_(beta2_adam).addcmul_(grad, grad, value=1.0 - beta2_adam)
|
|
152
158
|
|
|
153
159
|
if group.get('adam_use_atan2'):
|
|
154
160
|
denom = exp_avg_sq.sqrt()
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
|
|
4
|
+
"""
|
|
5
|
+
Projects the gradient `grad` to be orthogonal to the parameter `p`.
|
|
6
|
+
Modified from:
|
|
7
|
+
https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability/blob/720d2444df12b851d6cb417ab08cf125c822b2ae/orthograd.py
|
|
8
|
+
"""
|
|
9
|
+
if getattr(p, '_is_oft', False) or getattr(p, '_is_lora_A', False):
|
|
10
|
+
return _orthogonalize_gradient_granular(p, grad, dim=1)
|
|
11
|
+
elif getattr(p, '_is_lora_B', False):
|
|
12
|
+
return _orthogonalize_gradient_granular(p, grad, dim=0)
|
|
13
|
+
|
|
14
|
+
original_shape = grad.shape
|
|
15
|
+
original_dtype = grad.dtype
|
|
16
|
+
w = p.view(-1).float()
|
|
17
|
+
g = grad.view(-1).float()
|
|
18
|
+
w_norm_sq = torch.dot(w, w).add_(1e-30)
|
|
19
|
+
proj = torch.dot(w, g) / w_norm_sq
|
|
20
|
+
g_orth = g.sub(w * proj)
|
|
21
|
+
g_norm = g.norm(2)
|
|
22
|
+
g_orth_norm = g_orth.norm(2).add_(1e-30)
|
|
23
|
+
g_orth_scaled = g_orth * (g_norm / g_orth_norm)
|
|
24
|
+
return g_orth_scaled.view(original_shape).to(original_dtype)
|
|
25
|
+
|
|
26
|
+
def _orthogonalize_gradient_granular(p: torch.Tensor, grad: torch.Tensor, dim: int = 1, eps: float = 1e-30) -> torch.Tensor:
|
|
27
|
+
"""
|
|
28
|
+
Projects the gradient `grad` to be orthogonal to the parameter `p` row/col-wise,
|
|
29
|
+
while preserving the original norm of the gradient for each row/col.
|
|
30
|
+
"""
|
|
31
|
+
original_dtype = grad.dtype
|
|
32
|
+
p_f32 = p.float()
|
|
33
|
+
grad_f32 = grad.float()
|
|
34
|
+
|
|
35
|
+
# Calculate the dot product <p, grad> for each row/col
|
|
36
|
+
dot_prod = torch.sum(p_f32 * grad_f32, dim=dim, keepdim=True)
|
|
37
|
+
|
|
38
|
+
# Calculate ||p||^2 for each row/col
|
|
39
|
+
p_norm_sq = torch.sum(p_f32 * p_f32, dim=dim, keepdim=True).add_(eps)
|
|
40
|
+
|
|
41
|
+
# Project: g_orth = g - (p * <p, g> / ||p||^2)
|
|
42
|
+
proj = dot_prod / p_norm_sq
|
|
43
|
+
grad_orth = grad_f32 - (proj * p_f32)
|
|
44
|
+
|
|
45
|
+
# Magnitude Preservation
|
|
46
|
+
g_norm = torch.norm(grad_f32, p=2, dim=dim, keepdim=True)
|
|
47
|
+
g_orth_norm = torch.norm(grad_orth, p=2, dim=dim, keepdim=True).add_(eps)
|
|
48
|
+
grad_orth_scaled = grad_orth * (g_norm / g_orth_norm)
|
|
49
|
+
|
|
50
|
+
return grad_orth_scaled.to(original_dtype)
|
|
@@ -109,4 +109,4 @@ def dequantize_anchor(p, state, group, dtype):
|
|
|
109
109
|
anchor_blocks = quantized_blocks.to(dtype) * scales.unsqueeze(1) + mins.unsqueeze(1)
|
|
110
110
|
|
|
111
111
|
# Flatten, truncate any padding added during quantization, and reshape
|
|
112
|
-
return anchor_blocks.view(-1)[:orig_numel].view(orig_shape)
|
|
112
|
+
return anchor_blocks.view(-1)[:orig_numel].view(orig_shape)
|
|
@@ -138,7 +138,7 @@ def set_seed(device: torch.device):
|
|
|
138
138
|
|
|
139
139
|
def get_generator(device: torch.device) -> torch.Generator:
|
|
140
140
|
"""
|
|
141
|
-
Retrieves (and initializes if necessary) the deterministic generator
|
|
141
|
+
Retrieves (and initializes if necessary) the deterministic generator
|
|
142
142
|
for the specified device.
|
|
143
143
|
"""
|
|
144
144
|
if device not in _generators:
|
|
@@ -241,9 +241,9 @@ def post_process_loaded_state(optimizer: Optimizer) -> None:
|
|
|
241
241
|
# Deterministically check if this parameter skipped quantization
|
|
242
242
|
numel = p.numel()
|
|
243
243
|
is_skipped = (
|
|
244
|
-
numel == 0 or
|
|
245
|
-
(mode in ['int8', 'int4'] and numel < 10000) or
|
|
246
|
-
p.ndim == 1 or
|
|
244
|
+
numel == 0 or
|
|
245
|
+
(mode in ['int8', 'int4'] and numel < 10000) or
|
|
246
|
+
p.ndim == 1 or
|
|
247
247
|
getattr(p, '_is_dora_scale', False)
|
|
248
248
|
)
|
|
249
249
|
|
|
@@ -283,4 +283,4 @@ def post_process_loaded_state(optimizer: Optimizer) -> None:
|
|
|
283
283
|
|
|
284
284
|
# Ensure device match
|
|
285
285
|
if state[key].device != p.device:
|
|
286
|
-
state[key] = state[key].to(p.device)
|
|
286
|
+
state[key] = state[key].to(p.device)
|
|
@@ -9,7 +9,7 @@ def scale_update(
|
|
|
9
9
|
vector_state: torch.Tensor | None = None
|
|
10
10
|
) -> torch.Tensor:
|
|
11
11
|
"""
|
|
12
|
-
Applies adaptive scaling to the parameter update based on the parameter's
|
|
12
|
+
Applies adaptive scaling to the parameter update based on the parameter's
|
|
13
13
|
role (DoRA, OFT, or LoRA/Full Finetuning).
|
|
14
14
|
|
|
15
15
|
Args:
|
|
@@ -28,11 +28,15 @@ def scale_update(
|
|
|
28
28
|
if is_dora_scale or p.ndim == 1:
|
|
29
29
|
return rms_normalization(update, dim=None, lr=lr)
|
|
30
30
|
|
|
31
|
-
# Orthogonal Fine-Tuning (OFT)
|
|
32
|
-
# RMS normalization (dim=1 normalizes per block)
|
|
31
|
+
# Orthogonal Fine-Tuning (OFT)
|
|
33
32
|
# This guarantees O(1) update complexity scaling, independent of block sizes.
|
|
34
33
|
if is_oft:
|
|
35
|
-
|
|
34
|
+
n = update.shape[1]
|
|
35
|
+
# Calculate block size (b)
|
|
36
|
+
b = (1 + (1 + 8 * n) ** 0.5) / 2
|
|
37
|
+
target_norm = (b / 8) ** 0.5
|
|
38
|
+
scale = target_norm / (n ** 0.5)
|
|
39
|
+
return rms_normalization(update, dim=1, lr=lr * scale)
|
|
36
40
|
|
|
37
41
|
# LoRA Factors or Full Finetuning weights
|
|
38
42
|
# Scales update to maintain consistent spectral norm across different layer sizes and ranks.
|
|
@@ -44,7 +48,7 @@ def scale_update(
|
|
|
44
48
|
|
|
45
49
|
def scale_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
|
|
46
50
|
"""
|
|
47
|
-
Adjusts standard weight decay and centered weight decay based on the parameter's
|
|
51
|
+
Adjusts standard weight decay and centered weight decay based on the parameter's
|
|
48
52
|
shape and type to maintain effective regularization strength.
|
|
49
53
|
"""
|
|
50
54
|
# DoRA Scale (Magnitude Vector)
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
def _grams_update(mt: torch.Tensor, grad: torch.Tensor, inplace: bool=False):
|
|
4
|
+
"""
|
|
5
|
+
Applies the update rule of "Gradient Descent with Adaptive Momentum Scaling"
|
|
6
|
+
(https://arxiv.org/abs/2412.17107).
|
|
7
|
+
"""
|
|
8
|
+
if inplace:
|
|
9
|
+
return mt.abs_().mul_(grad.sign())
|
|
10
|
+
return grad.sign().mul_(mt.abs())
|
|
11
|
+
|
|
12
|
+
def _cautious_update(mt: torch.Tensor, grad: torch.Tensor, inplace: bool=False):
|
|
13
|
+
"""
|
|
14
|
+
Applies the update rule of "Cautious Optimizers: Improving Training with One
|
|
15
|
+
Line of Code" (https://arxiv.org/abs/2411.16085).
|
|
16
|
+
"""
|
|
17
|
+
mask = (mt * grad > 0).to(grad.dtype)
|
|
18
|
+
mask.div_(mask.mean().clamp_min_(1e-3))
|
|
19
|
+
if inplace:
|
|
20
|
+
update_mt = mt.mul_(mask)
|
|
21
|
+
else:
|
|
22
|
+
update_mt = mt.mul(mask)
|
|
23
|
+
del mask
|
|
24
|
+
return update_mt
|
|
25
|
+
|
|
26
|
+
def _scale_sim_AdEMAMix_update(beta: float, current_step: int, alpha_grad: float, lr: float, scaled_optm: bool=False):
|
|
27
|
+
if scaled_optm:
|
|
28
|
+
return lr
|
|
29
|
+
momentum_scale = (1 - beta ** current_step) / (1 - beta)
|
|
30
|
+
total_scale = 1 / (momentum_scale + alpha_grad)
|
|
31
|
+
lr = lr * total_scale
|
|
32
|
+
return lr
|
|
33
|
+
|
|
34
|
+
def _get_l1_adaptive_lr(
|
|
35
|
+
p: torch.Tensor,
|
|
36
|
+
update: torch.Tensor,
|
|
37
|
+
state: dict,
|
|
38
|
+
group: dict,
|
|
39
|
+
kappa_p: float
|
|
40
|
+
) -> torch.Tensor:
|
|
41
|
+
"""
|
|
42
|
+
Calculates the L1 adaptive learning rate based on gradient heterogeneity.
|
|
43
|
+
"""
|
|
44
|
+
if not group.get("l1_adaptive", False) and kappa_p != 1:
|
|
45
|
+
return None
|
|
46
|
+
|
|
47
|
+
momentum = group["momentum"]
|
|
48
|
+
alpha_grad = group["alpha_grad"]
|
|
49
|
+
update_view = update.view(p.shape)
|
|
50
|
+
|
|
51
|
+
# Calculate scale factor based on momentum/update magnitude
|
|
52
|
+
scale_factor = _scale_sim_AdEMAMix_update(
|
|
53
|
+
momentum, state["step"] + 1, alpha_grad, 1, False
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
# Determine dimension for mean calculation based on parameter type
|
|
57
|
+
if getattr(p, '_is_oft', False) or getattr(p, '_is_lora_A', False):
|
|
58
|
+
l1_dim = 1
|
|
59
|
+
elif getattr(p, '_is_lora_B', False):
|
|
60
|
+
l1_dim = 0
|
|
61
|
+
else:
|
|
62
|
+
update_abs = update_view.abs() * scale_factor
|
|
63
|
+
if update_abs.ndim >= 2:
|
|
64
|
+
orig_shape = update_abs.shape
|
|
65
|
+
update_2d = update_abs.view(orig_shape[0], -1)
|
|
66
|
+
mean_l1_norm_2d = torch.outer(update_2d.mean(dim=1), update_2d.mean(dim=0))
|
|
67
|
+
return mean_l1_norm_2d.view(orig_shape)
|
|
68
|
+
else:
|
|
69
|
+
return update_abs.mean()
|
|
70
|
+
|
|
71
|
+
mean_l1_norm = update_view.abs().mean(dim=l1_dim, keepdim=True) * scale_factor
|
|
72
|
+
|
|
73
|
+
return mean_l1_norm
|
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
|
|
3
|
-
def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
|
|
4
|
-
"""
|
|
5
|
-
Projects the gradient `grad` to be orthogonal to the parameter `p`.
|
|
6
|
-
Modified from:
|
|
7
|
-
https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability/blob/720d2444df12b851d6cb417ab08cf125c822b2ae/orthograd.py
|
|
8
|
-
"""
|
|
9
|
-
if grad.is_sparse:
|
|
10
|
-
raise RuntimeError("OrthoGrad logic does not support sparse gradients.")
|
|
11
|
-
original_shape = grad.shape
|
|
12
|
-
original_dtype = grad.dtype
|
|
13
|
-
w = p.view(-1).float()
|
|
14
|
-
g = grad.view(-1).float()
|
|
15
|
-
w_norm_sq = torch.dot(w, w).add_(1e-30)
|
|
16
|
-
proj = torch.dot(w, g) / w_norm_sq
|
|
17
|
-
g_orth = g.sub(w * proj)
|
|
18
|
-
g_norm = g.norm(2)
|
|
19
|
-
g_orth_norm = g_orth.norm(2).add_(1e-30)
|
|
20
|
-
g_orth_scaled = g_orth * (g_norm / g_orth_norm)
|
|
21
|
-
return g_orth_scaled.view(original_shape).to(original_dtype)
|
|
@@ -1,32 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
|
|
3
|
-
def _grams_update(mt: torch.Tensor, grad: torch.Tensor, inplace: bool=False):
|
|
4
|
-
"""
|
|
5
|
-
Applies the update rule of "Gradient Descent with Adaptive Momentum Scaling"
|
|
6
|
-
(https://arxiv.org/abs/2412.17107).
|
|
7
|
-
"""
|
|
8
|
-
if inplace:
|
|
9
|
-
return mt.abs_().mul_(grad.sign())
|
|
10
|
-
return grad.sign().mul_(mt.abs())
|
|
11
|
-
|
|
12
|
-
def _cautious_update(mt: torch.Tensor, grad: torch.Tensor, inplace: bool=False):
|
|
13
|
-
"""
|
|
14
|
-
Applies the update rule of "Cautious Optimizers: Improving Training with One
|
|
15
|
-
Line of Code" (https://arxiv.org/abs/2411.16085).
|
|
16
|
-
"""
|
|
17
|
-
mask = (mt * grad > 0).to(grad.dtype)
|
|
18
|
-
mask.div_(mask.mean().clamp_min_(1e-3))
|
|
19
|
-
if inplace:
|
|
20
|
-
update_mt = mt.mul_(mask)
|
|
21
|
-
else:
|
|
22
|
-
update_mt = mt.mul(mask)
|
|
23
|
-
del mask
|
|
24
|
-
return update_mt
|
|
25
|
-
|
|
26
|
-
def _scale_sim_AdEMAMix_update(beta: float, current_step: int, alpha_grad: float, lr: float, scaled_optm: bool=False):
|
|
27
|
-
if scaled_optm:
|
|
28
|
-
return lr
|
|
29
|
-
momentum_scale = (1 - beta ** current_step) / (1 - beta)
|
|
30
|
-
total_scale = 1 / (momentum_scale + alpha_grad)
|
|
31
|
-
lr = lr * total_scale
|
|
32
|
-
return lr
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|