adv-optm 1.2.dev17__py3-none-any.whl → 1.2.dev18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdaMuon_adv.py +7 -1
- adv_optm/optim/AdamW_adv.py +2 -2
- adv_optm/optim/Adopt_adv.py +2 -2
- adv_optm/optim/Muon_adv.py +7 -0
- adv_optm/optim/Prodigy_adv.py +1 -1
- {adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev18.dist-info}/METADATA +1 -1
- {adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev18.dist-info}/RECORD +11 -11
- {adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev18.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev18.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev18.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdaMuon_adv.py
CHANGED
@@ -46,6 +46,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
             (default: (3.4445, -4.7750, 2.0315)).
         stochastic_rounding (bool): whether to use stochastic rounding for
             BF16 parameter updates (default: True).
+        orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
         nesterov (bool): enables Nesterov momentum (default: False).
         use_atan2 (bool): whether to use the atan2 update rule. (default: False)
         Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
@@ -95,6 +96,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
         ns_eps: float = 1e-7,
         ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
         stochastic_rounding: bool = False,
+        orthogonal_gradient: bool = False,
         use_atan2: bool = False,
         nesterov: bool = False,
         Simplified_AdEMAMix: bool = False,
@@ -147,7 +149,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
             "vector_reshape": vector_reshape,
             "nesterov":nesterov, "use_atan2":use_atan2,
             "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
-            "normuon_variant": normuon_variant,
+            "normuon_variant": normuon_variant, "orthogonal_gradient": orthogonal_gradient,
             # Low-rank Ortho
             "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
             "compiled_optimizer":compiled_optimizer,
@@ -282,6 +284,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
             nesterov = group['nesterov']
             Simplified_AdEMAMix = group['Simplified_AdEMAMix']
             alpha_grad = group['alpha_grad']
+            if grad.dtype != torch.float32 and state.get('factored', False):
+                grad = grad.float()
+            if group.get("orthogonal_gradient"):
+                grad = _orthogonalize_gradient(p, grad)

             if state['factored']: # Factored AdaMuon
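The new orthogonal_gradient path feeds the gradient through _orthogonalize_gradient(p, grad) before the factored/Muon branch; the helper itself lives in adv_optm/util/OrthoGrad.py and is not shown in this diff. As a rough sketch only: an OrthoGrad-style projection typically removes the gradient component parallel to the parameter vector and rescales the result back to the original gradient norm. The body and the eps argument below are assumptions for illustration, not the package's actual implementation.

    import torch

    def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor, eps: float = 1e-30) -> torch.Tensor:
        # Sketch of an OrthoGrad-style projection (assumed, not the shipped helper).
        w = p.detach().reshape(-1).float()
        g = grad.reshape(-1).float()
        # Drop the component of g parallel to w: g_orth = g - (<w, g> / <w, w>) * w
        g_orth = g - (torch.dot(w, g) / (torch.dot(w, w) + eps)) * w
        # Rescale so the orthogonalized gradient keeps the original norm.
        g_orth = g_orth * (g.norm() / (g_orth.norm() + eps))
        return g_orth.view_as(grad).to(grad.dtype)

At the call site above the result simply replaces grad, so the rest of the step (factored or full-size) runs unchanged.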
adv_optm/optim/AdamW_adv.py
CHANGED
@@ -252,7 +252,7 @@ class AdamW_adv(torch.optim.Optimizer):
                 # Update momentum in full-size
                 mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
                 if self.grams_moment:
-                    mt
+                    mt = (grad_reshaped.sign().mul_(mt.abs()))
                 elif self.cautious_mask:
                     mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
                     mask.div_(mask.mean().clamp_(min=1e-3))
@@ -310,7 +310,7 @@ class AdamW_adv(torch.optim.Optimizer):
             exp_avg = state['exp_avg']
             exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
             if self.grams_moment:
-                exp_avg = grad.sign()
+                exp_avg = grad.sign().mul_(exp_avg.abs())
             elif self.cautious_mask:
                 mask = (exp_avg * grad > 0).to(grad.dtype)
                 mask.div_(mask.mean().clamp_(min=1e-3))
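Both hunks adjust the grams_moment branch so that the updated first moment takes its sign from the current gradient while keeping the momentum's magnitude (previously the second hunk assigned exp_avg = grad.sign(), which discarded that magnitude). A small self-contained illustration of the corrected rule, using made-up tensors rather than the optimizer's internal state:

    import torch

    grad = torch.tensor([ 0.5, -0.2,  0.10])   # current gradient
    m    = torch.tensor([-0.3,  0.4,  0.05])   # first moment after the usual EMA update

    # Old behavior (second hunk): direction and magnitude both taken from the gradient sign.
    m_old = grad.sign()                         # tensor([ 1., -1.,  1.])

    # Fixed behavior: gradient's sign, momentum's magnitude.
    m_new = grad.sign().mul(m.abs())            # tensor([ 0.3000, -0.4000,  0.0500])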
adv_optm/optim/Adopt_adv.py
CHANGED
@@ -307,7 +307,7 @@ class Adopt_adv(torch.optim.Optimizer):
                 else:
                     mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
                 if self.grams_moment:
-                    mt = grad_reshaped.sign()
+                    mt = grad_reshaped.sign().mul_(mt.abs())
                 elif self.cautious_mask:
                     mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
                     mask.div_(mask.mean().clamp_(min=1e-3))
@@ -376,7 +376,7 @@ class Adopt_adv(torch.optim.Optimizer):
             m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)

             if self.grams_moment:
-                m = grad.sign()
+                m = grad.sign().mul_(m.abs())
             elif self.cautious_mask:
                 mask = (m * grad > 0).to(grad.dtype)
                 mask.div_(mask.mean().clamp_(min=1e-3))
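The unchanged cautious_mask branch visible in both hunks is the alternative to the Grams-style rule: it zeroes momentum entries whose sign disagrees with the current gradient and rescales the survivors so the average update magnitude is preserved. A minimal sketch of that masking step with illustrative tensors (not optimizer state):

    import torch

    grad = torch.tensor([ 0.5, -0.2,  0.1,  0.3])
    m    = torch.tensor([ 0.4,  0.1, -0.2,  0.6])   # first-moment estimate

    mask = (m * grad > 0).to(grad.dtype)            # sign-agreement mask -> [1., 0., 0., 1.]
    mask.div_(mask.mean().clamp_(min=1e-3))         # mean is 0.5, so mask becomes [2., 0., 0., 2.]
    m_cautious = m * mask                           # disagreeing entries dropped, survivors scaled up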
adv_optm/optim/Muon_adv.py
CHANGED
@@ -41,6 +41,7 @@ class Muon_adv(torch.optim.Optimizer):
             stability. (default: 100.0)
         stochastic_rounding (bool): whether to use stochastic rounding for
             BF16 parameter updates (default: True).
+        orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
         vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
             matrices for muon NewtonSchulz (default: False).
         vector_reshape (bool): whether to reshape 1D vectors into 2D
@@ -92,6 +93,7 @@ class Muon_adv(torch.optim.Optimizer):
         Simplified_AdEMAMix: bool = False,
         alpha_grad: float = 100.0,
         stochastic_rounding: bool = True,
+        orthogonal_gradient: bool = False,
         vector_reshape_muon: bool = False,
         vector_reshape: bool = False,
         nnmf_factor: bool = False,
@@ -149,6 +151,7 @@ class Muon_adv(torch.optim.Optimizer):
             "vector_reshape": vector_reshape,
             "vector_reshape_muon": vector_reshape_muon,
             "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+            "orthogonal_gradient": orthogonal_gradient,
             'compiled_optimizer': compiled_optimizer,
             # Low-rank Ortho
             "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
@@ -293,6 +296,10 @@ class Muon_adv(torch.optim.Optimizer):
             nesterov = group['nesterov']
             Simplified_AdEMAMix = group['Simplified_AdEMAMix']
             alpha_grad = group['alpha_grad']
+            if grad.dtype != torch.float32 and state.get('factored', False):
+                grad = grad.float()
+            if group.get("orthogonal_gradient"):
+                grad = _orthogonalize_gradient(p, grad)

             if state['factored']: # Factored Muon
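With this release the OrthoGrad behavior can be switched on per optimizer instance through the new keyword. A hypothetical usage sketch; the learning rate and the training snippet are placeholders, and only the module path visible in this diff is assumed:

    import torch
    from adv_optm.optim.Muon_adv import Muon_adv   # module path as listed in this diff

    model = torch.nn.Linear(128, 128)
    opt = Muon_adv(
        model.parameters(),
        lr=1e-3,                    # placeholder value, not a recommendation
        orthogonal_gradient=True,   # new in 1.2.dev18: orthogonalize grads against the weights
    )

    loss = model(torch.randn(4, 128)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()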
adv_optm/optim/Prodigy_adv.py
CHANGED
@@ -343,7 +343,7 @@ class Prodigy_adv(torch.optim.Optimizer):
                 else:
                     mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
                 if self.grams_moment:
-                    mt
+                    mt = (grad_reshaped.sign().mul_(mt.abs()))
                 elif self.cautious_mask:
                     mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
                     mask.div_(mask.mean().clamp_(min=1e-3))
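For context, the EMA line kept in this hunk folds the adaptive scale self.d into the gradient term, i.e. roughly m_t = beta1 * m_(t-1) + d * (1 - beta1) * g_t. The corrected grams_moment line then replaces the sign of that d-scaled momentum with the sign of the current gradient while keeping its magnitude, the same fix applied in AdamW_adv and Adopt_adv above.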
{adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev18.dist-info}/RECORD
CHANGED

@@ -1,11 +1,11 @@
-adv_optm/__init__.py,sha256=
-adv_optm/optim/AdaMuon_adv.py,sha256=
-adv_optm/optim/AdamW_adv.py,sha256=
-adv_optm/optim/Adopt_adv.py,sha256=
+adv_optm/__init__.py,sha256=1UzgEkreoqaobiwUZ8yR-8Fnda7T7XiHQ4PhJKQocy4,380
+adv_optm/optim/AdaMuon_adv.py,sha256=VpNsw2CnU8bZThj9cJJ6HGIATPxv4VkIf3xTsUMXQAY,35027
+adv_optm/optim/AdamW_adv.py,sha256=jgMuRAfsnUh_2wUEZgYpJX5uwoT_kQjtMs2Xn2vJ3x0,17480
+adv_optm/optim/Adopt_adv.py,sha256=kbAeBG4bXWBvgj_qrE9W67J6c0swpEi4Erj2rfYrMXE,21252
 adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
 adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
-adv_optm/optim/Muon_adv.py,sha256=
-adv_optm/optim/Prodigy_adv.py,sha256=
+adv_optm/optim/Muon_adv.py,sha256=0D4k8UfMSzITJwQEDfqpceD5H7HQvv0f8uyVKvdvkHo,33704
+adv_optm/optim/Prodigy_adv.py,sha256=k7f2J_RQpnrUXjwER_XOokISlQWpTSwGG-OL-bjMfBk,26061
 adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
 adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
 adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
@@ -16,8 +16,8 @@ adv_optm/util/Newton_Schulz.py,sha256=bBboYw_jm5_FMf0Citl79uqNedkHOTjQnUI7rZgLBm
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=CXzS703GB4gil85khZi7sgKOnbzXGBOltshIOSPqj18,435
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
+adv_optm-1.2.dev18.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.2.dev18.dist-info/METADATA,sha256=cfQdGhiRlf_-xnPKqwuCE8PR6faLYx_RC6MrkjYDqI8,14023
+adv_optm-1.2.dev18.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.2.dev18.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.2.dev18.dist-info/RECORD,,

{adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev18.dist-info}/WHEEL
File without changes

{adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev18.dist-info}/licenses/LICENSE
File without changes

{adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev18.dist-info}/top_level.txt
File without changes