heavyball-0.20.0-py3-none-any.whl → heavyball-0.20.1-py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- heavyball/utils.py +5 -14
- {heavyball-0.20.0.dist-info → heavyball-0.20.1.dist-info}/METADATA +1 -1
- {heavyball-0.20.0.dist-info → heavyball-0.20.1.dist-info}/RECORD +6 -6
- {heavyball-0.20.0.dist-info → heavyball-0.20.1.dist-info}/LICENSE +0 -0
- {heavyball-0.20.0.dist-info → heavyball-0.20.1.dist-info}/WHEEL +0 -0
- {heavyball-0.20.0.dist-info → heavyball-0.20.1.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -721,7 +721,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
|
|
721
721
|
return [Q, (exprA, tuple(exprGs), exprP)]
|
722
722
|
|
723
723
|
|
724
|
-
@
|
724
|
+
@decorator
|
725
725
|
def psgd_balance_Q(Q_in):
|
726
726
|
norms = torch.stack([q.norm(float("inf")) for q in Q_in])
|
727
727
|
geometric_mean = norms.log().mean().exp()
|
@@ -734,8 +734,7 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
|
|
734
734
|
A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
|
735
735
|
order = G.dim()
|
736
736
|
p = list(range(order))
|
737
|
-
|
738
|
-
conjB = torch.permute(V, p[1:] + p[:1])
|
737
|
+
conjB = torch.randn_like(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
|
739
738
|
Q = [promote(q) for q in Q]
|
740
739
|
for i, q in enumerate(Q):
|
741
740
|
if q.dim() <= 1:
|
@@ -755,21 +754,13 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
|
|
755
754
|
def psgd_lb(A, max_abs):
|
756
755
|
A /= max_abs
|
757
756
|
a0 = torch.einsum('ij,ij->j', A, A)
|
758
|
-
a1 = torch.einsum('ij,ij->i', A, A)
|
759
|
-
value0 = torch.max(a0)
|
760
|
-
value1 = torch.max(a1)
|
761
757
|
i = torch.argmax(a0)
|
762
|
-
j = torch.argmax(a1)
|
763
758
|
|
764
|
-
|
765
|
-
x = torch.cond(comp, lambda a: torch.index_select(a, 1, i).flatten().contiguous(), #
|
766
|
-
lambda a: torch.index_select(a, 0, j).flatten().contiguous(), (A,))
|
759
|
+
x = torch.index_select(a, 1, i).flatten().contiguous()
|
767
760
|
|
768
|
-
x = torch.
|
769
|
-
(x, A,))
|
761
|
+
x = torch.einsum('i,ij->j', x_, a)
|
770
762
|
x /= x.norm()
|
771
|
-
x = torch.
|
772
|
-
(x, A,))
|
763
|
+
x = torch.einsum('j,kj->k', x_, a)
|
773
764
|
x = x.norm()
|
774
765
|
x *= max_abs
|
775
766
|
return x
|
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=PNneiOkrRyV1yIZn91lPmYofd1_OiLqJTDy
|
|
16
16
|
heavyball/psgd_kron.py,sha256=RSLJi5FnSmYjvYBufNDClnYRm-eN_Kpa1Ar2tNP6-X0,5824
|
17
17
|
heavyball/pure_psgd.py,sha256=LZK0qmvZkBF8g00evaVLtW-sIUJmdoxag1K7O26AqEo,4820
|
18
18
|
heavyball/schedule_free_palm_foreach_soap.py,sha256=_AqrnChY6iDlQUkF2YUxS7eLjSWCIuvEUOHvMHVM1yY,6873
|
19
|
-
heavyball/utils.py,sha256=
|
20
|
-
heavyball-0.20.
|
21
|
-
heavyball-0.20.
|
22
|
-
heavyball-0.20.
|
23
|
-
heavyball-0.20.
|
24
|
-
heavyball-0.20.
|
19
|
+
heavyball/utils.py,sha256=14vt4r_MeTsp1q3m0lpgF-Q3PCJg6GLGJrhjRxnbWwQ,35174
|
20
|
+
heavyball-0.20.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
21
|
+
heavyball-0.20.1.dist-info/METADATA,sha256=qzF2P7e2EREeTy_4h85tvUY53omjNm32z83CUHTqt3U,11926
|
22
|
+
heavyball-0.20.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
23
|
+
heavyball-0.20.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
24
|
+
heavyball-0.20.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|