adv-optm 2.2.0__tar.gz → 2.2.dev1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/PKG-INFO +1 -1
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/__init__.py +1 -1
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/optim/AdaMuon_adv.py +3 -10
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/optim/Muon_adv.py +4 -10
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/util/Muon_util.py +2 -7
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/util/param_update.py +7 -20
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/setup.py +1 -1
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/LICENSE +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/README.md +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/optim/SignSGD_adv.py +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/util/Muon_AuxAdam.py +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.2.0 → adv_optm-2.2.dev1}/setup.cfg +0 -0
adv_optm/optim/AdaMuon_adv.py

```diff
@@ -225,7 +225,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
             "adam_k_warmup_steps": adam_k_warmup_steps, "adam_nnmf_factor": adam_nnmf_factor,
         }
         self.stochastic_rounding = stochastic_rounding
-        self._init_lr = lr

         super().__init__(params, defaults)

@@ -319,13 +318,11 @@ class AdaMuon_adv(torch.optim.Optimizer):

                 # Spectral Normalization
                 if group.get('spectral_normalization', False):
-                    gen = param_update.get_generator(device)
-
                     # Case A: Factored Muon
                     if state['factored']:
                         d1, d2 = state['effective_shape']
                         # We need a vector matching the 'inner' dimension d2
-                        state['spectral_v'] = torch.randn(d2, device=device, dtype=dtype, generator=gen)
+                        state['spectral_v'] = torch.randn(d2, device=device, dtype=dtype)

                     # Case B: Standard Muon (Linear, Conv2d, etc.)
                     elif len(p.shape) >= 2:
@@ -333,7 +330,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
                         # (p.shape[0], product_of_rest).
                         d_in_flat = p.numel() // p.shape[0]

-                        state['spectral_v'] = torch.randn(d_in_flat, device=device, dtype=dtype, generator=gen)
+                        state['spectral_v'] = torch.randn(d_in_flat, device=device, dtype=dtype)

                     # Normalize initial vector for stability
                     if 'spectral_v' in state:
@@ -442,13 +439,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
                     scaled_eps, adaptive_eps, spectral_target, wd_scale = get_spectral_scaling(shape_for_scaling, group['n_layers'])

                     weight_decay = group['weight_decay'] * wd_scale
-                    decoupled_wd = True
-
                     ns_eps = scaled_eps
-
                 else:
                     weight_decay = group['weight_decay']
-                    decoupled_wd = False
                     ns_eps = group['ns_eps']
                     adaptive_eps = group['eps']

@@ -594,7 +587,7 @@ class AdaMuon_adv(torch.optim.Optimizer):

                 update = update.reshape(original_shape)

-                param_update.apply_parameter_update(self, p, group, update, lr, wd=weight_decay, random_int_tensor=random_int_tensor, decoupled=decoupled_wd)
+                param_update.apply_parameter_update(self, p, group, update, lr, wd=weight_decay, random_int_tensor=random_int_tensor)

     @torch.no_grad()
     def step(self, closure=None):
```
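The `spectral_v` initializations above previously drew from a cached per-device generator (`param_update.get_generator`, removed below); they now use PyTorch's global RNG. A minimal sketch of what that implies for reproducibility (the seed value and size here are illustrative, not package defaults):

```python
import torch

# With no dedicated generator, reproducible spectral_v initialization
# depends on the global RNG state at optimizer construction time.
torch.manual_seed(42)    # seed once, before building the optimizer
v = torch.randn(1024)    # analogous to the randn init above
v.div_(v.norm())         # unit norm, per the "Normalize initial vector" step
```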
adv_optm/optim/Muon_adv.py

```diff
@@ -204,7 +204,6 @@ class Muon_adv(torch.optim.Optimizer):
         }
         self.stochastic_rounding = stochastic_rounding
         self.compiled_optimizer = compiled_optimizer
-        self._init_lr = lr

         super().__init__(params, defaults)

@@ -285,13 +284,11 @@ class Muon_adv(torch.optim.Optimizer):

                 # Spectral Normalization
                 if group.get('spectral_normalization', False):
-                    gen = param_update.get_generator(device)
-
                     # Case A: Factored Muon
                     if state['factored']:
                         d1, d2 = state['effective_shape']
                         # We need a vector matching the 'inner' dimension d2
-                        state['spectral_v'] = torch.randn(d2, device=device, dtype=dtype, generator=gen)
+                        state['spectral_v'] = torch.randn(d2, device=device, dtype=dtype)

                     # Case B: Standard Muon (Linear, Conv2d, etc.)
                     elif len(p.shape) >= 2:
@@ -299,11 +296,11 @@ class Muon_adv(torch.optim.Optimizer):
                         # (p.shape[0], product_of_rest).
                         d_in_flat = p.numel() // p.shape[0]

-                        state['spectral_v'] = torch.randn(d_in_flat, device=device, dtype=dtype, generator=gen)
+                        state['spectral_v'] = torch.randn(d_in_flat, device=device, dtype=dtype)

                     # Normalize initial vector for stability
                     if 'spectral_v' in state:
-
+                        state['spectral_v'].div_(state['spectral_v'].norm())

                 # MARS-M state initialization
                 if group.get('approx_mars', False):
@@ -407,12 +404,9 @@ class Muon_adv(torch.optim.Optimizer):
                     scaled_eps, _, spectral_target, wd_scale = get_spectral_scaling(shape_for_scaling, group['n_layers'])

                     weight_decay = group['weight_decay'] * wd_scale
-                    decoupled_wd = True
-
                     ns_eps = scaled_eps
                 else:
                     weight_decay = group['weight_decay']
-                    decoupled_wd = False
                     ns_eps = group['ns_eps']

                 # MARS-M Approximated (Variance Reduction)
@@ -527,7 +521,7 @@ class Muon_adv(torch.optim.Optimizer):

                 update = update.reshape(original_shape)

-                param_update.apply_parameter_update(self, p, group, update, lr, wd=weight_decay, random_int_tensor=random_int_tensor, decoupled=decoupled_wd)
+                param_update.apply_parameter_update(self, p, group, update, lr, wd=weight_decay, random_int_tensor=random_int_tensor)

     @torch.no_grad()
     def step(self, closure=None):
```
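For context on Case B above: the weight is treated as a 2-D matrix of shape `(p.shape[0], d_in_flat)` and the power-iteration vector lives on the flattened input dimension. A small sketch with an assumed Conv2d-shaped weight:

```python
import torch

p = torch.empty(64, 3, 3, 3)            # hypothetical Conv2d weight
d_in_flat = p.numel() // p.shape[0]     # 27 = 3*3*3, the flattened fan-in
A = p.reshape(p.shape[0], d_in_flat)    # matrix whose largest singular value is tracked
v = torch.randn(d_in_flat)              # spectral_v matches A's column dimension
v.div_(v.norm())                        # matches the normalization added above
```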
adv_optm/util/Muon_util.py

```diff
@@ -352,12 +352,7 @@ def spectral_norm_update(update: torch.Tensor, vector_state: torch.Tensor, targe
     # Normalize v_new to get next state
     v_norm = torch.linalg.vector_norm(v_new)

-
-    # vector_state.copy_(v_new.div_(v_norm.clamp_min_(1e-12))).to(vector_state.dtype))
-    candidate_v = v_new / v_norm
-    next_state = torch.where(v_norm >= 0.5, candidate_v, vector_state)
-    vector_state.copy_(next_state.to(vector_state.dtype))
-    # Else: We keep the old vector_state (which is a random unit vector at init)
+    vector_state.copy_(v_new.div_(v_norm.clamp_min_(1e-12)).to(vector_state.dtype))

     # Estimate sigma = ||A @ v|| (since v is unit norm)
     # Re-compute A @ v_new with the updated vector for better estimate
```
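The replaced block guarded against a collapsing iterate (`v_norm >= 0.5` kept the previous vector); the new single line always renormalizes, with `clamp_min_(1e-12)` guarding the division instead. For reference, a self-contained sketch of the one-step power iteration the surrounding comments describe (the function name and argument layout are assumptions, not the package's signature):

```python
import torch

@torch.no_grad()
def power_iter_sigma(A: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """One power-iteration step on A^T A; returns an estimate of sigma_max(A)."""
    v_new = A.T @ (A @ v)                            # v_new ~ (A^T A) v
    v_norm = torch.linalg.vector_norm(v_new)
    v.copy_(v_new.div(v_norm.clamp_min(1e-12)))      # renormalize, as in the new line
    return torch.linalg.vector_norm(A @ v)           # sigma ~ ||A v|| since ||v|| = 1
```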
```diff
@@ -384,7 +379,7 @@ def get_spectral_scaling(shape: torch.Size, n_layers: int):
         wd_scale: Weight decay scale
     """
     d_out, d_in = shape[0], shape[1]
-
+
     # Handle Convolutional/Flattened tensors
     if len(shape) > 2:
         d_in = shape[1:].numel()
```
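The visible tail of `get_spectral_scaling` shows how the effective fan-in is computed for tensors with more than two dimensions. A worked sketch of just that shape handling:

```python
import torch

shape = torch.Size([64, 3, 3, 3])  # Conv2d(in_channels=3, out_channels=64, kernel_size=3)
d_out, d_in = shape[0], shape[1]   # 64, 3
if len(shape) > 2:
    d_in = shape[1:].numel()       # 27: all non-output dims flattened together
```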
adv_optm/util/param_update.py

```diff
@@ -14,7 +14,6 @@ def apply_parameter_update(
     lr: float | Tensor,
     wd: float | None = None,
     random_int_tensor: Tensor | None = None,
-    decoupled: bool = False,
 ) -> None:
     """
     Applies decoupled weight decay (standard or cautious) and the final
@@ -28,14 +27,9 @@ def apply_parameter_update(
         wd: Optional float value for weight decay, if another value other than group["weight_decay"] is needed.
         random_int_tensor: Optional pre-generated random tensor for stochastic
             rounding. Required for the `torch.compile` path.
-        decoupled: Whenever to use the true decoupled weight decay.
     """
     wd = group["weight_decay"] if wd is None else wd
     cautious = group.get('cautious_wd', False)
-    if decoupled:
-        scaled_wd = wd * (lr / self._init_lr)
-    else:
-        scaled_wd = wd * lr

     # Compute full update in float32 if using bfloat16 with stochastic rounding
     if p.dtype == torch.bfloat16 and self.stochastic_rounding:
```
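The removed branch implemented a second flavor of decay: with `decoupled=True`, the effective decay followed only the LR schedule's ratio `lr / _init_lr` rather than the LR itself. After this change, `apply_parameter_update` always uses AdamW-style `wd * lr` (see the hunks below). A sketch of the two scalings, with `init_lr` standing in for the removed `_init_lr` attribute:

```python
def scaled_wd(wd: float, lr: float, init_lr: float, fully_decoupled: bool) -> float:
    # fully_decoupled=True reproduces the removed branch: decay tracks the
    # schedule's relative decay (lr / init_lr), not the absolute LR.
    return wd * (lr / init_lr) if fully_decoupled else wd * lr
```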
```diff
@@ -47,11 +41,11 @@ def apply_parameter_update(
         if cautious:
             # Cautious Weight Decay
             mask = (update_fp32 * p_fp32 >= 0).float()
-            p_fp32.addcmul_(p_fp32, mask, value=-scaled_wd)
+            p_fp32.addcmul_(p_fp32, mask, value=-wd * lr)
             del mask
         else:
             # Standard decoupled weight decay
-            p_fp32.add_(p_fp32, alpha=-scaled_wd)
+            p_fp32.add_(p_fp32, alpha=-wd * lr)

         # Apply main update
         p_fp32.add_(-update_fp32)
```
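For context, a standalone sketch of the cautious decay in this hunk: decay is applied only where the incoming update (which is subtracted from `p`) shares a sign with the weight, i.e. where the step is already pulling the weight toward zero. Names are illustrative:

```python
import torch

def cautious_decay_(p: torch.Tensor, update: torch.Tensor, wd: float, lr: float) -> None:
    # Mask decay to coordinates where the (subtracted) update agrees in sign
    # with the weight, so decay never fights the optimizer step.
    mask = (update * p >= 0).to(p.dtype)
    p.addcmul_(p, mask, value=-wd * lr)  # p -= wd * lr * p, masked
```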
```diff
@@ -60,7 +54,6 @@ def apply_parameter_update(
         if random_int_tensor is not None:
             # Compiled path: use the pre-computed random tensor
             _copy_stochastic_core_(p, p_fp32, random_int_tensor)
-            del random_int_tensor
         else:
             # Uncompiled path: generate randoms inside
             copy_stochastic_(p, p_fp32)
```
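Dropping `del random_int_tensor` here is behavior-neutral: `del` only removes the local name binding, not the tensor itself, which the caller still references; the memory is freed in either case only when the last reference goes away.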
```diff
@@ -72,11 +65,11 @@ def apply_parameter_update(
         if cautious:
             # Cautious Weight Decay
             mask = (update * p >= 0).to(p.dtype)
-            p.addcmul_(p, mask, value=-scaled_wd)
+            p.addcmul_(p, mask, value=-wd * lr)
             del mask
         else:
             # Standard decoupled weight decay
-            p.add_(p, alpha=-scaled_wd)
+            p.add_(p, alpha=-wd * lr)

         # Apply main update
         p.add_(-update)
```
```diff
@@ -94,14 +87,6 @@ def set_seed(device: torch.device):
         _generators[device] = torch.Generator(device=device)
         _generators[device].manual_seed(42)

-def get_generator(device: torch.device) -> torch.Generator:
-    """
-    Retrieves (and initializes if necessary) the deterministic generator
-    for the specified device.
-    """
-    if device not in _generators:
-        set_seed(device)
-    return _generators[device]

 def _get_random_int_for_sr(source: Tensor) -> Tensor:
     """
```
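Only the public accessor is removed here; `set_seed` and the per-device `_generators` cache stay in place for the stochastic-rounding path. This matches the optimizer-side change above, where `torch.randn` no longer receives an explicit generator.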
```diff
@@ -133,7 +118,7 @@ def _copy_stochastic_core_(target: Tensor, source: Tensor, random_int_tensor: Te
     Core logic for stochastic rounding using a pre-computed random integer tensor.
     This version is designed to be torch.compile-friendly.
     """
-    result = random_int_tensor
+    result = random_int_tensor.clone()
     # add the random number to the lower 16 bit of the mantissa
     result.add_(source.view(dtype=torch.int32))

@@ -143,6 +128,8 @@ def _copy_stochastic_core_(target: Tensor, source: Tensor, random_int_tensor: Te
     # copy the higher 16 bit into the target tensor
     target.copy_(result.view(dtype=torch.float32))

+    del result
+

 def copy_stochastic_(target: Tensor, source: Tensor):
     """
```
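Aliasing `result = random_int_tensor` meant the in-place `result.add_` mutated the caller's pre-computed random tensor, so a tensor reused across steps would be corrupted after the first call; cloning (plus the explicit `del`) fixes that at the cost of one temporary. For reference, a self-contained sketch of the underlying bit trick (float32 to bfloat16 stochastic rounding; names are illustrative, not the package's API):

```python
import torch

def stochastic_round_bf16_(target: torch.Tensor, source: torch.Tensor) -> None:
    """Stochastically round float32 `source` into bfloat16 `target`."""
    # Random integer for the lower 16 bits of the float32 mantissa.
    result = torch.randint_like(source, 0, 1 << 16, dtype=torch.int32)
    # Adding it to the raw float32 bits lets carries propagate into the
    # upper 16 bits with probability proportional to the truncated residual.
    result.add_(source.view(dtype=torch.int32))
    result.bitwise_and_(-65536)                     # keep only the upper 16 bits
    target.copy_(result.view(dtype=torch.float32))  # bf16 cast is now exact

src = torch.randn(4, dtype=torch.float32)
dst = torch.empty(4, dtype=torch.bfloat16)
stochastic_round_bf16_(dst, src)
```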