adv-optm 2.5.9__tar.gz → 2.6.1.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/PKG-INFO +1 -1
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/__init__.py +1 -1
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/AdaMuon_adv.py +2 -1
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/AdamW_adv.py +2 -1
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/Adopt_adv.py +2 -1
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/Lion_adv.py +2 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/Muon_adv.py +2 -1
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/Prodigy_adv.py +2 -1
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/SignSGD_adv.py +2 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/SinkSGD_adv.py +2 -1
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/OrthoGrad.py +1 -1
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/param_update.py +33 -27
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/scaled_optm.py +24 -37
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/setup.py +1 -1
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/LICENSE +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/README.md +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/sinkhorn.py +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/state_util.py +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.5.9 → adv_optm-2.6.1.dev2}/setup.cfg +0 -0
|
@@ -137,6 +137,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
137
137
|
# Decoupled/cautious weight decay
|
|
138
138
|
weight_decay: float = 0,
|
|
139
139
|
cautious_wd: bool = False,
|
|
140
|
+
scaled_wd: bool = False,
|
|
140
141
|
# Nesterov momentum
|
|
141
142
|
nesterov: bool = True,
|
|
142
143
|
nesterov_coef: float | None = None,
|
|
@@ -227,7 +228,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
227
228
|
|
|
228
229
|
defaults = {
|
|
229
230
|
"lr": lr, "betas": betas, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
|
|
230
|
-
"eps": eps, "rms_rescaling": rms_rescaling, "ns_steps": ns_steps,
|
|
231
|
+
"eps": eps, "rms_rescaling": rms_rescaling, "ns_steps": ns_steps, "scaled_wd": scaled_wd,
|
|
231
232
|
"ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
|
|
232
233
|
"vector_reshape": vector_reshape,
|
|
233
234
|
"nesterov":nesterov, "nesterov_coef": nesterov_coef, "use_atan2":use_atan2,
|
|
@@ -98,6 +98,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
98
98
|
weight_decay: float = 0.0,
|
|
99
99
|
fisher_wd: bool = False,
|
|
100
100
|
cautious_wd: bool = False,
|
|
101
|
+
scaled_wd: bool = False,
|
|
101
102
|
# Adam's Bias Correction
|
|
102
103
|
use_bias_correction: bool = True,
|
|
103
104
|
# Stochastic Rounding for BF16
|
|
@@ -156,7 +157,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
156
157
|
|
|
157
158
|
defaults = {
|
|
158
159
|
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
159
|
-
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
|
|
160
|
+
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "scaled_wd": scaled_wd,
|
|
160
161
|
"use_atan2": use_atan2, "nesterov": nesterov, "nesterov_coef": nesterov_coef,
|
|
161
162
|
"normed_momentum": normed_momentum,
|
|
162
163
|
"orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
|
|
@@ -101,6 +101,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
101
101
|
weight_decay: float = 0.0,
|
|
102
102
|
fisher_wd: bool = False,
|
|
103
103
|
cautious_wd: bool = False,
|
|
104
|
+
scaled_wd: bool = False,
|
|
104
105
|
# ADOPT clipping
|
|
105
106
|
clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
|
|
106
107
|
# Adam_atan2 (scale invariant)
|
|
@@ -157,7 +158,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
157
158
|
state_precision = "factored"
|
|
158
159
|
|
|
159
160
|
defaults = {
|
|
160
|
-
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
161
|
+
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "scaled_wd": scaled_wd,
|
|
161
162
|
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "orthogonal_gradient": orthogonal_gradient,
|
|
162
163
|
"nesterov": nesterov, "nesterov_coef": nesterov_coef,
|
|
163
164
|
"kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
|
|
@@ -64,6 +64,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
64
64
|
# Decoupled/cautious weight decay
|
|
65
65
|
weight_decay: float = 0.0,
|
|
66
66
|
cautious_wd: bool = False,
|
|
67
|
+
scaled_wd: bool = False,
|
|
67
68
|
# Stochastic Rounding for BF16
|
|
68
69
|
stochastic_rounding: bool = True,
|
|
69
70
|
# OrthoGrad
|
|
@@ -96,6 +97,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
96
97
|
betas=betas,
|
|
97
98
|
weight_decay=weight_decay,
|
|
98
99
|
cautious_wd=cautious_wd,
|
|
100
|
+
scaled_wd=scaled_wd,
|
|
99
101
|
vector_reshape=vector_reshape,
|
|
100
102
|
orthogonal_gradient=orthogonal_gradient,
|
|
101
103
|
kappa_p=kappa_p,
|
|
@@ -111,6 +111,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
111
111
|
# Decoupled/cautious weight decay
|
|
112
112
|
weight_decay: float = 0.0,
|
|
113
113
|
cautious_wd: bool = False,
|
|
114
|
+
scaled_wd: bool = False,
|
|
114
115
|
# Nesterov momentum
|
|
115
116
|
nesterov: bool = True,
|
|
116
117
|
nesterov_coef: float | None = None,
|
|
@@ -201,7 +202,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
201
202
|
defaults = {
|
|
202
203
|
"lr": lr, "beta1": beta1, "weight_decay": weight_decay, "cautious_wd": cautious_wd,
|
|
203
204
|
"nesterov": nesterov, "nesterov_coef": nesterov_coef, "ns_steps": ns_steps, "ns_eps": ns_eps,
|
|
204
|
-
"ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
|
|
205
|
+
"ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor, "scaled_wd": scaled_wd,
|
|
205
206
|
"vector_reshape": vector_reshape, "rms_rescaling": rms_rescaling,
|
|
206
207
|
"orthogonal_gradient": orthogonal_gradient,
|
|
207
208
|
'compiled_optimizer': compiled_optimizer,
|
|
@@ -115,6 +115,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
115
115
|
weight_decay: float = 0.0,
|
|
116
116
|
fisher_wd: bool = False,
|
|
117
117
|
cautious_wd: bool = False,
|
|
118
|
+
scaled_wd: bool = False,
|
|
118
119
|
# Stochastic Rounding for BF16
|
|
119
120
|
stochastic_rounding: bool = True,
|
|
120
121
|
# Adam_atan2 (scale invariant)
|
|
@@ -181,7 +182,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
181
182
|
|
|
182
183
|
defaults = {
|
|
183
184
|
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
184
|
-
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
|
|
185
|
+
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "scaled_wd": scaled_wd,
|
|
185
186
|
"use_atan2": use_atan2,
|
|
186
187
|
"orthogonal_gradient": orthogonal_gradient,
|
|
187
188
|
"compiled_optimizer": compiled_optimizer,
|
|
@@ -59,6 +59,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
59
59
|
# weight decay features
|
|
60
60
|
geometric_wd: bool = False,
|
|
61
61
|
cautious_wd: bool = False,
|
|
62
|
+
scaled_wd: bool = False,
|
|
62
63
|
# Stochastic Rounding for BF16
|
|
63
64
|
stochastic_rounding: bool = True,
|
|
64
65
|
# OrthoGrad
|
|
@@ -108,6 +109,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
108
109
|
momentum=momentum,
|
|
109
110
|
weight_decay=weight_decay,
|
|
110
111
|
cautious_wd=cautious_wd,
|
|
112
|
+
scaled_wd=scaled_wd,
|
|
111
113
|
geometric_wd=geometric_wd,
|
|
112
114
|
vector_reshape=vector_reshape,
|
|
113
115
|
orthogonal_gradient=orthogonal_gradient,
|
|
@@ -66,6 +66,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
66
66
|
# weight decay features
|
|
67
67
|
geometric_wd: bool = False,
|
|
68
68
|
cautious_wd: bool = False,
|
|
69
|
+
scaled_wd: bool = False,
|
|
69
70
|
# Stochastic Rounding for BF16
|
|
70
71
|
stochastic_rounding: bool = True,
|
|
71
72
|
# OrthoGrad
|
|
@@ -103,7 +104,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
103
104
|
defaults = {
|
|
104
105
|
"lr": lr, "momentum": momentum,
|
|
105
106
|
"weight_decay": weight_decay, "nesterov": nesterov, "nesterov_coef": nesterov_coef, "normed_momentum": normed_momentum, "snr_cond": snr_cond,
|
|
106
|
-
"geometric_wd": geometric_wd, "cautious_wd": cautious_wd,
|
|
107
|
+
"geometric_wd": geometric_wd, "cautious_wd": cautious_wd, "scaled_wd": scaled_wd,
|
|
107
108
|
"orthogonal_gradient": orthogonal_gradient,
|
|
108
109
|
"compiled_optimizer": compiled_optimizer,
|
|
109
110
|
"sinkhorn_iterations": sinkhorn_iterations,
|
|
@@ -43,7 +43,7 @@ def iterative_ortho_project(p: torch.Tensor, grad: torch.Tensor, iters: int = 3)
|
|
|
43
43
|
# 1D Vector Case fallback to the standard OrthoGrad
|
|
44
44
|
is_vector = p.ndim < 2 or getattr(p, '_is_dora_scale', False) or getattr(p, 'is_vector', False)
|
|
45
45
|
if is_vector:
|
|
46
|
-
return
|
|
46
|
+
return flattened_ortho_project(p, grad)
|
|
47
47
|
|
|
48
48
|
original_shape = grad.shape
|
|
49
49
|
|
|
@@ -6,7 +6,7 @@ import torch.nn.functional as F
|
|
|
6
6
|
|
|
7
7
|
from typing import Dict, Any
|
|
8
8
|
|
|
9
|
-
from .scaled_optm import adjust_wds
|
|
9
|
+
from .scaled_optm import adjust_wds, scale_wd
|
|
10
10
|
from .centered_decay import dequantize_anchor
|
|
11
11
|
|
|
12
12
|
_generators: Dict[torch.device, torch.Generator] = {}
|
|
@@ -18,8 +18,8 @@ def _apply_weight_decay(
|
|
|
18
18
|
p: Tensor,
|
|
19
19
|
state: Dict[str, Any],
|
|
20
20
|
group: Dict[str, Any],
|
|
21
|
-
|
|
22
|
-
|
|
21
|
+
eff_wd: float | Tensor | None,
|
|
22
|
+
eff_cwd: float | Tensor | None,
|
|
23
23
|
wd_target: Tensor | None = None,
|
|
24
24
|
cwd_target: Tensor | None = None,
|
|
25
25
|
) -> None:
|
|
@@ -29,26 +29,26 @@ def _apply_weight_decay(
|
|
|
29
29
|
cautious = group.get('cautious_wd', False)
|
|
30
30
|
|
|
31
31
|
# Standard Weight Decay (pulls toward zero)
|
|
32
|
-
if
|
|
32
|
+
if eff_wd is not None:
|
|
33
33
|
if wd_target is None:
|
|
34
34
|
wd_target = p_calc
|
|
35
35
|
# Cautious Weight Decay: only decay if the update pushes in the same direction as the decay
|
|
36
36
|
if cautious:
|
|
37
37
|
mask = (update_calc * p_calc >= 0).to(p_calc.dtype)
|
|
38
|
-
if isinstance(
|
|
39
|
-
p_calc.addcmul_(wd_target, mask *
|
|
38
|
+
if isinstance(eff_wd, Tensor):
|
|
39
|
+
p_calc.addcmul_(wd_target, mask * eff_wd, value=-1.0)
|
|
40
40
|
else:
|
|
41
|
-
p_calc.addcmul_(wd_target, mask, value=-
|
|
41
|
+
p_calc.addcmul_(wd_target, mask, value=-eff_wd)
|
|
42
42
|
del mask
|
|
43
43
|
else:
|
|
44
44
|
# Standard decoupled weight decay
|
|
45
|
-
if isinstance(
|
|
46
|
-
p_calc.addcmul_(wd_target,
|
|
45
|
+
if isinstance(eff_wd, Tensor):
|
|
46
|
+
p_calc.addcmul_(wd_target, eff_wd, value=-1.0)
|
|
47
47
|
else:
|
|
48
|
-
p_calc.add_(wd_target, alpha=-
|
|
48
|
+
p_calc.add_(wd_target, alpha=-eff_wd)
|
|
49
49
|
|
|
50
50
|
# Centered Weight Decay (pulls toward anchor)
|
|
51
|
-
if
|
|
51
|
+
if eff_cwd is not None and 'anchor_data' in state:
|
|
52
52
|
if cwd_target is not None:
|
|
53
53
|
decay_target = cwd_target
|
|
54
54
|
else:
|
|
@@ -59,17 +59,17 @@ def _apply_weight_decay(
|
|
|
59
59
|
if cautious:
|
|
60
60
|
# Cautious Weight Decay: only decay if the update pushes in the same direction as the decay
|
|
61
61
|
mask = (update_calc * decay_target >= 0).to(p_calc.dtype)
|
|
62
|
-
if isinstance(
|
|
63
|
-
p_calc.addcmul_(decay_target, mask *
|
|
62
|
+
if isinstance(eff_cwd, Tensor):
|
|
63
|
+
p_calc.addcmul_(decay_target, mask * eff_cwd, value=-1.0)
|
|
64
64
|
else:
|
|
65
|
-
p_calc.addcmul_(decay_target, mask, value=-
|
|
65
|
+
p_calc.addcmul_(decay_target, mask, value=-eff_cwd)
|
|
66
66
|
del mask
|
|
67
67
|
else:
|
|
68
68
|
# Standard decoupled weight decay
|
|
69
|
-
if isinstance(
|
|
70
|
-
p_calc.addcmul_(decay_target,
|
|
69
|
+
if isinstance(eff_cwd, Tensor):
|
|
70
|
+
p_calc.addcmul_(decay_target, eff_cwd, value=-1.0)
|
|
71
71
|
else:
|
|
72
|
-
p_calc.add_(decay_target, alpha=-
|
|
72
|
+
p_calc.add_(decay_target, alpha=-eff_cwd)
|
|
73
73
|
|
|
74
74
|
if cwd_target is None:
|
|
75
75
|
del decay_target
|
|
@@ -105,18 +105,24 @@ def apply_parameter_update(
|
|
|
105
105
|
wd = group["weight_decay"] if wd is None else wd
|
|
106
106
|
cwd = group.get("centered_wd", 0.0)
|
|
107
107
|
wd, cwd = adjust_wds(wd, cwd, p)
|
|
108
|
+
scaled_wd = group.get("scaled_wd", False)
|
|
109
|
+
decoupled = scaled_wd
|
|
108
110
|
|
|
109
111
|
# Calculate global decay factor for decoupled vs standard
|
|
110
112
|
decay_factor = (lr / self._init_lr) if decoupled else lr
|
|
111
113
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
+
eff_wd = (wd * decay_factor) if wd != 0 else None
|
|
115
|
+
eff_cwd = (cwd * decay_factor) if cwd != 0 else None
|
|
114
116
|
|
|
115
117
|
if wd_scaler is not None:
|
|
116
|
-
if
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
118
|
+
if eff_wd is not None:
|
|
119
|
+
if scaled_wd:
|
|
120
|
+
eff_wd = scale_wd(eff_wd, p)
|
|
121
|
+
eff_wd = eff_wd * wd_scaler
|
|
122
|
+
if eff_cwd is not None:
|
|
123
|
+
if scaled_wd:
|
|
124
|
+
eff_cwd = scale_wd(eff_cwd, p)
|
|
125
|
+
eff_cwd = eff_cwd * wd_scaler
|
|
120
126
|
|
|
121
127
|
state = self.state[p]
|
|
122
128
|
|
|
@@ -129,8 +135,8 @@ def apply_parameter_update(
|
|
|
129
135
|
cwd_t = cwd_target.float() if cwd_target is not None else None
|
|
130
136
|
|
|
131
137
|
# Apply weight decay if needed
|
|
132
|
-
if
|
|
133
|
-
_apply_weight_decay(p_fp32, update_fp32, p, state, group,
|
|
138
|
+
if eff_wd is not None or eff_cwd is not None:
|
|
139
|
+
_apply_weight_decay(p_fp32, update_fp32, p, state, group, eff_wd, eff_cwd, wd_t, cwd_t)
|
|
134
140
|
|
|
135
141
|
# Apply main update
|
|
136
142
|
p_fp32.add_(-update_fp32)
|
|
@@ -147,8 +153,8 @@ def apply_parameter_update(
|
|
|
147
153
|
|
|
148
154
|
else:
|
|
149
155
|
# Standard path for non-bfloat16 or without stochastic rounding
|
|
150
|
-
if
|
|
151
|
-
_apply_weight_decay(p, update, p, state, group,
|
|
156
|
+
if eff_wd is not None or eff_cwd is not None:
|
|
157
|
+
_apply_weight_decay(p, update, p, state, group, eff_wd, eff_cwd, wd_target, cwd_target)
|
|
152
158
|
|
|
153
159
|
# Apply main update
|
|
154
160
|
p.add_(-update)
|
|
@@ -7,14 +7,14 @@ import math
|
|
|
7
7
|
_OFT_INDICES_CACHE = {}
|
|
8
8
|
_OFT_IDENTITY_CACHE = {}
|
|
9
9
|
|
|
10
|
-
def get_cached_structural_tensors(b: int,
|
|
10
|
+
def get_cached_structural_tensors(b: int, device: torch.device):
|
|
11
11
|
"""
|
|
12
|
-
Retrieves or creates structural tensors (indices
|
|
12
|
+
Retrieves or creates structural tensors (indices) for OFT exact geometry.
|
|
13
13
|
Caches them globally to prevent redundant memory allocation across thousands of layers.
|
|
14
14
|
"""
|
|
15
|
-
global _OFT_INDICES_CACHE
|
|
15
|
+
global _OFT_INDICES_CACHE
|
|
16
16
|
|
|
17
|
-
# Cache for Indices
|
|
17
|
+
# Cache for Indices
|
|
18
18
|
idx_key = (b, device)
|
|
19
19
|
if idx_key not in _OFT_INDICES_CACHE:
|
|
20
20
|
rows, cols = torch.triu_indices(b, b, 1, device=device)
|
|
@@ -22,15 +22,8 @@ def get_cached_structural_tensors(b: int, dtype: torch.dtype, device: torch.devi
|
|
|
22
22
|
else:
|
|
23
23
|
rows, cols = _OFT_INDICES_CACHE[idx_key]
|
|
24
24
|
|
|
25
|
-
# Cache for Identity Matrix (Depends on block size, dtype, and device)
|
|
26
|
-
id_key = (b, dtype, device)
|
|
27
|
-
if id_key not in _OFT_IDENTITY_CACHE:
|
|
28
|
-
I = torch.eye(b, dtype=dtype, device=device).unsqueeze(0)
|
|
29
|
-
_OFT_IDENTITY_CACHE[id_key] = I
|
|
30
|
-
else:
|
|
31
|
-
I = _OFT_IDENTITY_CACHE[id_key]
|
|
32
25
|
|
|
33
|
-
return rows, cols
|
|
26
|
+
return rows, cols
|
|
34
27
|
|
|
35
28
|
def scale_update(
|
|
36
29
|
p: torch.Tensor,
|
|
@@ -59,7 +52,7 @@ def scale_update(
|
|
|
59
52
|
return max_abs_normalization(update, dim=None, lr=lr)
|
|
60
53
|
|
|
61
54
|
# OFT Block Parameters: shape (k, C(b,2))
|
|
62
|
-
# Direct spectral normalization on the skew-symmetric blocks
|
|
55
|
+
# Direct spectral normalization on the skew-symmetric blocks.
|
|
63
56
|
if is_oft:
|
|
64
57
|
return apply_spectral_riemannian_oft(p, update, lr, state)
|
|
65
58
|
|
|
@@ -104,6 +97,20 @@ def adjust_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
|
|
|
104
97
|
# Centered WD safely regularizes the delta without collapsing base feature variance.
|
|
105
98
|
return wd, cwd
|
|
106
99
|
|
|
100
|
+
def scale_wd(wd: float, p: torch.Tensor) -> float:
    """
    Scale-invariant, dimension-scaled weight decay.

    Args:
        wd: The weight-decay coefficient to rescale.
        p: The parameter the decay will be applied to. Its shape (and the
            optional ``_is_oft`` attribute) selects the scaling rule.

    Returns:
        The rescaled coefficient:
        * OFT-tagged parameters: ``2 * wd / (b - 1)``, where ``b`` is derived
          from the last dimension ``n_el`` via ``n_el = b * (b - 1) / 2``.
        * Parameters with ``ndim >= 2``: ``wd / width`` where
          ``width = numel / shape[0]``.
        * Anything else (scalars, 1-D vectors): ``wd`` unchanged.
    """
    if getattr(p, '_is_oft', False):
        # The last dim holds the strict upper-triangular entries of a b x b
        # block, so invert n_el = b * (b - 1) / 2 to recover b.
        n_el = p.shape[-1]
        b = (1.0 + math.sqrt(1.0 + 8.0 * n_el)) / 2.0
        return (2 * wd) / (b - 1)

    if p.ndim >= 2:
        # Everything past the leading dimension counts as the "width".
        width = p.numel() // p.shape[0]
        return wd / width

    # Previously this path fell off the end and returned None, which breaks
    # callers that go on to multiply the result (None * scaler -> TypeError).
    # NOTE(review): leaving wd unscaled for 0-D/1-D parameters is the assumed
    # intent -- confirm with the author.
    return wd
|
|
113
|
+
|
|
107
114
|
|
|
108
115
|
def is_spectral(p: torch.Tensor) -> bool:
|
|
109
116
|
"""Determines if a parameter should undergo spectral normalization updates."""
|
|
@@ -176,34 +183,26 @@ def apply_spectral_riemannian_oft(
|
|
|
176
183
|
state: dict
|
|
177
184
|
) -> torch.Tensor:
|
|
178
185
|
"""
|
|
179
|
-
Applies Spectral Normalization directly on the skew-symmetric gradient
|
|
180
|
-
then uses True Matrix Preconditioning: M @ G @ M where M = (I - Q^2).
|
|
181
|
-
Neutralizes the derivative shrinkage of the Cayley transform.
|
|
186
|
+
Applies Spectral Normalization directly on the skew-symmetric gradient.
|
|
182
187
|
"""
|
|
183
188
|
n_el = p.shape[-1]
|
|
184
189
|
block_size = int((1 + math.sqrt(1 + 8 * n_el)) / 2)
|
|
185
190
|
device, dtype = p.device, p.dtype
|
|
186
|
-
rows, cols
|
|
191
|
+
rows, cols = get_cached_structural_tensors(block_size, device)
|
|
187
192
|
|
|
188
193
|
# Flatten any prepended batch dimensions for processing
|
|
189
194
|
orig_shape = p.shape
|
|
190
195
|
|
|
191
196
|
# Align the scale of p with the forward pass
|
|
192
197
|
scale_factor = getattr(p, '_oft_scale_factor', 1.0)
|
|
193
|
-
p_flat = p.view(-1, n_el) / scale_factor
|
|
194
198
|
|
|
195
199
|
update_flat = update.view(-1, n_el)
|
|
196
|
-
batch_size =
|
|
200
|
+
batch_size = update_flat.shape[0]
|
|
197
201
|
|
|
198
202
|
# Initialize matrices
|
|
199
|
-
Q = torch.zeros(batch_size, block_size, block_size, device=device, dtype=dtype)
|
|
200
203
|
G = torch.zeros(batch_size, block_size, block_size, device=device, dtype=dtype)
|
|
201
204
|
batch_idx = torch.arange(batch_size, device=device)[:, None]
|
|
202
205
|
|
|
203
|
-
# Construct skew-symmetric parameter matrix Q
|
|
204
|
-
Q = Q.index_put((batch_idx, rows, cols), p_flat)
|
|
205
|
-
Q = Q - Q.transpose(-2, -1)
|
|
206
|
-
|
|
207
206
|
# Construct skew-symmetric gradient matrix G
|
|
208
207
|
G = G.index_put((batch_idx, rows, cols), update_flat)
|
|
209
208
|
G = G - G.transpose(-2, -1)
|
|
@@ -235,21 +234,9 @@ def apply_spectral_riemannian_oft(
|
|
|
235
234
|
target_scale = 0.5 * scale_factor
|
|
236
235
|
spectral_eps = 1.0 / (2.0 * math.sqrt(block_size))
|
|
237
236
|
|
|
238
|
-
# Rescale G
|
|
239
237
|
scale = lr * (target_scale / max_sigma.clamp_min(spectral_eps))
|
|
240
|
-
G = G * scale
|
|
241
|
-
|
|
242
|
-
# Apply Riemannian Preconditioning
|
|
243
|
-
# Compute True Matrix Preconditioner M = I - Q^2
|
|
244
|
-
M = I - torch.bmm(Q, Q)
|
|
245
|
-
|
|
246
|
-
# Apply exact preconditioning: G_prec = M @ G @ M
|
|
247
|
-
G_prec = torch.bmm(torch.bmm(M, G), M)
|
|
248
|
-
|
|
249
|
-
# Extract the preconditioned upper-triangular elements
|
|
250
|
-
update_prec_flat = G_prec[batch_idx, rows, cols]
|
|
251
238
|
|
|
252
|
-
return
|
|
239
|
+
return update_flat.mul_(scale).view(orig_shape)
|
|
253
240
|
|
|
254
241
|
|
|
255
242
|
@torch.no_grad()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|