adv-optm 2.4.dev17__tar.gz → 2.4.dev19__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/PKG-INFO +1 -1
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/SinkSGD_adv.py +51 -11
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/sinkhorn.py +53 -4
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/setup.py +1 -1
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/LICENSE +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/README.md +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/AdaMuon_adv.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/Muon_adv.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/SignSGD_adv.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/param_update.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/scaled_optm.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/state_util.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev17 → adv_optm-2.4.dev19}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
import math
|
|
4
4
|
|
|
5
5
|
from ..util import param_update
|
|
6
6
|
from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
|
|
@@ -9,7 +9,7 @@ from ..util.OrthoGrad import _orthogonalize_gradient
|
|
|
9
9
|
from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
|
|
10
10
|
from ..util.centered_decay import _init_anchor
|
|
11
11
|
from ..util.state_util import init_state_tensor, get_state, set_state, upcast_grad_for_precision
|
|
12
|
-
from ..util.sinkhorn import apply_sr_sinkhorn
|
|
12
|
+
from ..util.sinkhorn import apply_sr_sinkhorn, get_sinkhorn_wd_scaler
|
|
13
13
|
from ..util.signed_util import apply_stochastic_sign_
|
|
14
14
|
|
|
15
15
|
class SinkSGD_adv(torch.optim.Optimizer):
|
|
@@ -26,8 +26,6 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
26
26
|
weight_decay (float): weight decay (L2 penalty or decoupled) (default: 0).
|
|
27
27
|
nesterov (bool): enables Nesterov momentum. Only applicable when momentum
|
|
28
28
|
is non-zero. (default: False)
|
|
29
|
-
decoupled_wd (bool): whether to apply decoupled weight decay (like AdamW)
|
|
30
|
-
instead of standard L2 penalty. (default: False)
|
|
31
29
|
cautious_wd (bool): Enables Cautious Weight Decay. If True, weight decay is
|
|
32
30
|
applied only to parameter coordinates where the sign of the parameter
|
|
33
31
|
and the sign of the optimizer update align (default: False).
|
|
@@ -61,11 +59,13 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
61
59
|
orthogonal_sinkhorn: bool = False,
|
|
62
60
|
# Normalization then Momentum
|
|
63
61
|
normed_momentum: bool = False,
|
|
62
|
+
# Centered Variance Precondition
|
|
63
|
+
centered_vt: bool = False,
|
|
64
64
|
# Nesterov Momentum
|
|
65
65
|
nesterov: bool = False,
|
|
66
66
|
nesterov_coef: float | None = None,
|
|
67
|
-
#
|
|
68
|
-
|
|
67
|
+
# weight decay features
|
|
68
|
+
geometric_wd: bool = False,
|
|
69
69
|
cautious_wd: bool = False,
|
|
70
70
|
# Stochastic Rounding for BF16
|
|
71
71
|
stochastic_rounding: bool = True,
|
|
@@ -90,6 +90,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
90
90
|
raise ValueError(f"Momentum should be >= 0.0. Got {momentum}")
|
|
91
91
|
if not (weight_decay >= 0.0):
|
|
92
92
|
raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
|
|
93
|
+
if centered_vt and not normed_momentum:
|
|
94
|
+
raise NotImplementedError(f"centered_vt is intended to be used with normed_momentum")
|
|
93
95
|
|
|
94
96
|
state_precision = state_precision.lower()
|
|
95
97
|
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp8_sr", "int8_sr"}
|
|
@@ -101,8 +103,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
101
103
|
|
|
102
104
|
defaults = {
|
|
103
105
|
"lr": lr, "momentum": momentum,
|
|
104
|
-
"weight_decay": weight_decay, "nesterov": nesterov, "nesterov_coef": nesterov_coef, "normed_momentum": normed_momentum,
|
|
105
|
-
"
|
|
106
|
+
"weight_decay": weight_decay, "nesterov": nesterov, "nesterov_coef": nesterov_coef, "normed_momentum": normed_momentum, "centered_vt": centered_vt,
|
|
107
|
+
"geometric_wd": geometric_wd, "cautious_wd": cautious_wd,
|
|
106
108
|
"orthogonal_gradient": orthogonal_gradient,
|
|
107
109
|
"compiled_optimizer": compiled_optimizer,
|
|
108
110
|
"sinkhorn_iterations": sinkhorn_iterations,
|
|
@@ -182,6 +184,13 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
182
184
|
if group['momentum'] != 0:
|
|
183
185
|
init_state_tensor(state, 'momentum_buffer', p.shape, actual_precision, p.device, dtype)
|
|
184
186
|
|
|
187
|
+
if group.get('centered_vt', False):
|
|
188
|
+
# Align shapes with Sinkhorn's 2D flattening
|
|
189
|
+
dim0 = p.shape[0]
|
|
190
|
+
dim1 = p.numel() // dim0
|
|
191
|
+
state['vt_row'] = torch.zeros(dim0, device=device, dtype=torch.float32)
|
|
192
|
+
state['vt_col'] = torch.zeros(dim1, device=device, dtype=torch.float32)
|
|
193
|
+
|
|
185
194
|
if group.get('spectral_normalization', False) and is_spectral(p):
|
|
186
195
|
init_spectral_norm(state, p)
|
|
187
196
|
|
|
@@ -237,7 +246,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
237
246
|
if group.get('normed_momentum', False):
|
|
238
247
|
if not is_vector:
|
|
239
248
|
# Sinkhorn iterative normalization
|
|
240
|
-
grad = apply_sr_sinkhorn(grad, p, ortho_project=orthogonal_sinkhorn
|
|
249
|
+
grad = apply_sr_sinkhorn(grad, iters=sinkhorn_iterations, p=p, ortho_project=orthogonal_sinkhorn)
|
|
241
250
|
else:
|
|
242
251
|
# For vectors, apply adaptive stochastic sign
|
|
243
252
|
grad = apply_stochastic_sign_(grad, sign_noise, is_vector=is_vector)
|
|
@@ -271,6 +280,23 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
271
280
|
|
|
272
281
|
if momentum != 0:
|
|
273
282
|
buf = get_state(state, 'momentum_buffer', actual_precision)
|
|
283
|
+
|
|
284
|
+
if group.get('centered_vt', False):
|
|
285
|
+
vt_row, vt_col = state['vt_row'], state['vt_col']
|
|
286
|
+
grad_vt = grad - buf
|
|
287
|
+
grad_vt_sq = grad_vt.mul_(grad_vt).view(grad.shape[0], -1)
|
|
288
|
+
mean_row_grad = grad_vt_sq.mean(dim=-1)
|
|
289
|
+
mean_col_grad = grad_vt_sq.mean(dim=-2)
|
|
290
|
+
vt_row.mul_(momentum).add_(mean_row_grad, alpha=1.0 - momentum)
|
|
291
|
+
vt_col.mul_(momentum).add_(mean_col_grad, alpha=1.0 - momentum)
|
|
292
|
+
if nesterov:
|
|
293
|
+
nv_coef = momentum if nesterov_coef is None else nesterov_coef
|
|
294
|
+
vt_row = vt_row.lerp(mean_row_grad, 1.0 - nv_coef)
|
|
295
|
+
vt_col = vt_col.lerp(mean_col_grad, 1.0 - nv_coef)
|
|
296
|
+
else:
|
|
297
|
+
vt_row = None
|
|
298
|
+
vt_col = None
|
|
299
|
+
|
|
274
300
|
buf.lerp_(grad, 1 - momentum)
|
|
275
301
|
|
|
276
302
|
set_state(state, 'momentum_buffer', buf, actual_precision, random_int_state_tensor)
|
|
@@ -285,21 +311,35 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
285
311
|
|
|
286
312
|
del random_int_state_tensor
|
|
287
313
|
|
|
314
|
+
if group.get('centered_vt', False):
|
|
315
|
+
# Align with Sinkhorn: Alternate row/col preconditioning
|
|
316
|
+
update_2d = update.view(update.shape[0], -1)
|
|
317
|
+
update_2d.div_(vt_row.clamp_min(1e-30).sqrt().unsqueeze(1))
|
|
318
|
+
update_2d.div_(vt_col.clamp_min(1e-30).sqrt().unsqueeze(0))
|
|
319
|
+
update = update_2d.atan_().view_as(p)
|
|
320
|
+
|
|
288
321
|
if not group.get('normed_momentum', False):
|
|
289
322
|
if not is_vector:
|
|
290
323
|
# Sinkhorn iterative normalization
|
|
291
|
-
update = apply_sr_sinkhorn(update, p, ortho_project=orthogonal_sinkhorn
|
|
324
|
+
update = apply_sr_sinkhorn(update, iters=sinkhorn_iterations, p=p, ortho_project=orthogonal_sinkhorn)
|
|
292
325
|
else:
|
|
293
326
|
# For vectors, apply adaptive stochastic sign
|
|
294
327
|
update = apply_stochastic_sign_(update, sign_noise, is_vector=is_vector)
|
|
295
328
|
|
|
329
|
+
if group.get('geometric_wd', False):
|
|
330
|
+
wd_scaler = get_sinkhorn_wd_scaler(p, row_denom=vt_row, col_denom=vt_col)
|
|
331
|
+
else:
|
|
332
|
+
wd_scaler = None
|
|
333
|
+
|
|
296
334
|
update_scaling = step_size
|
|
297
335
|
if group.get('spectral_normalization', False):
|
|
298
336
|
update = scale_update(p, update, update_scaling, state=state)
|
|
299
337
|
else:
|
|
338
|
+
if group.get('centered_vt', False):
|
|
339
|
+
update_scaling = update_scaling * (4/math.pi)
|
|
300
340
|
update.mul_(update_scaling)
|
|
301
341
|
|
|
302
|
-
param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor)
|
|
342
|
+
param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
|
|
303
343
|
|
|
304
344
|
def compile(self, *args, **kwargs):
|
|
305
345
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import math
|
|
2
2
|
import torch
|
|
3
3
|
|
|
4
|
-
def apply_sr_sinkhorn(update: torch.Tensor, p: torch.Tensor, ortho_project: bool
|
|
4
|
+
def apply_sr_sinkhorn(update: torch.Tensor, iters: int = 5, p: torch.Tensor | None = None, ortho_project: bool = False) -> torch.Tensor:
|
|
5
5
|
"""
|
|
6
6
|
Applies Square-Root Sinkhorn (SR-Sinkhorn) multi-normalization.
|
|
7
7
|
As described in 'Gradient Multi-Normalization for Efficient LLM Training'.
|
|
@@ -47,13 +47,16 @@ def apply_sr_sinkhorn(update: torch.Tensor, p: torch.Tensor, ortho_project: bool
|
|
|
47
47
|
# In-place alternating Sinkhorn normalization steps
|
|
48
48
|
for _ in range(iters):
|
|
49
49
|
# First normalization step
|
|
50
|
-
|
|
50
|
+
# Stability floor: equivalent to a single-element vector norm lower bound (lb)
|
|
51
|
+
norm1_lb = 1 / math.sqrt(update_2d.shape[dim])
|
|
52
|
+
norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm1_lb)
|
|
51
53
|
update_2d.mul_(scale_first / norm1)
|
|
52
54
|
if ortho_project:
|
|
53
55
|
update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_dim, dim, scale_first)
|
|
54
56
|
|
|
55
57
|
# Second normalization step
|
|
56
|
-
|
|
58
|
+
norm2_lb = 1 / math.sqrt(update_2d.shape[1-dim])
|
|
59
|
+
norm2 = update_2d.norm(p=2, dim=1-dim, keepdim=True).clamp_min_(norm2_lb)
|
|
57
60
|
update_2d.mul_(scale_second / norm2)
|
|
58
61
|
if ortho_project:
|
|
59
62
|
update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_adim, 1-dim, scale_second)
|
|
@@ -72,6 +75,52 @@ def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm):
|
|
|
72
75
|
update_2d.addcmul_(proj, p_2d, value=-1.0)
|
|
73
76
|
|
|
74
77
|
# Magnitude Preservation
|
|
75
|
-
|
|
78
|
+
norm_lb = 1 / math.sqrt(update_2d.shape[dim])
|
|
79
|
+
g_orth_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(norm_lb)
|
|
76
80
|
scale_factor = target_norm / g_orth_norm
|
|
77
81
|
return update_2d.mul_(scale_factor)
|
|
82
|
+
|
|
83
|
+
def get_sinkhorn_wd_scaler(
|
|
84
|
+
p: torch.Tensor,
|
|
85
|
+
row_denom: torch.Tensor | None = None,
|
|
86
|
+
col_denom: torch.Tensor | None = None
|
|
87
|
+
):
|
|
88
|
+
"""
|
|
89
|
+
Computes a structural weight decay multiplier.
|
|
90
|
+
Penalizes parameters belonging to dominant rows/columns more heavily,
|
|
91
|
+
while protecting parameters in under-utilized/noisy rows/columns from decay.
|
|
92
|
+
"""
|
|
93
|
+
if p.ndim < 2:
|
|
94
|
+
return 1.0
|
|
95
|
+
|
|
96
|
+
p_2d = p.view(p.shape[0], -1)
|
|
97
|
+
|
|
98
|
+
# Lower bounds based on the effective 2D shapes
|
|
99
|
+
row_lb = 1 / math.sqrt(p_2d.shape[1])
|
|
100
|
+
col_lb = 1 / math.sqrt(p_2d.shape[0])
|
|
101
|
+
|
|
102
|
+
# Get the norms
|
|
103
|
+
row_norms = torch.linalg.vector_norm(p_2d, ord=2, dim=1, keepdim=True).clamp_min_(row_lb)
|
|
104
|
+
col_norms = torch.linalg.vector_norm(p_2d, ord=2, dim=0, keepdim=True).clamp_min_(col_lb)
|
|
105
|
+
|
|
106
|
+
# Compute the structural scaler
|
|
107
|
+
row_factor = row_norms.sqrt_()
|
|
108
|
+
col_factor = col_norms.sqrt_()
|
|
109
|
+
|
|
110
|
+
if row_denom is not None and col_denom is not None:
|
|
111
|
+
# Reshape denominators to ensure safe in-place broadcasting
|
|
112
|
+
row_denom = row_denom.sqrt().view(p_2d.shape[0], 1)
|
|
113
|
+
col_denom = col_denom.sqrt().view(1, p_2d.shape[1])
|
|
114
|
+
|
|
115
|
+
# High denom (noise) -> smaller angle (protects weights)
|
|
116
|
+
# Low denom (confident) -> larger angle (decays weights)
|
|
117
|
+
row_factor.atan2_(row_denom)
|
|
118
|
+
col_factor.atan2_(col_denom)
|
|
119
|
+
|
|
120
|
+
# Outer product: merges the row and column confidences into a 2D matrix
|
|
121
|
+
wd_scaler = row_factor * col_factor
|
|
122
|
+
|
|
123
|
+
# Normalize the scaler so its mean is exactly 1.0
|
|
124
|
+
wd_scaler.div_(wd_scaler.mean().clamp_min_(1e-12))
|
|
125
|
+
|
|
126
|
+
return wd_scaler.view_as(p)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|