heavyball 0.20.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/utils.py +40 -27
- {heavyball-0.20.0.dist-info → heavyball-0.21.0.dist-info}/METADATA +2 -2
- {heavyball-0.20.0.dist-info → heavyball-0.21.0.dist-info}/RECORD +6 -6
- {heavyball-0.20.0.dist-info → heavyball-0.21.0.dist-info}/LICENSE +0 -0
- {heavyball-0.20.0.dist-info → heavyball-0.21.0.dist-info}/WHEEL +0 -0
- {heavyball-0.20.0.dist-info → heavyball-0.21.0.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -329,23 +329,33 @@ def get_orthogonal_matrix(mat):
|
|
329
329
|
|
330
330
|
|
331
331
|
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
|
332
|
-
def
|
333
|
-
|
334
|
-
|
332
|
+
def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
|
333
|
+
for x_, y_ in zip(x, y):
|
334
|
+
x32 = promote(x_)
|
335
|
+
y32 = promote(y_)
|
336
|
+
x32.lerp_(y32, a)
|
337
|
+
copy_stochastic_(x_, x32)
|
335
338
|
|
336
|
-
torch._foreach_lerp_(x32, y32, a)
|
337
339
|
|
338
|
-
|
340
|
+
def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
|
341
|
+
if not isinstance(a, torch.Tensor):
|
342
|
+
a = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(a)
|
343
|
+
_compilable_stochastic_lerp_(x, y, a)
|
339
344
|
|
340
345
|
|
341
346
|
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
|
342
|
-
def
|
343
|
-
|
344
|
-
|
347
|
+
def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
|
348
|
+
for x_, y_ in zip(x, y):
|
349
|
+
x32 = promote(x_)
|
350
|
+
y32 = promote(y_)
|
351
|
+
x32.add_(y32, alpha=alpha)
|
352
|
+
copy_stochastic_(x_, x32)
|
345
353
|
|
346
|
-
[x_.add_(y_, alpha=alpha) for x_, y_ in zip(x32, y32)]
|
347
354
|
|
348
|
-
|
355
|
+
def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
|
356
|
+
if not isinstance(alpha, torch.Tensor):
|
357
|
+
alpha = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(alpha)
|
358
|
+
_compilable_stochastic_add_(x, y, alpha)
|
349
359
|
|
350
360
|
|
351
361
|
@decorator
|
@@ -572,7 +582,7 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
|
|
572
582
|
|
573
583
|
|
574
584
|
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
|
575
|
-
def
|
585
|
+
def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
|
576
586
|
beta1 = beta_debias(beta1, step)
|
577
587
|
beta2 = beta_debias(beta2, step)
|
578
588
|
|
@@ -585,6 +595,18 @@ def exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
|
|
585
595
|
return denom
|
586
596
|
|
587
597
|
|
598
|
+
def exp_avg_(exp_avg: List[torch.Tensor], exp_avg_sq: List[torch.Tensor], grad: List[torch.Tensor],
|
599
|
+
grad_projected: List[torch.Tensor], beta1: float, beta2: float, step: int):
|
600
|
+
if isinstance(beta1, float):
|
601
|
+
beta1 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta1)
|
602
|
+
if isinstance(beta2, float):
|
603
|
+
beta2 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta2)
|
604
|
+
if isinstance(step, int):
|
605
|
+
step = torch.empty((), dtype=torch.int32, device=exp_avg[0].device).fill_(step)
|
606
|
+
denom = _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step)
|
607
|
+
return denom
|
608
|
+
|
609
|
+
|
588
610
|
# this can be dynamic for most optimizers - just not for PSGD. So, it's disabled for all
|
589
611
|
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True)
|
590
612
|
def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
|
@@ -721,7 +743,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
|
|
721
743
|
return [Q, (exprA, tuple(exprGs), exprP)]
|
722
744
|
|
723
745
|
|
724
|
-
@
|
746
|
+
@decorator
|
725
747
|
def psgd_balance_Q(Q_in):
|
726
748
|
norms = torch.stack([q.norm(float("inf")) for q in Q_in])
|
727
749
|
geometric_mean = norms.log().mean().exp()
|
@@ -734,8 +756,7 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
|
|
734
756
|
A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
|
735
757
|
order = G.dim()
|
736
758
|
p = list(range(order))
|
737
|
-
|
738
|
-
conjB = torch.permute(V, p[1:] + p[:1])
|
759
|
+
conjB = torch.randn(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
|
739
760
|
Q = [promote(q) for q in Q]
|
740
761
|
for i, q in enumerate(Q):
|
741
762
|
if q.dim() <= 1:
|
@@ -755,27 +776,19 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
|
|
755
776
|
def psgd_lb(A, max_abs):
|
756
777
|
A /= max_abs
|
757
778
|
a0 = torch.einsum('ij,ij->j', A, A)
|
758
|
-
a1 = torch.einsum('ij,ij->i', A, A)
|
759
|
-
value0 = torch.max(a0)
|
760
|
-
value1 = torch.max(a1)
|
761
779
|
i = torch.argmax(a0)
|
762
|
-
j = torch.argmax(a1)
|
763
780
|
|
764
|
-
|
765
|
-
x = torch.cond(comp, lambda a: torch.index_select(a, 1, i).flatten().contiguous(), #
|
766
|
-
lambda a: torch.index_select(a, 0, j).flatten().contiguous(), (A,))
|
781
|
+
x = torch.index_select(A, 1, i).flatten().contiguous()
|
767
782
|
|
768
|
-
x = torch.
|
769
|
-
(x, A,))
|
783
|
+
x = torch.einsum('i,ij->j', x, A)
|
770
784
|
x /= x.norm()
|
771
|
-
x = torch.
|
772
|
-
(x, A,))
|
785
|
+
x = torch.einsum('j,kj->k', x, A)
|
773
786
|
x = x.norm()
|
774
787
|
x *= max_abs
|
775
788
|
return x
|
776
789
|
|
777
790
|
|
778
|
-
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
|
791
|
+
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
|
779
792
|
def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
|
780
793
|
"""Update Kronecker product preconditioner Q with pair (V, G)."""
|
781
794
|
exprA, exprGs, _ = exprs
|
@@ -808,7 +821,7 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
|
|
808
821
|
stochastic_add_([o], [term1], -1)
|
809
822
|
|
810
823
|
|
811
|
-
@
|
824
|
+
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
|
812
825
|
def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
|
813
826
|
"""Precondition gradient G with preconditioner Q."""
|
814
827
|
md = min_dtype(Q)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: heavyball
|
3
|
-
Version: 0.20.0
|
3
|
+
Version: 0.21.0
|
4
4
|
Summary: Efficient optimizers
|
5
5
|
Home-page: https://github.com/clashluke/heavyball
|
6
6
|
Author: Lucas Nestler
|
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
|
|
32
32
|
The goal is not to strive for completeness, full maintenance, or abstraction, but instead to provide a simple
|
33
33
|
largely static alternative to `torch.optim` with more and better optimizers.
|
34
34
|
|
35
|
-
Currently (2024-11-22, 0.20.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
|
35
|
+
Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
|
36
36
|
recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
|
37
37
|
|
38
38
|
## Features
|
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=PNneiOkrRyV1yIZn91lPmYofd1_OiLqJTDy
|
|
16
16
|
heavyball/psgd_kron.py,sha256=RSLJi5FnSmYjvYBufNDClnYRm-eN_Kpa1Ar2tNP6-X0,5824
|
17
17
|
heavyball/pure_psgd.py,sha256=LZK0qmvZkBF8g00evaVLtW-sIUJmdoxag1K7O26AqEo,4820
|
18
18
|
heavyball/schedule_free_palm_foreach_soap.py,sha256=_AqrnChY6iDlQUkF2YUxS7eLjSWCIuvEUOHvMHVM1yY,6873
|
19
|
-
heavyball/utils.py,sha256=
|
20
|
-
heavyball-0.
|
21
|
-
heavyball-0.
|
22
|
-
heavyball-0.
|
23
|
-
heavyball-0.
|
24
|
-
heavyball-0.
|
19
|
+
heavyball/utils.py,sha256=H8RsADNAXVbjQ9wWstYIKkXMq9E81aUF1j-2wfCeSLA,36471
|
20
|
+
heavyball-0.21.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
21
|
+
heavyball-0.21.0.dist-info/METADATA,sha256=hbXhr4XcPAkgfW8hpgoPRPrUoKeTCTvhZdofj4h8_8c,11926
|
22
|
+
heavyball-0.21.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
23
|
+
heavyball-0.21.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
24
|
+
heavyball-0.21.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|