adv-optm 2.4.dev5__tar.gz → 2.4.dev6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/PKG-INFO +1 -1
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/AdamW_adv.py +8 -3
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/Adopt_adv.py +8 -3
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/scaled_optm.py +18 -6
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/setup.py +1 -1
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/LICENSE +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/README.md +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/AdaMuon_adv.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/Muon_adv.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/SignSGD_adv.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/param_update.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev5 → adv_optm-2.4.dev6}/setup.cfg +0 -0
adv_optm/optim/AdamW_adv.py

@@ -9,7 +9,7 @@ from ..util.factorization_util import _get_effective_shape, _reconstruct_state,
 from ..util.update_util import _grams_update, _cautious_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.Kourkoutas import KourkoutasHelper
-from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
+from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm, scale_eps
 from ..util.centered_decay import _init_anchor

 A = 4 / math.pi
@@ -153,6 +153,9 @@ class AdamW_adv(torch.optim.Optimizer):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
         if kourkoutas_beta and not (betas[1] > beta2_min):
             raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
+        if scaled_optm and use_atan2:
+            print("Warning: use_atan2 is incompatible with scaled_optm, Disabling atan2.")
+            use_atan2 = False

         if cautious_mask and grams_moment:
             print("Warning: cautious is incompatible with grams, Disabling cautious.")
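The new guard makes scaled_optm and use_atan2 mutually exclusive, presumably because the atan2 denominator path (update.atan2_(denom), visible in the step() hunks below) never adds an explicit epsilon, so there is nothing for the new scale_eps helper to replace; with scaling enabled the optimizer therefore falls back to the classic sqrt(v) + eps path.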
@@ -330,6 +333,8 @@ class AdamW_adv(torch.optim.Optimizer):
                 # Determine if we are using dense first-moments alongside a factored second-order second-moment
                 factored_2nd = group.get('factored_2nd', False)

+                adaptive_eps = scale_eps(group, p)
+
                 if state['factored']:
                     d1, d2 = state['effective_shape']
                     grad_reshaped = grad.view(d1, d2)
@@ -394,7 +399,7 @@ class AdamW_adv(torch.optim.Optimizer):
                         update.atan2_(denom)
                     else:
                         denom = vt.sqrt_()
-                        denom.div_(sqrt_bias_correction2).add_(group['eps'])
+                        denom.div_(sqrt_bias_correction2).add_(adaptive_eps)
                         update.div_(denom)

                     wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
@@ -438,7 +443,7 @@ class AdamW_adv(torch.optim.Optimizer):
                         update.atan2_(denom)
                     else:
                         denom = exp_avg_sq.sqrt()
-                        denom.div_(sqrt_bias_correction2).add_(group['eps'])
+                        denom.div_(sqrt_bias_correction2).add_(adaptive_eps)
                         update.div_(denom)

                     wd_scaler = _get_fisher_wd_scaler(group, state.get("wd_scaler"), p, denom, group['use_atan2'])
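In both the factored and dense branches of AdamW_adv.step, the fixed epsilon added to the denominator is replaced by the per-parameter adaptive_eps returned by scale_eps. A minimal, self-contained sketch of the changed denominator arithmetic (tensor shapes, n_layers, and hyperparameters below are illustrative assumptions, not the package's defaults):

import math
import torch

# Assumed example values; only the denominator arithmetic mirrors the diff.
n_layers = 12
p = torch.randn(256, 256)
exp_avg = torch.randn_like(p)
exp_avg_sq = torch.rand_like(p)
beta2, step = 0.999, 100

# scale_eps(group, p) with scaled_optm enabled reduces to this expression:
adaptive_eps = (1.0 / n_layers) * (1.0 / math.sqrt(p.numel()))

sqrt_bias_correction2 = math.sqrt(1.0 - beta2 ** step)
denom = exp_avg_sq.sqrt().div_(sqrt_bias_correction2).add_(adaptive_eps)
update = exp_avg / denom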
adv_optm/optim/Adopt_adv.py

@@ -8,7 +8,7 @@ from ..util.factorization_util import _get_effective_shape, _reconstruct_state,
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.Kourkoutas import KourkoutasHelper
 from ..util.update_util import _grams_update, _cautious_update, _scale_sim_AdEMAMix_update, _init_fisher_wd_scaler, _get_fisher_wd_scaler
-from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
+from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm, scale_eps
 from ..util.centered_decay import _init_anchor

 A = 4 / math.pi
@@ -183,6 +183,9 @@ class Adopt_adv(torch.optim.Optimizer):
             print("Warning: grams is incompatible with Simplified_AdEMAMix, Disabling grams.")
         if cautious_mask and Simplified_AdEMAMix:
             print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
+        if scaled_optm and use_atan2:
+            print("Warning: use_atan2 is incompatible with scaled_optm, Disabling atan2.")
+            use_atan2 = False

         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
@@ -364,6 +367,8 @@ class Adopt_adv(torch.optim.Optimizer):
                 # Determine if we are using dense first-moments alongside a factored second-order second-moment
                 factored_2nd = group.get('factored_2nd', False)

+                adaptive_eps = scale_eps(group, p)
+
                 if state['factored']:
                     d1, d2 = state['effective_shape']
                     grad_reshaped = grad.view(d1, d2)
@@ -387,7 +392,7 @@ class Adopt_adv(torch.optim.Optimizer):
                     if self.use_atan2:
                         normalized_grad = torch.atan2(grad_reshaped, denom, out=denom)
                     else:
-                        normalized_grad = torch.div(grad_reshaped, denom.add_(group['eps']), out=denom)
+                        normalized_grad = torch.div(grad_reshaped, denom.add_(adaptive_eps), out=denom)
                     if self.clip_lambda is not None:
                         clip_val = self.clip_lambda(state['step'])
                         normalized_grad.clamp_(-clip_val, clip_val)
@@ -457,7 +462,7 @@ class Adopt_adv(torch.optim.Optimizer):
                     if self.use_atan2:
                         normalized_grad = torch.atan2(grad, denom, out=denom)
                     else:
-                        normalized_grad = torch.div(grad, denom.add_(group['eps']), out=denom)
+                        normalized_grad = torch.div(grad, denom.add_(adaptive_eps), out=denom)
                     if self.clip_lambda is not None:
                         clip_val = self.clip_lambda(state['step'])
                         normalized_grad.clamp_(-clip_val, clip_val)
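Adopt_adv receives the same treatment as AdamW_adv: scale_eps(group, p) is evaluated once per parameter, the epsilon added to denom in both the factored (grad_reshaped) and dense (grad) division paths is swapped for that adaptive value, and the atan2 path is left untouched.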
adv_optm/util/scaled_optm.py

@@ -2,6 +2,8 @@ import torch

 from . import param_update

+import math
+
 def scale_update(
     p: torch.Tensor,
     update: torch.Tensor,
@@ -26,16 +28,16 @@ def scale_update(

     # DoRA Magnitude Scales (1D) or 1D Bias/Norm layers
     if is_dora_scale or p.ndim == 1:
-        return
+        return l2_normalization(update, dim=None, lr=lr)

     # Orthogonal Fine-Tuning (OFT)
     # This guarantees O(1) update complexity scaling, independent of block sizes.
     if is_oft:
         n = update.shape[1]
         # Calculate block size (b)
-        b = (1 + (1 + 8 * n) ** 0.5) / 2
-        target_norm = (b / 8) ** 0.5
-        scale = target_norm / (n ** 0.5)
+        b = (1 + math.sqrt(1 + 8 * n)) / 2
+        target_norm = math.sqrt(b / 8)
+        scale = target_norm / math.sqrt(n)
         return rms_normalization(update, dim=1, lr=lr * scale)

     # LoRA Factors or Full Finetuning weights
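The OFT branch now spells its square roots with math.sqrt, but the arithmetic is unchanged: b = (1 + sqrt(1 + 8n)) / 2 inverts n = b(b - 1) / 2, i.e. it recovers the block size b from the count n, and given the rms_normalization shown further down, each row of the update ends up with L2 norm lr * sqrt(b / 8). A quick standalone check of that inversion (n = 28 is just an example value):

import math

n = 28                                   # b * (b - 1) / 2 pairs for a block of size 8
b = (1 + math.sqrt(1 + 8 * n)) / 2       # -> 8.0
target_norm = math.sqrt(b / 8)           # -> 1.0
scale = target_norm / math.sqrt(n)       # -> ~0.189
print(b, target_norm, scale)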
@@ -46,6 +48,16 @@ def scale_update(
     return update.mul_(lr)


+def scale_eps(group: dict, p) -> tuple[float, float]:
+    """
+    Scales Adam eps to be scale-invariant.
+    """
+    if group.get('scaled_optm', False):
+        adaptive_eps = (1.0 / group['n_layers']) * (1.0 / math.sqrt(p.numel()))
+    else:
+        adaptive_eps = group['eps']
+    return adaptive_eps
+
 def adjust_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
     """
     Adjusts standard weight decay and centered weight decay.
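When scaled_optm is enabled, the new helper replaces the user-supplied eps with 1 / (n_layers * sqrt(numel(p))), so larger tensors (and deeper models) get a proportionally smaller epsilon; otherwise it passes group['eps'] through unchanged. (The annotation advertises tuple[float, float] even though a single float is returned.) A standalone illustration of the magnitudes involved, with assumed values for n_layers and the parameter shapes:

import math

def adaptive_eps(n_layers: int, numel: int) -> float:
    # Same expression as the scaled_optm branch of scale_eps.
    return (1.0 / n_layers) * (1.0 / math.sqrt(numel))

print(adaptive_eps(24, 1024 * 1024))   # ~4.1e-05 for a 1024x1024 weight
print(adaptive_eps(24, 1024))          # ~1.3e-03 for a 1024-dim bias/norm vector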
@@ -93,7 +105,7 @@ def rms_normalization(update: torch.Tensor, dim: int | None, lr: float) -> torch
     """Performs Root Mean Square normalization on the update tensor."""
     n = update.numel() if dim is None else update.shape[dim]
     norm = torch.linalg.vector_norm(update, ord=2, dim=dim, keepdim=True).clamp_min_(1e-12)
-    scale_n = n ** 0.5
+    scale_n = math.sqrt(n)
     return update.mul_(lr * scale_n / norm)


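The switch to math.sqrt here (and in spectral_normalization below) appears purely stylistic; the computed value is the same. For reference, the rule update * lr * sqrt(n) / ||update|| drives the update to unit RMS when lr = 1, which a quick standalone check confirms (random tensor, assumed shape):

import torch

u = torch.randn(64, 128)
n = u.numel()
norm = torch.linalg.vector_norm(u, ord=2).clamp_min(1e-12)
scaled = u * (n ** 0.5) / norm          # rms_normalization with dim=None, lr=1
print(scaled.pow(2).mean().sqrt())      # ~1.0: unit root-mean-square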
@@ -123,7 +135,7 @@ def spectral_normalization(update: torch.Tensor, vector_state: torch.Tensor, lr:
     update = update.to(vector_state.dtype)
     update_flat = update.view(d_out, d_in)
     # Target scale derived from the "Modular Norm" paper
-    target_scale = (d_out / d_in) ** 0.5
+    target_scale = math.sqrt(d_out / d_in)
     # Power Iteration step to estimate the largest singular value (sigma)
     # u = Wv
     u = torch.mv(update_flat, vector_state)
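target_scale = sqrt(d_out / d_in) is the spectral-norm target this function aims for, and the u = Wv line begins a power-iteration estimate of the largest singular value. The diff cuts off before the rest of the routine, so the following is only a generic sketch of that idea (one normalized power-iteration step, then rescaling toward the target), not the package's exact code:

import torch

d_out, d_in = 128, 256
W = torch.randn(d_out, d_in)             # stands in for update_flat
v = torch.randn(d_in)
v = v / v.norm().clamp_min(1e-12)        # persistent power-iteration vector
u = torch.mv(W, v)                       # u = Wv
sigma = u.norm().clamp_min(1e-12)        # rough estimate of the top singular value
target_scale = (d_out / d_in) ** 0.5
W_scaled = W * (target_scale / sigma)    # roughly rescales W toward the target spectral norm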