adv-optm 2.4.dev21__tar.gz → 2.4.dev22__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/PKG-INFO +1 -1
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/optim/SignSGD_adv.py +10 -10
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/optim/SinkSGD_adv.py +10 -10
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/setup.py +1 -1
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/LICENSE +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/README.md +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/optim/AdaMuon_adv.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/optim/Muon_adv.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/util/param_update.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/util/scaled_optm.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/util/sinkhorn.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/util/state_util.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev21 → adv_optm-2.4.dev22}/setup.cfg +0 -0
|
@@ -70,8 +70,8 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
70
70
|
nesterov_coef: float | None = None,
|
|
71
71
|
# Normalization then Momentum
|
|
72
72
|
normed_momentum: bool = False,
|
|
73
|
-
#
|
|
74
|
-
|
|
73
|
+
# SNR Precondition
|
|
74
|
+
snr_cond: bool = False,
|
|
75
75
|
# Centered WD
|
|
76
76
|
centered_wd: float = 0.0,
|
|
77
77
|
centered_wd_mode: str = 'float8',
|
|
@@ -91,8 +91,8 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
91
91
|
raise ValueError(f"momentum should be in [0.0, 1.0], but got {momentum}")
|
|
92
92
|
if not weight_decay >= 0.0:
|
|
93
93
|
raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
|
|
94
|
-
if
|
|
95
|
-
raise NotImplementedError(f"
|
|
94
|
+
if snr_cond and not normed_momentum and not momentum > 0:
|
|
95
|
+
raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum")
|
|
96
96
|
|
|
97
97
|
state_precision = state_precision.lower()
|
|
98
98
|
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
|
|
@@ -115,7 +115,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
115
115
|
nesterov=nesterov,
|
|
116
116
|
nesterov_coef=nesterov_coef,
|
|
117
117
|
normed_momentum=normed_momentum,
|
|
118
|
-
|
|
118
|
+
snr_cond=snr_cond,
|
|
119
119
|
spectral_normalization=spectral_normalization,
|
|
120
120
|
centered_wd= centered_wd,
|
|
121
121
|
centered_wd_mode= centered_wd_mode,
|
|
@@ -254,7 +254,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
254
254
|
nesterov = group.get('nesterov', False)
|
|
255
255
|
nesterov_coef = group.get('nesterov_coef', None)
|
|
256
256
|
sso = group.get('stochastic_sign', False)
|
|
257
|
-
|
|
257
|
+
snr_cond = group.get('snr_cond', False) and group.get('normed_momentum', False) and momentum > 0
|
|
258
258
|
|
|
259
259
|
denom = None
|
|
260
260
|
wd_target = None
|
|
@@ -278,7 +278,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
278
278
|
# Reconstruct momentum m_{t-1}
|
|
279
279
|
exp_avg = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True, shifter=state['shifter'])
|
|
280
280
|
|
|
281
|
-
if
|
|
281
|
+
if snr_cond:
|
|
282
282
|
denom = (1.0 - exp_avg.square()).clamp_min_(1e-30).sqrt_().view_as(p)
|
|
283
283
|
|
|
284
284
|
exp_avg.lerp_(grad_reshaped, 1 - momentum)
|
|
@@ -302,7 +302,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
302
302
|
actual_precision = group['actual_state_precision']
|
|
303
303
|
exp_avg = get_state(state, 'exp_avg', actual_precision)
|
|
304
304
|
|
|
305
|
-
if
|
|
305
|
+
if snr_cond:
|
|
306
306
|
denom = (1.0 - exp_avg.square()).clamp_min_(1e-30).sqrt_()
|
|
307
307
|
|
|
308
308
|
exp_avg.lerp_(grad, 1 - momentum)
|
|
@@ -325,7 +325,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
325
325
|
else:
|
|
326
326
|
update = raw_update
|
|
327
327
|
|
|
328
|
-
if
|
|
328
|
+
if snr_cond:
|
|
329
329
|
update.atan2_(denom)
|
|
330
330
|
|
|
331
331
|
if group.get('geometric_wd', False) and group["weight_decay"] > 0 :
|
|
@@ -339,7 +339,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
339
339
|
if group.get('spectral_normalization', False):
|
|
340
340
|
update = scale_update(p, update, lr, state=state)
|
|
341
341
|
else:
|
|
342
|
-
update_scaling = lr * A if
|
|
342
|
+
update_scaling = lr * A if snr_cond else lr
|
|
343
343
|
update.mul_(update_scaling)
|
|
344
344
|
|
|
345
345
|
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_target=wd_target, cwd_target=cwd_target)
|
|
@@ -58,8 +58,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
58
58
|
orthogonal_sinkhorn: bool = False,
|
|
59
59
|
# Normalization then Momentum
|
|
60
60
|
normed_momentum: bool = False,
|
|
61
|
-
#
|
|
62
|
-
|
|
61
|
+
# SNR Precondition
|
|
62
|
+
snr_cond: bool = False,
|
|
63
63
|
# Nesterov Momentum
|
|
64
64
|
nesterov: bool = False,
|
|
65
65
|
nesterov_coef: float | None = None,
|
|
@@ -89,8 +89,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
89
89
|
raise ValueError(f"Momentum should be >= 0.0. Got {momentum}")
|
|
90
90
|
if not (weight_decay >= 0.0):
|
|
91
91
|
raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
|
|
92
|
-
if
|
|
93
|
-
raise NotImplementedError(f"
|
|
92
|
+
if snr_cond and not normed_momentum:
|
|
93
|
+
raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum")
|
|
94
94
|
|
|
95
95
|
state_precision = state_precision.lower()
|
|
96
96
|
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
|
|
@@ -102,7 +102,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
102
102
|
|
|
103
103
|
defaults = {
|
|
104
104
|
"lr": lr, "momentum": momentum,
|
|
105
|
-
"weight_decay": weight_decay, "nesterov": nesterov, "nesterov_coef": nesterov_coef, "normed_momentum": normed_momentum, "
|
|
105
|
+
"weight_decay": weight_decay, "nesterov": nesterov, "nesterov_coef": nesterov_coef, "normed_momentum": normed_momentum, "snr_cond": snr_cond,
|
|
106
106
|
"geometric_wd": geometric_wd, "cautious_wd": cautious_wd,
|
|
107
107
|
"orthogonal_gradient": orthogonal_gradient,
|
|
108
108
|
"compiled_optimizer": compiled_optimizer,
|
|
@@ -228,7 +228,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
228
228
|
momentum = group['momentum']
|
|
229
229
|
nesterov = group['nesterov']
|
|
230
230
|
nesterov_coef = group.get('nesterov_coef', None)
|
|
231
|
-
|
|
231
|
+
snr_cond = group.get('snr_cond', False)
|
|
232
232
|
|
|
233
233
|
vt_row = None
|
|
234
234
|
vt_col = None
|
|
@@ -256,7 +256,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
256
256
|
if momentum != 0:
|
|
257
257
|
buf = _reconstruct_state((state['mu_b_nmf'], state['mv_b_nmf'], state['sign'], d2), signed=True, shifter=state['shifter'])
|
|
258
258
|
|
|
259
|
-
if
|
|
259
|
+
if snr_cond:
|
|
260
260
|
if not is_vector:
|
|
261
261
|
buf_2d_sq = buf.view(grad.shape[0], -1).square()
|
|
262
262
|
vt_row = (1 - buf_2d_sq.mean(dim=-1)).clamp_min_(1e-30)
|
|
@@ -286,7 +286,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
286
286
|
if momentum != 0:
|
|
287
287
|
buf = get_state(state, 'momentum_buffer', actual_precision)
|
|
288
288
|
|
|
289
|
-
if
|
|
289
|
+
if snr_cond:
|
|
290
290
|
if not is_vector:
|
|
291
291
|
buf_2d_sq = buf.view(grad.shape[0], -1).square()
|
|
292
292
|
vt_row = (1 - buf_2d_sq.mean(dim=-1)).clamp_min_(1e-30)
|
|
@@ -309,7 +309,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
309
309
|
|
|
310
310
|
del random_int_state_tensor
|
|
311
311
|
|
|
312
|
-
if
|
|
312
|
+
if snr_cond:
|
|
313
313
|
if not is_vector:
|
|
314
314
|
# Align with Sinkhorn: Alternate row/col preconditioning
|
|
315
315
|
update_2d = update.view(update.shape[0], -1)
|
|
@@ -342,7 +342,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
342
342
|
if group.get('spectral_normalization', False):
|
|
343
343
|
update = scale_update(p, update, update_scaling, state=state)
|
|
344
344
|
else:
|
|
345
|
-
if
|
|
345
|
+
if snr_cond:
|
|
346
346
|
update_scaling = update_scaling * (4/math.pi)
|
|
347
347
|
update.mul_(update_scaling)
|
|
348
348
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|