heavyball-0.17.1-py3-none-any.whl → heavyball-0.17.3-py3-none-any.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- heavyball/utils.py +32 -12
- {heavyball-0.17.1.dist-info → heavyball-0.17.3.dist-info}/METADATA +1 -1
- {heavyball-0.17.1.dist-info → heavyball-0.17.3.dist-info}/RECORD +6 -6
- {heavyball-0.17.1.dist-info → heavyball-0.17.3.dist-info}/LICENSE +0 -0
- {heavyball-0.17.1.dist-info → heavyball-0.17.3.dist-info}/WHEEL +0 -0
- {heavyball-0.17.1.dist-info → heavyball-0.17.3.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -332,6 +332,16 @@ def promote(x):
|
|
332
332
|
return x
|
333
333
|
|
334
334
|
|
335
|
+
def min_dtype(xs: List[torch.Tensor]):
|
336
|
+
dtypes = [x.dtype for x in xs]
|
337
|
+
for d in (torch.float32, torch.bfloat16, torch.float16):
|
338
|
+
if all(d == x for x in dtypes):
|
339
|
+
return d
|
340
|
+
if all(d in (x, torch.float32, torch.float64) for x in dtypes):
|
341
|
+
return d
|
342
|
+
return torch.float32
|
343
|
+
|
344
|
+
|
335
345
|
def update_preconditioner(grad, state, max_precond_dim, precondition_1d, beta, update_precond):
|
336
346
|
"""
|
337
347
|
Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
|
@@ -471,13 +481,8 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
|
|
471
481
|
copy_stochastic_(t, s)
|
472
482
|
|
473
483
|
|
474
|
-
|
475
|
-
|
476
|
-
return
|
477
|
-
if target.dtype != torch.bfloat16:
|
478
|
-
set_(target, source)
|
479
|
-
return
|
480
|
-
|
484
|
+
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
|
485
|
+
def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
|
481
486
|
"""Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
|
482
487
|
# create a random 16 bit integer
|
483
488
|
result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
|
@@ -492,6 +497,15 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
|
|
492
497
|
target.copy_(result.view(dtype=torch.float32))
|
493
498
|
|
494
499
|
|
500
|
+
def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
|
501
|
+
if target.data_ptr() == source.data_ptr():
|
502
|
+
return
|
503
|
+
if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
|
504
|
+
set_(target, source)
|
505
|
+
return
|
506
|
+
_compilable_copy_stochastic_(target, source)
|
507
|
+
|
508
|
+
|
495
509
|
def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
|
496
510
|
add_fn: callable = None):
|
497
511
|
param32 = [promote(p) for p in param]
|
@@ -602,7 +616,8 @@ def psgd_balance_Q(Q_in):
|
|
602
616
|
|
603
617
|
|
604
618
|
def psgd_calc_A_and_conjB(exprA, G, Q, V):
|
605
|
-
|
619
|
+
md = min_dtype(Q)
|
620
|
+
A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md))
|
606
621
|
order = G.dim()
|
607
622
|
p = list(range(order))
|
608
623
|
conjB = torch.permute(V.conj(), p[1:] + p[:1])
|
@@ -653,7 +668,10 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
|
|
653
668
|
|
654
669
|
term2 += term1 # a + b
|
655
670
|
term1 *= 2 # 2a
|
656
|
-
term1
|
671
|
+
if term1.dtype == term2.dtype:
|
672
|
+
term1 -= term2 # 2a - (a + b) == a - b
|
673
|
+
else:
|
674
|
+
term1 = term1 - term2
|
657
675
|
|
658
676
|
term1 *= step
|
659
677
|
norm = term2.norm(float('inf'))
|
@@ -669,7 +687,8 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
|
|
669
687
|
@decorator
|
670
688
|
def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
|
671
689
|
"""Precondition gradient G with preconditioner Q."""
|
672
|
-
|
690
|
+
md = min_dtype(Q)
|
691
|
+
out = torch.einsum(exprs[-1], *[q.conj().to(md) for q in Q], *[q.to(md) for q in Q], G.to(md))
|
673
692
|
if inplace:
|
674
693
|
set_(G, out)
|
675
694
|
return G
|
@@ -787,14 +806,15 @@ class PSGDBase(StatefulOptimizer):
|
|
787
806
|
if g.dim() > 1:
|
788
807
|
psgd_balance_Q(q)
|
789
808
|
|
790
|
-
def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
|
809
|
+
def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
|
810
|
+
store_triu_as_line=False):
|
791
811
|
for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
|
792
812
|
psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
|
793
813
|
if original_q:
|
794
814
|
if store_triu_as_line:
|
795
815
|
update_triu_(original_q[i], Q)
|
796
816
|
else:
|
797
|
-
|
817
|
+
copy_stochastic_list_(original_q[i], Q)
|
798
818
|
|
799
819
|
|
800
820
|
def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
|
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdo
|
|
16
16
|
heavyball/psgd_kron.py,sha256=2IpPj2TOExNGm8hSewi3er2GczJRNgC7r2J5yYSSA_0,5998
|
17
17
|
heavyball/pure_psgd.py,sha256=uA7W9a3Qm1sxHQhtNxaUYrmE5x55lP5iJOKy_qT8XaQ,5341
|
18
18
|
heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
|
19
|
-
heavyball/utils.py,sha256=
|
20
|
-
heavyball-0.17.
|
21
|
-
heavyball-0.17.
|
22
|
-
heavyball-0.17.
|
23
|
-
heavyball-0.17.
|
24
|
-
heavyball-0.17.
|
19
|
+
heavyball/utils.py,sha256=wEwg-TElm56CEf-Q_F7wgJdz17P4Hp0CakALGJj1090,29563
|
20
|
+
heavyball-0.17.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
21
|
+
heavyball-0.17.3.dist-info/METADATA,sha256=QMYWssjAbuY_eYSGRVuSOVeFrMiq7YZA4-qJe8E8u74,11810
|
22
|
+
heavyball-0.17.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
23
|
+
heavyball-0.17.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
24
|
+
heavyball-0.17.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|