adv-optm 2.5.7__tar.gz → 2.6.dev1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/PKG-INFO +1 -1
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/__init__.py +1 -1
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/AdaMuon_adv.py +3 -3
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/AdamW_adv.py +3 -1
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/Adopt_adv.py +3 -1
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/Lion_adv.py +3 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/Muon_adv.py +3 -3
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/Prodigy_adv.py +6 -1
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/SignSGD_adv.py +3 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/SinkSGD_adv.py +3 -1
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/Kourkoutas.py +0 -28
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/Muon_AuxAdam.py +1 -0
- adv_optm-2.6.dev1/adv_optm/util/msign.py +114 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/param_update.py +12 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm.egg-info/SOURCES.txt +1 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/setup.py +1 -1
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/LICENSE +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/README.md +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/scaled_optm.py +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/sinkhorn.py +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/state_util.py +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.5.7 → adv_optm-2.6.dev1}/setup.cfg +0 -0
|
@@ -106,7 +106,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
106
106
|
low-rank compression while keeping the first moment (momentum_buffer)
|
|
107
107
|
dense. Ignored when `nnmf_factor=True` (full SMMF) or `normuon_variant=True`.
|
|
108
108
|
Combines well with `state_precision` on the first moment. (default: False)
|
|
109
|
-
n_layers (int): The depth of the network (L). Required for optimal epsilon scaling. (default: 1)
|
|
110
109
|
spectral_normalization (bool): Enable explicit spectral normalization using power iteration. (default: False)
|
|
111
110
|
--- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
|
|
112
111
|
adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
|
|
@@ -177,8 +176,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
177
176
|
approx_mars: bool = False,
|
|
178
177
|
mars_gamma: float = 0.025,
|
|
179
178
|
# Spectral Normalization
|
|
180
|
-
n_layers: int = 1,
|
|
181
179
|
spectral_normalization: bool = False,
|
|
180
|
+
# Orthogonalize the weights (Matrix Sign - MSign) every x steps
|
|
181
|
+
MSign_interval: int | None = None,
|
|
182
182
|
# Centered WD
|
|
183
183
|
centered_wd: float = 0.0,
|
|
184
184
|
centered_wd_mode: str = 'float8',
|
|
@@ -249,7 +249,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
249
249
|
# MARS-M
|
|
250
250
|
"approx_mars": approx_mars, "mars_gamma": mars_gamma,
|
|
251
251
|
# Spectral Normalization
|
|
252
|
-
"
|
|
252
|
+
"spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
|
|
253
253
|
# Centered WD
|
|
254
254
|
"centered_wd": centered_wd,
|
|
255
255
|
"centered_wd_mode": centered_wd_mode,
|
|
@@ -121,6 +121,8 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
121
121
|
layer_key_fn: Optional[Callable] = None,
|
|
122
122
|
# Spectral Normed Optimizer
|
|
123
123
|
spectral_normalization: bool = False,
|
|
124
|
+
# Orthogonalize the weights (Matrix Sign - MSign) every x steps
|
|
125
|
+
MSign_interval: int | None = None,
|
|
124
126
|
# Centered WD
|
|
125
127
|
centered_wd: float = 0.0,
|
|
126
128
|
centered_wd_mode: str = 'float8',
|
|
@@ -163,7 +165,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
163
165
|
"compiled_optimizer": compiled_optimizer,
|
|
164
166
|
"kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
|
|
165
167
|
"tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
|
|
166
|
-
"spectral_normalization": spectral_normalization,
|
|
168
|
+
"spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
|
|
167
169
|
"centered_wd": centered_wd, "centered_wd_mode": centered_wd_mode,
|
|
168
170
|
"state_precision": state_precision,
|
|
169
171
|
"nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd
|
|
@@ -122,6 +122,8 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
122
122
|
layer_key_fn: Optional[Callable] = None,
|
|
123
123
|
# Spectral Normed Optimizer
|
|
124
124
|
spectral_normalization: bool = False,
|
|
125
|
+
# Orthogonalize the weights (Matrix Sign - MSign) every x steps
|
|
126
|
+
MSign_interval: int | None = None,
|
|
125
127
|
# Centered WD
|
|
126
128
|
centered_wd: float = 0.0,
|
|
127
129
|
centered_wd_mode: str = 'float8',
|
|
@@ -162,7 +164,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
162
164
|
"nesterov": nesterov, "nesterov_coef": nesterov_coef,
|
|
163
165
|
"kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
|
|
164
166
|
"tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
|
|
165
|
-
"spectral_normalization": spectral_normalization,
|
|
167
|
+
"spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
|
|
166
168
|
"centered_wd": centered_wd,
|
|
167
169
|
"centered_wd_mode": centered_wd_mode,
|
|
168
170
|
"state_precision": state_precision,
|
|
@@ -78,6 +78,8 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
78
78
|
centered_wd_mode: str = 'float8',
|
|
79
79
|
# Spectral Normed Optimizer
|
|
80
80
|
spectral_normalization: bool = False,
|
|
81
|
+
# Orthogonalize the weights (Matrix Sign - MSign) every x steps
|
|
82
|
+
MSign_interval: int | None = None,
|
|
81
83
|
# SMMF factorization
|
|
82
84
|
nnmf_factor: bool = False,
|
|
83
85
|
vector_reshape: bool = False,
|
|
@@ -102,6 +104,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
102
104
|
auto_kappa_p=auto_kappa_p,
|
|
103
105
|
stochastic_sign=stochastic_sign,
|
|
104
106
|
spectral_normalization=spectral_normalization,
|
|
107
|
+
MSign_interval=MSign_interval,
|
|
105
108
|
nnmf_factor=nnmf_factor,
|
|
106
109
|
centered_wd= centered_wd,
|
|
107
110
|
centered_wd_mode= centered_wd_mode,
|
|
@@ -81,7 +81,6 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
81
81
|
'float8': Uses torch.float8_e4m3fn for a balance of precision and memory.
|
|
82
82
|
'int8': Uses 8-bit block-wise quantization (block size 128).
|
|
83
83
|
'int4': Uses 4-bit block-wise quantization (block size 32).
|
|
84
|
-
n_layers (int): The depth of the network (L). Required for optimal epsilon scaling. (default: 1)
|
|
85
84
|
spectral_normalization (bool): Enable explicit spectral normalization using power iteration. (default: False)
|
|
86
85
|
--- Auxiliary AdamW_adv Parameters (used for 'adam' groups) ---
|
|
87
86
|
adam_betas (tuple[float, float]): Betas for the AdamW optimizer part.
|
|
@@ -146,8 +145,9 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
146
145
|
approx_mars: bool = False,
|
|
147
146
|
mars_gamma: float = 0.025,
|
|
148
147
|
# Spectral Normalization
|
|
149
|
-
n_layers: int = 1,
|
|
150
148
|
spectral_normalization: bool = False,
|
|
149
|
+
# Orthogonalize the weights (Matrix Sign - MSign) every x steps
|
|
150
|
+
MSign_interval: int | None = None,
|
|
151
151
|
# Centered WD
|
|
152
152
|
centered_wd: float = 0.0,
|
|
153
153
|
centered_wd_mode: str = 'float8',
|
|
@@ -220,7 +220,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
220
220
|
# MARS-M
|
|
221
221
|
"approx_mars": approx_mars, "mars_gamma": mars_gamma,
|
|
222
222
|
# Spectral Normalization
|
|
223
|
-
"
|
|
223
|
+
"spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
|
|
224
224
|
# Centered WD
|
|
225
225
|
"centered_wd": centered_wd,
|
|
226
226
|
"centered_wd_mode": centered_wd_mode,
|
|
@@ -154,6 +154,10 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
154
154
|
# Centered WD
|
|
155
155
|
centered_wd: float = 0.0,
|
|
156
156
|
centered_wd_mode: str = 'float8',
|
|
157
|
+
# Spectral Normalization
|
|
158
|
+
spectral_normalization: bool = False,
|
|
159
|
+
# Orthogonalize the weights (Matrix Sign - MSign) every x steps
|
|
160
|
+
MSign_interval: int | None = None,
|
|
157
161
|
):
|
|
158
162
|
if not (lr >= 0.0):
|
|
159
163
|
raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
|
|
@@ -190,7 +194,8 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
190
194
|
"kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
|
|
191
195
|
"tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
|
|
192
196
|
"centered_wd": centered_wd, "centered_wd_mode": centered_wd_mode,
|
|
193
|
-
"nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd
|
|
197
|
+
"nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape, "factored_2nd": factored_2nd,
|
|
198
|
+
"spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
|
|
194
199
|
}
|
|
195
200
|
self.stochastic_rounding = stochastic_rounding
|
|
196
201
|
self.fsdp_in_use = fsdp_in_use
|
|
@@ -79,6 +79,8 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
79
79
|
state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'int8_sr'.
|
|
80
80
|
# Spectral Normed Optimizer
|
|
81
81
|
spectral_normalization: bool = False,
|
|
82
|
+
# Orthogonalize the weights (Matrix Sign - MSign) every x steps
|
|
83
|
+
MSign_interval: int | None = None,
|
|
82
84
|
# SMMF factorization
|
|
83
85
|
nnmf_factor: bool = False,
|
|
84
86
|
vector_reshape: bool = False,
|
|
@@ -117,6 +119,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
117
119
|
normed_momentum=normed_momentum,
|
|
118
120
|
snr_cond=snr_cond,
|
|
119
121
|
spectral_normalization=spectral_normalization,
|
|
122
|
+
MSign_interval=MSign_interval,
|
|
120
123
|
centered_wd= centered_wd,
|
|
121
124
|
centered_wd_mode= centered_wd_mode,
|
|
122
125
|
state_precision=state_precision,
|
|
@@ -72,6 +72,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
72
72
|
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
73
73
|
# Spectral Normed Optimizer
|
|
74
74
|
spectral_normalization: bool = False,
|
|
75
|
+
# Orthogonalize the weights (Matrix Sign - MSign) every x steps
|
|
76
|
+
MSign_interval: int | None = None,
|
|
75
77
|
# Centered WD
|
|
76
78
|
centered_wd: float = 0.0,
|
|
77
79
|
centered_wd_mode: str = 'float8',
|
|
@@ -108,7 +110,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
108
110
|
"compiled_optimizer": compiled_optimizer,
|
|
109
111
|
"sinkhorn_iterations": sinkhorn_iterations,
|
|
110
112
|
"orthogonal_sinkhorn": orthogonal_sinkhorn,
|
|
111
|
-
"spectral_normalization": spectral_normalization,
|
|
113
|
+
"spectral_normalization": spectral_normalization, "MSign_interval": MSign_interval,
|
|
112
114
|
"centered_wd": centered_wd, "centered_wd_mode": centered_wd_mode,
|
|
113
115
|
"state_precision": state_precision,
|
|
114
116
|
"nnmf_factor": nnmf_factor, "vector_reshape": vector_reshape
|
|
@@ -168,8 +168,6 @@ class KourkoutasHelper:
|
|
|
168
168
|
# Update the persistent EMA tensor in-place.
|
|
169
169
|
r_ema_tensor.mul_(ema_alpha).add_(pooled_grad_norm, alpha=1.0 - ema_alpha)
|
|
170
170
|
|
|
171
|
-
tiny_spike = scale_tiny_spike(group, info['params'], tiny_spike)
|
|
172
|
-
|
|
173
171
|
# Calculate Beta2
|
|
174
172
|
raw = pooled_grad_norm / (r_ema_tensor + tiny_spike)
|
|
175
173
|
sun = raw / (1.0 + raw)
|
|
@@ -255,29 +253,3 @@ class KourkoutasHelper:
|
|
|
255
253
|
# The default is the max value, which is correct for unmapped params or edge cases
|
|
256
254
|
beta2_default = group.get('betas', group.get('adam_betas'))[1] if group.get('betas', group.get('adam_betas')) else 0.999
|
|
257
255
|
return self.layer_state.get(layer_key, {}).get('dynamic_beta2', beta2_default)
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
def scale_tiny_spike(group: dict, layer_params: list, tiny_spike: float) -> float:
|
|
261
|
-
"""
|
|
262
|
-
Derives scale-invariant tiny_spike from the EMA tensor's effective numel.
|
|
263
|
-
"""
|
|
264
|
-
if not group.get('spectral_normalization', False):
|
|
265
|
-
return tiny_spike
|
|
266
|
-
|
|
267
|
-
p0 = layer_params[0]
|
|
268
|
-
if getattr(p0, '_is_lora_A', False) or p0.ndim < 2:
|
|
269
|
-
# No depth scaling for:
|
|
270
|
-
# - lora_A: non-zero init, different gradient dynamics than B
|
|
271
|
-
# - 1D params (biases, norms, DoRA scales): additive, don't compound through depth.
|
|
272
|
-
L = 1
|
|
273
|
-
else:
|
|
274
|
-
L = group['n_layers']
|
|
275
|
-
|
|
276
|
-
if getattr(p0, '_is_lora_A', False) or getattr(p0, '_is_oft', False):
|
|
277
|
-
ema_numel = p0.shape[1] # (1, in_features)
|
|
278
|
-
elif getattr(p0, '_is_lora_B', False):
|
|
279
|
-
ema_numel = p0.shape[0] # (out_features, 1)
|
|
280
|
-
else:
|
|
281
|
-
ema_numel = sum(p.numel() for p in layer_params) # scalar EMA
|
|
282
|
-
|
|
283
|
-
return 1.0 / (L * math.sqrt(ema_numel))
|
|
@@ -66,6 +66,7 @@ def _init_auxadam_state(self, p, group):
|
|
|
66
66
|
_init_anchor(p, state, group)
|
|
67
67
|
_init_fisher_wd_scaler(group, state, p)
|
|
68
68
|
|
|
69
|
+
group["MSign_interval"] = None
|
|
69
70
|
|
|
70
71
|
@torch.no_grad()
|
|
71
72
|
def _adam_step_parameter(self, p, grad, state, group, beta1_adam, beta2_adam, sqrt_bias_correction2, step_size, random_int_tensor, random_int_state_tensor=None):
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
@torch.no_grad()
def msign_ortho_precond_(p):
    """
    Applies an orthogonal preconditioner to a parameter tensor in-place using Matrix Sign.

    Originally proposed in the paper:
    "MSign: An Optimizer Preventing Training Instability in Large Language Models via
    Stable Rank Restoration" (arxiv:2602.01734).
    Modified to use CANS (Newton-Schulz with Chebyshev-type coefficients) instead of SVD.

    Args:
        p (torch.Tensor): The parameter tensor to be preconditioned. It must have at least
            one dimension. Vectors and biases should generally be excluded before calling
            this. A 1D tensor is treated as an (n, 1) matrix, whose matrix sign is just the
            normalized vector, so the call then only rescales it back to its own norm.
            The operation modifies `p` in-place.

    Note:
        - The Frobenius norm of the matrix is preserved (up to the 1e-12 clamps and
          rounding in `p`'s dtype).
        - High-dimensional tensors (e.g., Conv2D weights) are reshaped into
          `(out_channels, in_channels * ...)` automatically before computation.
        - Non-contiguous tensors are supported. The 2D matrix is built with `reshape`,
          which copies when a view is impossible, and the result is written back into
          `p` with `copy_`.
    """
    # Record the original Frobenius norm of the weight matrix.
    orig_norm = torch.linalg.vector_norm(p, ord=2).clamp_min_(1e-12)

    # `reshape` instead of `view`: `view` raises for non-contiguous parameters.
    # A copy here is harmless because the final result is copied back into `p`.
    p_2d = p.reshape(p.shape[0], -1)

    # Approximate the matrix sign UV^T using CANS Newton-Schulz.
    p_sign = _cans_newton_schulz_iteration(
        p_2d,
        steps=15,
        eps=1e-7,
        cns_a_bound=None,  # auto
    )

    # Frobenius norm of the sign-projected matrix.
    sign_norm = torch.linalg.vector_norm(p_sign, ord=2).clamp_min_(1e-12)

    # Restore the original Frobenius norm scale and write back in-place.
    # `reshape` tolerates the non-contiguous (transposed) result the helper may return.
    p.copy_(p_sign.mul_(orig_norm / sign_norm).reshape(p.shape))
|
|
39
|
+
|
|
40
|
+
@torch.no_grad()
|
|
41
|
+
def _cans_newton_schulz_iteration(
|
|
42
|
+
G: torch.Tensor,
|
|
43
|
+
steps: int = 15,
|
|
44
|
+
eps: float = 1e-7,
|
|
45
|
+
cns_a_bound: float | None = None,
|
|
46
|
+
) -> torch.Tensor:
|
|
47
|
+
"""
|
|
48
|
+
Computes the matrix sign function using Chebyshev-Optimized Newton-Schulz (CANS) iterations.
|
|
49
|
+
|
|
50
|
+
This function iteratively approximates the orthogonalized version of a matrix (i.e., UV^T from SVD)
|
|
51
|
+
using purely matrix multiplications. This is highly optimized for Tensor Cores on modern GPUs
|
|
52
|
+
and is significantly faster than computing an exact SVD.
|
|
53
|
+
|
|
54
|
+
Args:
|
|
55
|
+
G (torch.Tensor): The input matrix/tensor to be orthogonalized. Must be at least 2D.
|
|
56
|
+
steps (int, optional): Number of Newton-Schulz iterations. Deep learning models
|
|
57
|
+
usually only require 5 or 6 steps for sufficient orthogonal preconditioner
|
|
58
|
+
precision due to quadratic convergence. Defaults to 6.
|
|
59
|
+
eps (float, optional): A small epsilon value used to prevent division by zero
|
|
60
|
+
during the initial spectral normalization. Defaults to 1e-7.
|
|
61
|
+
cns_a_bound (float | None, optional): The lower bound for singular values after
|
|
62
|
+
normalization. If None, it is automatically calculated based on the Marchenko-Pastur
|
|
63
|
+
theoretical minimum and baseline L2 bounds. Defaults to None.
|
|
64
|
+
|
|
65
|
+
Returns:
|
|
66
|
+
torch.Tensor: The approximate matrix sign of G (orthogonalized matrix), preserving
|
|
67
|
+
the original shape and device of G.
|
|
68
|
+
"""
|
|
69
|
+
X = G
|
|
70
|
+
|
|
71
|
+
# Transpose if needed
|
|
72
|
+
transposed = X.size(-2) > X.size(-1)
|
|
73
|
+
if transposed:
|
|
74
|
+
X = X.mT
|
|
75
|
+
|
|
76
|
+
# Normalize spectral norm to at most 1
|
|
77
|
+
X.div_(X.norm(dim=(-2, -1), keepdim=True).clamp_min_(eps))
|
|
78
|
+
|
|
79
|
+
if cns_a_bound is None:
|
|
80
|
+
M = G.shape[-2]
|
|
81
|
+
N = G.shape[-1]
|
|
82
|
+
# baseline L2 norm bound (for square matrices)
|
|
83
|
+
baseline_bound = 1.0 / math.sqrt(M * N)
|
|
84
|
+
# Marchenko-Pastur theoretical minimum (for rectangular matrices)
|
|
85
|
+
mp_bound = abs(1.0 / math.sqrt(N) - 1.0 / math.sqrt(M))
|
|
86
|
+
# The optimal bound is safely the maximum of the two
|
|
87
|
+
cns_a_bound = max(baseline_bound, mp_bound)
|
|
88
|
+
|
|
89
|
+
lower_bound = cns_a_bound
|
|
90
|
+
upper_bound = 1.0
|
|
91
|
+
for _ in range(steps):
|
|
92
|
+
lb, ub = lower_bound, upper_bound
|
|
93
|
+
lb_ub = lb * ub
|
|
94
|
+
# Calculate Mean Square Error term
|
|
95
|
+
e_sq = (lb**2 + lb_ub + ub**2) / 3.0
|
|
96
|
+
# Calculate components for alpha and bounds update
|
|
97
|
+
K = 2.0 * e_sq**1.5
|
|
98
|
+
L = lb_ub * (lb + ub)
|
|
99
|
+
denom = K + L
|
|
100
|
+
alpha = 6.0 / denom
|
|
101
|
+
c1 = alpha * e_sq
|
|
102
|
+
c3 = -alpha / 3.0
|
|
103
|
+
# Apply the 3rd-order Newton-Schulz update
|
|
104
|
+
A = X @ X.mT
|
|
105
|
+
X = c1 * X + c3 * (A @ X)
|
|
106
|
+
# Update the singular value bounds for the next iteration based on the error
|
|
107
|
+
eps_val = (K - L) / denom
|
|
108
|
+
lower_bound, upper_bound = 1.0 - eps_val, 1.0 + eps_val
|
|
109
|
+
|
|
110
|
+
# Transpose back if necessary
|
|
111
|
+
if transposed:
|
|
112
|
+
X = X.mT
|
|
113
|
+
|
|
114
|
+
return X
|
|
@@ -8,6 +8,7 @@ from typing import Dict, Any
|
|
|
8
8
|
|
|
9
9
|
from .scaled_optm import adjust_wds
|
|
10
10
|
from .centered_decay import dequantize_anchor
|
|
11
|
+
from .msign import msign_ortho_precond_
|
|
11
12
|
|
|
12
13
|
_generators: Dict[torch.device, torch.Generator] = {}
|
|
13
14
|
|
|
@@ -120,6 +121,11 @@ def apply_parameter_update(
|
|
|
120
121
|
|
|
121
122
|
state = self.state[p]
|
|
122
123
|
|
|
124
|
+
ortho_interval = group.get('MSign_interval', None)
|
|
125
|
+
if ortho_interval is not None:
|
|
126
|
+
is_vector = p.ndim < 2 or getattr(p, '_is_dora_scale', False) or getattr(p, 'is_vector', False)
|
|
127
|
+
is_ortho_step = (state['step'] % ortho_interval == 0) and not is_vector
|
|
128
|
+
|
|
123
129
|
# Compute full update in float32 if using bfloat16 with stochastic rounding
|
|
124
130
|
if p.dtype == torch.bfloat16 and self.stochastic_rounding:
|
|
125
131
|
p_fp32 = p.float()
|
|
@@ -135,6 +141,9 @@ def apply_parameter_update(
|
|
|
135
141
|
# Apply main update
|
|
136
142
|
p_fp32.add_(-update_fp32)
|
|
137
143
|
|
|
144
|
+
if is_ortho_step:
|
|
145
|
+
msign_ortho_precond_(p_fp32)
|
|
146
|
+
|
|
138
147
|
# Single stochastic rounding at the end
|
|
139
148
|
if random_int_tensor is not None:
|
|
140
149
|
# Compiled path: use the pre-computed random tensor
|
|
@@ -153,6 +162,9 @@ def apply_parameter_update(
|
|
|
153
162
|
# Apply main update
|
|
154
163
|
p.add_(-update)
|
|
155
164
|
|
|
165
|
+
if is_ortho_step:
|
|
166
|
+
msign_ortho_precond_(p)
|
|
167
|
+
|
|
156
168
|
del update
|
|
157
169
|
|
|
158
170
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|