adv-optm 1.2.dev17__py3-none-any.whl → 1.2.dev19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of adv-optm has been flagged as a potentially problematic release.

adv_optm/__init__.py CHANGED
@@ -20,4 +20,4 @@ __all__ = [
     "AdaMuon_adv",
 ]
 
-__version__ = "1.2.dev17"
+__version__ = "1.2.dev19"

adv_optm/optim/AdaMuon_adv.py CHANGED

@@ -46,6 +46,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
         (default: (3.4445, -4.7750, 2.0315)).
     stochastic_rounding (bool): whether to use stochastic rounding for
         BF16 parameter updates (default: True).
+    orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
     nesterov (bool): enables Nesterov momentum (default: False).
     use_atan2 (bool): whether to use the atan2 update rule. (default: False)
     Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
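
For context on the ns_coeffs default above, (3.4445, -4.7750, 2.0315) are the quintic Newton-Schulz coefficients popularized by Muon. The sketch below shows the standard iteration those coefficients belong to; it is illustrative only and is not the package's own adv_optm/util/Newton_Schulz.py (the function name and steps argument here are assumptions).

import torch

def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5, eps: float = 1e-7,
                                coeffs=(3.4445, -4.7750, 2.0315)) -> torch.Tensor:
    # Quintic Newton-Schulz iteration: pushes the singular values of G toward 1
    # without an explicit SVD, yielding an approximately orthogonal update.
    a, b, c = coeffs
    X = G.float() / (G.norm() + eps)           # start with spectral norm <= 1
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.T                                # iterate on the wide orientation
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X  # X <- a*X + b*A*X + c*A^2*X
    return X.T if transposed else X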
@@ -95,6 +96,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
     ns_eps: float = 1e-7,
     ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
     stochastic_rounding: bool = False,
+    orthogonal_gradient: bool = False,
     use_atan2: bool = False,
     nesterov: bool = False,
     Simplified_AdEMAMix: bool = False,
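
A minimal usage sketch for the new flag. Only the keyword arguments visible in this diff are certain; lr is assumed to follow the usual torch.optim convention and is not shown here.

import torch
from adv_optm import AdaMuon_adv  # exported via the package __all__ shown at the top of this diff

model = torch.nn.Linear(128, 64)
# orthogonal_gradient is the option added in 1.2.dev19; it defaults to False.
opt = AdaMuon_adv(model.parameters(), lr=1e-3, orthogonal_gradient=True)

loss = model(torch.randn(8, 128)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()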
@@ -147,7 +149,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
     "vector_reshape": vector_reshape,
     "nesterov":nesterov, "use_atan2":use_atan2,
     "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
-    "normuon_variant": normuon_variant,
+    "normuon_variant": normuon_variant, "orthogonal_gradient": orthogonal_gradient,
     # Low-rank Ortho
     "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
     "compiled_optimizer":compiled_optimizer,
@@ -282,6 +284,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
     nesterov = group['nesterov']
     Simplified_AdEMAMix = group['Simplified_AdEMAMix']
     alpha_grad = group['alpha_grad']
+    if grad.dtype != torch.float32 and state.get('factored', False):
+        grad = grad.float()
+    if group.get("orthogonal_gradient"):
+        grad = _orthogonalize_gradient(p, grad)
 
     if state['factored']: # Factored AdaMuon
 
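
_orthogonalize_gradient comes from the package's adv_optm/util/OrthoGrad.py (listed in the RECORD section below), which this diff does not show. As a rough sketch under the common OrthoGrad formulation, the helper removes the component of the gradient that points along the parameter vector and then rescales to preserve the gradient norm; the function below is an illustration, not the package's implementation.

import torch

def orthogonalize_gradient_sketch(p: torch.Tensor, grad: torch.Tensor, eps: float = 1e-30) -> torch.Tensor:
    # Project grad onto the subspace orthogonal to the flattened parameter p,
    # then rescale so the result keeps the original gradient norm.
    w = p.detach().view(-1).float()
    g = grad.view(-1).float()
    proj = torch.dot(w, g) / (torch.dot(w, w) + eps)
    g_orth = g - proj * w
    g_orth.mul_(g.norm() / (g_orth.norm() + eps))
    return g_orth.view_as(grad)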
@@ -345,11 +351,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
     mean_squared_update = torch.mean(update.square(), dim=1)
     v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
     # Normalize update
-    if group['use_atan2']:
-        a = 1.2732395
-        update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
-    else:
-        update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+    update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
     # Scale learning rate
     update_norm = torch.linalg.vector_norm(update)
     scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
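
The branch removed above is the use_atan2 variant of the second-moment normalization; this release keeps only the epsilon-division form. Both expressions appear verbatim in the hunk, and the constant 1.2732395 is approximately 4/pi, which rescales atan2 so small ratios behave like a plain division while the output stays bounded. A standalone comparison of the two forms:

import torch

update = torch.randn(4, 8)     # per-row update for one weight matrix
v_t = torch.rand(4)            # per-row second-moment estimate
eps = 1e-8

kept = update / (v_t.sqrt().unsqueeze(1) + eps)                        # form kept in 1.2.dev19
removed = torch.atan2(update, v_t.sqrt().unsqueeze(1)) * 1.2732395     # form removed (bounded by ~2)
print(kept.abs().max(), removed.abs().max())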
@@ -454,11 +456,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
     mean_squared_update = torch.mean(update.square(), dim=1)
     v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
     # Normalize update
-    if group['use_atan2']:
-        a = 1.2732395
-        update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
-    else:
-        update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+    update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
     # Scale learning rate
     update_norm = torch.linalg.vector_norm(update)
     scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])

adv_optm/optim/AdamW_adv.py CHANGED

@@ -252,7 +252,7 @@ class AdamW_adv(torch.optim.Optimizer):
     # Update momentum in full-size
     mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
     if self.grams_moment:
-        mt.copy_(grad_reshaped.sign() * mt.abs())
+        mt = (grad_reshaped.sign().mul_(mt.abs()))
     elif self.cautious_mask:
         mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
         mask.div_(mask.mean().clamp_(min=1e-3))
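
The grams_moment change above swaps an in-place copy_ into the momentum buffer for a rebinding of the local mt; the Grams rule itself is unchanged: take the sign from the current gradient and the magnitude from the momentum. Note that .sign() returns a fresh tensor, so the .mul_ in the new form does not modify the gradient. A tiny standalone illustration of the rule:

import torch

grad = torch.tensor([ 0.5, -0.2,  0.1])
m    = torch.tensor([-1.0,  0.8,  0.3])   # first-moment (momentum) estimate

m_grams = grad.sign() * m.abs()           # -> tensor([ 1.0000, -0.8000,  0.3000])
print(m_grams)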
@@ -310,7 +310,7 @@ class AdamW_adv(torch.optim.Optimizer):
     exp_avg = state['exp_avg']
     exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
     if self.grams_moment:
-        exp_avg = grad.sign() * exp_avg.abs()
+        exp_avg = grad.sign().mul_(exp_avg.abs())
     elif self.cautious_mask:
         mask = (exp_avg * grad > 0).to(grad.dtype)
         mask.div_(mask.mean().clamp_(min=1e-3))

adv_optm/optim/Adopt_adv.py CHANGED

@@ -307,7 +307,7 @@ class Adopt_adv(torch.optim.Optimizer):
     else:
         mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
     if self.grams_moment:
-        mt = grad_reshaped.sign() * mt.abs()
+        mt = grad_reshaped.sign().mul_(mt.abs())
     elif self.cautious_mask:
         mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
         mask.div_(mask.mean().clamp_(min=1e-3))
@@ -376,7 +376,7 @@ class Adopt_adv(torch.optim.Optimizer):
     m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
 
     if self.grams_moment:
-        m = grad.sign() * m.abs()
+        m = grad.sign().mul_(m.abs())
     elif self.cautious_mask:
         mask = (m * grad > 0).to(grad.dtype)
         mask.div_(mask.mean().clamp_(min=1e-3))

adv_optm/optim/Muon_adv.py CHANGED

@@ -41,6 +41,7 @@ class Muon_adv(torch.optim.Optimizer):
         stability. (default: 100.0)
     stochastic_rounding (bool): whether to use stochastic rounding for
         BF16 parameter updates (default: True).
+    orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
     vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
         matrices for muon NewtonSchulz (default: False).
     vector_reshape (bool): whether to reshape 1D vectors into 2D
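
stochastic_rounding, mentioned in the docstring above, is implemented in the package's adv_optm/util/BF16_Stochastic_Rounding.py, which this diff does not touch. A common way such a helper is written, sketched here as an assumption (the function name is illustrative): randomize the FP32 bits that a BF16 store would discard, then truncate, so rounding errors average out over many steps instead of biasing the parameters.

import torch

def copy_stochastic_(target: torch.Tensor, source: torch.Tensor) -> None:
    # target: bfloat16 parameter, source: float32 result of the update.
    # BF16 keeps the top 16 bits of an FP32 value, so randomizing the low 16 bits
    # before masking them off makes the stored value unbiased in expectation.
    noise = torch.randint_like(source, 0, 1 << 16, dtype=torch.int32)
    bits = source.view(torch.int32) + noise
    bits &= -65536                             # keep only the upper 16 bits (0xFFFF0000 mask)
    target.copy_(bits.view(torch.float32))     # exact copy: the masked value fits in BF16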
@@ -59,7 +60,6 @@ class Muon_adv(torch.optim.Optimizer):
     normuon_eps (float): Epsilon for NorMuon normalization stability. (default: 1e-8)
     normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
         (default: 0.2)
-    normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
     accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
         dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
     cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
@@ -92,6 +92,7 @@ class Muon_adv(torch.optim.Optimizer):
     Simplified_AdEMAMix: bool = False,
     alpha_grad: float = 100.0,
     stochastic_rounding: bool = True,
+    orthogonal_gradient: bool = False,
     vector_reshape_muon: bool = False,
     vector_reshape: bool = False,
     nnmf_factor: bool = False,
@@ -103,7 +104,6 @@ class Muon_adv(torch.optim.Optimizer):
     beta2_normuon: float = 0.95,
     normuon_eps: float = 1e-8,
     normuon_lr_scale: float = 0.2,
-    normuon_atan2: bool = False,
     # CANS
     accelerated_ns: bool = False,
     cns_a_bound: float = 1e-4,
@@ -149,13 +149,13 @@ class Muon_adv(torch.optim.Optimizer):
     "vector_reshape": vector_reshape,
     "vector_reshape_muon": vector_reshape_muon,
     "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+    "orthogonal_gradient": orthogonal_gradient,
     'compiled_optimizer': compiled_optimizer,
     # Low-rank Ortho
     "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
     # NorMuon
     "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
     "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
-    "normuon_atan2": normuon_atan2,
     # CANS
     "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
     # AdamW_adv defaults
@@ -293,6 +293,10 @@ class Muon_adv(torch.optim.Optimizer):
     nesterov = group['nesterov']
     Simplified_AdEMAMix = group['Simplified_AdEMAMix']
     alpha_grad = group['alpha_grad']
+    if grad.dtype != torch.float32 and state.get('factored', False):
+        grad = grad.float()
+    if group.get("orthogonal_gradient"):
+        grad = _orthogonalize_gradient(p, grad)
 
     if state['factored']: # Factored Muon
 
@@ -359,11 +363,7 @@ class Muon_adv(torch.optim.Optimizer):
     mean_squared_update = torch.mean(update.square(), dim=1)
     v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
     # Normalize update
-    if group['normuon_atan2']:
-        a = 1.2732395
-        update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
-    else:
-        update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
+    update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
     # Scale learning rate
     update_norm = torch.linalg.vector_norm(update)
 
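
For readers skimming the NorMuon hunks, the surviving normalization keeps one second-moment entry per output row of the orthogonalized update. A shape-annotated, standalone restatement of the three surviving lines, with the constants filled in from the defaults shown earlier in this diff:

import torch

rows, cols = 4, 16
update = torch.randn(rows, cols)   # orthogonalized momentum for one weight matrix
v_t = torch.zeros(rows)            # per-row second-moment state
beta2_normuon, normuon_eps = 0.95, 1e-8

mean_squared_update = torch.mean(update.square(), dim=1)                    # shape (rows,)
v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)  # EMA of row energy
update.div_(v_t.sqrt().unsqueeze(1).add_(normuon_eps))                      # row-wise rescale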
@@ -464,11 +464,7 @@ class Muon_adv(torch.optim.Optimizer):
     mean_squared_update = torch.mean(update.square(), dim=1)
     v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
     # Normalize update
-    if group['normuon_atan2']:
-        a = 1.2732395
-        update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
-    else:
-        update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
+    update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
     # Scale learning rate
     update_norm = torch.linalg.vector_norm(update)
     scaled_lr = group['normuon_lr_scale'] * lr * (p.numel()**0.5) / update_norm.add_(group['normuon_eps'])

adv_optm/optim/Prodigy_adv.py CHANGED

@@ -343,7 +343,7 @@ class Prodigy_adv(torch.optim.Optimizer):
     else:
         mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
     if self.grams_moment:
-        mt.copy_(grad_reshaped.sign() * mt.abs())
+        mt = (grad_reshaped.sign().mul_(mt.abs()))
     elif self.cautious_mask:
         mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
         mask.div_(mask.mean().clamp_(min=1e-3))

adv_optm-1.2.dev17.dist-info/METADATA → adv_optm-1.2.dev19.dist-info/METADATA RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 1.2.dev17
+Version: 1.2.dev19
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu

adv_optm-1.2.dev17.dist-info/RECORD → adv_optm-1.2.dev19.dist-info/RECORD RENAMED

@@ -1,11 +1,11 @@
-adv_optm/__init__.py,sha256=5Dww3w78iQNwyVH82E_dmD-s6luvQjiqYS0BxKQHYCE,380
-adv_optm/optim/AdaMuon_adv.py,sha256=zjZHFS7ng5KwemQzePjFiGtNZlcgbzmmnqF6A80h_Tg,34652
-adv_optm/optim/AdamW_adv.py,sha256=KL9SCJWZ_ckAQEApB6ofbndVYjancN-v7Us7hJLFf54,17475
-adv_optm/optim/Adopt_adv.py,sha256=S8XI2YA7683jsW8p7igc2YcU30lsN0H18qL02Kpvj8E,21244
+adv_optm/__init__.py,sha256=1AKxG--scx5Bl9G08tQcnfzAMaQVSgmW99uy3v2QWMw,380
+adv_optm/optim/AdaMuon_adv.py,sha256=7Had92OcsCiN1E9UJRyrpPV7VzHqmIvS-qM6OEcc24I,34671
+adv_optm/optim/AdamW_adv.py,sha256=jgMuRAfsnUh_2wUEZgYpJX5uwoT_kQjtMs2Xn2vJ3x0,17480
+adv_optm/optim/Adopt_adv.py,sha256=kbAeBG4bXWBvgj_qrE9W67J6c0swpEi4Erj2rfYrMXE,21252
 adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
 adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
-adv_optm/optim/Muon_adv.py,sha256=d91wvmKKt_3IPqsqK1ZZ5cY71kuXyzy04IU3krn2NQ8,33316
-adv_optm/optim/Prodigy_adv.py,sha256=lEjbtuQbomsCX39DnTPeI8Z5YG0f2aZPXN_E7-nGgWw,26060
+adv_optm/optim/Muon_adv.py,sha256=tZY8K3pNBCGk1V09GbK05lJooFw92NfkF7_T548up3Q,33171
+adv_optm/optim/Prodigy_adv.py,sha256=k7f2J_RQpnrUXjwER_XOokISlQWpTSwGG-OL-bjMfBk,26061
 adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
 adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
 adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
@@ -16,8 +16,8 @@ adv_optm/util/Newton_Schulz.py,sha256=bBboYw_jm5_FMf0Citl79uqNedkHOTjQnUI7rZgLBm
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=CXzS703GB4gil85khZi7sgKOnbzXGBOltshIOSPqj18,435
-adv_optm-1.2.dev17.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-adv_optm-1.2.dev17.dist-info/METADATA,sha256=xE_ECrY_ALerNQRFBtKml1w_n8wSp8zOH0tIz-BLiqY,14023
-adv_optm-1.2.dev17.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-adv_optm-1.2.dev17.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
-adv_optm-1.2.dev17.dist-info/RECORD,,
+adv_optm-1.2.dev19.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.2.dev19.dist-info/METADATA,sha256=pQm5WuMKvf5Xse10viziVK9ry1UufcYRDwOd55jad8Y,14023
+adv_optm-1.2.dev19.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.2.dev19.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.2.dev19.dist-info/RECORD,,