adv-optm 1.2.dev17-py3-none-any.whl → 1.2.dev18-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm might be problematic.

adv_optm/__init__.py CHANGED
@@ -20,4 +20,4 @@ __all__ = [
     "AdaMuon_adv",
 ]
 
-__version__ = "1.2.dev17"
+__version__ = "1.2.dev18"
adv_optm/optim/AdaMuon_adv.py CHANGED
@@ -46,6 +46,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
             (default: (3.4445, -4.7750, 2.0315)).
         stochastic_rounding (bool): whether to use stochastic rounding for
             BF16 parameter updates (default: True).
+        orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
         nesterov (bool): enables Nesterov momentum (default: False).
         use_atan2 (bool): whether to use the atan2 update rule. (default: False)
         Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
@@ -95,6 +96,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
         ns_eps: float = 1e-7,
         ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
         stochastic_rounding: bool = False,
+        orthogonal_gradient: bool = False,
         use_atan2: bool = False,
         nesterov: bool = False,
         Simplified_AdEMAMix: bool = False,
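
For context, enabling the new flag from user code would look roughly like the sketch below. This is a minimal illustration, not an example from the package: the toy model and the lr value are placeholders, and all other constructor arguments are assumed to keep the defaults shown in the signature above.

    import torch
    from adv_optm import AdaMuon_adv

    model = torch.nn.Linear(128, 64)
    # lr=1e-3 is an arbitrary placeholder; orthogonal_gradient enables the new OrthoGrad path.
    optimizer = AdaMuon_adv(model.parameters(), lr=1e-3, orthogonal_gradient=True)

    loss = model(torch.randn(4, 128)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
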
@@ -147,7 +149,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
             "vector_reshape": vector_reshape,
             "nesterov":nesterov, "use_atan2":use_atan2,
             "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
-            "normuon_variant": normuon_variant,
+            "normuon_variant": normuon_variant, "orthogonal_gradient": orthogonal_gradient,
             # Low-rank Ortho
             "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
             "compiled_optimizer":compiled_optimizer,
@@ -282,6 +284,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
             nesterov = group['nesterov']
             Simplified_AdEMAMix = group['Simplified_AdEMAMix']
             alpha_grad = group['alpha_grad']
+            if grad.dtype != torch.float32 and state.get('factored', False):
+                grad = grad.float()
+            if group.get("orthogonal_gradient"):
+                grad = _orthogonalize_gradient(p, grad)
 
             if state['factored']: # Factored AdaMuon
 
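
The _orthogonalize_gradient helper comes from adv_optm/util/OrthoGrad.py, whose body is not shown in this diff. As a rough sketch of what an OrthoGrad-style step typically does (an assumption about the technique, not the package's actual code), the gradient is projected onto the subspace orthogonal to the parameter and then rescaled to keep its original norm:

    import torch

    def orthogonalize_gradient_sketch(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        # Hypothetical illustration of OrthoGrad: drop the component of grad
        # that is parallel to the parameter vector p (both treated as flat vectors).
        p_flat, g_flat = p.reshape(-1).float(), grad.reshape(-1).float()
        proj = torch.dot(p_flat, g_flat) / (torch.dot(p_flat, p_flat) + 1e-30)
        g_orth = g_flat - proj * p_flat
        # Rescale so the orthogonalized gradient keeps the original gradient norm.
        g_orth *= g_flat.norm() / (g_orth.norm() + 1e-30)
        return g_orth.reshape_as(grad)
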
adv_optm/optim/AdamW_adv.py CHANGED
@@ -252,7 +252,7 @@ class AdamW_adv(torch.optim.Optimizer):
             # Update momentum in full-size
             mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
             if self.grams_moment:
-                mt.copy_(grad_reshaped.sign() * mt.abs())
+                mt = (grad_reshaped.sign().mul_(mt.abs()))
             elif self.cautious_mask:
                 mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
                 mask.div_(mask.mean().clamp_(min=1e-3))
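
For reference, the Grams-style moment keeps the magnitude of the first-moment EMA but takes its sign from the current gradient. The change in this hunk is also a semantic one: the old mt.copy_(...) wrote the sign-corrected values back into the stored momentum buffer, while the new out-of-place assignment leaves the buffer holding the plain EMA and uses the corrected tensor only for the current step. A minimal sketch with made-up tensors:

    import torch

    beta1 = 0.9
    mt = torch.randn(4)     # stands in for the stored first-moment buffer
    grad = torch.randn(4)   # stands in for the current (reshaped) gradient

    mt.mul_(beta1).add_(grad, alpha=1.0 - beta1)  # plain EMA update, as before

    # Old behaviour: overwrite the state buffer with the sign-corrected moment.
    #   mt.copy_(grad.sign() * mt.abs())
    # New behaviour: keep the buffer as the plain EMA and build the Grams-corrected
    # direction out of place, for this step only.
    update_direction = grad.sign().mul_(mt.abs())
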
@@ -310,7 +310,7 @@ class AdamW_adv(torch.optim.Optimizer):
             exp_avg = state['exp_avg']
             exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
             if self.grams_moment:
-                exp_avg = grad.sign() * exp_avg.abs()
+                exp_avg = grad.sign().mul_(exp_avg.abs())
             elif self.cautious_mask:
                 mask = (exp_avg * grad > 0).to(grad.dtype)
                 mask.div_(mask.mean().clamp_(min=1e-3))
adv_optm/optim/Adopt_adv.py CHANGED
@@ -307,7 +307,7 @@ class Adopt_adv(torch.optim.Optimizer):
             else:
                 mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
             if self.grams_moment:
-                mt = grad_reshaped.sign() * mt.abs()
+                mt = grad_reshaped.sign().mul_(mt.abs())
             elif self.cautious_mask:
                 mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
                 mask.div_(mask.mean().clamp_(min=1e-3))
@@ -376,7 +376,7 @@ class Adopt_adv(torch.optim.Optimizer):
             m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
 
             if self.grams_moment:
-                m = grad.sign() * m.abs()
+                m = grad.sign().mul_(m.abs())
             elif self.cautious_mask:
                 mask = (m * grad > 0).to(grad.dtype)
                 mask.div_(mask.mean().clamp_(min=1e-3))
adv_optm/optim/Muon_adv.py CHANGED
@@ -41,6 +41,7 @@ class Muon_adv(torch.optim.Optimizer):
             stability. (default: 100.0)
         stochastic_rounding (bool): whether to use stochastic rounding for
             BF16 parameter updates (default: True).
+        orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
         vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
             matrices for muon NewtonSchulz (default: False).
         vector_reshape (bool): whether to reshape 1D vectors into 2D
@@ -92,6 +93,7 @@ class Muon_adv(torch.optim.Optimizer):
         Simplified_AdEMAMix: bool = False,
         alpha_grad: float = 100.0,
         stochastic_rounding: bool = True,
+        orthogonal_gradient: bool = False,
         vector_reshape_muon: bool = False,
         vector_reshape: bool = False,
         nnmf_factor: bool = False,
@@ -149,6 +151,7 @@ class Muon_adv(torch.optim.Optimizer):
             "vector_reshape": vector_reshape,
             "vector_reshape_muon": vector_reshape_muon,
             "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+            "orthogonal_gradient": orthogonal_gradient,
             'compiled_optimizer': compiled_optimizer,
             # Low-rank Ortho
             "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
@@ -293,6 +296,10 @@ class Muon_adv(torch.optim.Optimizer):
             nesterov = group['nesterov']
             Simplified_AdEMAMix = group['Simplified_AdEMAMix']
             alpha_grad = group['alpha_grad']
+            if grad.dtype != torch.float32 and state.get('factored', False):
+                grad = grad.float()
+            if group.get("orthogonal_gradient"):
+                grad = _orthogonalize_gradient(p, grad)
 
             if state['factored']: # Factored Muon
 
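
The new cast upcasts non-float32 gradients before the factored branch builds its statistics. As a reminder of why that matters (a generic illustration, not code from the package), bfloat16 carries only an 8-bit mantissa, so small contributions are easily rounded away unless the arithmetic is done in float32:

    import torch

    one = torch.tensor(1.0, dtype=torch.bfloat16)
    tiny = torch.tensor(1e-3, dtype=torch.bfloat16)

    print(one + tiny)                  # tensor(1., dtype=torch.bfloat16): the contribution is lost
    print(one.float() + tiny.float())  # tensor(1.0010): preserved in float32
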
adv_optm/optim/Prodigy_adv.py CHANGED
@@ -343,7 +343,7 @@ class Prodigy_adv(torch.optim.Optimizer):
             else:
                 mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
             if self.grams_moment:
-                mt.copy_(grad_reshaped.sign() * mt.abs())
+                mt = (grad_reshaped.sign().mul_(mt.abs()))
             elif self.cautious_mask:
                 mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
                 mask.div_(mask.mean().clamp_(min=1e-3))
adv_optm-1.2.dev18.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 1.2.dev17
+Version: 1.2.dev18
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu
@@ -1,11 +1,11 @@
1
- adv_optm/__init__.py,sha256=5Dww3w78iQNwyVH82E_dmD-s6luvQjiqYS0BxKQHYCE,380
2
- adv_optm/optim/AdaMuon_adv.py,sha256=zjZHFS7ng5KwemQzePjFiGtNZlcgbzmmnqF6A80h_Tg,34652
3
- adv_optm/optim/AdamW_adv.py,sha256=KL9SCJWZ_ckAQEApB6ofbndVYjancN-v7Us7hJLFf54,17475
4
- adv_optm/optim/Adopt_adv.py,sha256=S8XI2YA7683jsW8p7igc2YcU30lsN0H18qL02Kpvj8E,21244
1
+ adv_optm/__init__.py,sha256=1UzgEkreoqaobiwUZ8yR-8Fnda7T7XiHQ4PhJKQocy4,380
2
+ adv_optm/optim/AdaMuon_adv.py,sha256=VpNsw2CnU8bZThj9cJJ6HGIATPxv4VkIf3xTsUMXQAY,35027
3
+ adv_optm/optim/AdamW_adv.py,sha256=jgMuRAfsnUh_2wUEZgYpJX5uwoT_kQjtMs2Xn2vJ3x0,17480
4
+ adv_optm/optim/Adopt_adv.py,sha256=kbAeBG4bXWBvgj_qrE9W67J6c0swpEi4Erj2rfYrMXE,21252
5
5
  adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
6
6
  adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
7
- adv_optm/optim/Muon_adv.py,sha256=d91wvmKKt_3IPqsqK1ZZ5cY71kuXyzy04IU3krn2NQ8,33316
8
- adv_optm/optim/Prodigy_adv.py,sha256=lEjbtuQbomsCX39DnTPeI8Z5YG0f2aZPXN_E7-nGgWw,26060
7
+ adv_optm/optim/Muon_adv.py,sha256=0D4k8UfMSzITJwQEDfqpceD5H7HQvv0f8uyVKvdvkHo,33704
8
+ adv_optm/optim/Prodigy_adv.py,sha256=k7f2J_RQpnrUXjwER_XOokISlQWpTSwGG-OL-bjMfBk,26061
9
9
  adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
10
10
  adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
11
11
  adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
@@ -16,8 +16,8 @@ adv_optm/util/Newton_Schulz.py,sha256=bBboYw_jm5_FMf0Citl79uqNedkHOTjQnUI7rZgLBm
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=CXzS703GB4gil85khZi7sgKOnbzXGBOltshIOSPqj18,435
-adv_optm-1.2.dev17.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-adv_optm-1.2.dev17.dist-info/METADATA,sha256=xE_ECrY_ALerNQRFBtKml1w_n8wSp8zOH0tIz-BLiqY,14023
-adv_optm-1.2.dev17.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-adv_optm-1.2.dev17.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
-adv_optm-1.2.dev17.dist-info/RECORD,,
+adv_optm-1.2.dev18.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.2.dev18.dist-info/METADATA,sha256=cfQdGhiRlf_-xnPKqwuCE8PR6faLYx_RC6MrkjYDqI8,14023
+adv_optm-1.2.dev18.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.2.dev18.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.2.dev18.dist-info/RECORD,,