heavyball 0.20.0__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
heavyball/utils.py CHANGED
@@ -721,7 +721,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
721
721
  return [Q, (exprA, tuple(exprGs), exprP)]
722
722
 
723
723
 
724
- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
724
+ @decorator
725
725
  def psgd_balance_Q(Q_in):
726
726
  norms = torch.stack([q.norm(float("inf")) for q in Q_in])
727
727
  geometric_mean = norms.log().mean().exp()
@@ -734,8 +734,7 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
734
734
  A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
735
735
  order = G.dim()
736
736
  p = list(range(order))
737
- V = torch.randn_like(G, dtype=promote(G.dtype))
738
- conjB = torch.permute(V, p[1:] + p[:1])
737
+ conjB = torch.randn_like(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
739
738
  Q = [promote(q) for q in Q]
740
739
  for i, q in enumerate(Q):
741
740
  if q.dim() <= 1:
@@ -755,21 +754,13 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
755
754
  def psgd_lb(A, max_abs):
756
755
  A /= max_abs
757
756
  a0 = torch.einsum('ij,ij->j', A, A)
758
- a1 = torch.einsum('ij,ij->i', A, A)
759
- value0 = torch.max(a0)
760
- value1 = torch.max(a1)
761
757
  i = torch.argmax(a0)
762
- j = torch.argmax(a1)
763
758
 
764
- comp = value0 > value1
765
- x = torch.cond(comp, lambda a: torch.index_select(a, 1, i).flatten().contiguous(), #
766
- lambda a: torch.index_select(a, 0, j).flatten().contiguous(), (A,))
759
+ x = torch.index_select(a, 1, i).flatten().contiguous()
767
760
 
768
- x = torch.cond(comp, lambda x_, a: torch.einsum('i,ij->j', x_, a), lambda x_, a: torch.einsum('i,ji->j', x_, a),
769
- (x, A,))
761
+ x = torch.einsum('i,ij->j', x_, a)
770
762
  x /= x.norm()
771
- x = torch.cond(comp, lambda x_, a: torch.einsum('j,kj->k', x_, a), lambda x_, a: torch.einsum('j,jk->k', x_, a),
772
- (x, A,))
763
+ x = torch.einsum('j,kj->k', x_, a)
773
764
  x = x.norm()
774
765
  x *= max_abs
775
766
  return x
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.20.0
3
+ Version: 0.20.1
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=PNneiOkrRyV1yIZn91lPmYofd1_OiLqJTDy
16
16
  heavyball/psgd_kron.py,sha256=RSLJi5FnSmYjvYBufNDClnYRm-eN_Kpa1Ar2tNP6-X0,5824
17
17
  heavyball/pure_psgd.py,sha256=LZK0qmvZkBF8g00evaVLtW-sIUJmdoxag1K7O26AqEo,4820
18
18
  heavyball/schedule_free_palm_foreach_soap.py,sha256=_AqrnChY6iDlQUkF2YUxS7eLjSWCIuvEUOHvMHVM1yY,6873
19
- heavyball/utils.py,sha256=ESazD0yv14Aa8XKi_pz2CyfVkpcbgYcG2-WMvhQOnxk,35719
20
- heavyball-0.20.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
- heavyball-0.20.0.dist-info/METADATA,sha256=dJ43LOTrNqh7cDTDzZDSu57goP1gNhU3dfZ26BUK9hA,11926
22
- heavyball-0.20.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
- heavyball-0.20.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
- heavyball-0.20.0.dist-info/RECORD,,
19
+ heavyball/utils.py,sha256=14vt4r_MeTsp1q3m0lpgF-Q3PCJg6GLGJrhjRxnbWwQ,35174
20
+ heavyball-0.20.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
+ heavyball-0.20.1.dist-info/METADATA,sha256=qzF2P7e2EREeTy_4h85tvUY53omjNm32z83CUHTqt3U,11926
22
+ heavyball-0.20.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
+ heavyball-0.20.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
+ heavyball-0.20.1.dist-info/RECORD,,