adv-optm 2.4.dev24__tar.gz → 2.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev24 → adv_optm-2.5}/PKG-INFO +1 -1
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/optim/AdaMuon_adv.py +7 -7
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/optim/AdamW_adv.py +5 -5
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/optim/Adopt_adv.py +4 -6
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/optim/Lion_adv.py +6 -5
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/optim/Muon_adv.py +7 -7
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/optim/Prodigy_adv.py +4 -4
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/optim/SignSGD_adv.py +15 -12
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/optim/SinkSGD_adv.py +15 -12
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/util/Muon_AuxAdam.py +2 -3
- adv_optm-2.5/adv_optm/util/OrthoGrad.py +92 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/util/param_update.py +3 -3
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/util/scaled_optm.py +2 -2
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/util/state_util.py +1 -1
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev24 → adv_optm-2.5}/setup.py +1 -1
- adv_optm-2.4.dev24/adv_optm/util/OrthoGrad.py +0 -80
- {adv_optm-2.4.dev24 → adv_optm-2.5}/LICENSE +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/README.md +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/util/sinkhorn.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.5}/setup.cfg +0 -0
|
@@ -57,7 +57,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
57
57
|
(default: (3.4445, -4.7750, 2.0315)).
|
|
58
58
|
stochastic_rounding (bool): whether to use stochastic rounding for
|
|
59
59
|
BF16 parameter updates (default: True).
|
|
60
|
-
orthogonal_gradient (
|
|
60
|
+
orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
|
|
61
|
+
'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
|
|
61
62
|
nesterov (bool): enables Nesterov momentum (default: False).
|
|
62
63
|
use_atan2 (bool): whether to use the atan2 update rule. (default: False)
|
|
63
64
|
vector_reshape (bool): whether to reshape 1D vectors into 2D
|
|
@@ -114,7 +115,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
114
115
|
adam_fisher_wd (bool): Fisher Adam (FAdam) weight decay for the AdamW part. (default: False)
|
|
115
116
|
adam_use_bias_correction (bool): Bias correction for AdamW.
|
|
116
117
|
adam_use_atan2 (bool): Atan2 update rule for AdamW.
|
|
117
|
-
adam_orthogonal_gradient (
|
|
118
|
+
adam_orthogonal_gradient (str): OrthoGrad for AdamW.
|
|
118
119
|
adam_nesterov (bool): Nesterov momentum for AdamW. (default: False)
|
|
119
120
|
adam_nesterov_coef (float, optional): Nesterov coefficient for AdamW. (default: None)
|
|
120
121
|
adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
|
|
@@ -149,7 +150,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
149
150
|
# Stochastic Rounding for BF16
|
|
150
151
|
stochastic_rounding: bool = True,
|
|
151
152
|
# OrthoGrad
|
|
152
|
-
orthogonal_gradient:
|
|
153
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
153
154
|
# Adam_atan2 (scale invariant)
|
|
154
155
|
use_atan2: bool = False,
|
|
155
156
|
# NorMuon
|
|
@@ -190,7 +191,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
190
191
|
adam_fisher_wd: bool = False,
|
|
191
192
|
adam_use_bias_correction: bool = True,
|
|
192
193
|
adam_use_atan2: bool = False,
|
|
193
|
-
adam_orthogonal_gradient:
|
|
194
|
+
adam_orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
194
195
|
adam_nesterov: bool = False,
|
|
195
196
|
adam_nesterov_coef: float | None = None,
|
|
196
197
|
adam_kourkoutas_beta: bool = False,
|
|
@@ -213,7 +214,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
213
214
|
print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
|
|
214
215
|
rms_rescaling = False
|
|
215
216
|
if spectral_normalization and accelerated_ns:
|
|
216
|
-
ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
217
|
+
raise ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
217
218
|
|
|
218
219
|
# Legacy backwards compatibility support for `nnmf_factor=True`
|
|
219
220
|
if nnmf_factor:
|
|
@@ -515,8 +516,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
515
516
|
grad = approx_mars(grad, state['last_grad'], group['mars_gamma'], beta1)
|
|
516
517
|
|
|
517
518
|
|
|
518
|
-
|
|
519
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
519
|
+
grad = _orthogonalize_gradient(p, grad, group.get("orthogonal_gradient"))
|
|
520
520
|
|
|
521
521
|
if state['factored']: # Factored Muon
|
|
522
522
|
d1, d2 = state['effective_shape']
|
|
@@ -45,7 +45,8 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
45
45
|
stochastic_rounding (bool): whether to use stochastic
|
|
46
46
|
rounding for BF16 parameter updates (default: True).
|
|
47
47
|
use_atan2 (bool): whether to use the atan2 update rule. (default: False)
|
|
48
|
-
orthogonal_gradient (
|
|
48
|
+
orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
|
|
49
|
+
'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
|
|
49
50
|
normed_momentum (bool): whether to compute the first moment on the normalized gradient. (default: False)
|
|
50
51
|
kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
|
|
51
52
|
If `False`, the optimizer behaves as standard AdamW. (default: False)
|
|
@@ -104,7 +105,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
104
105
|
# Adam_atan2 (scale invariant)
|
|
105
106
|
use_atan2: bool = False,
|
|
106
107
|
# OrthoGrad
|
|
107
|
-
orthogonal_gradient:
|
|
108
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
108
109
|
# Nesterov momentum
|
|
109
110
|
nesterov: bool = False,
|
|
110
111
|
nesterov_coef: float | None = None,
|
|
@@ -326,8 +327,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
326
327
|
def _step_parameter(self, p, grad, state, group, step_size, beta1, beta2, sqrt_bias_correction2, random_int_tensor, random_int_state_tensor):
|
|
327
328
|
grad = upcast_grad_for_precision(grad, state, group['state_precision'])
|
|
328
329
|
|
|
329
|
-
|
|
330
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
330
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
331
331
|
|
|
332
332
|
nesterov = group.get('nesterov', False)
|
|
333
333
|
nesterov_coef = group.get('nesterov_coef', None)
|
|
@@ -462,7 +462,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
462
462
|
else:
|
|
463
463
|
update.mul_(update_scaling)
|
|
464
464
|
|
|
465
|
-
param_update.apply_parameter_update(self, p, group, update,
|
|
465
|
+
param_update.apply_parameter_update(self, p, group, update, group['lr'], random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
|
|
466
466
|
|
|
467
467
|
def compile(self, *args, **kwargs):
|
|
468
468
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
@@ -108,7 +108,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
108
108
|
# Stochastic Rounding for BF16
|
|
109
109
|
stochastic_rounding: bool = True,
|
|
110
110
|
# OrthoGrad
|
|
111
|
-
orthogonal_gradient:
|
|
111
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
112
112
|
# Nesterov momentum
|
|
113
113
|
nesterov: bool = False,
|
|
114
114
|
nesterov_coef: float | None = None,
|
|
@@ -158,7 +158,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
158
158
|
|
|
159
159
|
defaults = {
|
|
160
160
|
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
161
|
-
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
|
|
161
|
+
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "orthogonal_gradient": orthogonal_gradient,
|
|
162
162
|
"nesterov": nesterov, "nesterov_coef": nesterov_coef,
|
|
163
163
|
"kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
|
|
164
164
|
"tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
|
|
@@ -172,7 +172,6 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
172
172
|
self.clip_lambda = clip_lambda
|
|
173
173
|
self.stochastic_rounding = stochastic_rounding
|
|
174
174
|
self.use_atan2 = use_atan2
|
|
175
|
-
self.orthogonal_gradient = orthogonal_gradient
|
|
176
175
|
self.kourkoutas_beta = kourkoutas_beta
|
|
177
176
|
self.layer_key_fn = layer_key_fn
|
|
178
177
|
self._init_lr = lr if lr > 0 else 1
|
|
@@ -237,7 +236,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
237
236
|
dtype = torch.float32 if (state['factored'] or req_precision == 'factored') else p.dtype
|
|
238
237
|
|
|
239
238
|
vt_dtype = torch.float32 if (state['factored'] or state['factored_2nd'] or req_precision in ['factored', 'bf16_sr', 'int8_sr']) else dtype
|
|
240
|
-
vt_init = grad.pow(2).to(vt_dtype)
|
|
239
|
+
vt_init = grad.pow(2).to(vt_dtype)
|
|
241
240
|
|
|
242
241
|
if state['factored']:
|
|
243
242
|
state['effective_shape'] = _get_effective_shape(p.numel())
|
|
@@ -329,8 +328,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
329
328
|
def _step_parameter(self, p, grad, state, group, lr, beta1, beta2, random_int_tensor, random_int_state_tensor):
|
|
330
329
|
grad = upcast_grad_for_precision(grad, state, group['state_precision'])
|
|
331
330
|
|
|
332
|
-
|
|
333
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
331
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
334
332
|
|
|
335
333
|
nesterov = group.get('nesterov', False)
|
|
336
334
|
nesterov_coef = group.get('nesterov_coef', None)
|
|
@@ -67,7 +67,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
67
67
|
# Stochastic Rounding for BF16
|
|
68
68
|
stochastic_rounding: bool = True,
|
|
69
69
|
# OrthoGrad
|
|
70
|
-
orthogonal_gradient:
|
|
70
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
71
71
|
# Lion-k
|
|
72
72
|
kappa_p: float = 1.0,
|
|
73
73
|
auto_kappa_p: bool = False,
|
|
@@ -213,8 +213,9 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
213
213
|
def _step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_noise_tensor):
|
|
214
214
|
if grad.dtype != torch.float32 and state['factored']:
|
|
215
215
|
grad = grad.float()
|
|
216
|
-
|
|
217
|
-
|
|
216
|
+
is_vector = p.ndim < 2 or getattr(p, '_is_dora_scale', False) or getattr(p, 'is_vector', False)
|
|
217
|
+
|
|
218
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
218
219
|
|
|
219
220
|
# Lion-K Logic
|
|
220
221
|
kappa_p = group.get("kappa_p", 1.0)
|
|
@@ -250,7 +251,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
250
251
|
update = update.view(p.shape)
|
|
251
252
|
|
|
252
253
|
if group.get('stochastic_sign', False):
|
|
253
|
-
update = apply_stochastic_sign_(update, noise=random_noise_tensor)
|
|
254
|
+
update = apply_stochastic_sign_(update, noise=random_noise_tensor, is_vector=is_vector)
|
|
254
255
|
else:
|
|
255
256
|
update = _get_lion_k_update(update, kappa_p)
|
|
256
257
|
|
|
@@ -265,7 +266,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
265
266
|
exp_avg.lerp_(grad, 1 - beta2)
|
|
266
267
|
|
|
267
268
|
if group.get('stochastic_sign', False):
|
|
268
|
-
update = apply_stochastic_sign_(update, noise=random_noise_tensor)
|
|
269
|
+
update = apply_stochastic_sign_(update, noise=random_noise_tensor, is_vector=is_vector)
|
|
269
270
|
else:
|
|
270
271
|
update = _get_lion_k_update(update, kappa_p)
|
|
271
272
|
|
|
@@ -39,7 +39,8 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
39
39
|
(default: (3.4445, -4.7750, 2.0315)).
|
|
40
40
|
stochastic_rounding (bool): whether to use stochastic rounding for
|
|
41
41
|
BF16 parameter updates (default: True).
|
|
42
|
-
orthogonal_gradient (
|
|
42
|
+
orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
|
|
43
|
+
'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
|
|
43
44
|
vector_reshape (bool): whether to reshape 1D vectors into 2D
|
|
44
45
|
matrices to apply low-rank compression (default: True).
|
|
45
46
|
nnmf_factor (bool): whether to use the factorization or disable it to use
|
|
@@ -89,7 +90,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
89
90
|
adam_fisher_wd (bool): Fisher Adam (FAdam) weight decay for the AdamW part. (default: False)
|
|
90
91
|
adam_use_bias_correction (bool): Bias correction for AdamW.
|
|
91
92
|
adam_use_atan2 (bool): Atan2 update rule for AdamW.
|
|
92
|
-
adam_orthogonal_gradient (
|
|
93
|
+
adam_orthogonal_gradient (str): OrthoGrad for AdamW.
|
|
93
94
|
adam_nesterov (bool): Nesterov momentum for AdamW. (default: False)
|
|
94
95
|
adam_nesterov_coef (float, optional): Nesterov coefficient for AdamW. (default: None)
|
|
95
96
|
adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
|
|
@@ -121,7 +122,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
121
122
|
# Stochastic Rounding for BF16
|
|
122
123
|
stochastic_rounding: bool = True,
|
|
123
124
|
# OrthoGrad
|
|
124
|
-
orthogonal_gradient:
|
|
125
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
125
126
|
# RMS Rescaling
|
|
126
127
|
rms_rescaling: bool = True,
|
|
127
128
|
# SMMF factorization
|
|
@@ -159,7 +160,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
159
160
|
adam_fisher_wd: bool = False,
|
|
160
161
|
adam_use_bias_correction: bool = True,
|
|
161
162
|
adam_use_atan2: bool = False,
|
|
162
|
-
adam_orthogonal_gradient:
|
|
163
|
+
adam_orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
163
164
|
adam_nesterov: bool = False,
|
|
164
165
|
adam_nesterov_coef: float | None = None,
|
|
165
166
|
adam_kourkoutas_beta: bool = False,
|
|
@@ -186,7 +187,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
186
187
|
print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
|
|
187
188
|
rms_rescaling = False
|
|
188
189
|
if spectral_normalization and accelerated_ns:
|
|
189
|
-
ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
190
|
+
raise ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
190
191
|
|
|
191
192
|
# Legacy backwards compatibility support for `nnmf_factor=True`
|
|
192
193
|
if nnmf_factor:
|
|
@@ -457,8 +458,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
457
458
|
if grad.dtype != torch.float32 and state.get('factored', False):
|
|
458
459
|
grad = grad.float()
|
|
459
460
|
|
|
460
|
-
|
|
461
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
461
|
+
grad = _orthogonalize_gradient(p, grad, group.get("orthogonal_gradient"))
|
|
462
462
|
|
|
463
463
|
if state['factored']: # Factored Muon
|
|
464
464
|
d1, d2 = state['effective_shape']
|
|
@@ -43,7 +43,8 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
43
43
|
stochastic_rounding (bool): whether to use stochastic
|
|
44
44
|
rounding for BF16 parameter updates (default: True).
|
|
45
45
|
use_atan2 (bool): whether to use the atan2 update rule. (default: False)
|
|
46
|
-
orthogonal_gradient (
|
|
46
|
+
orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
|
|
47
|
+
'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
|
|
47
48
|
nnmf_factor (bool): whether to use the factorization or disable it to use
|
|
48
49
|
the uncompressed optimizer. (default: False)
|
|
49
50
|
factored_2nd (bool): whether to keep the first moment uncompressed (dense)
|
|
@@ -119,7 +120,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
119
120
|
# Adam_atan2 (scale invariant)
|
|
120
121
|
use_atan2: bool = False,
|
|
121
122
|
# OrthoGrad
|
|
122
|
-
orthogonal_gradient:
|
|
123
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
123
124
|
# Nesterov momentum
|
|
124
125
|
nesterov: bool = False,
|
|
125
126
|
nesterov_coef: float | None = None,
|
|
@@ -371,8 +372,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
371
372
|
def _step_parameter(self, p, grad, state, group, beta2, d, dlr, random_int_tensor, random_int_state_tensor):
|
|
372
373
|
grad = upcast_grad_for_precision(grad, state, group['state_precision'])
|
|
373
374
|
|
|
374
|
-
|
|
375
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
375
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
376
376
|
|
|
377
377
|
nesterov = group.get('nesterov', False)
|
|
378
378
|
nesterov_coef = group.get('nesterov_coef', None)
|
|
@@ -62,7 +62,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
62
62
|
# Stochastic Rounding for BF16
|
|
63
63
|
stochastic_rounding: bool = True,
|
|
64
64
|
# OrthoGrad
|
|
65
|
-
orthogonal_gradient:
|
|
65
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
66
66
|
# Stochastic Sign Operator
|
|
67
67
|
stochastic_sign: bool = False,
|
|
68
68
|
# Nesterov momentum
|
|
@@ -171,7 +171,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
171
171
|
def __init_state(self, p, group):
|
|
172
172
|
state = self.state[p]
|
|
173
173
|
# State Initialization
|
|
174
|
-
if
|
|
174
|
+
if 'step' not in state:
|
|
175
175
|
req_precision = group['state_precision']
|
|
176
176
|
is_vector = len(p.shape) == 1 and not group['vector_reshape']
|
|
177
177
|
|
|
@@ -259,8 +259,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
259
259
|
wd_target = None
|
|
260
260
|
cwd_target = None
|
|
261
261
|
|
|
262
|
-
|
|
263
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
262
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
264
263
|
|
|
265
264
|
if normed_mt:
|
|
266
265
|
if sso:
|
|
@@ -280,16 +279,18 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
280
279
|
if snr_cond:
|
|
281
280
|
denom = (1.0 - exp_avg.square()).clamp_min_(1e-30).sqrt_().view_as(p)
|
|
282
281
|
|
|
282
|
+
if nesterov and normed_mt:
|
|
283
|
+
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
284
|
+
normed_grad = exp_avg.abs().mul_(grad_reshaped)
|
|
285
|
+
|
|
283
286
|
exp_avg.lerp_(grad_reshaped, 1 - momentum)
|
|
284
287
|
|
|
285
288
|
if nesterov:
|
|
286
289
|
nv_coef = momentum if nesterov_coef is None else nesterov_coef
|
|
287
290
|
if normed_mt:
|
|
288
|
-
|
|
289
|
-
ema_std = math.sqrt((1 - momentum) / (1 + momentum))
|
|
290
|
-
raw_update = (grad_reshaped * ema_std).lerp_(exp_avg, nv_coef)
|
|
291
|
+
raw_update = normed_grad.lerp_(exp_avg, nv_coef)
|
|
291
292
|
else:
|
|
292
|
-
raw_update =
|
|
293
|
+
raw_update = grad_reshaped.lerp(exp_avg, nv_coef)
|
|
293
294
|
else:
|
|
294
295
|
raw_update = exp_avg.clone()
|
|
295
296
|
|
|
@@ -309,14 +310,16 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
309
310
|
if snr_cond:
|
|
310
311
|
denom = (1.0 - exp_avg.square()).clamp_min_(1e-30).sqrt_()
|
|
311
312
|
|
|
313
|
+
if nesterov and normed_mt:
|
|
314
|
+
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
315
|
+
normed_grad = exp_avg.abs().mul_(grad)
|
|
316
|
+
|
|
312
317
|
exp_avg.lerp_(grad, 1 - momentum)
|
|
313
318
|
|
|
314
319
|
if nesterov:
|
|
315
320
|
nv_coef = momentum if nesterov_coef is None else nesterov_coef
|
|
316
321
|
if normed_mt:
|
|
317
|
-
|
|
318
|
-
ema_std = math.sqrt((1 - momentum) / (1 + momentum))
|
|
319
|
-
raw_update = (grad * ema_std).lerp_(exp_avg, nv_coef)
|
|
322
|
+
raw_update = normed_grad.lerp_(exp_avg, nv_coef)
|
|
320
323
|
else:
|
|
321
324
|
raw_update = grad.lerp(exp_avg, nv_coef)
|
|
322
325
|
else:
|
|
@@ -340,7 +343,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
340
343
|
if group.get('geometric_wd', False) and group["weight_decay"] > 0 :
|
|
341
344
|
wd_target = get_signsgd_wd_target(p, denom=denom, stochastic_sign=sso, noise=random_noise_tensor, is_vector=is_vector)
|
|
342
345
|
|
|
343
|
-
if group.get('centered_wd', 0.0) > 0 and '
|
|
346
|
+
if group.get('centered_wd', 0.0) > 0 and 'anchor_data' in state:
|
|
344
347
|
anchor = dequantize_anchor(p, state, group, p.dtype)
|
|
345
348
|
cwd_target = get_signsgd_wd_target(p.sub(anchor), denom=denom, stochastic_sign=sso, noise=random_noise_tensor, is_vector=is_vector)
|
|
346
349
|
del anchor
|
|
@@ -69,7 +69,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
69
69
|
# Stochastic Rounding for BF16
|
|
70
70
|
stochastic_rounding: bool = True,
|
|
71
71
|
# OrthoGrad
|
|
72
|
-
orthogonal_gradient:
|
|
72
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
73
73
|
# Spectral Normed Optimizer
|
|
74
74
|
spectral_normalization: bool = False,
|
|
75
75
|
# Centered WD
|
|
@@ -89,8 +89,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
89
89
|
raise ValueError(f"Momentum should be >= 0.0. Got {momentum}")
|
|
90
90
|
if not (weight_decay >= 0.0):
|
|
91
91
|
raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
|
|
92
|
-
if snr_cond and not normed_momentum:
|
|
93
|
-
raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum")
|
|
92
|
+
if snr_cond and not normed_momentum and not momentum > 0:
|
|
93
|
+
raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum.")
|
|
94
94
|
|
|
95
95
|
state_precision = state_precision.lower()
|
|
96
96
|
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
|
|
@@ -237,8 +237,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
237
237
|
wd_target = None
|
|
238
238
|
cwd_target = None
|
|
239
239
|
|
|
240
|
-
|
|
241
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
240
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
242
241
|
|
|
243
242
|
if normed_mt:
|
|
244
243
|
if not is_vector:
|
|
@@ -264,6 +263,10 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
264
263
|
else:
|
|
265
264
|
denom = (1.0 - buf.square()).clamp_min_(1e-30).sqrt_().view_as(p)
|
|
266
265
|
|
|
266
|
+
if nesterov and normed_mt:
|
|
267
|
+
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
268
|
+
normed_grad = buf.abs().mul_(grad_reshaped)
|
|
269
|
+
|
|
267
270
|
buf.lerp_(grad_reshaped, 1 - momentum)
|
|
268
271
|
|
|
269
272
|
# Factorize updated buffer
|
|
@@ -272,9 +275,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
272
275
|
if nesterov:
|
|
273
276
|
nv_coef = momentum if nesterov_coef is None else nesterov_coef
|
|
274
277
|
if normed_mt:
|
|
275
|
-
|
|
276
|
-
ema_std = math.sqrt((1 - momentum) / (1 + momentum))
|
|
277
|
-
update = (grad_reshaped * ema_std).lerp_(buf, nv_coef)
|
|
278
|
+
update = normed_grad.lerp_(buf, nv_coef)
|
|
278
279
|
else:
|
|
279
280
|
update = grad_reshaped.lerp(buf, nv_coef)
|
|
280
281
|
else:
|
|
@@ -299,6 +300,10 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
299
300
|
else:
|
|
300
301
|
denom = (1.0 - buf.square()).clamp_min_(1e-30).sqrt_()
|
|
301
302
|
|
|
303
|
+
if nesterov and normed_mt:
|
|
304
|
+
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
305
|
+
normed_grad = buf.abs().mul_(grad)
|
|
306
|
+
|
|
302
307
|
buf.lerp_(grad, 1 - momentum)
|
|
303
308
|
|
|
304
309
|
set_state(state, 'momentum_buffer', buf, actual_precision, random_int_state_tensor)
|
|
@@ -306,9 +311,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
306
311
|
if nesterov:
|
|
307
312
|
nv_coef = momentum if nesterov_coef is None else nesterov_coef
|
|
308
313
|
if normed_mt:
|
|
309
|
-
|
|
310
|
-
ema_std = math.sqrt((1 - momentum) / (1 + momentum))
|
|
311
|
-
update = (grad * ema_std).lerp_(buf, nv_coef)
|
|
314
|
+
update = normed_grad.lerp_(buf, nv_coef)
|
|
312
315
|
else:
|
|
313
316
|
update = grad.lerp(buf, nv_coef)
|
|
314
317
|
else:
|
|
@@ -342,7 +345,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
342
345
|
wd_scaler = get_sinkhorn_wd_scaler(p, row_denom=vt_row, col_denom=vt_col)
|
|
343
346
|
else:
|
|
344
347
|
wd_target = get_signsgd_wd_target(p, denom=denom)
|
|
345
|
-
if is_vector and group.get('centered_wd', 0.0) > 0 and '
|
|
348
|
+
if is_vector and group.get('centered_wd', 0.0) > 0 and 'anchor_data' in state:
|
|
346
349
|
anchor = dequantize_anchor(p, state, group, p.dtype)
|
|
347
350
|
cwd_target = get_signsgd_wd_target(p.sub(anchor), denom=denom)
|
|
348
351
|
del anchor
|
|
@@ -71,8 +71,7 @@ def _init_auxadam_state(self, p, group):
|
|
|
71
71
|
def _adam_step_parameter(self, p, grad, state, group, beta1_adam, beta2_adam, sqrt_bias_correction2, step_size, random_int_tensor, random_int_state_tensor=None):
|
|
72
72
|
grad = upcast_grad_for_precision(grad, state, group.get('adam_state_precision', 'auto'))
|
|
73
73
|
|
|
74
|
-
|
|
75
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
74
|
+
grad = _orthogonalize_gradient(p, grad, group.get("adam_orthogonal_gradient"))
|
|
76
75
|
|
|
77
76
|
if hasattr(self, 'kourkoutas_helper') and self.kourkoutas_helper:
|
|
78
77
|
# Accumulate current grad's norm for the *next* step
|
|
@@ -190,4 +189,4 @@ def _adam_step_parameter(self, p, grad, state, group, beta1_adam, beta2_adam, sq
|
|
|
190
189
|
else:
|
|
191
190
|
update.mul_(update_scaling)
|
|
192
191
|
|
|
193
|
-
param_update.apply_parameter_update(self, p, group, update,
|
|
192
|
+
param_update.apply_parameter_update(self, p, group, update, group['lr'], group["adam_weight_decay"], random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import math
|
|
3
|
+
|
|
4
|
+
def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor, mode: str) -> torch.Tensor:
|
|
5
|
+
"""
|
|
6
|
+
Projects the gradient `grad` to be orthogonal to the parameter `p`.
|
|
7
|
+
Supports two modes: 'flattened' (vectorized) and 'iterative' (matrix-wise).
|
|
8
|
+
"""
|
|
9
|
+
if mode == 'disabled':
|
|
10
|
+
return grad
|
|
11
|
+
elif mode == 'flattened':
|
|
12
|
+
return flattened_ortho_project(p, grad)
|
|
13
|
+
elif mode == 'iterative':
|
|
14
|
+
return iterative_ortho_project(p, grad, iters=3)
|
|
15
|
+
|
|
16
|
+
def flattened_ortho_project(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    """
    Projects the flattened gradient `grad` to be orthogonal to the flattened parameter `p`.

    After removing the component of `grad` parallel to `p`, the result is rescaled
    to the pre-projection L2 norm of `grad`, so only its direction changes. The
    computation runs in float32. The result is returned with `grad`'s shape and
    dtype. Neither `p` nor `grad` is modified.

    Modified from:
    https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability/blob/720d2444df12b851d6cb417ab08cf125c822b2ae/orthograd.py
    """
    original_shape = grad.shape
    original_dtype = grad.dtype
    # `reshape` rather than `view`: `view(-1)` raises on non-contiguous tensors
    # (e.g. transposed parameters or gradients). Both inputs are only read here.
    w = p.reshape(-1).float()
    g = grad.reshape(-1).float()
    # Tiny epsilons avoid division by zero for an all-zero `p` or an orthogonal
    # component of zero length.
    w_norm_sq = torch.dot(w, w).add_(1e-30)
    proj = torch.dot(w, g) / w_norm_sq
    g_orth = g.sub(w * proj)
    # Restore the original gradient norm.
    g_norm = g.norm(2)
    g_orth_norm = g_orth.norm(2).add_(1e-30)
    g_orth_scaled = g_orth * (g_norm / g_orth_norm)
    return g_orth_scaled.view(original_shape).to(original_dtype)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def iterative_ortho_project(p: torch.Tensor, grad: torch.Tensor, iters: int = 3) -> torch.Tensor:
    """
    Applies iterative alternating orthogonal projection to a 2D matrix.

    Projects the grad to be orthogonal to the parameter matrix along rows and
    columns, alternating dimensions for `iters` rounds (inspired by the Sinkhorn
    algorithm). Per the author's note, 2-3 rounds are enough to bring every
    row/column to a cosine similarity with `p` of magnitude ~1e-4 to 1e-6
    (semi-orthogonal). Tensors with more than two dims are treated as
    (shape[0], prod(shape[1:])).

    1D vectors, and parameters flagged `_is_dora_scale` or `is_vector`, fall back
    to `flattened_ortho_project`.

    The projection runs in float32 on a private copy: `grad` is not modified, and
    the result is returned with `grad`'s shape and dtype.
    """
    # 1D Vector Case fallback to the standard OrthoGrad
    is_vector = p.ndim < 2 or getattr(p, '_is_dora_scale', False) or getattr(p, 'is_vector', False)
    if is_vector:
        # Previously `_orthogonalize_gradient(p, grad)`, which omitted that
        # function's required `mode` argument and raised TypeError.
        return flattened_ortho_project(p, grad)

    original_shape = grad.shape
    original_dtype = grad.dtype

    # 2D+ Matrix Case.
    # `_ortho_normed_dim` updates its input in place. A `view` of `grad` would
    # leak those writes into the caller's tensor (possibly `p.grad`), so work on
    # an explicit float32 copy. Float32 also matches `flattened_ortho_project`
    # and keeps bf16/fp16 grads from losing the projection's precision.
    grad_2d = grad.reshape(grad.shape[0], -1).to(torch.float32, copy=True)
    param_2d = p.reshape(p.shape[0], -1).float()

    m, n = grad_2d.shape

    # Dynamically determine the order based on aspect ratio
    row_first = m > n
    dim = 0 if row_first else 1

    # Squared parameter norms along each axis; fixed across iterations.
    p_norm_sq_dim = torch.sum(param_2d * param_2d, dim=dim, keepdim=True).add_(1e-30)
    p_norm_sq_adim = torch.sum(param_2d * param_2d, dim=1 - dim, keepdim=True).add_(1e-30)

    for _ in range(iters):
        # First dimension
        grad_2d = _ortho_normed_dim(param_2d, grad_2d, p_norm_sq_dim, dim)
        # Second dimension
        grad_2d = _ortho_normed_dim(param_2d, grad_2d, p_norm_sq_adim, 1 - dim)

    return grad_2d.reshape(original_shape).to(original_dtype)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _ortho_normed_dim(p_2d: torch.Tensor, grad_2d: torch.Tensor, p_norm_sq: torch.Tensor, dim: int) -> torch.Tensor:
|
|
73
|
+
"""
|
|
74
|
+
Projects the grad to be orthogonal to p along 'dim' and dynamically restores
|
|
75
|
+
the original magnitude of that dimension pre-projection.
|
|
76
|
+
"""
|
|
77
|
+
# Record target magnitude before projection
|
|
78
|
+
norm_lb = 1 / math.sqrt(grad_2d.shape[dim])
|
|
79
|
+
target_norm = grad_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm_lb)
|
|
80
|
+
|
|
81
|
+
# Project: g_orth = g - (p * <p, g> / ||p||^2)
|
|
82
|
+
dot_prod = torch.sum(p_2d * grad_2d, dim=dim, keepdim=True)
|
|
83
|
+
proj = dot_prod / p_norm_sq
|
|
84
|
+
|
|
85
|
+
# In-place subtraction: grad_2d = grad_2d - (proj * p_2d)
|
|
86
|
+
# Standard gamma is -1, but -1.01 proved to converge faster
|
|
87
|
+
grad_2d.addcmul_(proj, p_2d, value=-1.01)
|
|
88
|
+
|
|
89
|
+
# Magnitude Preservation
|
|
90
|
+
g_orth_norm = grad_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm_lb)
|
|
91
|
+
scale_factor = target_norm / g_orth_norm
|
|
92
|
+
return grad_2d.mul_(scale_factor)
|
|
@@ -6,7 +6,7 @@ import torch.nn.functional as F
|
|
|
6
6
|
|
|
7
7
|
from typing import Dict, Any
|
|
8
8
|
|
|
9
|
-
from .scaled_optm import adjust_wds
|
|
9
|
+
from .scaled_optm import adjust_wds
|
|
10
10
|
from .centered_decay import dequantize_anchor
|
|
11
11
|
|
|
12
12
|
_generators: Dict[torch.device, torch.Generator] = {}
|
|
@@ -48,7 +48,7 @@ def _apply_weight_decay(
|
|
|
48
48
|
p_calc.add_(wd_target, alpha=-scaled_wd)
|
|
49
49
|
|
|
50
50
|
# Centered Weight Decay (pulls toward anchor)
|
|
51
|
-
if scaled_cwd is not None and '
|
|
51
|
+
if scaled_cwd is not None and 'anchor_data' in state:
|
|
52
52
|
if cwd_target is not None:
|
|
53
53
|
decay_target = cwd_target
|
|
54
54
|
else:
|
|
@@ -330,7 +330,7 @@ def _copy_int8_sym_blockwise_stochastic_core_(
|
|
|
330
330
|
target: torch.Tensor,
|
|
331
331
|
source: torch.Tensor,
|
|
332
332
|
scales: torch.Tensor,
|
|
333
|
-
random_int_tensor: torch.Tensor
|
|
333
|
+
random_int_tensor: torch.Tensor,
|
|
334
334
|
block_size: int = 2048,
|
|
335
335
|
val_blocks: torch.Tensor | None = None,
|
|
336
336
|
) -> None:
|
|
@@ -61,7 +61,7 @@ def adjust_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
|
|
|
61
61
|
"""
|
|
62
62
|
# DoRA Scale (Magnitude Vector)
|
|
63
63
|
if getattr(p, '_is_dora_scale', False):
|
|
64
|
-
return
|
|
64
|
+
return wd, cwd
|
|
65
65
|
|
|
66
66
|
if getattr(p, '_is_oft', False):
|
|
67
67
|
return wd, 0.0
|
|
@@ -76,7 +76,7 @@ def adjust_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
|
|
|
76
76
|
else:
|
|
77
77
|
# 1D Biases or generic 1D parameters
|
|
78
78
|
# Centered WD safely regularizes the delta without collapsing base feature variance.
|
|
79
|
-
return
|
|
79
|
+
return wd, cwd
|
|
80
80
|
|
|
81
81
|
|
|
82
82
|
def scale_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
|
|
@@ -209,7 +209,7 @@ def fix_loaded_state_dtype(state: dict, p: torch.Tensor, group: dict) -> None:
|
|
|
209
209
|
|
|
210
210
|
# Pre-define sets for known exact-match keys
|
|
211
211
|
uint8_keys = {'sign', 'sign_slow', 'sign_buf', 'shifter'}
|
|
212
|
-
fp32_keys = {'mu_m_nmf', 'mv_m_nmf', 'mu_v_nmf', 'mv_v_nmf', 'mu_m_slow_nmf', 'mv_m_slow_nmf'}
|
|
212
|
+
fp32_keys = {'mu_m_nmf', 'mv_m_nmf', 'mu_v_nmf', 'mv_v_nmf', 'mu_m_slow_nmf', 'mv_m_slow_nmf', "mu_mbuf_nmf", "mv_mbuf_nmf", "mu_b_nmf", "normuon_v"}
|
|
213
213
|
|
|
214
214
|
for key, val in state.items():
|
|
215
215
|
if not isinstance(val, torch.Tensor):
|
|
@@ -1,80 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
import math
|
|
3
|
-
|
|
4
|
-
def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
|
|
5
|
-
"""
|
|
6
|
-
Projects the gradient `grad` to be orthogonal to the parameter `p`.
|
|
7
|
-
Modified from:
|
|
8
|
-
https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability/blob/720d2444df12b851d6cb417ab08cf125c822b2ae/orthograd.py
|
|
9
|
-
"""
|
|
10
|
-
original_shape = grad.shape
|
|
11
|
-
original_dtype = grad.dtype
|
|
12
|
-
w = p.view(-1).float()
|
|
13
|
-
g = grad.view(-1).float()
|
|
14
|
-
w_norm_sq = torch.dot(w, w).add_(1e-30)
|
|
15
|
-
proj = torch.dot(w, g) / w_norm_sq
|
|
16
|
-
g_orth = g.sub(w * proj)
|
|
17
|
-
g_norm = g.norm(2)
|
|
18
|
-
g_orth_norm = g_orth.norm(2).add_(1e-30)
|
|
19
|
-
g_orth_scaled = g_orth * (g_norm / g_orth_norm)
|
|
20
|
-
return g_orth_scaled.view(original_shape).to(original_dtype)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def iterative_ortho_project(p: torch.Tensor, update: torch.Tensor, iters: int = 5) -> torch.Tensor:
|
|
24
|
-
"""
|
|
25
|
-
Applies iterative alternating orthogonal projection to a 2D matrix.
|
|
26
|
-
Projects the update to be orthogonal to the parameter matrix along
|
|
27
|
-
rows and columns sequentially, alternating dimensions.
|
|
28
|
-
Inspired from Sinkhorn algorithm, 2 iterations is enough to converge
|
|
29
|
-
to cosine similarity of -1e4 to -1e-5 (semi orthogonal).
|
|
30
|
-
"""
|
|
31
|
-
# 1D Vector Case fallback to the standard OrthoGrad
|
|
32
|
-
is_vector = p.ndim < 2 or getattr(p, '_is_dora_scale', False) or getattr(p, 'is_vector', False)
|
|
33
|
-
if is_vector:
|
|
34
|
-
return _orthogonalize_gradient(p, update)
|
|
35
|
-
|
|
36
|
-
original_shape = update.shape
|
|
37
|
-
|
|
38
|
-
# 2D+ Matrix Case
|
|
39
|
-
update_2d = update.view(update.shape[0], -1)
|
|
40
|
-
param_2d = p.view(p.shape[0], -1)
|
|
41
|
-
|
|
42
|
-
m, n = update_2d.shape
|
|
43
|
-
|
|
44
|
-
# Dynamically determine the order based on aspect ratio
|
|
45
|
-
row_first = m > n
|
|
46
|
-
dim = 0 if row_first else 1
|
|
47
|
-
|
|
48
|
-
p_norm_sq_dim = torch.sum(param_2d * param_2d, dim=dim, keepdim=True).add_(1e-30)
|
|
49
|
-
p_norm_sq_adim = torch.sum(param_2d * param_2d, dim=1-dim, keepdim=True).add_(1e-30)
|
|
50
|
-
|
|
51
|
-
for _ in range(iters):
|
|
52
|
-
# First dimension
|
|
53
|
-
update_2d = _ortho_normed_dim(param_2d, update_2d, p_norm_sq_dim, dim)
|
|
54
|
-
# Second dimension
|
|
55
|
-
update_2d = _ortho_normed_dim(param_2d, update_2d, p_norm_sq_adim, 1 - dim)
|
|
56
|
-
|
|
57
|
-
return update_2d.view(original_shape)
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def _ortho_normed_dim(p_2d: torch.Tensor, update_2d: torch.Tensor, p_norm_sq: torch.Tensor, dim: int) -> torch.Tensor:
|
|
61
|
-
"""
|
|
62
|
-
Projects the update to be orthogonal to p along 'dim' and dynamically restores
|
|
63
|
-
the original magnitude of that dimension pre-projection.
|
|
64
|
-
"""
|
|
65
|
-
# Record target magnitude before projection
|
|
66
|
-
norm_lb = 1 / math.sqrt(update_2d.shape[dim])
|
|
67
|
-
target_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm_lb)
|
|
68
|
-
|
|
69
|
-
# Project: g_orth = g - (p * <p, g> / ||p||^2)
|
|
70
|
-
dot_prod = torch.sum(p_2d * update_2d, dim=dim, keepdim=True)
|
|
71
|
-
proj = dot_prod / p_norm_sq
|
|
72
|
-
|
|
73
|
-
# In-place subtraction: update_2d = update_2d - (proj * p_2d)
|
|
74
|
-
# Standard gamma is -1, but -1.01 proved to converge faster
|
|
75
|
-
update_2d.addcmul_(proj, p_2d, value=-1.01)
|
|
76
|
-
|
|
77
|
-
# Magnitude Preservation
|
|
78
|
-
g_orth_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm_lb)
|
|
79
|
-
scale_factor = target_norm / g_orth_norm
|
|
80
|
-
return update_2d.mul_(scale_factor)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|