heavyball-0.17.1-py3-none-any.whl → heavyball-0.17.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/utils.py +28 -11
- {heavyball-0.17.1.dist-info → heavyball-0.17.2.dist-info}/METADATA +1 -1
- {heavyball-0.17.1.dist-info → heavyball-0.17.2.dist-info}/RECORD +6 -6
- {heavyball-0.17.1.dist-info → heavyball-0.17.2.dist-info}/LICENSE +0 -0
- {heavyball-0.17.1.dist-info → heavyball-0.17.2.dist-info}/WHEEL +0 -0
- {heavyball-0.17.1.dist-info → heavyball-0.17.2.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -332,6 +332,16 @@ def promote(x):
|
|
332
332
|
return x
|
333
333
|
|
334
334
|
|
335
|
+
def min_dtype(xs: List[torch.Tensor]):
    """Return the lowest-precision dtype that every tensor in ``xs`` can be cast to.

    Candidates are tried in the order float32, bfloat16, float16. A candidate
    is chosen when each tensor either already has that dtype or is
    float32/float64, so a mix of bfloat16 and float32 tensors yields bfloat16.

    :param xs: tensors whose dtypes are inspected; the tensors are not modified.
    :return: the selected ``torch.dtype``. Falls back to ``torch.float32`` when
        no candidate fits (e.g. bfloat16 mixed with float16). An empty ``xs``
        also gives ``torch.float32``.
    """
    dtypes = [x.dtype for x in xs]
    for candidate in (torch.float32, torch.bfloat16, torch.float16):
        # Membership is tested on each tensor's dtype, not on the candidate:
        # testing the candidate against a tuple that contains float32 is always
        # true for the first candidate and would pin the result to float32.
        allowed = (candidate, torch.float32, torch.float64)
        if all(dtype in allowed for dtype in dtypes):
            return candidate
    return torch.float32
|
343
|
+
|
344
|
+
|
335
345
|
def update_preconditioner(grad, state, max_precond_dim, precondition_1d, beta, update_precond):
|
336
346
|
"""
|
337
347
|
Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
|
@@ -471,13 +481,8 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
|
|
471
481
|
copy_stochastic_(t, s)
|
472
482
|
|
473
483
|
|
474
|
-
|
475
|
-
|
476
|
-
return
|
477
|
-
if target.dtype != torch.bfloat16:
|
478
|
-
set_(target, source)
|
479
|
-
return
|
480
|
-
|
484
|
+
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
|
485
|
+
def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
|
481
486
|
"""Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
|
482
487
|
# create a random 16 bit integer
|
483
488
|
result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
|
@@ -492,6 +497,15 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
|
|
492
497
|
target.copy_(result.view(dtype=torch.float32))
|
493
498
|
|
494
499
|
|
500
|
+
def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
    """Copy ``source`` into ``target``, using the stochastic kernel for bfloat16 targets.

    :param target: tensor written in place.
    :param source: tensor providing the values.
    :return: ``None``.

    Dispatch:
      * ``target`` and ``source`` report the same ``data_ptr()``: nothing to do.
      * ``target`` is not bfloat16, or ``source`` is not float16/float32/float64:
        delegate to ``set_``.
      * otherwise: delegate to ``_compilable_copy_stochastic_``.
    """
    if target.data_ptr() == source.data_ptr():
        return
    if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
        set_(target, source)
        return
    # The kernel builds an int32 tensor shaped like `source` and reinterprets it
    # as float32, so it is handed float32 input; `.float()` returns the tensor
    # itself when it is already float32.
    # NOTE(review): the middle of the kernel is not visible from this hunk -
    # confirm it bit-casts `source` to int32 before relying on this reasoning.
    _compilable_copy_stochastic_(target, source.float())
|
507
|
+
|
508
|
+
|
495
509
|
def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
|
496
510
|
add_fn: callable = None):
|
497
511
|
param32 = [promote(p) for p in param]
|
@@ -602,7 +616,8 @@ def psgd_balance_Q(Q_in):
|
|
602
616
|
|
603
617
|
|
604
618
|
def psgd_calc_A_and_conjB(exprA, G, Q, V):
|
605
|
-
|
619
|
+
md = min_dtype(Q)
|
620
|
+
A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md))
|
606
621
|
order = G.dim()
|
607
622
|
p = list(range(order))
|
608
623
|
conjB = torch.permute(V.conj(), p[1:] + p[:1])
|
@@ -669,7 +684,8 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
|
|
669
684
|
@decorator
|
670
685
|
def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
|
671
686
|
"""Precondition gradient G with preconditioner Q."""
|
672
|
-
|
687
|
+
md = min_dtype(Q)
|
688
|
+
out = torch.einsum(exprs[-1], *[q.conj().to(md) for q in Q], *[q.to(md) for q in Q], G.to(md))
|
673
689
|
if inplace:
|
674
690
|
set_(G, out)
|
675
691
|
return G
|
@@ -787,14 +803,15 @@ class PSGDBase(StatefulOptimizer):
|
|
787
803
|
if g.dim() > 1:
|
788
804
|
psgd_balance_Q(q)
|
789
805
|
|
790
|
-
def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
|
806
|
+
def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
              store_triu_as_line=False):
    """Run one preconditioner update per (parameter, gradient, Q) triple.

    For each triple, ``psgd_update_precond`` is called with the expressions
    stored under ``self.state_(p)["exprs"]``, a fresh ``torch.randn_like(grad)``
    tensor, the gradient, ``precond_lr`` and ``self._tiny``. When ``original_q``
    is truthy, the updated ``Q`` is then written into ``original_q[i]``: through
    ``update_triu_`` if ``store_triu_as_line`` is set, otherwise through
    ``copy_stochastic_list_``.
    """
    for idx, (param, g, precond) in enumerate(zip(p_list, grad_list, q_list)):
        exprs = self.state_(param)["exprs"]
        probe = torch.randn_like(g)
        psgd_update_precond(precond, exprs, probe, g, precond_lr, self._tiny)
        if not original_q:
            continue
        # Pick the write-back routine matching how the stored copy is laid out.
        write_back = update_triu_ if store_triu_as_line else copy_stochastic_list_
        write_back(original_q[idx], precond)
|
798
815
|
|
799
816
|
|
800
817
|
def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
|
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdo
|
|
16
16
|
heavyball/psgd_kron.py,sha256=2IpPj2TOExNGm8hSewi3er2GczJRNgC7r2J5yYSSA_0,5998
|
17
17
|
heavyball/pure_psgd.py,sha256=uA7W9a3Qm1sxHQhtNxaUYrmE5x55lP5iJOKy_qT8XaQ,5341
|
18
18
|
heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
|
19
|
-
heavyball/utils.py,sha256=
|
20
|
-
heavyball-0.17.
|
21
|
-
heavyball-0.17.
|
22
|
-
heavyball-0.17.
|
23
|
-
heavyball-0.17.
|
24
|
-
heavyball-0.17.
|
19
|
+
heavyball/utils.py,sha256=8dhagAGj03D7kBEWOJmqsCjQKP069e1WwrzVp1JsBr8,29472
|
20
|
+
heavyball-0.17.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
21
|
+
heavyball-0.17.2.dist-info/METADATA,sha256=kCK9J8gg-6lj0qao7S-yDc7jsOzGxNJ6F4_JUFaiIR4,11810
|
22
|
+
heavyball-0.17.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
23
|
+
heavyball-0.17.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
24
|
+
heavyball-0.17.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|