adv-optm 1.2.dev18-py3-none-any.whl → 1.2.dev19-py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdaMuon_adv.py +2 -10
- adv_optm/optim/Muon_adv.py +2 -13
- {adv_optm-1.2.dev18.dist-info → adv_optm-1.2.dev19.dist-info}/METADATA +1 -1
- {adv_optm-1.2.dev18.dist-info → adv_optm-1.2.dev19.dist-info}/RECORD +8 -8
- {adv_optm-1.2.dev18.dist-info → adv_optm-1.2.dev19.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev18.dist-info → adv_optm-1.2.dev19.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev18.dist-info → adv_optm-1.2.dev19.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdaMuon_adv.py
CHANGED
@@ -351,11 +351,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
 mean_squared_update = torch.mean(update.square(), dim=1)
 v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
 # Normalize update
-
- a = 1.2732395
- update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
- else:
- update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
 # Scale learning rate
 update_norm = torch.linalg.vector_norm(update)
 scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])

@@ -460,11 +456,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
 mean_squared_update = torch.mean(update.square(), dim=1)
 v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
 # Normalize update
-
- a = 1.2732395
- update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
- else:
- update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
 # Scale learning rate
 update_norm = torch.linalg.vector_norm(update)
 scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
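In both AdaMuon_adv hunks the atan2-based normalization branch is removed and only the epsilon-guarded division by sqrt(v_t) survives. The snippet below is a minimal standalone sketch of that step, not code from the package: the tensor shapes and hyperparameter values are illustrative assumptions, while the two formulas and the constant 1.2732395 (approximately 4/pi) come from the hunk.

import torch

# Illustrative stand-ins; only the formulas mirror the hunk above.
update = torch.randn(4, 8)   # 2-D update matrix
v_t = torch.zeros(4)         # per-row second-moment EMA
beta2, eps = 0.95, 1e-8      # assumed hyperparameter values

mean_squared_update = torch.mean(update.square(), dim=1)
v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)

# Branch removed in 1.2.dev19: atan2 normalization scaled by a = 1.2732395.
removed_variant = torch.atan2(update, v_t.sqrt().unsqueeze(1)).mul(1.2732395)

# Branch retained in 1.2.dev19: epsilon-guarded division.
update.div_(v_t.sqrt().unsqueeze(1).add_(eps))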
adv_optm/optim/Muon_adv.py
CHANGED
@@ -60,7 +60,6 @@ class Muon_adv(torch.optim.Optimizer):
 normuon_eps (float): Epsilon for NorMuon normalization stability. (default: 1e-8)
 normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
     (default: 0.2)
- normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
 accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
     dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
 cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)

@@ -105,7 +104,6 @@ class Muon_adv(torch.optim.Optimizer):
 beta2_normuon: float = 0.95,
 normuon_eps: float = 1e-8,
 normuon_lr_scale: float = 0.2,
- normuon_atan2: bool = False,
 # CANS
 accelerated_ns: bool = False,
 cns_a_bound: float = 1e-4,

@@ -158,7 +156,6 @@ class Muon_adv(torch.optim.Optimizer):
 # NorMuon
 "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
 "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
- "normuon_atan2": normuon_atan2,
 # CANS
 "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
 # AdamW_adv defaults

@@ -366,11 +363,7 @@ class Muon_adv(torch.optim.Optimizer):
 mean_squared_update = torch.mean(update.square(), dim=1)
 v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
 # Normalize update
-
- a = 1.2732395
- update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
- else:
- update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
 # Scale learning rate
 update_norm = torch.linalg.vector_norm(update)

@@ -471,11 +464,7 @@ class Muon_adv(torch.optim.Optimizer):
 mean_squared_update = torch.mean(update.square(), dim=1)
 v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
 # Normalize update
-
- a = 1.2732395
- update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
- else:
- update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
 # Scale learning rate
 update_norm = torch.linalg.vector_norm(update)
 scaled_lr = group['normuon_lr_scale'] * lr * (p.numel()**0.5) / update_norm.add_(group['normuon_eps'])
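Muon_adv receives the same simplification for its NorMuon path: the normuon_atan2 docstring entry, constructor parameter, and defaults entry are removed, and the normalization always divides by sqrt(v_t) plus normuon_eps, so callers that previously passed normuon_atan2 must drop that argument. The sketch below replays the normalization and learning-rate scaling visible in the surviving context lines; the tensors and the base lr are illustrative, the other values are the defaults listed in the hunks, and the final parameter update is a hypothetical application step rather than code from the package.

import torch

p = torch.randn(32, 16)        # stand-in parameter matrix
update = torch.randn_like(p)   # stand-in orthogonalized Muon update
v_t = torch.zeros(p.shape[0])  # per-row second-moment EMA
lr = 1e-3                      # assumed base learning rate
beta2_normuon, normuon_eps, normuon_lr_scale = 0.95, 1e-8, 0.2  # defaults from the diff

mean_squared_update = torch.mean(update.square(), dim=1)
v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
update.div_(v_t.sqrt().unsqueeze(1).add_(normuon_eps))   # retained normalization

update_norm = torch.linalg.vector_norm(update)
scaled_lr = normuon_lr_scale * lr * (p.numel() ** 0.5) / update_norm.add_(normuon_eps)
p.add_(update, alpha=-scaled_lr.item())   # hypothetical application of the scaled update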
{adv_optm-1.2.dev18.dist-info → adv_optm-1.2.dev19.dist-info}/RECORD
CHANGED

@@ -1,10 +1,10 @@
-adv_optm/__init__.py,sha256=
-adv_optm/optim/AdaMuon_adv.py,sha256=
+adv_optm/__init__.py,sha256=1AKxG--scx5Bl9G08tQcnfzAMaQVSgmW99uy3v2QWMw,380
+adv_optm/optim/AdaMuon_adv.py,sha256=7Had92OcsCiN1E9UJRyrpPV7VzHqmIvS-qM6OEcc24I,34671
 adv_optm/optim/AdamW_adv.py,sha256=jgMuRAfsnUh_2wUEZgYpJX5uwoT_kQjtMs2Xn2vJ3x0,17480
 adv_optm/optim/Adopt_adv.py,sha256=kbAeBG4bXWBvgj_qrE9W67J6c0swpEi4Erj2rfYrMXE,21252
 adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
 adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
-adv_optm/optim/Muon_adv.py,sha256=
+adv_optm/optim/Muon_adv.py,sha256=tZY8K3pNBCGk1V09GbK05lJooFw92NfkF7_T548up3Q,33171
 adv_optm/optim/Prodigy_adv.py,sha256=k7f2J_RQpnrUXjwER_XOokISlQWpTSwGG-OL-bjMfBk,26061
 adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
 adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489

@@ -16,8 +16,8 @@ adv_optm/util/Newton_Schulz.py,sha256=bBboYw_jm5_FMf0Citl79uqNedkHOTjQnUI7rZgLBm
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=CXzS703GB4gil85khZi7sgKOnbzXGBOltshIOSPqj18,435
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
+adv_optm-1.2.dev19.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.2.dev19.dist-info/METADATA,sha256=pQm5WuMKvf5Xse10viziVK9ry1UufcYRDwOd55jad8Y,14023
+adv_optm-1.2.dev19.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.2.dev19.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.2.dev19.dist-info/RECORD,,
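The RECORD hunk swaps in fresh sha256 digests and byte sizes for the rebuilt modules and the renamed dev19 dist-info files. For readers who want to check an installed file against its RECORD line, the helper below is a minimal sketch using only the standard library; the path in the example call is a placeholder, and the urlsafe-base64, padding-stripped encoding follows the wheel RECORD convention.

import base64
import hashlib

def record_hash(path: str) -> str:
    # Hash the file and encode the digest the way wheel RECORD files do:
    # urlsafe base64 with the trailing '=' padding stripped.
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode()

# Example (placeholder path): the result should match the hash field of the
# corresponding RECORD line; the trailing number on that line is the file size.
print(record_hash("adv_optm/optim/Muon_adv.py"))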
{adv_optm-1.2.dev18.dist-info → adv_optm-1.2.dev19.dist-info}/WHEEL
File without changes

{adv_optm-1.2.dev18.dist-info → adv_optm-1.2.dev19.dist-info}/licenses/LICENSE
File without changes

{adv_optm-1.2.dev18.dist-info → adv_optm-1.2.dev19.dist-info}/top_level.txt
File without changes