adv-optm 2.4.dev25__tar.gz → 2.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev25 → adv_optm-2.5}/PKG-INFO +1 -1
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/AdaMuon_adv.py +7 -7
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/AdamW_adv.py +5 -5
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/Adopt_adv.py +4 -6
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/Lion_adv.py +6 -5
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/Muon_adv.py +7 -7
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/Prodigy_adv.py +4 -4
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/SignSGD_adv.py +7 -8
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/SinkSGD_adv.py +7 -8
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/Muon_AuxAdam.py +2 -3
- adv_optm-2.5/adv_optm/util/OrthoGrad.py +92 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/param_update.py +3 -3
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/scaled_optm.py +2 -2
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/state_util.py +1 -1
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev25 → adv_optm-2.5}/setup.py +1 -1
- adv_optm-2.4.dev25/adv_optm/util/OrthoGrad.py +0 -80
- {adv_optm-2.4.dev25 → adv_optm-2.5}/LICENSE +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/README.md +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/sinkhorn.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev25 → adv_optm-2.5}/setup.cfg +0 -0
|
@@ -57,7 +57,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
57
57
|
(default: (3.4445, -4.7750, 2.0315)).
|
|
58
58
|
stochastic_rounding (bool): whether to use stochastic rounding for
|
|
59
59
|
BF16 parameter updates (default: True).
|
|
60
|
-
orthogonal_gradient (
|
|
60
|
+
orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
|
|
61
|
+
'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
|
|
61
62
|
nesterov (bool): enables Nesterov momentum (default: False).
|
|
62
63
|
use_atan2 (bool): whether to use the atan2 update rule. (default: False)
|
|
63
64
|
vector_reshape (bool): whether to reshape 1D vectors into 2D
|
|
@@ -114,7 +115,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
114
115
|
adam_fisher_wd (bool): Fisher Adam (FAdam) weight decay for the AdamW part. (default: False)
|
|
115
116
|
adam_use_bias_correction (bool): Bias correction for AdamW.
|
|
116
117
|
adam_use_atan2 (bool): Atan2 update rule for AdamW.
|
|
117
|
-
adam_orthogonal_gradient (
|
|
118
|
+
adam_orthogonal_gradient (str): OrthoGrad for AdamW.
|
|
118
119
|
adam_nesterov (bool): Nesterov momentum for AdamW. (default: False)
|
|
119
120
|
adam_nesterov_coef (float, optional): Nesterov coefficient for AdamW. (default: None)
|
|
120
121
|
adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
|
|
@@ -149,7 +150,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
149
150
|
# Stochastic Rounding for BF16
|
|
150
151
|
stochastic_rounding: bool = True,
|
|
151
152
|
# OrthoGrad
|
|
152
|
-
orthogonal_gradient:
|
|
153
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
153
154
|
# Adam_atan2 (scale invariant)
|
|
154
155
|
use_atan2: bool = False,
|
|
155
156
|
# NorMuon
|
|
@@ -190,7 +191,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
190
191
|
adam_fisher_wd: bool = False,
|
|
191
192
|
adam_use_bias_correction: bool = True,
|
|
192
193
|
adam_use_atan2: bool = False,
|
|
193
|
-
adam_orthogonal_gradient:
|
|
194
|
+
adam_orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
194
195
|
adam_nesterov: bool = False,
|
|
195
196
|
adam_nesterov_coef: float | None = None,
|
|
196
197
|
adam_kourkoutas_beta: bool = False,
|
|
@@ -213,7 +214,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
213
214
|
print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
|
|
214
215
|
rms_rescaling = False
|
|
215
216
|
if spectral_normalization and accelerated_ns:
|
|
216
|
-
ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
217
|
+
raise ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
217
218
|
|
|
218
219
|
# Legacy backwards compatibility support for `nnmf_factor=True`
|
|
219
220
|
if nnmf_factor:
|
|
@@ -515,8 +516,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
515
516
|
grad = approx_mars(grad, state['last_grad'], group['mars_gamma'], beta1)
|
|
516
517
|
|
|
517
518
|
|
|
518
|
-
|
|
519
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
519
|
+
grad = _orthogonalize_gradient(p, grad, group.get("orthogonal_gradient"))
|
|
520
520
|
|
|
521
521
|
if state['factored']: # Factored Muon
|
|
522
522
|
d1, d2 = state['effective_shape']
|
|
@@ -45,7 +45,8 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
45
45
|
stochastic_rounding (bool): whether to use stochastic
|
|
46
46
|
rounding for BF16 parameter updates (default: True).
|
|
47
47
|
use_atan2 (bool): whether to use the atan2 update rule. (default: False)
|
|
48
|
-
orthogonal_gradient (
|
|
48
|
+
orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
|
|
49
|
+
'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
|
|
49
50
|
normed_momentum (bool): whether to compute the first moment on the normalized gradient. (default: False)
|
|
50
51
|
kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
|
|
51
52
|
If `False`, the optimizer behaves as standard AdamW. (default: False)
|
|
@@ -104,7 +105,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
104
105
|
# Adam_atan2 (scale invariant)
|
|
105
106
|
use_atan2: bool = False,
|
|
106
107
|
# OrthoGrad
|
|
107
|
-
orthogonal_gradient:
|
|
108
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
108
109
|
# Nesterov momentum
|
|
109
110
|
nesterov: bool = False,
|
|
110
111
|
nesterov_coef: float | None = None,
|
|
@@ -326,8 +327,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
326
327
|
def _step_parameter(self, p, grad, state, group, step_size, beta1, beta2, sqrt_bias_correction2, random_int_tensor, random_int_state_tensor):
|
|
327
328
|
grad = upcast_grad_for_precision(grad, state, group['state_precision'])
|
|
328
329
|
|
|
329
|
-
|
|
330
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
330
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
331
331
|
|
|
332
332
|
nesterov = group.get('nesterov', False)
|
|
333
333
|
nesterov_coef = group.get('nesterov_coef', None)
|
|
@@ -462,7 +462,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
462
462
|
else:
|
|
463
463
|
update.mul_(update_scaling)
|
|
464
464
|
|
|
465
|
-
param_update.apply_parameter_update(self, p, group, update,
|
|
465
|
+
param_update.apply_parameter_update(self, p, group, update, group['lr'], random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
|
|
466
466
|
|
|
467
467
|
def compile(self, *args, **kwargs):
|
|
468
468
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
@@ -108,7 +108,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
108
108
|
# Stochastic Rounding for BF16
|
|
109
109
|
stochastic_rounding: bool = True,
|
|
110
110
|
# OrthoGrad
|
|
111
|
-
orthogonal_gradient:
|
|
111
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
112
112
|
# Nesterov momentum
|
|
113
113
|
nesterov: bool = False,
|
|
114
114
|
nesterov_coef: float | None = None,
|
|
@@ -158,7 +158,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
158
158
|
|
|
159
159
|
defaults = {
|
|
160
160
|
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
161
|
-
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
|
|
161
|
+
"fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "orthogonal_gradient": orthogonal_gradient,
|
|
162
162
|
"nesterov": nesterov, "nesterov_coef": nesterov_coef,
|
|
163
163
|
"kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
|
|
164
164
|
"tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
|
|
@@ -172,7 +172,6 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
172
172
|
self.clip_lambda = clip_lambda
|
|
173
173
|
self.stochastic_rounding = stochastic_rounding
|
|
174
174
|
self.use_atan2 = use_atan2
|
|
175
|
-
self.orthogonal_gradient = orthogonal_gradient
|
|
176
175
|
self.kourkoutas_beta = kourkoutas_beta
|
|
177
176
|
self.layer_key_fn = layer_key_fn
|
|
178
177
|
self._init_lr = lr if lr > 0 else 1
|
|
@@ -237,7 +236,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
237
236
|
dtype = torch.float32 if (state['factored'] or req_precision == 'factored') else p.dtype
|
|
238
237
|
|
|
239
238
|
vt_dtype = torch.float32 if (state['factored'] or state['factored_2nd'] or req_precision in ['factored', 'bf16_sr', 'int8_sr']) else dtype
|
|
240
|
-
vt_init = grad.pow(2).to(vt_dtype)
|
|
239
|
+
vt_init = grad.pow(2).to(vt_dtype)
|
|
241
240
|
|
|
242
241
|
if state['factored']:
|
|
243
242
|
state['effective_shape'] = _get_effective_shape(p.numel())
|
|
@@ -329,8 +328,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
329
328
|
def _step_parameter(self, p, grad, state, group, lr, beta1, beta2, random_int_tensor, random_int_state_tensor):
|
|
330
329
|
grad = upcast_grad_for_precision(grad, state, group['state_precision'])
|
|
331
330
|
|
|
332
|
-
|
|
333
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
331
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
334
332
|
|
|
335
333
|
nesterov = group.get('nesterov', False)
|
|
336
334
|
nesterov_coef = group.get('nesterov_coef', None)
|
|
@@ -67,7 +67,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
67
67
|
# Stochastic Rounding for BF16
|
|
68
68
|
stochastic_rounding: bool = True,
|
|
69
69
|
# OrthoGrad
|
|
70
|
-
orthogonal_gradient:
|
|
70
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
71
71
|
# Lion-k
|
|
72
72
|
kappa_p: float = 1.0,
|
|
73
73
|
auto_kappa_p: bool = False,
|
|
@@ -213,8 +213,9 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
213
213
|
def _step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_noise_tensor):
|
|
214
214
|
if grad.dtype != torch.float32 and state['factored']:
|
|
215
215
|
grad = grad.float()
|
|
216
|
-
|
|
217
|
-
|
|
216
|
+
is_vector = p.ndim < 2 or getattr(p, '_is_dora_scale', False) or getattr(p, 'is_vector', False)
|
|
217
|
+
|
|
218
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
218
219
|
|
|
219
220
|
# Lion-K Logic
|
|
220
221
|
kappa_p = group.get("kappa_p", 1.0)
|
|
@@ -250,7 +251,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
250
251
|
update = update.view(p.shape)
|
|
251
252
|
|
|
252
253
|
if group.get('stochastic_sign', False):
|
|
253
|
-
update = apply_stochastic_sign_(update, noise=random_noise_tensor)
|
|
254
|
+
update = apply_stochastic_sign_(update, noise=random_noise_tensor, is_vector=is_vector)
|
|
254
255
|
else:
|
|
255
256
|
update = _get_lion_k_update(update, kappa_p)
|
|
256
257
|
|
|
@@ -265,7 +266,7 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
265
266
|
exp_avg.lerp_(grad, 1 - beta2)
|
|
266
267
|
|
|
267
268
|
if group.get('stochastic_sign', False):
|
|
268
|
-
update = apply_stochastic_sign_(update, noise=random_noise_tensor)
|
|
269
|
+
update = apply_stochastic_sign_(update, noise=random_noise_tensor, is_vector=is_vector)
|
|
269
270
|
else:
|
|
270
271
|
update = _get_lion_k_update(update, kappa_p)
|
|
271
272
|
|
|
@@ -39,7 +39,8 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
39
39
|
(default: (3.4445, -4.7750, 2.0315)).
|
|
40
40
|
stochastic_rounding (bool): whether to use stochastic rounding for
|
|
41
41
|
BF16 parameter updates (default: True).
|
|
42
|
-
orthogonal_gradient (
|
|
42
|
+
orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
|
|
43
|
+
'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
|
|
43
44
|
vector_reshape (bool): whether to reshape 1D vectors into 2D
|
|
44
45
|
matrices to apply low-rank compression (default: True).
|
|
45
46
|
nnmf_factor (bool): whether to use the factorization or disable it to use
|
|
@@ -89,7 +90,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
89
90
|
adam_fisher_wd (bool): Fisher Adam (FAdam) weight decay for the AdamW part. (default: False)
|
|
90
91
|
adam_use_bias_correction (bool): Bias correction for AdamW.
|
|
91
92
|
adam_use_atan2 (bool): Atan2 update rule for AdamW.
|
|
92
|
-
adam_orthogonal_gradient (
|
|
93
|
+
adam_orthogonal_gradient (str): OrthoGrad for AdamW.
|
|
93
94
|
adam_nesterov (bool): Nesterov momentum for AdamW. (default: False)
|
|
94
95
|
adam_nesterov_coef (float, optional): Nesterov coefficient for AdamW. (default: None)
|
|
95
96
|
adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
|
|
@@ -121,7 +122,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
121
122
|
# Stochastic Rounding for BF16
|
|
122
123
|
stochastic_rounding: bool = True,
|
|
123
124
|
# OrthoGrad
|
|
124
|
-
orthogonal_gradient:
|
|
125
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
125
126
|
# RMS Rescaling
|
|
126
127
|
rms_rescaling: bool = True,
|
|
127
128
|
# SMMF factorization
|
|
@@ -159,7 +160,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
159
160
|
adam_fisher_wd: bool = False,
|
|
160
161
|
adam_use_bias_correction: bool = True,
|
|
161
162
|
adam_use_atan2: bool = False,
|
|
162
|
-
adam_orthogonal_gradient:
|
|
163
|
+
adam_orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
163
164
|
adam_nesterov: bool = False,
|
|
164
165
|
adam_nesterov_coef: float | None = None,
|
|
165
166
|
adam_kourkoutas_beta: bool = False,
|
|
@@ -186,7 +187,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
186
187
|
print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
|
|
187
188
|
rms_rescaling = False
|
|
188
189
|
if spectral_normalization and accelerated_ns:
|
|
189
|
-
ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
190
|
+
raise ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
190
191
|
|
|
191
192
|
# Legacy backwards compatibility support for `nnmf_factor=True`
|
|
192
193
|
if nnmf_factor:
|
|
@@ -457,8 +458,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
457
458
|
if grad.dtype != torch.float32 and state.get('factored', False):
|
|
458
459
|
grad = grad.float()
|
|
459
460
|
|
|
460
|
-
|
|
461
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
461
|
+
grad = _orthogonalize_gradient(p, grad, group.get("orthogonal_gradient"))
|
|
462
462
|
|
|
463
463
|
if state['factored']: # Factored Muon
|
|
464
464
|
d1, d2 = state['effective_shape']
|
|
@@ -43,7 +43,8 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
43
43
|
stochastic_rounding (bool): whether to use stochastic
|
|
44
44
|
rounding for BF16 parameter updates (default: True).
|
|
45
45
|
use_atan2 (bool): whether to use the atan2 update rule. (default: False)
|
|
46
|
-
orthogonal_gradient (
|
|
46
|
+
orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
|
|
47
|
+
'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
|
|
47
48
|
nnmf_factor (bool): whether to use the factorization or disable it to use
|
|
48
49
|
the uncompressed optimizer. (default: False)
|
|
49
50
|
factored_2nd (bool): whether to keep the first moment uncompressed (dense)
|
|
@@ -119,7 +120,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
119
120
|
# Adam_atan2 (scale invariant)
|
|
120
121
|
use_atan2: bool = False,
|
|
121
122
|
# OrthoGrad
|
|
122
|
-
orthogonal_gradient:
|
|
123
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
123
124
|
# Nesterov momentum
|
|
124
125
|
nesterov: bool = False,
|
|
125
126
|
nesterov_coef: float | None = None,
|
|
@@ -371,8 +372,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
371
372
|
def _step_parameter(self, p, grad, state, group, beta2, d, dlr, random_int_tensor, random_int_state_tensor):
|
|
372
373
|
grad = upcast_grad_for_precision(grad, state, group['state_precision'])
|
|
373
374
|
|
|
374
|
-
|
|
375
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
375
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
376
376
|
|
|
377
377
|
nesterov = group.get('nesterov', False)
|
|
378
378
|
nesterov_coef = group.get('nesterov_coef', None)
|
|
@@ -62,7 +62,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
62
62
|
# Stochastic Rounding for BF16
|
|
63
63
|
stochastic_rounding: bool = True,
|
|
64
64
|
# OrthoGrad
|
|
65
|
-
orthogonal_gradient:
|
|
65
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
66
66
|
# Stochastic Sign Operator
|
|
67
67
|
stochastic_sign: bool = False,
|
|
68
68
|
# Nesterov momentum
|
|
@@ -171,7 +171,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
171
171
|
def __init_state(self, p, group):
|
|
172
172
|
state = self.state[p]
|
|
173
173
|
# State Initialization
|
|
174
|
-
if
|
|
174
|
+
if 'step' not in state:
|
|
175
175
|
req_precision = group['state_precision']
|
|
176
176
|
is_vector = len(p.shape) == 1 and not group['vector_reshape']
|
|
177
177
|
|
|
@@ -259,8 +259,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
259
259
|
wd_target = None
|
|
260
260
|
cwd_target = None
|
|
261
261
|
|
|
262
|
-
|
|
263
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
262
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
264
263
|
|
|
265
264
|
if normed_mt:
|
|
266
265
|
if sso:
|
|
@@ -282,7 +281,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
282
281
|
|
|
283
282
|
if nesterov and normed_mt:
|
|
284
283
|
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
285
|
-
normed_grad =
|
|
284
|
+
normed_grad = exp_avg.abs().mul_(grad_reshaped)
|
|
286
285
|
|
|
287
286
|
exp_avg.lerp_(grad_reshaped, 1 - momentum)
|
|
288
287
|
|
|
@@ -313,7 +312,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
313
312
|
|
|
314
313
|
if nesterov and normed_mt:
|
|
315
314
|
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
316
|
-
normed_grad =
|
|
315
|
+
normed_grad = exp_avg.abs().mul_(grad)
|
|
317
316
|
|
|
318
317
|
exp_avg.lerp_(grad, 1 - momentum)
|
|
319
318
|
|
|
@@ -344,7 +343,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
344
343
|
if group.get('geometric_wd', False) and group["weight_decay"] > 0 :
|
|
345
344
|
wd_target = get_signsgd_wd_target(p, denom=denom, stochastic_sign=sso, noise=random_noise_tensor, is_vector=is_vector)
|
|
346
345
|
|
|
347
|
-
if group.get('centered_wd', 0.0) > 0 and '
|
|
346
|
+
if group.get('centered_wd', 0.0) > 0 and 'anchor_data' in state:
|
|
348
347
|
anchor = dequantize_anchor(p, state, group, p.dtype)
|
|
349
348
|
cwd_target = get_signsgd_wd_target(p.sub(anchor), denom=denom, stochastic_sign=sso, noise=random_noise_tensor, is_vector=is_vector)
|
|
350
349
|
del anchor
|
|
@@ -355,7 +354,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
355
354
|
update_scaling = lr * A if snr_cond else lr
|
|
356
355
|
update.mul_(update_scaling)
|
|
357
356
|
|
|
358
|
-
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_target=wd_target, cwd_target=cwd_target
|
|
357
|
+
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_target=wd_target, cwd_target=cwd_target)
|
|
359
358
|
|
|
360
359
|
def compile(self, *args, **kwargs):
|
|
361
360
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
@@ -69,7 +69,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
69
69
|
# Stochastic Rounding for BF16
|
|
70
70
|
stochastic_rounding: bool = True,
|
|
71
71
|
# OrthoGrad
|
|
72
|
-
orthogonal_gradient:
|
|
72
|
+
orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
|
|
73
73
|
# Spectral Normed Optimizer
|
|
74
74
|
spectral_normalization: bool = False,
|
|
75
75
|
# Centered WD
|
|
@@ -89,8 +89,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
89
89
|
raise ValueError(f"Momentum should be >= 0.0. Got {momentum}")
|
|
90
90
|
if not (weight_decay >= 0.0):
|
|
91
91
|
raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
|
|
92
|
-
if snr_cond and not normed_momentum:
|
|
93
|
-
raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum")
|
|
92
|
+
if snr_cond and not normed_momentum and not momentum > 0:
|
|
93
|
+
raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum.")
|
|
94
94
|
|
|
95
95
|
state_precision = state_precision.lower()
|
|
96
96
|
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
|
|
@@ -237,8 +237,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
237
237
|
wd_target = None
|
|
238
238
|
cwd_target = None
|
|
239
239
|
|
|
240
|
-
|
|
241
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
240
|
+
grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
|
|
242
241
|
|
|
243
242
|
if normed_mt:
|
|
244
243
|
if not is_vector:
|
|
@@ -266,7 +265,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
266
265
|
|
|
267
266
|
if nesterov and normed_mt:
|
|
268
267
|
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
269
|
-
normed_grad =
|
|
268
|
+
normed_grad = buf.abs().mul_(grad_reshaped)
|
|
270
269
|
|
|
271
270
|
buf.lerp_(grad_reshaped, 1 - momentum)
|
|
272
271
|
|
|
@@ -303,7 +302,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
303
302
|
|
|
304
303
|
if nesterov and normed_mt:
|
|
305
304
|
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
306
|
-
normed_grad =
|
|
305
|
+
normed_grad = buf.abs().mul_(grad)
|
|
307
306
|
|
|
308
307
|
buf.lerp_(grad, 1 - momentum)
|
|
309
308
|
|
|
@@ -346,7 +345,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
346
345
|
wd_scaler = get_sinkhorn_wd_scaler(p, row_denom=vt_row, col_denom=vt_col)
|
|
347
346
|
else:
|
|
348
347
|
wd_target = get_signsgd_wd_target(p, denom=denom)
|
|
349
|
-
if is_vector and group.get('centered_wd', 0.0) > 0 and '
|
|
348
|
+
if is_vector and group.get('centered_wd', 0.0) > 0 and 'anchor_data' in state:
|
|
350
349
|
anchor = dequantize_anchor(p, state, group, p.dtype)
|
|
351
350
|
cwd_target = get_signsgd_wd_target(p.sub(anchor), denom=denom)
|
|
352
351
|
del anchor
|
|
@@ -71,8 +71,7 @@ def _init_auxadam_state(self, p, group):
|
|
|
71
71
|
def _adam_step_parameter(self, p, grad, state, group, beta1_adam, beta2_adam, sqrt_bias_correction2, step_size, random_int_tensor, random_int_state_tensor=None):
|
|
72
72
|
grad = upcast_grad_for_precision(grad, state, group.get('adam_state_precision', 'auto'))
|
|
73
73
|
|
|
74
|
-
|
|
75
|
-
grad = _orthogonalize_gradient(p, grad)
|
|
74
|
+
grad = _orthogonalize_gradient(p, grad, group.get("adam_orthogonal_gradient"))
|
|
76
75
|
|
|
77
76
|
if hasattr(self, 'kourkoutas_helper') and self.kourkoutas_helper:
|
|
78
77
|
# Accumulate current grad's norm for the *next* step
|
|
@@ -190,4 +189,4 @@ def _adam_step_parameter(self, p, grad, state, group, beta1_adam, beta2_adam, sq
|
|
|
190
189
|
else:
|
|
191
190
|
update.mul_(update_scaling)
|
|
192
191
|
|
|
193
|
-
param_update.apply_parameter_update(self, p, group, update,
|
|
192
|
+
param_update.apply_parameter_update(self, p, group, update, group['lr'], group["adam_weight_decay"], random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import math
|
|
3
|
+
|
|
4
|
+
def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor, mode: str) -> torch.Tensor:
|
|
5
|
+
"""
|
|
6
|
+
Projects the gradient `grad` to be orthogonal to the parameter `p`.
|
|
7
|
+
Supports two modes: 'flattened' (vectorized) and 'iterative' (matrix-wise).
|
|
8
|
+
"""
|
|
9
|
+
if mode == 'disabled':
|
|
10
|
+
return grad
|
|
11
|
+
elif mode == 'flattened':
|
|
12
|
+
return flattened_ortho_project(p, grad)
|
|
13
|
+
elif mode == 'iterative':
|
|
14
|
+
return iterative_ortho_project(p, grad, iters=3)
|
|
15
|
+
|
|
16
|
+
def flattened_ortho_project(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    """
    Project the flattened gradient ``grad`` to be orthogonal to the flattened parameter ``p``.

    The component of ``grad`` along ``p`` is removed, and the result is rescaled
    back to the original L2 norm of ``grad``. The math runs in float32; the result
    is reshaped to ``grad.shape`` and cast back to ``grad.dtype``.

    Modified from:
    https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability/blob/720d2444df12b851d6cb417ab08cf125c822b2ae/orthograd.py

    Args:
        p: Parameter tensor.
        grad: Gradient tensor with the same number of elements as ``p``.

    Returns:
        A new tensor with ``grad``'s shape and dtype.
    """
    original_shape = grad.shape
    original_dtype = grad.dtype
    # reshape (not view): non-contiguous tensors (e.g. channels_last conv
    # weights/grads, transposed views) cannot be .view(-1)'d. reshape stays a
    # zero-copy view whenever the memory layout allows it.
    w = p.reshape(-1).float()
    g = grad.reshape(-1).float()
    w_norm_sq = torch.dot(w, w).add_(1e-30)
    proj = torch.dot(w, g) / w_norm_sq
    g_orth = g.sub(w * proj)
    # Restore the pre-projection gradient norm; the epsilon guards g_orth == 0.
    g_norm = g.norm(2)
    g_orth_norm = g_orth.norm(2).add_(1e-30)
    g_orth_scaled = g_orth * (g_norm / g_orth_norm)
    return g_orth_scaled.view(original_shape).to(original_dtype)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def iterative_ortho_project(p: torch.Tensor, grad: torch.Tensor, iters: int = 3) -> torch.Tensor:
    """
    Apply iterative alternating orthogonal projection to a matrix-shaped gradient.

    The gradient is projected to be orthogonal to the parameter along one
    dimension, then along the other, and this pair of sweeps is repeated
    ``iters`` times. The alternating scheme is inspired by the Sinkhorn
    algorithm. After each single-dimension projection, every slice along that
    dimension is rescaled to its pre-projection norm. The original note reports
    that 2-3 iterations bring the per-row/per-column cosine similarity with
    ``p`` down to a magnitude of roughly 1e-4 to 1e-6 (semi-orthogonal).

    Tensors with more than two dimensions are processed as a
    ``(shape[0], prod(shape[1:]))`` matrix. 1D parameters, and parameters
    flagged with ``_is_dora_scale`` or ``is_vector``, fall back to the standard
    flattened OrthoGrad.

    Args:
        p: Parameter tensor.
        grad: Gradient tensor; must have the same shape as ``p`` (the two are
            combined elementwise).
        iters: Number of (first-dim, second-dim) projection sweeps.

    Returns:
        A new tensor with ``grad``'s shape and dtype. ``grad`` itself is not
        modified.
    """
    # 1D Vector Case fallback to the standard OrthoGrad. Call the flattened
    # variant directly: the dispatcher requires an explicit ``mode`` argument.
    is_vector = p.ndim < 2 or getattr(p, '_is_dora_scale', False) or getattr(p, 'is_vector', False)
    if is_vector:
        return flattened_ortho_project(p, grad)

    original_shape = grad.shape
    original_dtype = grad.dtype

    # 2D+ Matrix Case. Work in float32 on a private copy. The projection helper
    # mutates its buffer in place, so this avoids clobbering the caller's grad.
    # It also avoids bf16/fp16 precision loss and matches the flattened path.
    grad_2d = grad.reshape(grad.shape[0], -1).to(torch.float32, copy=True)
    param_2d = p.reshape(p.shape[0], -1).float()

    m, n = grad_2d.shape

    # Dynamically determine the order based on aspect ratio
    dim = 0 if m > n else 1

    param_sq = param_2d * param_2d
    p_norm_sq_dim = torch.sum(param_sq, dim=dim, keepdim=True).add_(1e-30)
    p_norm_sq_adim = torch.sum(param_sq, dim=1 - dim, keepdim=True).add_(1e-30)
    del param_sq

    for _ in range(iters):
        # First dimension
        grad_2d = _ortho_normed_dim(param_2d, grad_2d, p_norm_sq_dim, dim)
        # Second dimension
        grad_2d = _ortho_normed_dim(param_2d, grad_2d, p_norm_sq_adim, 1 - dim)

    # reshape (not view): the fp32 copy may keep a non-default stride layout.
    return grad_2d.reshape(original_shape).to(original_dtype)


def _ortho_normed_dim(p_2d: torch.Tensor, grad_2d: torch.Tensor, p_norm_sq: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Project ``grad_2d`` to be orthogonal to ``p_2d`` along ``dim``, in place.

    Each slice reduced over ``dim`` (each column for ``dim=0``, each row for
    ``dim=1``) has its component along the matching slice of ``p_2d`` removed.
    It is then rescaled back to its pre-projection L2 norm.

    Args:
        p_2d: 2D parameter matrix.
        grad_2d: 2D gradient matrix; modified in place.
        p_norm_sq: ``sum(p_2d * p_2d, dim=dim, keepdim=True)`` plus a small
            epsilon, precomputed by the caller.
        dim: Dimension to reduce over (0 or 1).

    Returns:
        ``grad_2d`` itself, after projection and rescaling.
    """
    # Record target magnitude before projection. Both norms are floored at
    # 1/sqrt(len), which caps the rescale factor when the projected slice is
    # (near) zero instead of dividing by a tiny number.
    norm_lb = 1 / math.sqrt(grad_2d.shape[dim])
    target_norm = grad_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm_lb)

    # Project: g_orth = g - (p * <p, g> / ||p||^2)
    dot_prod = torch.sum(p_2d * grad_2d, dim=dim, keepdim=True)
    proj = dot_prod / p_norm_sq

    # In-place subtraction: grad_2d = grad_2d - 1.01 * (proj * p_2d).
    # Standard gamma is -1, but -1.01 proved to converge faster across the
    # alternating sweeps.
    grad_2d.addcmul_(proj, p_2d, value=-1.01)

    # Magnitude Preservation
    g_orth_norm = grad_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm_lb)
    scale_factor = target_norm / g_orth_norm
    return grad_2d.mul_(scale_factor)
|
|
@@ -6,7 +6,7 @@ import torch.nn.functional as F
|
|
|
6
6
|
|
|
7
7
|
from typing import Dict, Any
|
|
8
8
|
|
|
9
|
-
from .scaled_optm import adjust_wds
|
|
9
|
+
from .scaled_optm import adjust_wds
|
|
10
10
|
from .centered_decay import dequantize_anchor
|
|
11
11
|
|
|
12
12
|
_generators: Dict[torch.device, torch.Generator] = {}
|
|
@@ -48,7 +48,7 @@ def _apply_weight_decay(
|
|
|
48
48
|
p_calc.add_(wd_target, alpha=-scaled_wd)
|
|
49
49
|
|
|
50
50
|
# Centered Weight Decay (pulls toward anchor)
|
|
51
|
-
if scaled_cwd is not None and '
|
|
51
|
+
if scaled_cwd is not None and 'anchor_data' in state:
|
|
52
52
|
if cwd_target is not None:
|
|
53
53
|
decay_target = cwd_target
|
|
54
54
|
else:
|
|
@@ -330,7 +330,7 @@ def _copy_int8_sym_blockwise_stochastic_core_(
|
|
|
330
330
|
target: torch.Tensor,
|
|
331
331
|
source: torch.Tensor,
|
|
332
332
|
scales: torch.Tensor,
|
|
333
|
-
random_int_tensor: torch.Tensor
|
|
333
|
+
random_int_tensor: torch.Tensor,
|
|
334
334
|
block_size: int = 2048,
|
|
335
335
|
val_blocks: torch.Tensor | None = None,
|
|
336
336
|
) -> None:
|
|
@@ -61,7 +61,7 @@ def adjust_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
|
|
|
61
61
|
"""
|
|
62
62
|
# DoRA Scale (Magnitude Vector)
|
|
63
63
|
if getattr(p, '_is_dora_scale', False):
|
|
64
|
-
return
|
|
64
|
+
return wd, cwd
|
|
65
65
|
|
|
66
66
|
if getattr(p, '_is_oft', False):
|
|
67
67
|
return wd, 0.0
|
|
@@ -76,7 +76,7 @@ def adjust_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
|
|
|
76
76
|
else:
|
|
77
77
|
# 1D Biases or generic 1D parameters
|
|
78
78
|
# Centered WD safely regularizes the delta without collapsing base feature variance.
|
|
79
|
-
return
|
|
79
|
+
return wd, cwd
|
|
80
80
|
|
|
81
81
|
|
|
82
82
|
def scale_wds(wd: float, cwd: float, p: torch.Tensor) -> tuple[float, float]:
|
|
@@ -209,7 +209,7 @@ def fix_loaded_state_dtype(state: dict, p: torch.Tensor, group: dict) -> None:
|
|
|
209
209
|
|
|
210
210
|
# Pre-define sets for known exact-match keys
|
|
211
211
|
uint8_keys = {'sign', 'sign_slow', 'sign_buf', 'shifter'}
|
|
212
|
-
fp32_keys = {'mu_m_nmf', 'mv_m_nmf', 'mu_v_nmf', 'mv_v_nmf', 'mu_m_slow_nmf', 'mv_m_slow_nmf'}
|
|
212
|
+
fp32_keys = {'mu_m_nmf', 'mv_m_nmf', 'mu_v_nmf', 'mv_v_nmf', 'mu_m_slow_nmf', 'mv_m_slow_nmf', "mu_mbuf_nmf", "mv_mbuf_nmf", "mu_b_nmf", "normuon_v"}
|
|
213
213
|
|
|
214
214
|
for key, val in state.items():
|
|
215
215
|
if not isinstance(val, torch.Tensor):
|
|
@@ -1,80 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
import math
|
|
3
|
-
|
|
4
|
-
def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
|
|
5
|
-
"""
|
|
6
|
-
Projects the gradient `grad` to be orthogonal to the parameter `p`.
|
|
7
|
-
Modified from:
|
|
8
|
-
https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability/blob/720d2444df12b851d6cb417ab08cf125c822b2ae/orthograd.py
|
|
9
|
-
"""
|
|
10
|
-
original_shape = grad.shape
|
|
11
|
-
original_dtype = grad.dtype
|
|
12
|
-
w = p.view(-1).float()
|
|
13
|
-
g = grad.view(-1).float()
|
|
14
|
-
w_norm_sq = torch.dot(w, w).add_(1e-30)
|
|
15
|
-
proj = torch.dot(w, g) / w_norm_sq
|
|
16
|
-
g_orth = g.sub(w * proj)
|
|
17
|
-
g_norm = g.norm(2)
|
|
18
|
-
g_orth_norm = g_orth.norm(2).add_(1e-30)
|
|
19
|
-
g_orth_scaled = g_orth * (g_norm / g_orth_norm)
|
|
20
|
-
return g_orth_scaled.view(original_shape).to(original_dtype)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def iterative_ortho_project(p: torch.Tensor, update: torch.Tensor, iters: int = 5) -> torch.Tensor:
|
|
24
|
-
"""
|
|
25
|
-
Applies iterative alternating orthogonal projection to a 2D matrix.
|
|
26
|
-
Projects the update to be orthogonal to the parameter matrix along
|
|
27
|
-
rows and columns sequentially, alternating dimensions.
|
|
28
|
-
Inspired from Sinkhorn algorithm, 2 iterations is enough to converge
|
|
29
|
-
to cosine similarity of -1e4 to -1e-5 (semi orthogonal).
|
|
30
|
-
"""
|
|
31
|
-
# 1D Vector Case fallback to the standard OrthoGrad
|
|
32
|
-
is_vector = p.ndim < 2 or getattr(p, '_is_dora_scale', False) or getattr(p, 'is_vector', False)
|
|
33
|
-
if is_vector:
|
|
34
|
-
return _orthogonalize_gradient(p, update)
|
|
35
|
-
|
|
36
|
-
original_shape = update.shape
|
|
37
|
-
|
|
38
|
-
# 2D+ Matrix Case
|
|
39
|
-
update_2d = update.view(update.shape[0], -1)
|
|
40
|
-
param_2d = p.view(p.shape[0], -1)
|
|
41
|
-
|
|
42
|
-
m, n = update_2d.shape
|
|
43
|
-
|
|
44
|
-
# Dynamically determine the order based on aspect ratio
|
|
45
|
-
row_first = m > n
|
|
46
|
-
dim = 0 if row_first else 1
|
|
47
|
-
|
|
48
|
-
p_norm_sq_dim = torch.sum(param_2d * param_2d, dim=dim, keepdim=True).add_(1e-30)
|
|
49
|
-
p_norm_sq_adim = torch.sum(param_2d * param_2d, dim=1-dim, keepdim=True).add_(1e-30)
|
|
50
|
-
|
|
51
|
-
for _ in range(iters):
|
|
52
|
-
# First dimension
|
|
53
|
-
update_2d = _ortho_normed_dim(param_2d, update_2d, p_norm_sq_dim, dim)
|
|
54
|
-
# Second dimension
|
|
55
|
-
update_2d = _ortho_normed_dim(param_2d, update_2d, p_norm_sq_adim, 1 - dim)
|
|
56
|
-
|
|
57
|
-
return update_2d.view(original_shape)
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def _ortho_normed_dim(p_2d: torch.Tensor, update_2d: torch.Tensor, p_norm_sq: torch.Tensor, dim: int) -> torch.Tensor:
|
|
61
|
-
"""
|
|
62
|
-
Projects the update to be orthogonal to p along 'dim' and dynamically restores
|
|
63
|
-
the original magnitude of that dimension pre-projection.
|
|
64
|
-
"""
|
|
65
|
-
# Record target magnitude before projection
|
|
66
|
-
norm_lb = 1 / math.sqrt(update_2d.shape[dim])
|
|
67
|
-
target_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm_lb)
|
|
68
|
-
|
|
69
|
-
# Project: g_orth = g - (p * <p, g> / ||p||^2)
|
|
70
|
-
dot_prod = torch.sum(p_2d * update_2d, dim=dim, keepdim=True)
|
|
71
|
-
proj = dot_prod / p_norm_sq
|
|
72
|
-
|
|
73
|
-
# In-place subtraction: update_2d = update_2d - (proj * p_2d)
|
|
74
|
-
# Standard gamma is -1, but -1.01 proved to converge faster
|
|
75
|
-
update_2d.addcmul_(proj, p_2d, value=-1.01)
|
|
76
|
-
|
|
77
|
-
# Magnitude Preservation
|
|
78
|
-
g_orth_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm_lb)
|
|
79
|
-
scale_factor = target_norm / g_orth_norm
|
|
80
|
-
return update_2d.mul_(scale_factor)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|