heavyball 0.17.1__py3-none-any.whl → 0.17.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/utils.py CHANGED
@@ -332,6 +332,16 @@ def promote(x):
332
332
  return x
333
333
 
334
334
 
335
+ def min_dtype(xs: List[torch.Tensor]):
336
+ dtypes = [x.dtype for x in xs]
337
+ for d in (torch.float32, torch.bfloat16, torch.float16):
338
+ if all(d == x for x in dtypes):
339
+ return d
340
+ if all(d in (x, torch.float32, torch.float64) for x in dtypes):
341
+ return d
342
+ return torch.float32
343
+
344
+
335
345
  def update_preconditioner(grad, state, max_precond_dim, precondition_1d, beta, update_precond):
336
346
  """
337
347
  Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
@@ -471,13 +481,8 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
471
481
  copy_stochastic_(t, s)
472
482
 
473
483
 
474
- def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
475
- if target.data_ptr() == source.data_ptr():
476
- return
477
- if target.dtype != torch.bfloat16:
478
- set_(target, source)
479
- return
480
-
484
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
485
+ def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
481
486
  """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
482
487
  # create a random 16 bit integer
483
488
  result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
@@ -492,6 +497,15 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
492
497
  target.copy_(result.view(dtype=torch.float32))
493
498
 
494
499
 
500
+ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
501
+ if target.data_ptr() == source.data_ptr():
502
+ return
503
+ if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
504
+ set_(target, source)
505
+ return
506
+ _compilable_copy_stochastic_(target, source)
507
+
508
+
495
509
  def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
496
510
  add_fn: callable = None):
497
511
  param32 = [promote(p) for p in param]
@@ -602,7 +616,8 @@ def psgd_balance_Q(Q_in):
602
616
 
603
617
 
604
618
  def psgd_calc_A_and_conjB(exprA, G, Q, V):
605
- A = torch.einsum(exprA, *Q, G)
619
+ md = min_dtype(Q)
620
+ A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md))
606
621
  order = G.dim()
607
622
  p = list(range(order))
608
623
  conjB = torch.permute(V.conj(), p[1:] + p[:1])
@@ -653,7 +668,10 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
653
668
 
654
669
  term2 += term1 # a + b
655
670
  term1 *= 2 # 2a
656
- term1 -= term2 # 2a - (a + b) == a - b
671
+ if term1.dtype == term2.dtype:
672
+ term1 -= term2 # 2a - (a + b) == a - b
673
+ else:
674
+ term1 = term1 - term2
657
675
 
658
676
  term1 *= step
659
677
  norm = term2.norm(float('inf'))
@@ -669,7 +687,8 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
669
687
  @decorator
670
688
  def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
671
689
  """Precondition gradient G with preconditioner Q."""
672
- out = torch.einsum(exprs[-1], *[q.conj() for q in Q], *Q, G.to(Q[0].dtype))
690
+ md = min_dtype(Q)
691
+ out = torch.einsum(exprs[-1], *[q.conj().to(md) for q in Q], *[q.to(md) for q in Q], G.to(md))
673
692
  if inplace:
674
693
  set_(G, out)
675
694
  return G
@@ -787,14 +806,15 @@ class PSGDBase(StatefulOptimizer):
787
806
  if g.dim() > 1:
788
807
  psgd_balance_Q(q)
789
808
 
790
- def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None, store_triu_as_line=False):
809
+ def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
810
+ store_triu_as_line=False):
791
811
  for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
792
812
  psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
793
813
  if original_q:
794
814
  if store_triu_as_line:
795
815
  update_triu_(original_q[i], Q)
796
816
  else:
797
- copy_stochastic_(original_q[i], Q)
817
+ copy_stochastic_list_(original_q[i], Q)
798
818
 
799
819
 
800
820
  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.17.1
3
+ Version: 0.17.3
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdo
16
16
  heavyball/psgd_kron.py,sha256=2IpPj2TOExNGm8hSewi3er2GczJRNgC7r2J5yYSSA_0,5998
17
17
  heavyball/pure_psgd.py,sha256=uA7W9a3Qm1sxHQhtNxaUYrmE5x55lP5iJOKy_qT8XaQ,5341
18
18
  heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
19
- heavyball/utils.py,sha256=Jqh7VdWGeiSdwaPtUNB9l14wuuFPSReLaTwJA3juFbM,28765
20
- heavyball-0.17.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
- heavyball-0.17.1.dist-info/METADATA,sha256=2FAgCpyuH4G-B_m0mhbl-sdkMizS1sd8oNmNkPpAKN0,11810
22
- heavyball-0.17.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
- heavyball-0.17.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
- heavyball-0.17.1.dist-info/RECORD,,
19
+ heavyball/utils.py,sha256=wEwg-TElm56CEf-Q_F7wgJdz17P4Hp0CakALGJj1090,29563
20
+ heavyball-0.17.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
+ heavyball-0.17.3.dist-info/METADATA,sha256=QMYWssjAbuY_eYSGRVuSOVeFrMiq7YZA4-qJe8E8u74,11810
22
+ heavyball-0.17.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
+ heavyball-0.17.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
+ heavyball-0.17.3.dist-info/RECORD,,