adv-optm 2.4.dev24.tar.gz → 2.4.dev25.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/PKG-INFO +1 -1
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/SignSGD_adv.py +12 -8
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/SinkSGD_adv.py +10 -6
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/setup.py +1 -1
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/LICENSE +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/README.md +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/AdaMuon_adv.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/Muon_adv.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/centered_decay.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/param_update.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/scaled_optm.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/sinkhorn.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/state_util.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev24 → adv_optm-2.4.dev25}/setup.cfg +0 -0
|
@@ -280,16 +280,18 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
280
280
|
if snr_cond:
|
|
281
281
|
denom = (1.0 - exp_avg.square()).clamp_min_(1e-30).sqrt_().view_as(p)
|
|
282
282
|
|
|
283
|
+
if nesterov and normed_mt:
|
|
284
|
+
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
285
|
+
normed_grad = grad_reshaped * exp_avg.abs()
|
|
286
|
+
|
|
283
287
|
exp_avg.lerp_(grad_reshaped, 1 - momentum)
|
|
284
288
|
|
|
285
289
|
if nesterov:
|
|
286
290
|
nv_coef = momentum if nesterov_coef is None else nesterov_coef
|
|
287
291
|
if normed_mt:
|
|
288
|
-
|
|
289
|
-
ema_std = math.sqrt((1 - momentum) / (1 + momentum))
|
|
290
|
-
raw_update = (grad_reshaped * ema_std).lerp_(exp_avg, nv_coef)
|
|
292
|
+
raw_update = normed_grad.lerp_(exp_avg, nv_coef)
|
|
291
293
|
else:
|
|
292
|
-
raw_update =
|
|
294
|
+
raw_update = grad_reshaped.lerp(exp_avg, nv_coef)
|
|
293
295
|
else:
|
|
294
296
|
raw_update = exp_avg.clone()
|
|
295
297
|
|
|
@@ -309,14 +311,16 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
309
311
|
if snr_cond:
|
|
310
312
|
denom = (1.0 - exp_avg.square()).clamp_min_(1e-30).sqrt_()
|
|
311
313
|
|
|
314
|
+
if nesterov and normed_mt:
|
|
315
|
+
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
316
|
+
normed_grad = grad * exp_avg.abs()
|
|
317
|
+
|
|
312
318
|
exp_avg.lerp_(grad, 1 - momentum)
|
|
313
319
|
|
|
314
320
|
if nesterov:
|
|
315
321
|
nv_coef = momentum if nesterov_coef is None else nesterov_coef
|
|
316
322
|
if normed_mt:
|
|
317
|
-
|
|
318
|
-
ema_std = math.sqrt((1 - momentum) / (1 + momentum))
|
|
319
|
-
raw_update = (grad * ema_std).lerp_(exp_avg, nv_coef)
|
|
323
|
+
raw_update = normed_grad.lerp_(exp_avg, nv_coef)
|
|
320
324
|
else:
|
|
321
325
|
raw_update = grad.lerp(exp_avg, nv_coef)
|
|
322
326
|
else:
|
|
@@ -351,7 +355,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
351
355
|
update_scaling = lr * A if snr_cond else lr
|
|
352
356
|
update.mul_(update_scaling)
|
|
353
357
|
|
|
354
|
-
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_target=wd_target, cwd_target=cwd_target)
|
|
358
|
+
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_target=wd_target, cwd_target=cwd_target, decoupled=True)
|
|
355
359
|
|
|
356
360
|
def compile(self, *args, **kwargs):
|
|
357
361
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|
|
@@ -264,6 +264,10 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
264
264
|
else:
|
|
265
265
|
denom = (1.0 - buf.square()).clamp_min_(1e-30).sqrt_().view_as(p)
|
|
266
266
|
|
|
267
|
+
if nesterov and normed_mt:
|
|
268
|
+
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
269
|
+
normed_grad = grad_reshaped * buf.abs()
|
|
270
|
+
|
|
267
271
|
buf.lerp_(grad_reshaped, 1 - momentum)
|
|
268
272
|
|
|
269
273
|
# Factorize updated buffer
|
|
@@ -272,9 +276,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
272
276
|
if nesterov:
|
|
273
277
|
nv_coef = momentum if nesterov_coef is None else nesterov_coef
|
|
274
278
|
if normed_mt:
|
|
275
|
-
|
|
276
|
-
ema_std = math.sqrt((1 - momentum) / (1 + momentum))
|
|
277
|
-
update = (grad_reshaped * ema_std).lerp_(buf, nv_coef)
|
|
279
|
+
update = normed_grad.lerp_(buf, nv_coef)
|
|
278
280
|
else:
|
|
279
281
|
update = grad_reshaped.lerp(buf, nv_coef)
|
|
280
282
|
else:
|
|
@@ -299,6 +301,10 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
299
301
|
else:
|
|
300
302
|
denom = (1.0 - buf.square()).clamp_min_(1e-30).sqrt_()
|
|
301
303
|
|
|
304
|
+
if nesterov and normed_mt:
|
|
305
|
+
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
306
|
+
normed_grad = grad * buf.abs()
|
|
307
|
+
|
|
302
308
|
buf.lerp_(grad, 1 - momentum)
|
|
303
309
|
|
|
304
310
|
set_state(state, 'momentum_buffer', buf, actual_precision, random_int_state_tensor)
|
|
@@ -306,9 +312,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
|
|
|
306
312
|
if nesterov:
|
|
307
313
|
nv_coef = momentum if nesterov_coef is None else nesterov_coef
|
|
308
314
|
if normed_mt:
|
|
309
|
-
|
|
310
|
-
ema_std = math.sqrt((1 - momentum) / (1 + momentum))
|
|
311
|
-
update = (grad * ema_std).lerp_(buf, nv_coef)
|
|
315
|
+
update = normed_grad.lerp_(buf, nv_coef)
|
|
312
316
|
else:
|
|
313
317
|
update = grad.lerp(buf, nv_coef)
|
|
314
318
|
else:
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|