adv-optm 2.4.dev18__tar.gz → 2.4.dev20__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/PKG-INFO +1 -1
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/SinkSGD_adv.py +14 -15
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/sinkhorn.py +2 -19
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/setup.py +1 -1
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/LICENSE +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/README.md +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/AdaMuon_adv.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/Muon_adv.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/SignSGD_adv.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/param_update.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/scaled_optm.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/state_util.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev20}/setup.cfg +0 -0
|
@@ -4,12 +4,11 @@ import math
|
|
|
4
4
|
|
|
5
5
|
from ..util import param_update
|
|
6
6
|
from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
|
|
7
|
-
from ..util.update_util import _grams_update, _cautious_update
|
|
8
7
|
from ..util.OrthoGrad import _orthogonalize_gradient
|
|
9
8
|
from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
|
|
10
9
|
from ..util.centered_decay import _init_anchor
|
|
11
10
|
from ..util.state_util import init_state_tensor, get_state, set_state, upcast_grad_for_precision
|
|
12
|
-
from ..util.sinkhorn import apply_sr_sinkhorn,
|
|
11
|
+
from ..util.sinkhorn import apply_sr_sinkhorn, get_sinkhorn_wd_scaler
|
|
13
12
|
from ..util.signed_util import apply_stochastic_sign_
|
|
14
13
|
|
|
15
14
|
class SinkSGD_adv(torch.optim.Optimizer):
|
|
@@ -90,6 +89,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
90
89
|
raise ValueError(f"Momentum should be >= 0.0. Got {momentum}")
|
|
91
90
|
if not (weight_decay >= 0.0):
|
|
92
91
|
raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
|
|
92
|
+
if centered_vt and not normed_momentum:
|
|
93
|
+
raise NotImplementedError(f"centered_vt is intended to be used with normed_momentum")
|
|
93
94
|
|
|
94
95
|
state_precision = state_precision.lower()
|
|
95
96
|
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp8_sr", "int8_sr"}
|
|
@@ -183,9 +184,11 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
183
184
|
init_state_tensor(state, 'momentum_buffer', p.shape, actual_precision, p.device, dtype)
|
|
184
185
|
|
|
185
186
|
if group.get('centered_vt', False):
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
187
|
+
# Align shapes with Sinkhorn's 2D flattening
|
|
188
|
+
dim0 = p.shape[0]
|
|
189
|
+
dim1 = p.numel() // dim0
|
|
190
|
+
state['vt_row'] = torch.zeros(dim0, device=device, dtype=torch.float32)
|
|
191
|
+
state['vt_col'] = torch.zeros(dim1, device=device, dtype=torch.float32)
|
|
189
192
|
|
|
190
193
|
if group.get('spectral_normalization', False) and is_spectral(p):
|
|
191
194
|
init_spectral_norm(state, p)
|
|
@@ -280,16 +283,11 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
280
283
|
if group.get('centered_vt', False):
|
|
281
284
|
vt_row, vt_col = state['vt_row'], state['vt_col']
|
|
282
285
|
grad_vt = grad - buf
|
|
283
|
-
grad_vt_sq = grad_vt
|
|
286
|
+
grad_vt_sq = grad_vt.mul_(grad_vt).view(grad.shape[0], -1)
|
|
284
287
|
mean_row_grad = grad_vt_sq.mean(dim=-1)
|
|
285
288
|
mean_col_grad = grad_vt_sq.mean(dim=-2)
|
|
286
289
|
vt_row.mul_(momentum).add_(mean_row_grad, alpha=1.0 - momentum)
|
|
287
290
|
vt_col.mul_(momentum).add_(mean_col_grad, alpha=1.0 - momentum)
|
|
288
|
-
if nesterov:
|
|
289
|
-
nv_coef = momentum if nesterov_coef is None else nesterov_coef
|
|
290
|
-
vt_row = vt_row.lerp(mean_row_grad, 1.0 - nv_coef)
|
|
291
|
-
vt_col = vt_col.lerp(mean_col_grad, 1.0 - nv_coef)
|
|
292
|
-
vt = _sinkhorn_sq_grad(vt_row, vt_col)
|
|
293
291
|
else:
|
|
294
292
|
vt_row = None
|
|
295
293
|
vt_col = None
|
|
@@ -309,10 +307,11 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
309
307
|
del random_int_state_tensor
|
|
310
308
|
|
|
311
309
|
if group.get('centered_vt', False):
|
|
312
|
-
|
|
313
|
-
update.
|
|
314
|
-
|
|
315
|
-
|
|
310
|
+
# Align with Sinkhorn: Alternate row/col preconditioning
|
|
311
|
+
update_2d = update.view(update.shape[0], -1)
|
|
312
|
+
update_2d.mul_(vt_row.clamp_min(1e-30).rsqrt().unsqueeze(1))
|
|
313
|
+
update_2d.mul_(vt_col.clamp_min(1e-30).rsqrt().unsqueeze(0))
|
|
314
|
+
update = update_2d.atan_().view_as(p)
|
|
316
315
|
|
|
317
316
|
if not group.get('normed_momentum', False):
|
|
318
317
|
if not is_vector:
|
|
@@ -80,23 +80,6 @@ def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm):
|
|
|
80
80
|
scale_factor = target_norm / g_orth_norm
|
|
81
81
|
return update_2d.mul_(scale_factor)
|
|
82
82
|
|
|
83
|
-
def _sinkhorn_sq_grad(
|
|
84
|
-
vt_row: torch.Tensor,
|
|
85
|
-
vt_col: torch.Tensor,
|
|
86
|
-
) -> torch.Tensor:
|
|
87
|
-
"""
|
|
88
|
-
Reconstructs the variance precondition from its rank-1 factors.
|
|
89
|
-
Modified from:
|
|
90
|
-
https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/adafactor.py
|
|
91
|
-
"""
|
|
92
|
-
r_factor = (
|
|
93
|
-
(vt_row / vt_row.mean(dim=-1).clamp_min_(1e-30))
|
|
94
|
-
.sqrt_()
|
|
95
|
-
.unsqueeze(-1)
|
|
96
|
-
)
|
|
97
|
-
c_factor = vt_col.unsqueeze(-2).sqrt()
|
|
98
|
-
return torch.mul(r_factor, c_factor)
|
|
99
|
-
|
|
100
83
|
def get_sinkhorn_wd_scaler(
|
|
101
84
|
p: torch.Tensor,
|
|
102
85
|
row_denom: torch.Tensor | None = None,
|
|
@@ -126,8 +109,8 @@ def get_sinkhorn_wd_scaler(
|
|
|
126
109
|
|
|
127
110
|
if row_denom is not None and col_denom is not None:
|
|
128
111
|
# Reshape denominators to ensure safe in-place broadcasting
|
|
129
|
-
row_denom = row_denom.view(p_2d.shape[0], 1)
|
|
130
|
-
col_denom = col_denom.view(1, p_2d.shape[1])
|
|
112
|
+
row_denom = row_denom.sqrt().view(p_2d.shape[0], 1)
|
|
113
|
+
col_denom = col_denom.sqrt().view(1, p_2d.shape[1])
|
|
131
114
|
|
|
132
115
|
# High denom (noise) -> smaller angle (protects weights)
|
|
133
116
|
# Low denom (confident) -> larger angle (decays weights)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|