heavyball 0.18.8__tar.gz → 0.19.0__tar.gz
This diff compares the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their public registries.
- {heavyball-0.18.8 → heavyball-0.19.0}/PKG-INFO +18 -16
- {heavyball-0.18.8 → heavyball-0.19.0}/README.md +17 -15
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/foreach_adamw.py +7 -5
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/foreach_adopt.py +6 -4
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/foreach_laprop.py +10 -5
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/foreach_sfadamw.py +7 -4
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/palm_foreach_sfadamw.py +9 -4
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball.egg-info/PKG-INFO +18 -16
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball.egg-info/SOURCES.txt +1 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/setup.py +1 -1
- heavyball-0.19.0/test/test_bf16_storage.py +60 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/LICENSE +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/__init__.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/cached_delayed_psgd_kron.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/cached_psgd_kron.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/delayed_psgd.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/p_adam.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/psgd_kron.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/pure_psgd.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball/utils.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/setup.cfg +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/test/test_bf16_params.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/test/test_bf16_q.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/test/test_closure.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/test/test_foreach.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/test/test_memory.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/test/test_merge.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/test/test_no_grad.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/test/test_psgd.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/test/test_soap.py +0 -0
- {heavyball-0.18.8 → heavyball-0.19.0}/test/test_stochastic_updates.py +0 -0
{heavyball-0.18.8 → heavyball-0.19.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.18.8
+Version: 0.19.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-
+Currently (2024-11-22, 0.19), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
@@ -45,8 +45,10 @@ recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psg
 * **ScheduleFree**: No learning rate schedule, but better convergence
 * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
   better step-per-second in late convergence (explained below)
-* **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory
-
+* **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) and `q_dtype` to trade off memory
+  usage for memory
+  bandwidth; Other optimizers have `storage_dtype`, supporting lower-precision EMAs at no(?) performance drop via
+  stochastic rounding
 
 ## Getting started
 
@@ -76,19 +78,19 @@ for _ in range(1000):
 
 ## Optimizers
 
-| Name
-|
-| **AdamW**
-| **LaProp**
-| **ADOPT**
-| **SFAdamW**
-| **PaLMSFAdamW**
-| **SOAP**
-| **PaLMSOAP**
-| **SFPaLMSOAP**
+| Name | Description | Advantages / Disadvantages |
+|-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
+| **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
+| **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
+| **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
+| **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
+| **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
 | **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
-| **PrecondSchedulePaLMSOAP**
-| **PrecondScheduleSOAP**
+| **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
+| **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
 
 ## Precond Schedule
 
{heavyball-0.18.8 → heavyball-0.19.0}/README.md

@@ -8,7 +8,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-
+Currently (2024-11-22, 0.19), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
@@ -21,8 +21,10 @@ recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psg
 * **ScheduleFree**: No learning rate schedule, but better convergence
 * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
   better step-per-second in late convergence (explained below)
-* **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory
-
+* **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) and `q_dtype` to trade off memory
+  usage for memory
+  bandwidth; Other optimizers have `storage_dtype`, supporting lower-precision EMAs at no(?) performance drop via
+  stochastic rounding
 
 ## Getting started
 
@@ -52,19 +54,19 @@ for _ in range(1000):
 
 ## Optimizers
 
-| Name
-|
-| **AdamW**
-| **LaProp**
-| **ADOPT**
-| **SFAdamW**
-| **PaLMSFAdamW**
-| **SOAP**
-| **PaLMSOAP**
-| **SFPaLMSOAP**
+| Name | Description | Advantages / Disadvantages |
+|-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
+| **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
+| **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
+| **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
+| **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
+| **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
 | **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
-| **PrecondSchedulePaLMSOAP**
-| **PrecondScheduleSOAP**
+| **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
+| **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
 
 ## Precond Schedule
 
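The `storage_dtype` option added above is a plain keyword argument on the non-PSGD optimizers (see the signatures in the diffs below). A minimal usage sketch, assuming the package exports the `ForeachAdamW` class shown in this diff; everything else here is illustrative:

```python
import torch
from torch import nn

import heavyball  # assumes heavyball exposes ForeachAdamW as shown in this diff

model = nn.Linear(128, 128)

# storage_dtype='bfloat16' allocates the exp_avg / exp_avg_sq buffers in bfloat16;
# per the feature bullet above, the update math still runs in float32 and is
# written back into the low-precision buffers with stochastic rounding.
opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-3, storage_dtype='bfloat16')

loss = model(torch.randn(32, 128)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```

PSGD-based optimizers expose `q_dtype` instead of `storage_dtype` (per the feature bullet); its accepted values are not shown in this diff.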
{heavyball-0.18.8 → heavyball-0.19.0}/heavyball/foreach_adamw.py

@@ -1,7 +1,7 @@
 import torch
 import torch.optim
-from heavyball.utils import copy_stochastic_list_
 
+from heavyball.utils import copy_stochastic_list_
 from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote
 
 
@@ -20,9 +20,9 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
 
 class ForeachAdamW(StatefulOptimizer):
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True):
+                 foreach: bool = True, storage_dtype: str = 'float32'):
         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-                        lr_max=-1.0, weight_decay=weight_decay)
+                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
         super().__init__(params, defaults, foreach)
 
     def _step(self, group):
@@ -38,10 +38,12 @@ class ForeachAdamW(StatefulOptimizer):
         if not active_p:
             return
 
+        storage_dtype = getattr(torch, group['storage_dtype'])
+
         for p in active_p:
             if 'exp_avg' not in self.state_(p):
-                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=storage_dtype)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype)
 
         y, grad, exp_avg_sq, exp_avg = zip(
             *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
{heavyball-0.18.8 → heavyball-0.19.0}/heavyball/foreach_adopt.py

@@ -27,9 +27,9 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
 class ForeachADOPT(StatefulOptimizer):
 
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True):
+                 foreach: bool = True, storage_dtype: str = 'float32'):
         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-                        lr_max=-1.0, weight_decay=weight_decay)
+                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
         super().__init__(params, defaults, foreach)
 
     def _step(self, group):
@@ -45,10 +45,12 @@ class ForeachADOPT(StatefulOptimizer):
         if not active_p:
             return
 
+        storage_dtype = getattr(torch, group['storage_dtype'])
+
         for p in active_p:
             if 'exp_avg' not in self.state_(p):
-                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=storage_dtype)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype)
 
         y, grad, exp_avg_sq, exp_avg = zip(
             *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
{heavyball-0.18.8 → heavyball-0.19.0}/heavyball/foreach_laprop.py

@@ -1,7 +1,7 @@
 import torch
 import torch.optim
 
-from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote
+from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote, copy_stochastic_list_
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
@@ -16,13 +16,16 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
 
     update_param_(y, exp_avg32, lr, decay)
 
+    copy_stochastic_list_(exp_avg, exp_avg32)
+    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
 
 class ForeachLaProp(StatefulOptimizer):
 
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
-                 foreach: bool = True):
+                 foreach: bool = True, storage_dtype: str = 'float32'):
         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-                        lr_max=-1.0, weight_decay=weight_decay)
+                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
         super().__init__(params, defaults, foreach)
 
     def _step(self, group):
@@ -38,10 +41,12 @@ class ForeachLaProp(StatefulOptimizer):
         if not active_p:
             return
 
+        storage_dtype = getattr(torch, group['storage_dtype'])
+
         for p in active_p:
             if 'exp_avg' not in self.state_(p):
-                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=storage_dtype)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype)
 
         y, grad, exp_avg_sq, exp_avg = zip(
             *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) #
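In the hunk above, `_compilable_step_` now writes its float32 working buffers back into the (possibly bfloat16) optimizer state via `copy_stochastic_list_`. That helper lives in `heavyball/utils.py`, which is unchanged in this diff; the following is only a sketch of the stochastic-rounding idea behind it, not the library's implementation:

```python
import torch


def stochastic_copy_(target: torch.Tensor, source: torch.Tensor) -> None:
    """Sketch: copy a float32 tensor into `target`, stochastically rounding to bfloat16.

    Random noise is added to the 16 low bits that bfloat16 discards, so each value
    rounds up or down with probability proportional to the discarded fraction and
    the low-precision EMA stays unbiased in expectation.
    """
    if target.dtype != torch.bfloat16:
        target.copy_(source)
        return
    bits = source.view(torch.int32)               # reinterpret the float32 bit pattern
    noise = torch.randint_like(bits, 0, 1 << 16)  # noise in the truncated mantissa bits
    rounded = (bits + noise) & -65536             # zero out the low 16 bits
    target.copy_(rounded.view(torch.float32))     # now an exact cast to bfloat16
```

A plain truncating cast would round deterministically and slowly bias the EMAs over many small updates; stochastic rounding is what the README's "no(?) performance drop" claim for low-precision state relies on.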
{heavyball-0.18.8 → heavyball-0.19.0}/heavyball/foreach_sfadamw.py

@@ -1,6 +1,6 @@
 import torch
 import torch.optim
-from heavyball.utils import get_ckp1
+from heavyball.utils import get_ckp1, copy_stochastic_list_
 
 from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, promote, _compilable_schedule_free_
 
@@ -19,14 +19,15 @@ def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, dec
     for p, z_, g in zip(y, z, g32):
         _compilable_schedule_free_(p, z_, ckp1, g, lr, beta1)
 
+    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
 
 class ForeachSFAdamW(ScheduleFree):
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                 weight_lr_power=2.0, foreach: bool = True):
+                 weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32'):
 
         defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
                         weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
-                        foreach=foreach)
+                        foreach=foreach, storage_dtype=storage_dtype)
         super().__init__(params, defaults, foreach)
 
     def _step(self, group):
@@ -42,10 +43,12 @@ class ForeachSFAdamW(ScheduleFree):
         if not active_p:
             return
 
+        storage_dtype = getattr(torch, group['storage_dtype'])
+
         for p in active_p:
             if 'z' not in self.state_(p):
                 self.state_(p)['z'] = torch.clone(p.data)
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype)
 
         y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z']) #
                                        for p in active_p])
{heavyball-0.18.8 → heavyball-0.19.0}/heavyball/palm_foreach_sfadamw.py

@@ -1,7 +1,8 @@
 import torch
 import torch.optim
 
-from .utils import
+from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, get_ckp1, promote, \
+    _compilable_schedule_free_, copy_stochastic_list_
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
@@ -18,15 +19,17 @@ def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, dec
     for p, z_, g in zip(y, z, g32):
         _compilable_schedule_free_(p, z_, ckp1, g, lr, beta1)
 
+    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
 
 class PaLMForeachSFAdamW(ScheduleFree):
     def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                 weight_lr_power=2.0, beta2_scale: float = 0.8, foreach: bool = True):
+                 weight_lr_power=2.0, beta2_scale: float = 0.8, foreach: bool = True, storage_dtype: str = 'float32'):
         if betas[0] is not None:
             beta = betas[0]
         defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                         lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
-                        beta2_scale=beta2_scale)
+                        beta2_scale=beta2_scale, storage_dtype=storage_dtype)
         super().__init__(params, defaults, foreach)
 
     def _step(self, group):
@@ -42,10 +45,12 @@ class PaLMForeachSFAdamW(ScheduleFree):
         if not active_p:
             return
 
+        storage_dtype = getattr(torch, group['storage_dtype'])
+
         for p in active_p:
             if 'z' not in self.state_(p):
                 self.state_(p)['z'] = torch.clone(p.data)
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype)
 
         # Decay the first moment running average coefficient
         beta2 = 1 - (k + 1) ** -group['beta2_scale']
{heavyball-0.18.8 → heavyball-0.19.0}/heavyball.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.18.8
+Version: 0.19.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-
+Currently (2024-11-22, 0.19), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
@@ -45,8 +45,10 @@ recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psg
 * **ScheduleFree**: No learning rate schedule, but better convergence
 * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
   better step-per-second in late convergence (explained below)
-* **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory
-
+* **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) and `q_dtype` to trade off memory
+  usage for memory
+  bandwidth; Other optimizers have `storage_dtype`, supporting lower-precision EMAs at no(?) performance drop via
+  stochastic rounding
 
 ## Getting started
 
@@ -76,19 +78,19 @@ for _ in range(1000):
 
 ## Optimizers
 
-| Name
-|
-| **AdamW**
-| **LaProp**
-| **ADOPT**
-| **SFAdamW**
-| **PaLMSFAdamW**
-| **SOAP**
-| **PaLMSOAP**
-| **SFPaLMSOAP**
+| Name | Description | Advantages / Disadvantages |
+|-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
+| **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
+| **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
+| **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
+| **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
+| **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
 | **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
-| **PrecondSchedulePaLMSOAP**
-| **PrecondScheduleSOAP**
+| **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
+| **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
 
 ## Precond Schedule
 
heavyball-0.19.0/test/test_bf16_storage.py (new file)

@@ -0,0 +1,60 @@
+import pytest
+import torch
+from torch import nn
+from torch._dynamo import config
+
+import heavyball
+import heavyball.utils
+from benchmark.utils import get_optim
+from heavyball.utils import clean, set_torch, PSGDBase
+
+config.cache_size_limit = 128
+
+
+def get_memory():
+    clean()
+    torch.cuda.synchronize()
+    clean()
+    torch.cuda.synchronize()
+    return torch.cuda.memory_allocated()
+
+
+@pytest.mark.parametrize("opt", heavyball.__all__)
+@pytest.mark.parametrize("size,depth", [(256, 2)])
+def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3):
+    set_torch()
+
+    if 'soap' in opt.lower():
+        raise pytest.skip('soap is not supported')
+
+    opt = getattr(heavyball, opt)
+
+    if PSGDBase in opt.__mro__:
+        raise pytest.skip('PSGD is not supported')
+
+    peaks = []
+    losses = []
+
+    for dtype_name in ["float32", "bfloat16"]:
+        torch.manual_seed(0x2131290)
+        peaks.append([])
+        losses.append([])
+
+        dtype = getattr(torch, dtype_name)
+
+        for i in range(outer_iterations):
+            model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda().to(dtype)
+            o = get_optim(opt, model.parameters(), lr=1e-3, storage_dtype=dtype_name)
+
+            for _ in range(iterations):
+                loss = model(torch.randn((1024, size), device='cuda', dtype=dtype)).square().mean()
+                loss.backward()
+                o.step()
+                o.zero_grad()
+                losses[-1].append(loss.detach())
+
+            del model, o
+            clean()
+
+    for i, (l0, l1) in enumerate(zip(*losses)):
+        print(i, l0.item(), l1.item())
+        assert torch.allclose(l0.float(), l1.float(), rtol=0.1)
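The new test requires CUDA and imports `benchmark.utils.get_optim`, which is not part of this sdist's file list. A rough standalone equivalent of the check it performs (names and tolerance chosen here for illustration, not taken from the package) looks like:

```python
import torch
from torch import nn

import heavyball  # assumes heavyball exposes ForeachAdamW as shown in this diff

# Train the same tiny model twice, once with float32 and once with bfloat16
# optimizer state, and compare the final losses.
final_losses = {}
for dtype_name in ("float32", "bfloat16"):
    torch.manual_seed(0)
    model = nn.Linear(256, 256).cuda()
    opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-3, storage_dtype=dtype_name)
    for _ in range(128):
        loss = model(torch.randn(1024, 256, device="cuda")).square().mean()
        loss.backward()
        opt.step()
        opt.zero_grad()
    final_losses[dtype_name] = loss.item()

print(final_losses)  # the two runs should land within roughly 10% of each other
```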