adv-optm 2.2.dev1__tar.gz → 2.2.dev3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/PKG-INFO +1 -1
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/__init__.py +1 -1
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/AdaMuon_adv.py +6 -1
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/Muon_adv.py +5 -1
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/Muon_util.py +7 -2
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/param_update.py +10 -4
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/setup.py +1 -1
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/LICENSE +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/README.md +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/SignSGD_adv.py +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.2.dev1 → adv_optm-2.2.dev3}/setup.cfg +0 -0
adv_optm/optim/AdaMuon_adv.py

@@ -225,6 +225,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
             "adam_k_warmup_steps": adam_k_warmup_steps, "adam_nnmf_factor": adam_nnmf_factor,
         }
         self.stochastic_rounding = stochastic_rounding
+        self._init_lr = lr

         super().__init__(params, defaults)

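Both optimizers now record the learning rate they were constructed with; it feeds the new weight-decay scaling added to param_update.py later in this diff. A minimal sketch of the idea, using a hypothetical class name:

    # Sketch only: SketchOptimizer is not part of adv_optm. It illustrates why
    # the constructor captures its initial lr; the formula mirrors the one
    # introduced in param_update.py in this release.
    class SketchOptimizer:
        def __init__(self, lr: float):
            self._init_lr = lr  # captured once, before any lr scheduling

        def decoupled_decay_factor(self, lr_t: float, wd: float) -> float:
            # Decay normalized by the schedule ratio lr_t / lr_0, rather than
            # scaling linearly with lr_t (classic AdamW style).
            return 1 / (1.0 + wd * lr_t / self._init_lr)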
@@ -439,9 +440,13 @@ class AdaMuon_adv(torch.optim.Optimizer):
                 scaled_eps, adaptive_eps, spectral_target, wd_scale = get_spectral_scaling(shape_for_scaling, group['n_layers'])

                 weight_decay = group['weight_decay'] * wd_scale
+                decoupled_wd = True
+
                 ns_eps = scaled_eps
+
             else:
                 weight_decay = group['weight_decay']
+                decoupled_wd = False
                 ns_eps = group['ns_eps']
                 adaptive_eps = group['eps']

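The spectral-scaling branch now also chooses the weight-decay formulation for the later update step; the flag is set next to weight_decay itself, so the apply_parameter_update call site (changed below) needs no extra branching. A distilled sketch, with the guard condition assumed since it sits outside this hunk:

    # `use_spectral_scaling` stands in for the actual branch condition.
    if use_spectral_scaling:
        weight_decay = group['weight_decay'] * wd_scale  # per-shape scaled decay
        decoupled_wd = True                              # request ratio-based decay
    else:
        weight_decay = group['weight_decay']
        decoupled_wd = False                             # keep classic wd * lr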
@@ -587,7 +592,7 @@ class AdaMuon_adv(torch.optim.Optimizer):

                 update = update.reshape(original_shape)

-                param_update.apply_parameter_update(self, p, group, update, lr, wd=weight_decay, random_int_tensor=random_int_tensor)
+                param_update.apply_parameter_update(self, p, group, update, lr, wd=weight_decay, random_int_tensor=random_int_tensor, decoupled=decoupled_wd)

     @torch.no_grad()
     def step(self, closure=None):
adv_optm/optim/Muon_adv.py

@@ -204,6 +204,7 @@ class Muon_adv(torch.optim.Optimizer):
         }
         self.stochastic_rounding = stochastic_rounding
         self.compiled_optimizer = compiled_optimizer
+        self._init_lr = lr

         super().__init__(params, defaults)

@@ -404,9 +405,12 @@ class Muon_adv(torch.optim.Optimizer):
                 scaled_eps, _, spectral_target, wd_scale = get_spectral_scaling(shape_for_scaling, group['n_layers'])

                 weight_decay = group['weight_decay'] * wd_scale
+                decoupled_wd = True
+
                 ns_eps = scaled_eps
             else:
                 weight_decay = group['weight_decay']
+                decoupled_wd = False
                 ns_eps = group['ns_eps']

             # MARS-M Approximated (Variance Reduction)
@@ -521,7 +525,7 @@ class Muon_adv(torch.optim.Optimizer):

                 update = update.reshape(original_shape)

-                param_update.apply_parameter_update(self, p, group, update, lr, wd=weight_decay, random_int_tensor=random_int_tensor)
+                param_update.apply_parameter_update(self, p, group, update, lr, wd=weight_decay, random_int_tensor=random_int_tensor, decoupled=decoupled_wd)

     @torch.no_grad()
     def step(self, closure=None):
adv_optm/util/Muon_util.py

@@ -352,7 +352,12 @@ def spectral_norm_update(update: torch.Tensor, vector_state: torch.Tensor, targe
     # Normalize v_new to get next state
     v_norm = torch.linalg.vector_norm(v_new)

-
+    # if v_norm >= 0.5:
+    #     vector_state.copy_(v_new.div_(v_norm.clamp_min_(1e-12))).to(vector_state.dtype))
+    candidate_v = v_new / v_norm
+    next_state = torch.where(v_norm >= 0.5, candidate_v, vector_state)
+    vector_state.copy_(next_state.to(vector_state.dtype))
+    # Else: We keep the old vector_state (which is a random unit vector at init)

     # Estimate sigma = ||A @ v|| (since v is unit norm)
     # Re-compute A @ v_new with the updated vector for better estimate
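The rewritten block replaces the commented-out Python `if` with a `torch.where` select, so the power-iteration state update involves no data-dependent control flow, presumably to stay traceable under torch.compile. A self-contained sketch of the pattern; `A` stands in for the matrix whose spectral norm is being tracked, and the function name is illustrative, not the package API:

    import torch

    def power_iteration_step(A: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """One guarded power-iteration step on A^T A; returns a spectral-norm estimate."""
        v_new = A.T @ (A @ v)
        v_norm = torch.linalg.vector_norm(v_new)

        # Adopt v_new only when its norm is non-degenerate; otherwise keep the
        # previous unit vector. torch.where with a 0-dim condition replaces the
        # data-dependent `if v_norm >= 0.5:` branch.
        candidate = v_new / v_norm
        v.copy_(torch.where(v_norm >= 0.5, candidate, v).to(v.dtype))

        # sigma ~= ||A @ v|| since v is (approximately) unit norm
        return torch.linalg.vector_norm(A @ v)

    # Usage: estimate the largest singular value of a random matrix.
    A = torch.randn(64, 32)
    v = torch.nn.functional.normalize(torch.randn(32), dim=0)
    for _ in range(20):
        sigma = power_iteration_step(A, v)
    print(float(sigma), float(torch.linalg.matrix_norm(A, ord=2)))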
@@ -379,7 +384,7 @@ def get_spectral_scaling(shape: torch.Size, n_layers: int):
         wd_scale: Weight decay scale
     """
     d_out, d_in = shape[0], shape[1]
-
+
     # Handle Convolutional/Flattened tensors
     if len(shape) > 2:
         d_in = shape[1:].numel()
adv_optm/util/param_update.py

@@ -14,6 +14,7 @@ def apply_parameter_update(
     lr: float | Tensor,
     wd: float | None = None,
     random_int_tensor: Tensor | None = None,
+    decoupled: bool = False,
 ) -> None:
     """
     Applies decoupled weight decay (standard or cautious) and the final
@@ -27,9 +28,14 @@ def apply_parameter_update(
         wd: Optional float value for weight decay, if another value other than group["weight_decay"] is needed.
         random_int_tensor: Optional pre-generated random tensor for stochastic
             rounding. Required for the `torch.compile` path.
+        decoupled: Whenever to use the true decoupled weight decay.
     """
     wd = group["weight_decay"] if wd is None else wd
     cautious = group.get('cautious_wd', False)
+    if decoupled:
+        scaled_wd = 1 / (1.0 + wd * lr / self._init_lr)
+    else:
+        scaled_wd = wd * lr

     # Compute full update in float32 if using bfloat16 with stochastic rounding
     if p.dtype == torch.bfloat16 and self.stochastic_rounding:
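This is the behavioral core of the release: with decoupled=True, the decay multiplier is normalized by the initial learning rate instead of scaling linearly with the current one. A simplified, self-contained sketch of the decay path exactly as written in this hunk (the reduced signature and function name are hypothetical; the real function also handles bfloat16 and stochastic rounding):

    import torch

    def apply_decay_and_update(p: torch.Tensor, update: torch.Tensor, *,
                               wd: float, lr: float, init_lr: float,
                               decoupled: bool = False,
                               cautious: bool = False) -> None:
        # Decay multiplier selection, mirroring the hunk above.
        if decoupled:
            scaled_wd = 1 / (1.0 + wd * lr / init_lr)  # ratio-based ("fully decoupled")
        else:
            scaled_wd = wd * lr                        # classic AdamW-style decay

        if cautious:
            # Cautious weight decay: decay only entries whose update agrees in sign.
            mask = (update * p >= 0).to(p.dtype)
            p.addcmul_(p, mask, value=-scaled_wd)
        else:
            p.add_(p, alpha=-scaled_wd)                # p <- p * (1 - scaled_wd)

        p.add_(-update)                                # apply the main update

Since the new parameter defaults to False, existing callers of apply_parameter_update keep the previous wd * lr behavior.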
@@ -41,11 +47,11 @@ def apply_parameter_update(
         if cautious:
             # Cautious Weight Decay
             mask = (update_fp32 * p_fp32 >= 0).float()
-            p_fp32.addcmul_(p_fp32, mask, value=-wd * lr)
+            p_fp32.addcmul_(p_fp32, mask, value=-scaled_wd)
             del mask
         else:
             # Standard decoupled weight decay
-            p_fp32.add_(p_fp32, alpha=-wd * lr)
+            p_fp32.add_(p_fp32, alpha=-scaled_wd)

         # Apply main update
         p_fp32.add_(-update_fp32)
@@ -65,11 +71,11 @@ def apply_parameter_update(
         if cautious:
             # Cautious Weight Decay
             mask = (update * p >= 0).to(p.dtype)
-            p.addcmul_(p, mask, value=-wd * lr)
+            p.addcmul_(p, mask, value=-scaled_wd)
             del mask
         else:
             # Standard decoupled weight decay
-            p.add_(p, alpha=-wd * lr)
+            p.add_(p, alpha=-scaled_wd)

         # Apply main update
         p.add_(-update)
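Taken together, the release threads a single new keyword through both optimizers. A sketch of the updated call site, mirroring the hunks above:

    # The new parameter defaults to False, so callers that don't pass it keep
    # the classic wd * lr scaling.
    param_update.apply_parameter_update(
        self, p, group, update, lr,
        wd=weight_decay,
        random_int_tensor=random_int_tensor,
        decoupled=decoupled_wd,  # new in 2.2.dev3
    )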