adv-optm 2.4.dev18__tar.gz → 2.4.dev19__tar.gz
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/PKG-INFO +1 -1
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/SinkSGD_adv.py +14 -10
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/sinkhorn.py +2 -19
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/setup.py +1 -1
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/LICENSE +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/README.md +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/AdaMuon_adv.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/Muon_adv.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/SignSGD_adv.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/param_update.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/scaled_optm.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/state_util.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev18 → adv_optm-2.4.dev19}/setup.cfg +0 -0
|
@@ -9,7 +9,7 @@ from ..util.OrthoGrad import _orthogonalize_gradient
|
|
|
9
9
|
from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
|
|
10
10
|
from ..util.centered_decay import _init_anchor
|
|
11
11
|
from ..util.state_util import init_state_tensor, get_state, set_state, upcast_grad_for_precision
|
|
12
|
-
from ..util.sinkhorn import apply_sr_sinkhorn,
|
|
12
|
+
from ..util.sinkhorn import apply_sr_sinkhorn, get_sinkhorn_wd_scaler
|
|
13
13
|
from ..util.signed_util import apply_stochastic_sign_
|
|
14
14
|
|
|
15
15
|
class SinkSGD_adv(torch.optim.Optimizer):
|
|
@@ -90,6 +90,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
90
90
|
raise ValueError(f"Momentum should be >= 0.0. Got {momentum}")
|
|
91
91
|
if not (weight_decay >= 0.0):
|
|
92
92
|
raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
|
|
93
|
+
if centered_vt and not normed_momentum:
|
|
94
|
+
raise NotImplementedError(f"centered_vt is intended to be used with normed_momentum")
|
|
93
95
|
|
|
94
96
|
state_precision = state_precision.lower()
|
|
95
97
|
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp8_sr", "int8_sr"}
|
|
@@ -183,9 +185,11 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
183
185
|
init_state_tensor(state, 'momentum_buffer', p.shape, actual_precision, p.device, dtype)
|
|
184
186
|
|
|
185
187
|
if group.get('centered_vt', False):
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
188
|
+
# Align shapes with Sinkhorn's 2D flattening
|
|
189
|
+
dim0 = p.shape[0]
|
|
190
|
+
dim1 = p.numel() // dim0
|
|
191
|
+
state['vt_row'] = torch.zeros(dim0, device=device, dtype=torch.float32)
|
|
192
|
+
state['vt_col'] = torch.zeros(dim1, device=device, dtype=torch.float32)
|
|
189
193
|
|
|
190
194
|
if group.get('spectral_normalization', False) and is_spectral(p):
|
|
191
195
|
init_spectral_norm(state, p)
|
|
@@ -280,7 +284,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
280
284
|
if group.get('centered_vt', False):
|
|
281
285
|
vt_row, vt_col = state['vt_row'], state['vt_col']
|
|
282
286
|
grad_vt = grad - buf
|
|
283
|
-
grad_vt_sq = grad_vt
|
|
287
|
+
grad_vt_sq = grad_vt.mul_(grad_vt).view(grad.shape[0], -1)
|
|
284
288
|
mean_row_grad = grad_vt_sq.mean(dim=-1)
|
|
285
289
|
mean_col_grad = grad_vt_sq.mean(dim=-2)
|
|
286
290
|
vt_row.mul_(momentum).add_(mean_row_grad, alpha=1.0 - momentum)
|
|
@@ -289,7 +293,6 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
289
293
|
nv_coef = momentum if nesterov_coef is None else nesterov_coef
|
|
290
294
|
vt_row = vt_row.lerp(mean_row_grad, 1.0 - nv_coef)
|
|
291
295
|
vt_col = vt_col.lerp(mean_col_grad, 1.0 - nv_coef)
|
|
292
|
-
vt = _sinkhorn_sq_grad(vt_row, vt_col)
|
|
293
296
|
else:
|
|
294
297
|
vt_row = None
|
|
295
298
|
vt_col = None
|
|
@@ -309,10 +312,11 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
309
312
|
del random_int_state_tensor
|
|
310
313
|
|
|
311
314
|
if group.get('centered_vt', False):
|
|
312
|
-
|
|
313
|
-
update.
|
|
314
|
-
|
|
315
|
-
|
|
315
|
+
# Align with Sinkhorn: Alternate row/col preconditioning
|
|
316
|
+
update_2d = update.view(update.shape[0], -1)
|
|
317
|
+
update_2d.div_(vt_row.clamp_min(1e-30).sqrt().unsqueeze(1))
|
|
318
|
+
update_2d.div_(vt_col.clamp_min(1e-30).sqrt().unsqueeze(0))
|
|
319
|
+
update = update_2d.atan_().view_as(p)
|
|
316
320
|
|
|
317
321
|
if not group.get('normed_momentum', False):
|
|
318
322
|
if not is_vector:
|
|
@@ -80,23 +80,6 @@ def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm):
|
|
|
80
80
|
scale_factor = target_norm / g_orth_norm
|
|
81
81
|
return update_2d.mul_(scale_factor)
|
|
82
82
|
|
|
83
|
-
def _sinkhorn_sq_grad(
|
|
84
|
-
vt_row: torch.Tensor,
|
|
85
|
-
vt_col: torch.Tensor,
|
|
86
|
-
) -> torch.Tensor:
|
|
87
|
-
"""
|
|
88
|
-
Reconstructs the variance precondition from its rank-1 factors.
|
|
89
|
-
Modified from:
|
|
90
|
-
https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/adafactor.py
|
|
91
|
-
"""
|
|
92
|
-
r_factor = (
|
|
93
|
-
(vt_row / vt_row.mean(dim=-1).clamp_min_(1e-30))
|
|
94
|
-
.sqrt_()
|
|
95
|
-
.unsqueeze(-1)
|
|
96
|
-
)
|
|
97
|
-
c_factor = vt_col.unsqueeze(-2).sqrt()
|
|
98
|
-
return torch.mul(r_factor, c_factor)
|
|
99
|
-
|
|
100
83
|
def get_sinkhorn_wd_scaler(
|
|
101
84
|
p: torch.Tensor,
|
|
102
85
|
row_denom: torch.Tensor | None = None,
|
|
@@ -126,8 +109,8 @@ def get_sinkhorn_wd_scaler(
|
|
|
126
109
|
|
|
127
110
|
if row_denom is not None and col_denom is not None:
|
|
128
111
|
# Reshape denominators to ensure safe in-place broadcasting
|
|
129
|
-
row_denom = row_denom.view(p_2d.shape[0], 1)
|
|
130
|
-
col_denom = col_denom.view(1, p_2d.shape[1])
|
|
112
|
+
row_denom = row_denom.sqrt().view(p_2d.shape[0], 1)
|
|
113
|
+
col_denom = col_denom.sqrt().view(1, p_2d.shape[1])
|
|
131
114
|
|
|
132
115
|
# High denom (noise) -> smaller angle (protects weights)
|
|
133
116
|
# Low denom (confident) -> larger angle (decays weights)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|