heavyball 0.20.1__tar.gz → 0.21.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.20.1 → heavyball-0.21.0}/PKG-INFO +2 -2
- {heavyball-0.20.1 → heavyball-0.21.0}/README.md +1 -1
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/utils.py +39 -17
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball.egg-info/PKG-INFO +2 -2
- {heavyball-0.20.1 → heavyball-0.21.0}/setup.py +1 -1
- {heavyball-0.20.1 → heavyball-0.21.0}/LICENSE +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/__init__.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/cached_delayed_psgd_kron.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/cached_psgd_kron.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/delayed_psgd.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/p_adam.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/psgd_kron.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/pure_psgd.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/setup.cfg +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/test/test_bf16_params.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/test/test_bf16_q.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/test/test_bf16_storage.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/test/test_closure.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/test/test_ema.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/test/test_foreach.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/test/test_memory.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/test/test_merge.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/test/test_no_grad.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/test/test_psgd.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/test/test_soap.py +0 -0
- {heavyball-0.20.1 → heavyball-0.21.0}/test/test_stochastic_updates.py +0 -0
{heavyball-0.20.1 → heavyball-0.21.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.20.1
+Version: 0.21.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.

-Currently (2024-11-22, 0.20.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

 ## Features
{heavyball-0.20.1 → heavyball-0.21.0}/README.md

@@ -8,7 +8,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.

-Currently (2024-11-22, 0.20.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

 ## Features
{heavyball-0.20.1 → heavyball-0.21.0}/heavyball/utils.py

@@ -329,23 +329,33 @@ def get_orthogonal_matrix(mat):


 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
-def
-
-
+def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+    for x_, y_ in zip(x, y):
+        x32 = promote(x_)
+        y32 = promote(y_)
+        x32.lerp_(y32, a)
+        copy_stochastic_(x_, x32)

-    torch._foreach_lerp_(x32, y32, a)

-
+def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+    if not isinstance(a, torch.Tensor):
+        a = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(a)
+    _compilable_stochastic_lerp_(x, y, a)


 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
-def
-
-
+def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+    for x_, y_ in zip(x, y):
+        x32 = promote(x_)
+        y32 = promote(y_)
+        x32.add_(y32, alpha=alpha)
+        copy_stochastic_(x_, x32)

-    [x_.add_(y_, alpha=alpha) for x_, y_ in zip(x32, y32)]

-
+def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+    if not isinstance(alpha, torch.Tensor):
+        alpha = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(alpha)
+    _compilable_stochastic_add_(x, y, alpha)


 @decorator
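The new `stochastic_lerp_`/`stochastic_add_` wrappers convert a Python scalar into a 0-dim tensor before dispatching to the compiled per-tensor loop, which promotes each tensor (typically to fp32), applies the op, and writes the result back with stochastic rounding via `copy_stochastic_`. A minimal usage sketch, assuming the helpers are importable from `heavyball.utils` under the names shown in the hunk above and that `torch.compile` works on the target device:

```python
import torch
from heavyball.utils import stochastic_add_, stochastic_lerp_

# bf16 buffers; the helpers promote internally and round back stochastically
params = [torch.randn(64, 64, dtype=torch.bfloat16) for _ in range(2)]
update = [torch.randn_like(p) for p in params]

stochastic_add_(params, update, alpha=-1e-3)  # params <- params - 1e-3 * update
stochastic_lerp_(params, update, a=0.1)       # params <- lerp(params, update, 0.1)
```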
@@ -572,7 +582,7 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]


 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
-def
+def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)

@@ -585,6 +595,18 @@ def exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
     return denom


+def exp_avg_(exp_avg: List[torch.Tensor], exp_avg_sq: List[torch.Tensor], grad: List[torch.Tensor],
+             grad_projected: List[torch.Tensor], beta1: float, beta2: float, step: int):
+    if isinstance(beta1, float):
+        beta1 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta1)
+    if isinstance(beta2, float):
+        beta2 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta2)
+    if isinstance(step, int):
+        step = torch.empty((), dtype=torch.int32, device=exp_avg[0].device).fill_(step)
+    denom = _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step)
+    return denom
+
+
 # this can be dynamic for most optimizers - just not for PSGD. So, it's disabled for all
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True)
 def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
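The `exp_avg_` wrapper follows the same pattern: `beta1`, `beta2`, and `step` are wrapped into 0-dim tensors before reaching the compiled `_compilable_exp_avg_`, so advancing the step counter or changing the betas does not force `torch.compile` to retrace. A standalone sketch of why this helps (generic PyTorch, not heavyball code; the bias-corrected EMA here is illustrative only):

```python
import torch

@torch.compile(fullgraph=True, dynamic=True)
def _bias_corrected_ema(buf: torch.Tensor, grad: torch.Tensor,
                        beta: torch.Tensor, step: torch.Tensor) -> torch.Tensor:
    # beta and step arrive as 0-dim tensors, so their values are graph inputs
    # rather than baked-in constants: no recompile when they change
    buf.lerp_(grad, 1 - beta)
    return buf / (1 - beta ** step)

buf, grad = torch.zeros(16), torch.randn(16)
beta = torch.tensor(0.9)
for i in range(1, 4):
    step = torch.tensor(float(i))
    debiased = _bias_corrected_ema(buf, grad, beta, step)
```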
@@ -734,7 +756,7 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
     A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
     order = G.dim()
     p = list(range(order))
-    conjB = torch.
+    conjB = torch.randn(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
     Q = [promote(q) for q in Q]
     for i, q in enumerate(Q):
         if q.dim() <= 1:
@@ -756,17 +778,17 @@ def psgd_lb(A, max_abs):
     a0 = torch.einsum('ij,ij->j', A, A)
     i = torch.argmax(a0)

-    x = torch.index_select(
+    x = torch.index_select(A, 1, i).flatten().contiguous()

-    x = torch.einsum('i,ij->j',
+    x = torch.einsum('i,ij->j', x, A)
     x /= x.norm()
-    x = torch.einsum('j,kj->k',
+    x = torch.einsum('j,kj->k', x, A)
     x = x.norm()
     x *= max_abs
     return x


-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
 def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
     """Update Kronecker product preconditioner Q with pair (V, G)."""
     exprA, exprGs, _ = exprs
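For reference, the restored `psgd_lb` body amounts to one power-iteration-style pass seeded at the column of `A` with the largest norm, which yields a cheap lower bound on the largest singular value, scaled by `max_abs`. A quick sanity-check sketch with a hypothetical standalone copy of the routine (not the package's own function):

```python
import torch

def spectral_lower_bound(A: torch.Tensor, max_abs: torch.Tensor) -> torch.Tensor:
    # mirrors the psgd_lb hunk above: start from the column with the largest norm,
    # push it through A twice, and rescale by max_abs
    a0 = torch.einsum('ij,ij->j', A, A)
    i = torch.argmax(a0)
    x = torch.index_select(A, 1, i).flatten().contiguous()
    x = torch.einsum('i,ij->j', x, A)
    x /= x.norm()
    x = torch.einsum('j,kj->k', x, A)
    return x.norm() * max_abs

A = torch.randn(64, 64)
lb = spectral_lower_bound(A, torch.tensor(1.0))
# lb stays at or below the top singular value (up to float error)
print(float(lb), float(torch.linalg.svdvals(A)[0]))
```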
@@ -799,7 +821,7 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
         stochastic_add_([o], [term1], -1)


-@
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
 def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
     """Precondition gradient G with preconditioner Q."""
     md = min_dtype(Q)
{heavyball-0.20.1 → heavyball-0.21.0}/heavyball.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.20.1
+Version: 0.21.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.

-Currently (2024-11-22, 0.20.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

 ## Features