heavyball 0.17.0.tar.gz → 0.17.2.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.17.0 → heavyball-0.17.2}/PKG-INFO +17 -17
- {heavyball-0.17.0 → heavyball-0.17.2}/README.md +16 -16
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/__init__.py +6 -7
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/utils.py +28 -11
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball.egg-info/PKG-INFO +17 -17
- {heavyball-0.17.0 → heavyball-0.17.2}/setup.py +1 -1
- {heavyball-0.17.0 → heavyball-0.17.2}/LICENSE +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/cached_delayed_psgd_kron.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/cached_psgd_kron.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/delayed_psgd.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/p_adam.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/psgd_kron.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/pure_psgd.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/setup.cfg +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/test/test_bf16_q.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/test/test_closure.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/test/test_foreach.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/test/test_memory.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/test/test_merge.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/test/test_no_grad.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/test/test_psgd.py +0 -0
- {heavyball-0.17.0 → heavyball-0.17.2}/test/test_soap.py +0 -0
{heavyball-0.17.0 → heavyball-0.17.2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.17.0
+Version: 0.17.2
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,8 +32,8 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-
-recommended experimental optimizer is `
+Currently (2024-11-20, 0.17.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
 
@@ -62,7 +62,7 @@ import heavyball
 model = torch.nn.Linear(16, 1)
 
 # Create an optimizer
-optimizer = heavyball.
+optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)
 
 x = torch.randn(128, 16)
 y = torch.randn(128, 1)
@@ -76,19 +76,19 @@ for _ in range(1000):
 
 ## Optimizers
 
-| Name
-|
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
+| Name | Description | Advantages / Disadvantages |
+|-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
+| **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
+| **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
+| **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
+| **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
+| **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
+| **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
+| **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
+| **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
 
 ## Precond Schedule
 
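The usage hunk above only shows fragments of the README example. For context, here is a minimal end-to-end sketch of that loop, assuming heavyball 0.17.2 and PyTorch are installed; the loss function and the backward/step calls are assumptions, not quoted from the package.

```python
import torch
import heavyball

# Assembled from the README fragments visible in this diff; the loss and
# backward/step lines below are assumptions, not quoted from the package.
model = torch.nn.Linear(16, 1)

# Create an optimizer (the stable recommendation named in the README diff)
optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)

x = torch.randn(128, 16)
y = torch.randn(128, 1)

for _ in range(1000):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
```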
{heavyball-0.17.0 → heavyball-0.17.2}/README.md

@@ -8,8 +8,8 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-
-recommended experimental optimizer is `
+Currently (2024-11-20, 0.17.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
 
@@ -38,7 +38,7 @@ import heavyball
 model = torch.nn.Linear(16, 1)
 
 # Create an optimizer
-optimizer = heavyball.
+optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)
 
 x = torch.randn(128, 16)
 y = torch.randn(128, 1)
@@ -52,19 +52,19 @@ for _ in range(1000):
 
 ## Optimizers
 
-| Name
-|
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
+| Name | Description | Advantages / Disadvantages |
+|-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
+| **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
+| **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
+| **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
+| **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
+| **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
+| **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
+| **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
+| **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
 
 ## Precond Schedule
 
{heavyball-0.17.0 → heavyball-0.17.2}/heavyball/__init__.py

@@ -21,26 +21,25 @@ PalmForEachSoap = PaLMForeachSOAP
 PaLMSOAP = PaLMForeachSOAP
 PaLMSFAdamW = PaLMForeachSFAdamW
 PaLMSFSoap = SFPaLMForeachSOAP
-PaLMForeachSOAP = PaLMForeachSOAP
 PrecondScheduleSFPaLMSOAP = PrecondScheduleSFPaLMSOAP
 SOAP = ForeachSOAP
 SFAdamW = ForeachSFAdamW
 LaProp = ForeachLaProp
 ADOPT = ForeachADOPT
-
-
+PrecondScheduleSOAP = PrecondScheduleForeachSOAP
+PrecondSchedulePaLMSOAP = PrecondSchedulePaLMForeachSOAP
 PSGDKron = ForeachPSGDKron
 AdamW = ForeachAdamW
 PurePSGD = ForeachPurePSGD
 PaLMPAdam = ForeachPaLMPAdam
 DelayedPSGD = ForeachDelayedPSGD
 CachedPSGDKron = ForeachCachedPSGDKron
-CachedDelayedPSGDKron
+CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
 
 __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
            'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
            'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
-           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron' #
-           'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', '
+           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron', #
+           'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP',
            'SOAP', 'SFAdamW', 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD',
-           'CachedPSGDKron', 'CachedDelayedPSGDKron']
+           'CachedPSGDKron', 'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP']
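The hunk above completes the dangling `CachedDelayedPSGDKron` alias and adds the `PrecondScheduleSOAP` and `PrecondSchedulePaLMSOAP` short names. A small sketch of what that means for callers, assuming heavyball 0.17.2 is installed; the constructor arguments mirror the README example and are otherwise an assumption.

```python
import torch
import heavyball

model = torch.nn.Linear(16, 1)

# The short names added in 0.17.2 are plain aliases of the Foreach classes,
# as shown in the __init__.py hunk above.
assert heavyball.PrecondScheduleSOAP is heavyball.PrecondScheduleForeachSOAP
assert heavyball.CachedDelayedPSGDKron is heavyball.ForeachCachedDelayedPSGDKron

# So either spelling constructs the same optimizer.
opt = heavyball.PrecondScheduleSOAP(model.parameters(), lr=1e-3)
```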
{heavyball-0.17.0 → heavyball-0.17.2}/heavyball/utils.py

@@ -332,6 +332,16 @@ def promote(x):
     return x
 
 
+def min_dtype(xs: List[torch.Tensor]):
+    dtypes = [x.dtype for x in xs]
+    for d in (torch.float32, torch.bfloat16, torch.float16):
+        if all(d == x for x in dtypes):
+            return d
+        if all(d in (x, torch.float32, torch.float64) for x in dtypes):
+            return d
+    return torch.float32
+
+
 def update_preconditioner(grad, state, max_precond_dim, precondition_1d, beta, update_precond):
     """
     Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
@@ -471,13 +481,8 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
         copy_stochastic_(t, s)
 
 
-
-
-        return
-    if target.dtype != torch.bfloat16:
-        set_(target, source)
-        return
-
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
     # create a random 16 bit integer
     result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
@@ -492,6 +497,15 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     target.copy_(result.view(dtype=torch.float32))
 
 
+def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
+    if target.data_ptr() == source.data_ptr():
+        return
+    if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
+        set_(target, source)
+        return
+    _compilable_copy_stochastic_(target, source)
+
+
 def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
                   add_fn: callable = None):
     param32 = [promote(p) for p in param]
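The two hunks above split `copy_stochastic_` into a cheap Python guard plus a `torch.compile`d kernel. The underlying trick, from the PyTorch issue linked in the docstring, is stochastic rounding of float32 values into bfloat16. A self-contained sketch of that idea follows; it is an illustration only, not the heavyball API.

```python
import torch

def stochastic_round_to_bf16(source: torch.Tensor) -> torch.Tensor:
    # Reinterpret the float32 bits as int32, add random noise to the 16 low
    # (to-be-truncated) mantissa bits, zero those bits, and reinterpret back.
    noise = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
    bits = source.view(dtype=torch.int32) + noise
    bits = bits & -65536  # keep only the upper 16 bits (the bf16 payload)
    return bits.view(dtype=torch.float32).bfloat16()

x = torch.full((4,), 1.0001)
print(stochastic_round_to_bf16(x))  # rounds up or down at random, unbiased on average
```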
@@ -602,7 +616,8 @@ def psgd_balance_Q(Q_in):
 
 
 def psgd_calc_A_and_conjB(exprA, G, Q, V):
-
+    md = min_dtype(Q)
+    A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md))
     order = G.dim()
     p = list(range(order))
     conjB = torch.permute(V.conj(), p[1:] + p[:1])
@@ -669,7 +684,8 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
 @decorator
 def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
     """Precondition gradient G with preconditioner Q."""
-
+    md = min_dtype(Q)
+    out = torch.einsum(exprs[-1], *[q.conj().to(md) for q in Q], *[q.to(md) for q in Q], G.to(md))
     if inplace:
         set_(G, out)
         return G
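Both PSGD hunks above follow the same pattern: pick a common dtype via the new `min_dtype` helper and cast every operand before the contraction, since `torch.einsum` generally requires its operands to share a dtype. A toy illustration of that pattern; the shapes and the einsum string here are made up and are not heavyball's actual expressions.

```python
import torch

q1 = torch.randn(4, 4, dtype=torch.bfloat16)
q2 = torch.randn(3, 3, dtype=torch.float32)
g = torch.randn(4, 3, dtype=torch.bfloat16)

# Mixed bf16/fp32 operands: cast everything to one dtype before contracting.
md = torch.float32  # what min_dtype would settle on for this mix
out = torch.einsum('ab,cd,bd->ac', q1.to(md), q2.to(md), g.to(md))
print(out.shape, out.dtype)  # torch.Size([4, 3]) torch.float32
```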
@@ -787,14 +803,15 @@ class PSGDBase(StatefulOptimizer):
             if g.dim() > 1:
                 psgd_balance_Q(q)
 
-    def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
+    def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
+                  store_triu_as_line=False):
         for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
             psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
             if original_q:
                 if store_triu_as_line:
                     update_triu_(original_q[i], Q)
                 else:
-
+                    copy_stochastic_list_(original_q[i], Q)
 
 
 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
{heavyball-0.17.0 → heavyball-0.17.2}/heavyball.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.17.0
+Version: 0.17.2
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,8 +32,8 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-
-recommended experimental optimizer is `
+Currently (2024-11-20, 0.17.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
 
@@ -62,7 +62,7 @@ import heavyball
 model = torch.nn.Linear(16, 1)
 
 # Create an optimizer
-optimizer = heavyball.
+optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)
 
 x = torch.randn(128, 16)
 y = torch.randn(128, 1)
@@ -76,19 +76,19 @@ for _ in range(1000):
 
 ## Optimizers
 
-| Name
-|
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
+| Name | Description | Advantages / Disadvantages |
+|-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
+| **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
+| **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
+| **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
+| **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
+| **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
+| **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
+| **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
+| **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
 
 ## Precond Schedule
 