adv-optm 2.4.dev11.tar.gz → 2.4.dev13.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/PKG-INFO +1 -1
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/__init__.py +3 -3
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/AdaMuon_adv.py +1 -1
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/AdamW_adv.py +16 -6
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/Adopt_adv.py +1 -1
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/Muon_adv.py +1 -1
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/SignSGD_adv.py +1 -1
- adv_optm-2.4.dev11/adv_optm/optim/SGD_adv.py → adv_optm-2.4.dev13/adv_optm/optim/SinkSGD_adv.py +31 -22
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/__init__.py +2 -2
- adv_optm-2.4.dev13/adv_optm/util/sinkhorn.py +77 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm.egg-info/SOURCES.txt +1 -1
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/setup.py +1 -1
- adv_optm-2.4.dev11/adv_optm/util/sinkhorn.py +0 -42
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/LICENSE +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/README.md +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/param_update.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/scaled_optm.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/state_util.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/setup.cfg +0 -0
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/__init__.py
```diff
@@ -8,7 +8,7 @@ from .optim import (
     Muon_adv,
     AdaMuon_adv,
     SignSGD_adv,
-    SGD_adv,
+    SinkSGD_adv,
 )
 
 __all__ = [
@@ -21,7 +21,7 @@ __all__ = [
     "Muon_adv",
    "AdaMuon_adv",
     "SignSGD_adv",
-    "SGD_adv",
+    "SinkSGD_adv",
 ]
 
-__version__ = "2.4.dev11"
+__version__ = "2.4.dev13"
```
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/AdaMuon_adv.py
```diff
@@ -468,7 +468,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
         if actual_precision == 'bf16_sr' and random_int_state_tensor is not None:
             random_int_state_tensor = param_update._get_random_int_for_sr(p)
         elif actual_precision == 'int8_sr':
-            random_int_state_tensor = param_update.
+            random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
         elif actual_precision == 'fp8_sr':
             random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
         else:
```
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/AdamW_adv.py
```diff
@@ -232,12 +232,13 @@ class AdamW_adv(torch.optim.Optimizer):
     def supports_flat_params(self):
         return False
 
-    @torch.no_grad()
-    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
-        if p.grad is None:
-            return
+    def init_step(self):
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.__init_state(p, group)
 
-        grad = p.grad
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]
 
         # State Initialization
@@ -303,6 +304,15 @@ class AdamW_adv(torch.optim.Optimizer):
 
         _init_fisher_wd_scaler(group, state, p)
 
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        state = self.state[p]
+        self.__init_state(p, group)
+
         beta1, beta2 = group['betas']
 
         current_step = state['step']
@@ -333,7 +343,7 @@ class AdamW_adv(torch.optim.Optimizer):
         if group['actual_state_precision'] == 'bf16_sr' and random_int_state_tensor is None:
             random_int_state_tensor = param_update._get_random_int_for_sr(p)
         elif group['actual_state_precision'] == 'int8_sr':
-            random_int_state_tensor = param_update.
+            random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
         elif group['actual_state_precision'] == 'fp8_sr':
             random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
         step_param_fn = self._compiled_step_parameter
```
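The AdamW_adv change above splits what was one monolithic `step_parameter` into three pieces: `__init_state` (per-tensor state creation), `init_step` (eager initialization over all groups), and a slim `step_parameter` that updates a single tensor. One thing this structure enables is stepping each parameter as soon as its gradient is ready, instead of in a separate optimizer pass. A minimal wiring sketch (hypothetical usage, not from the package; assumes the constructor accepts `params` and `lr` as usual and PyTorch ≥ 2.1 for the hook):

```python
import torch
from adv_optm import AdamW_adv

model = torch.nn.Linear(128, 128)
opt = AdamW_adv(model.parameters(), lr=1e-3)
opt.init_step()  # materialize all optimizer state up front

def make_hook(group: dict, i: int):
    def hook(p: torch.Tensor):
        opt.step_parameter(p, group, i)  # update this tensor immediately
        p.grad = None                    # free its gradient right away
    return hook

# Register a per-parameter hook that fires once its gradient is accumulated.
for group in opt.param_groups:
    for i, p in enumerate(group["params"]):
        p.register_post_accumulate_grad_hook(make_hook(group, i))

loss = model(torch.randn(4, 128)).sum()
loss.backward()  # parameters are stepped inside the backward pass
```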
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/Adopt_adv.py
```diff
@@ -359,7 +359,7 @@ class Adopt_adv(torch.optim.Optimizer):
         if group['actual_state_precision'] == 'bf16_sr' and random_int_state_tensor is None:
             random_int_state_tensor = param_update._get_random_int_for_sr(p)
         elif group['actual_state_precision'] == 'int8_sr':
-            random_int_state_tensor = param_update.
+            random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
         elif group['actual_state_precision'] == 'fp8_sr':
             random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
         step_param_fn = self._compiled_step_parameter
```
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/Muon_adv.py
```diff
@@ -422,7 +422,7 @@ class Muon_adv(torch.optim.Optimizer):
         if actual_precision == 'bf16_sr' and random_int_state_tensor is not None:
             random_int_state_tensor = param_update._get_random_int_for_sr(p)
         elif actual_precision == 'int8_sr':
-            random_int_state_tensor = param_update.
+            random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
         elif actual_precision == 'fp8_sr':
             random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
         else:
```
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/SignSGD_adv.py
```diff
@@ -252,7 +252,7 @@ class SignSGD_adv(torch.optim.Optimizer):
         if group['actual_state_precision'] == 'bf16_sr' and random_int_state_tensor is None:
             random_int_state_tensor = param_update._get_random_int_for_sr(p)
         elif group['actual_state_precision'] == 'int8_sr':
-            random_int_state_tensor = param_update.
+            random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
         elif group['actual_state_precision'] == 'fp8_sr':
             random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
 
```
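The same one-line fix recurs across AdaMuon_adv, AdamW_adv, Adopt_adv, Muon_adv, SignSGD_adv above (and SinkSGD_adv below): the `int8_sr` branch now draws its random bits from `_get_random_int_for_8bit_sr` (the removed line is truncated in this diff, so the old callee is not visible). The helper's body is not part of the diff either; for readers unfamiliar with why these branches need random integers at all, here is a generic illustration of stochastic rounding (fp32 → bf16), the technique the `*_sr` state precisions implement. This is an illustrative sketch, not the package's implementation:

```python
import torch

def stochastic_round_to_bf16(x: torch.Tensor) -> torch.Tensor:
    """Illustrative stochastic rounding, not adv_optm's actual helper.

    Truncating fp32 to bf16 drops the low 16 mantissa bits. Adding uniform
    random noise to those bits first makes the value round up with
    probability proportional to the discarded fraction, so the rounding is
    unbiased in expectation.
    """
    assert x.dtype == torch.float32
    bits = x.view(torch.int32)                  # reinterpret bits, no copy
    noise = torch.randint(0, 1 << 16, x.shape,
                          dtype=torch.int32, device=x.device)
    rounded = (bits + noise) & ~0xFFFF          # add noise, then truncate
    return rounded.view(torch.float32).to(torch.bfloat16)
```

In the hunks above, the random tensor is drawn once per parameter (`random_int_state_tensor`) and passed into the (possibly compiled) step function rather than being sampled inside it.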
adv_optm-2.4.dev11/adv_optm/optim/SGD_adv.py → adv_optm-2.4.dev13/adv_optm/optim/SinkSGD_adv.py
RENAMED
```diff
@@ -11,12 +11,11 @@ from ..util.centered_decay import _init_anchor
 from ..util.state_util import init_state_tensor, get_state, set_state, upcast_grad_for_precision
 from ..util.sinkhorn import apply_sr_sinkhorn
 
-class SGD_adv(torch.optim.Optimizer):
+class SinkSGD_adv(torch.optim.Optimizer):
     """
-    Implements an advanced Stochastic Gradient Descent (SGD) algorithm.
-    This is an advanced version of SGD with optional features like
-    low-rank factorization of optimizer states (SMMF), OrthoGrad,
-    Cautious updating, and AdEMAMix extensions.
+    Implements an advanced Stochastic Gradient Descent (SGD) with Sinkhorn Iterative Normalization (SinkSGD) algorithm.
+    This is an advanced version of SinkSGD with optional features like
+    low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.
 
     Args:
         params (iterable): iterable of parameters to optimize or dicts defining
@@ -62,11 +61,11 @@ class SGD_adv(torch.optim.Optimizer):
         cautious_wd: bool = False,
         # Stochastic Rounding for BF16
         stochastic_rounding: bool = True,
-        # OrthoGrad
-        orthogonal_gradient: bool = False,
         # Sinkhorn Iterative Normalization
-        sinkhorn: bool = False,
         sinkhorn_iterations: int = 5,
+        orthogonal_sinkhorn: bool = False,
+        # OrthoGrad
+        orthogonal_gradient: bool = False,
         # Spectral Normed Optimizer
         spectral_normalization: bool = False,
         # Centered WD
@@ -101,7 +100,8 @@ class SGD_adv(torch.optim.Optimizer):
             "decoupled_wd": decoupled_wd, "cautious_wd": cautious_wd,
             "orthogonal_gradient": orthogonal_gradient,
             "compiled_optimizer": compiled_optimizer,
-            "
+            "sinkhorn_iterations": sinkhorn_iterations,
+            "orthogonal_sinkhorn": orthogonal_sinkhorn,
             "spectral_normalization": spectral_normalization,
             "centered_wd": centered_wd, "centered_wd_mode": centered_wd_mode,
             "state_precision": state_precision,
@@ -116,6 +116,8 @@ class SGD_adv(torch.optim.Optimizer):
         for device in devices:
             param_update.set_seed(device)
 
+        self.init_step()
+
         self._compiled_step_parameter = None
         if compiled_optimizer:
             self.compile(fullgraph=True)
@@ -136,14 +138,14 @@ class SGD_adv(torch.optim.Optimizer):
     def supports_flat_params(self):
         return False
 
-    @torch.no_grad()
-    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
-        if p.grad is None:
-            return
+    def init_step(self):
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.__init_state(p, group)
 
-        grad = p.grad
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]
-
         # State Initialization
         if 'step' not in state:
             state['step'] = 0
@@ -180,6 +182,15 @@ class SGD_adv(torch.optim.Optimizer):
 
         _init_anchor(p, state, group)
 
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        state = self.state[p]
+        self.__init_state(p, group)
+
         step_size = group['lr']
 
         random_int_tensor = None
@@ -193,7 +204,7 @@
         if group['actual_state_precision'] == 'bf16_sr' and random_int_state_tensor is None:
             random_int_state_tensor = param_update._get_random_int_for_sr(p)
         elif group['actual_state_precision'] == 'int8_sr':
-            random_int_state_tensor = param_update.
+            random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
         elif group['actual_state_precision'] == 'fp8_sr':
             random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
         step_param_fn = self._compiled_step_parameter
@@ -219,7 +230,7 @@
 
         if momentum != 0:
             buf = _reconstruct_state((state['mu_b_nmf'], state['mv_b_nmf'], state['sign'], d2), signed=True)
-            buf.mul_(momentum).add_(grad_reshaped, alpha=1 - momentum)
+            buf.lerp_(grad_reshaped, 1 - momentum)
 
             # Factorize updated buffer
             state['mu_b_nmf'], state['mv_b_nmf'], state['sign'] = _factorize_state(buf.clone(), signed=True)
@@ -239,9 +250,7 @@
 
         if momentum != 0:
             buf = get_state(state, 'momentum_buffer', actual_precision)
-
-            buf.mul_(momentum).add_(grad, alpha=1 - momentum)
-
+            buf.lerp_(grad, 1 - momentum)
 
             set_state(state, 'momentum_buffer', buf, actual_precision, random_int_state_tensor)
 
@@ -254,8 +263,8 @@
 
         del random_int_state_tensor
 
-
-
+        # Sinkhorn iterative normalization
+        update = apply_sr_sinkhorn(update, p, ortho_project=group['orthogonal_sinkhorn'], iters=group['sinkhorn_iterations'])
 
         update_scaling = step_size
         if group.get('spectral_normalization', False):
```
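Beyond the rename, two behavioral details stand out in the hunks above: the `sinkhorn` on/off flag is gone from the constructor (replaced by `orthogonal_sinkhorn`), so the normalization now appears to run unconditionally, and both momentum EMA updates switch to `lerp_`. The `lerp_` rewrite is a pure refactor: `buf.lerp_(grad, 1 - momentum)` computes `buf + (1 - momentum) * (grad - buf)`, which is algebraically identical to `buf.mul_(momentum).add_(grad, alpha=1 - momentum)` but runs as a single fused in-place op. A quick check:

```python
import torch

momentum = 0.9
buf = torch.randn(1024)
grad = torch.randn(1024)

# Old form: buf = momentum * buf + (1 - momentum) * grad
expected = buf.clone().mul_(momentum).add_(grad, alpha=1 - momentum)

# New form: single fused linear interpolation toward grad
buf.lerp_(grad, 1 - momentum)

torch.testing.assert_close(buf, expected)
```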
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/__init__.py
```diff
@@ -7,7 +7,7 @@ from .Lion_Prodigy_adv import Lion_Prodigy_adv
 from .Muon_adv import Muon_adv
 from .AdaMuon_adv import AdaMuon_adv
 from .SignSGD_adv import SignSGD_adv
-from .SGD_adv import SGD_adv
+from .SinkSGD_adv import SinkSGD_adv
 
 __all__ = [
     "AdamW_adv",
@@ -19,5 +19,5 @@ __all__ = [
     "Muon_adv",
     "AdaMuon_adv",
     "SignSGD_adv",
-    "SGD_adv",
+    "SinkSGD_adv",
 ]
```
adv_optm-2.4.dev13/adv_optm/util/sinkhorn.py
ADDED
```diff
@@ -0,0 +1,77 @@
+import math
+import torch
+
+def apply_sr_sinkhorn(update: torch.Tensor, p: torch.Tensor, ortho_project: bool, iters: int = 5) -> torch.Tensor:
+    """
+    Applies Square-Root Sinkhorn (SR-Sinkhorn) multi-normalization.
+    As described in 'Gradient Multi-Normalization for Efficient LLM Training'.
+
+    This technique normalizes a 2D matrix alternately by its row-wise L2 norm
+    and column-wise L2 norm, driving it toward a fixed point that uniformly
+    distributes update magnitudes.
+    """
+    original_shape = update.shape
+    original_dtype = update.dtype
+    update = update.float()
+
+    # 1D Vector Case
+    if update.dim() == 1:
+        if ortho_project:
+            p_float = p.float()
+            p_norm_sq = torch.dot(p_float, p_float).add_(1e-30)
+            proj = torch.dot(p_float, update) / p_norm_sq
+            update.sub_(p_float * proj)
+        norm = update.norm(p=2).clamp_min_(1e-12)
+        return update.mul_(math.sqrt(update.numel()) / norm).view(original_shape).to(original_dtype)
+
+    # 2D+ Matrix Case
+    update_2d = update.view(update.shape[0], -1)
+
+    m, n = update_2d.shape
+
+    # Dynamically determine the order of normalization based on aspect ratio.
+    # Normalizing the longer dimension first aids stability.
+    scale_cond = update_2d.shape[0] > update_2d.shape[1]
+    dim = 0 if scale_cond else 1
+
+
+    # Precompute scaling factors.
+    scale_first = math.sqrt(m if scale_cond else n)
+    scale_second = math.sqrt(n if scale_cond else m)
+
+    if ortho_project:
+        param_2d = p.float().view(p.shape[0], -1)
+        p_norm_sq_dim = torch.sum(param_2d * param_2d, dim=dim, keepdim=True).add_(1e-30)
+        p_norm_sq_adim = torch.sum(param_2d * param_2d, dim=1-dim, keepdim=True).add_(1e-30)
+
+    # In-place alternating Sinkhorn normalization steps
+    for _ in range(iters):
+        # First normalization step
+        norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
+        update_2d.mul_(scale_first / norm1)
+        if ortho_project:
+            update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_dim, dim, scale_first)
+
+        # Second normalization step
+        norm2 = update_2d.norm(p=2, dim=1-dim, keepdim=True).clamp_min_(1e-12)
+        update_2d.mul_(scale_second / norm2)
+        if ortho_project:
+            update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_adim, 1-dim, scale_second)
+
+    return update_2d.view(original_shape).to(original_dtype)
+
+def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm):
+    """
+    Projects the update to be orthogonal to p along 'dim' and restores the original norm.
+    """
+    # Project: g_orth = g - (p * <p, g> / ||p||^2)
+    dot_prod = torch.sum(p_2d * update_2d, dim=dim, keepdim=True)
+    proj = dot_prod / p_norm_sq
+
+    # In-place subtraction: update_2d = update_2d - (proj * p_2d)
+    update_2d.addcmul_(proj, p_2d, value=-1.0)
+
+    # Magnitude Preservation
+    g_orth_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
+    scale_factor = target_norm / g_orth_norm
+    return update_2d.mul_(scale_factor)
```
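Compared with the deleted dev11 version below, the new signature threads the parameter `p` through so each normalization step can optionally project the update orthogonal to the weights (`ortho_project`), and the whole computation now runs in fp32 before casting back to the original dtype. A small standalone check of what the iteration converges to (illustrative; assumes the dev13 signature shown above):

```python
import torch
from adv_optm.util.sinkhorn import apply_sr_sinkhorn

torch.manual_seed(0)
p = torch.randn(64, 256)       # parameter, only used when ortho_project=True
update = torch.randn(64, 256)  # raw update to normalize

out = apply_sr_sinkhorn(update.clone(), p, ortho_project=False, iters=5)

# Here m=64 < n=256, so dim=1: rows are normalized first to sqrt(n)=16 and
# columns last to sqrt(m)=8. Column norms are exact; row norms converge.
print(out.norm(dim=0).min().item(), out.norm(dim=0).max().item())  # ~8.0
print(out.norm(dim=1).min().item(), out.norm(dim=1).max().item())  # ~16.0

# With ortho_project=True, the final ortho_normed call leaves every column
# of the update orthogonal to the corresponding column of p:
out2 = apply_sr_sinkhorn(update.clone(), p, ortho_project=True, iters=5)
print((p * out2).sum(dim=0).abs().max().item())  # ~0 (float tolerance)
```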
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm.egg-info/SOURCES.txt
```diff
@@ -14,9 +14,9 @@ adv_optm/optim/Lion_Prodigy_adv.py
 adv_optm/optim/Lion_adv.py
 adv_optm/optim/Muon_adv.py
 adv_optm/optim/Prodigy_adv.py
-adv_optm/optim/SGD_adv.py
 adv_optm/optim/SignSGD_adv.py
 adv_optm/optim/Simplified_AdEMAMix.py
+adv_optm/optim/SinkSGD_adv.py
 adv_optm/optim/__init__.py
 adv_optm/util/Kourkoutas.py
 adv_optm/util/Muon_AuxAdam.py
```
adv_optm-2.4.dev11/adv_optm/util/sinkhorn.py
DELETED
```diff
@@ -1,42 +0,0 @@
-import math
-import torch
-
-def apply_sr_sinkhorn(update: torch.Tensor, iters: int = 5) -> torch.Tensor:
-    """
-    Applies Square-Root Sinkhorn (SR-Sinkhorn) multi-normalization.
-    As described in 'Gradient Multi-Normalization for Efficient LLM Training'.
-
-    This technique normalizes a 2D matrix alternately by its row-wise L2 norm
-    and column-wise L2 norm, driving it toward a fixed point that uniformly
-    distributes update magnitudes.
-    """
-    original_shape = update.shape
-
-    if update.dim() == 1:
-        norm = update.norm(p=2).clamp_min_(1e-12)
-        return update.mul_(math.sqrt(update.numel()) / norm)
-    else:
-        # Flatten >= 3D tensors into 2D matrices
-        update_2d = update.view(update.shape[0], -1)
-
-        m, n = update_2d.shape
-
-        # Dynamically determine the order of normalization based on aspect ratio.
-        # Normalizing the longer dimension first aids stability.
-        dim = 0 if m > n else 1
-
-        # Precompute scaling factors.
-        scale_first = math.sqrt(m) if dim == 0 else math.sqrt(n)
-        scale_second = math.sqrt(n) if dim == 0 else math.sqrt(m)
-
-        # In-place alternating Sinkhorn normalization steps
-        for _ in range(iters):
-            # First normalization step
-            norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
-            update_2d.mul_(scale_first / norm1)
-
-            # Second normalization step
-            norm2 = update_2d.norm(p=2, dim=1-dim, keepdim=True).clamp_min_(1e-12)
-            update_2d.mul_(scale_second / norm2)
-
-        return update_2d.view(original_shape)
```