adv-optm 2.4.dev19__tar.gz → 2.4.dev20__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/PKG-INFO +1 -1
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/optim/SinkSGD_adv.py +2 -7
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/setup.py +1 -1
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/LICENSE +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/README.md +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/optim/AdaMuon_adv.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/optim/Muon_adv.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/optim/SignSGD_adv.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/util/param_update.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/util/scaled_optm.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/util/sinkhorn.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/util/state_util.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev19 → adv_optm-2.4.dev20}/setup.cfg +0 -0
|
@@ -4,7 +4,6 @@ import math
|
|
|
4
4
|
|
|
5
5
|
from ..util import param_update
|
|
6
6
|
from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
|
|
7
|
-
from ..util.update_util import _grams_update, _cautious_update
|
|
8
7
|
from ..util.OrthoGrad import _orthogonalize_gradient
|
|
9
8
|
from ..util.scaled_optm import scale_update, is_spectral, init_spectral_norm
|
|
10
9
|
from ..util.centered_decay import _init_anchor
|
|
@@ -289,10 +288,6 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
289
288
|
mean_col_grad = grad_vt_sq.mean(dim=-2)
|
|
290
289
|
vt_row.mul_(momentum).add_(mean_row_grad, alpha=1.0 - momentum)
|
|
291
290
|
vt_col.mul_(momentum).add_(mean_col_grad, alpha=1.0 - momentum)
|
|
292
|
-
if nesterov:
|
|
293
|
-
nv_coef = momentum if nesterov_coef is None else nesterov_coef
|
|
294
|
-
vt_row = vt_row.lerp(mean_row_grad, 1.0 - nv_coef)
|
|
295
|
-
vt_col = vt_col.lerp(mean_col_grad, 1.0 - nv_coef)
|
|
296
291
|
else:
|
|
297
292
|
vt_row = None
|
|
298
293
|
vt_col = None
|
|
@@ -314,8 +309,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
314
309
|
if group.get('centered_vt', False):
|
|
315
310
|
# Align with Sinkhorn: Alternate row/col preconditioning
|
|
316
311
|
update_2d = update.view(update.shape[0], -1)
|
|
317
|
-
update_2d.
|
|
318
|
-
update_2d.
|
|
312
|
+
update_2d.mul_(vt_row.clamp_min(1e-30).rsqrt().unsqueeze(1))
|
|
313
|
+
update_2d.mul_(vt_col.clamp_min(1e-30).rsqrt().unsqueeze(0))
|
|
319
314
|
update = update_2d.atan_().view_as(p)
|
|
320
315
|
|
|
321
316
|
if not group.get('normed_momentum', False):
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|