adv-optm 1.2.dev17-py3-none-any.whl → 1.2.dev19-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdaMuon_adv.py +9 -11
- adv_optm/optim/AdamW_adv.py +2 -2
- adv_optm/optim/Adopt_adv.py +2 -2
- adv_optm/optim/Muon_adv.py +9 -13
- adv_optm/optim/Prodigy_adv.py +1 -1
- {adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev19.dist-info}/METADATA +1 -1
- {adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev19.dist-info}/RECORD +11 -11
- {adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev19.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev19.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev19.dist-info}/top_level.txt +0 -0
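Summary of the code changes in this release: AdaMuon_adv and Muon_adv gain an `orthogonal_gradient` (OrthoGrad) option and an fp32 cast for factored low-precision gradients, the Grams-style momentum in AdamW_adv, Adopt_adv and Prodigy_adv is fixed so it keeps the momentum magnitude, and the NorMuon `normuon_atan2` branch is removed. A minimal usage sketch for the new flag follows; the import path and the `lr` value are assumptions, not taken from this diff.

import torch
from adv_optm import Muon_adv  # assumed top-level export; the class is defined in adv_optm/optim/Muon_adv.py

model = torch.nn.Linear(128, 128)
# orthogonal_gradient=True enables the OrthoGrad path added in this release;
# lr and the remaining defaults here are illustrative assumptions.
opt = Muon_adv(model.parameters(), lr=1e-3, orthogonal_gradient=True)
loss = model(torch.randn(4, 128)).pow(2).mean()
loss.backward()
opt.step()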
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdaMuon_adv.py
CHANGED
@@ -46,6 +46,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
 (default: (3.4445, -4.7750, 2.0315)).
 stochastic_rounding (bool): whether to use stochastic rounding for
 BF16 parameter updates (default: True).
+orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
 nesterov (bool): enables Nesterov momentum (default: False).
 use_atan2 (bool): whether to use the atan2 update rule. (default: False)
 Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
@@ -95,6 +96,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
 ns_eps: float = 1e-7,
 ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
 stochastic_rounding: bool = False,
+orthogonal_gradient: bool = False,
 use_atan2: bool = False,
 nesterov: bool = False,
 Simplified_AdEMAMix: bool = False,
@@ -147,7 +149,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
 "vector_reshape": vector_reshape,
 "nesterov":nesterov, "use_atan2":use_atan2,
 "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
-"normuon_variant": normuon_variant,
+"normuon_variant": normuon_variant, "orthogonal_gradient": orthogonal_gradient,
 # Low-rank Ortho
 "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
 "compiled_optimizer":compiled_optimizer,
@@ -282,6 +284,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
 nesterov = group['nesterov']
 Simplified_AdEMAMix = group['Simplified_AdEMAMix']
 alpha_grad = group['alpha_grad']
+if grad.dtype != torch.float32 and state.get('factored', False):
+    grad = grad.float()
+if group.get("orthogonal_gradient"):
+    grad = _orthogonalize_gradient(p, grad)

 if state['factored']: # Factored AdaMuon

@@ -345,11 +351,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
 mean_squared_update = torch.mean(update.square(), dim=1)
 v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
 # Normalize update
-
-a = 1.2732395
-update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
-else:
-update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
 # Scale learning rate
 update_norm = torch.linalg.vector_norm(update)
 scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
@@ -454,11 +456,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
 mean_squared_update = torch.mean(update.square(), dim=1)
 v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
 # Normalize update
-
-a = 1.2732395
-update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
-else:
-update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
 # Scale learning rate
 update_norm = torch.linalg.vector_norm(update)
 scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
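The new lines above route the gradient through `_orthogonalize_gradient(p, grad)` when `orthogonal_gradient` is enabled; the helper ships in adv_optm/util/OrthoGrad.py. Below is a minimal sketch of what an OrthoGrad step typically computes, assuming the standard formulation: remove the component of the gradient parallel to the weights, then rescale back to the original gradient norm. The function name and the small epsilons are illustrative, not the package's exact implementation.

import torch

def orthogonalize_gradient_sketch(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # Flatten weights and gradient so the projection is a single dot product.
    w = p.detach().view(-1).float()
    g = grad.view(-1).float()
    # Remove the component of g parallel to w: g_perp = g - (<w, g> / <w, w>) * w
    proj = torch.dot(w, g) / (torch.dot(w, w) + 1e-30)
    g_orth = g - proj * w
    # Rescale so the orthogonalized gradient keeps the original gradient norm.
    g_orth.mul_(g.norm() / (g_orth.norm() + 1e-30))
    return g_orth.view_as(grad).to(grad.dtype)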
adv_optm/optim/AdamW_adv.py
CHANGED
@@ -252,7 +252,7 @@ class AdamW_adv(torch.optim.Optimizer):
 # Update momentum in full-size
 mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
 if self.grams_moment:
-mt
+mt = (grad_reshaped.sign().mul_(mt.abs()))
 elif self.cautious_mask:
 mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
 mask.div_(mask.mean().clamp_(min=1e-3))
@@ -310,7 +310,7 @@ class AdamW_adv(torch.optim.Optimizer):
 exp_avg = state['exp_avg']
 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
 if self.grams_moment:
-exp_avg = grad.sign()
+exp_avg = grad.sign().mul_(exp_avg.abs())
 elif self.cautious_mask:
 mask = (exp_avg * grad > 0).to(grad.dtype)
 mask.div_(mask.mean().clamp_(min=1e-3))
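Both hunks above fix the Grams-style momentum: the previous non-factored path reduced the update to the bare gradient sign (`exp_avg = grad.sign()`), dropping the momentum magnitude, while the new code takes the sign from the current gradient and the magnitude from the EMA momentum. A standalone sketch of that rule, assuming it follows the usual Grams (gradient-sign, momentum-magnitude) formulation rather than the optimizer's exact code path:

import torch

def grams_first_moment(exp_avg: torch.Tensor, grad: torch.Tensor, beta1: float) -> torch.Tensor:
    # Usual EMA momentum update (in place, as in the optimizer).
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    # Grams re-signing: magnitude from the momentum, sign from the current gradient.
    return grad.sign().mul_(exp_avg.abs())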
adv_optm/optim/Adopt_adv.py
CHANGED
@@ -307,7 +307,7 @@ class Adopt_adv(torch.optim.Optimizer):
 else:
 mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
 if self.grams_moment:
-mt = grad_reshaped.sign()
+mt = grad_reshaped.sign().mul_(mt.abs())
 elif self.cautious_mask:
 mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
 mask.div_(mask.mean().clamp_(min=1e-3))
@@ -376,7 +376,7 @@ class Adopt_adv(torch.optim.Optimizer):
 m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)

 if self.grams_moment:
-m = grad.sign()
+m = grad.sign().mul_(m.abs())
 elif self.cautious_mask:
 mask = (m * grad > 0).to(grad.dtype)
 mask.div_(mask.mean().clamp_(min=1e-3))
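The `cautious_mask` branch visible in the context lines of these hunks is unchanged, but for reference it implements the cautious-optimizer rule: keep only momentum entries whose sign agrees with the current gradient, renormalized so the surviving entries compensate for the zeroed ones. A standalone sketch, not the optimizer's exact code path:

import torch

def cautious_masked_momentum(m: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # 1 where momentum and gradient agree in sign, 0 otherwise.
    mask = (m * grad > 0).to(grad.dtype)
    # Rescale so the expected mask value stays near one.
    mask.div_(mask.mean().clamp_(min=1e-3))
    return m * mask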
adv_optm/optim/Muon_adv.py
CHANGED
@@ -41,6 +41,7 @@ class Muon_adv(torch.optim.Optimizer):
 stability. (default: 100.0)
 stochastic_rounding (bool): whether to use stochastic rounding for
 BF16 parameter updates (default: True).
+orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
 vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
 matrices for muon NewtonSchulz (default: False).
 vector_reshape (bool): whether to reshape 1D vectors into 2D
@@ -59,7 +60,6 @@ class Muon_adv(torch.optim.Optimizer):
 normuon_eps (float): Epsilon for NorMuon normalization stability. (default: 1e-8)
 normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
 (default: 0.2)
-normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
 accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
 dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
 cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
@@ -92,6 +92,7 @@ class Muon_adv(torch.optim.Optimizer):
 Simplified_AdEMAMix: bool = False,
 alpha_grad: float = 100.0,
 stochastic_rounding: bool = True,
+orthogonal_gradient: bool = False,
 vector_reshape_muon: bool = False,
 vector_reshape: bool = False,
 nnmf_factor: bool = False,
@@ -103,7 +104,6 @@ class Muon_adv(torch.optim.Optimizer):
 beta2_normuon: float = 0.95,
 normuon_eps: float = 1e-8,
 normuon_lr_scale: float = 0.2,
-normuon_atan2: bool = False,
 # CANS
 accelerated_ns: bool = False,
 cns_a_bound: float = 1e-4,
@@ -149,13 +149,13 @@ class Muon_adv(torch.optim.Optimizer):
 "vector_reshape": vector_reshape,
 "vector_reshape_muon": vector_reshape_muon,
 "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+"orthogonal_gradient": orthogonal_gradient,
 'compiled_optimizer': compiled_optimizer,
 # Low-rank Ortho
 "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
 # NorMuon
 "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
 "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
-"normuon_atan2": normuon_atan2,
 # CANS
 "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
 # AdamW_adv defaults
@@ -293,6 +293,10 @@ class Muon_adv(torch.optim.Optimizer):
 nesterov = group['nesterov']
 Simplified_AdEMAMix = group['Simplified_AdEMAMix']
 alpha_grad = group['alpha_grad']
+if grad.dtype != torch.float32 and state.get('factored', False):
+    grad = grad.float()
+if group.get("orthogonal_gradient"):
+    grad = _orthogonalize_gradient(p, grad)

 if state['factored']: # Factored Muon

@@ -359,11 +363,7 @@ class Muon_adv(torch.optim.Optimizer):
 mean_squared_update = torch.mean(update.square(), dim=1)
 v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
 # Normalize update
-
-a = 1.2732395
-update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
-else:
-update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
+update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
 # Scale learning rate
 update_norm = torch.linalg.vector_norm(update)

@@ -464,11 +464,7 @@ class Muon_adv(torch.optim.Optimizer):
 mean_squared_update = torch.mean(update.square(), dim=1)
 v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
 # Normalize update
-
-a = 1.2732395
-update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
-else:
-update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
+update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
 # Scale learning rate
 update_norm = torch.linalg.vector_norm(update)
 scaled_lr = group['normuon_lr_scale'] * lr * (p.numel()**0.5) / update_norm.add_(group['normuon_eps'])
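With the `normuon_atan2` option removed, the NorMuon branch in both the factored and non-factored paths always normalizes by a row-wise second-moment EMA and then applies an RMS-style learning-rate scale. A standalone sketch of the retained path, with the group hyperparameters passed in explicitly for clarity (param_numel stands in for p.numel() in the optimizer):

import torch

def normuon_normalize(update: torch.Tensor, v_t: torch.Tensor, lr: float,
                      beta2_normuon: float, normuon_eps: float,
                      normuon_lr_scale: float, param_numel: int) -> tuple[torch.Tensor, torch.Tensor]:
    # EMA of the per-row mean squared update (update is 2D: rows x cols).
    mean_squared_update = torch.mean(update.square(), dim=1)
    v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
    # Always divide by sqrt(v_t) + eps; the atan2 branch no longer exists.
    update = update / (v_t.sqrt().unsqueeze(1) + normuon_eps)
    # RMS-style learning-rate scaling, as in the lines following the hunk.
    update_norm = torch.linalg.vector_norm(update)
    scaled_lr = normuon_lr_scale * lr * (param_numel ** 0.5) / (update_norm + normuon_eps)
    return update, scaled_lr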
adv_optm/optim/Prodigy_adv.py
CHANGED
@@ -343,7 +343,7 @@ class Prodigy_adv(torch.optim.Optimizer):
 else:
 mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
 if self.grams_moment:
-mt
+mt = (grad_reshaped.sign().mul_(mt.abs()))
 elif self.cautious_mask:
 mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
 mask.div_(mask.mean().clamp_(min=1e-3))
{adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev19.dist-info}/RECORD
CHANGED
@@ -1,11 +1,11 @@
-adv_optm/__init__.py,sha256=
-adv_optm/optim/AdaMuon_adv.py,sha256=
-adv_optm/optim/AdamW_adv.py,sha256=
-adv_optm/optim/Adopt_adv.py,sha256=
+adv_optm/__init__.py,sha256=1AKxG--scx5Bl9G08tQcnfzAMaQVSgmW99uy3v2QWMw,380
+adv_optm/optim/AdaMuon_adv.py,sha256=7Had92OcsCiN1E9UJRyrpPV7VzHqmIvS-qM6OEcc24I,34671
+adv_optm/optim/AdamW_adv.py,sha256=jgMuRAfsnUh_2wUEZgYpJX5uwoT_kQjtMs2Xn2vJ3x0,17480
+adv_optm/optim/Adopt_adv.py,sha256=kbAeBG4bXWBvgj_qrE9W67J6c0swpEi4Erj2rfYrMXE,21252
 adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
 adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
-adv_optm/optim/Muon_adv.py,sha256=
-adv_optm/optim/Prodigy_adv.py,sha256=
+adv_optm/optim/Muon_adv.py,sha256=tZY8K3pNBCGk1V09GbK05lJooFw92NfkF7_T548up3Q,33171
+adv_optm/optim/Prodigy_adv.py,sha256=k7f2J_RQpnrUXjwER_XOokISlQWpTSwGG-OL-bjMfBk,26061
 adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
 adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
 adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
@@ -16,8 +16,8 @@ adv_optm/util/Newton_Schulz.py,sha256=bBboYw_jm5_FMf0Citl79uqNedkHOTjQnUI7rZgLBm
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=CXzS703GB4gil85khZi7sgKOnbzXGBOltshIOSPqj18,435
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
+adv_optm-1.2.dev19.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.2.dev19.dist-info/METADATA,sha256=pQm5WuMKvf5Xse10viziVK9ry1UufcYRDwOd55jad8Y,14023
+adv_optm-1.2.dev19.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.2.dev19.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.2.dev19.dist-info/RECORD,,
{adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev19.dist-info}/WHEEL: file without changes
{adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev19.dist-info}/licenses/LICENSE: file without changes
{adv_optm-1.2.dev17.dist-info → adv_optm-1.2.dev19.dist-info}/top_level.txt: file without changes