heavyball 0.16.0__tar.gz → 0.17.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.16.0 → heavyball-0.17.1}/PKG-INFO +17 -17
- {heavyball-0.16.0 → heavyball-0.17.1}/README.md +16 -16
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/__init__.py +7 -6
- heavyball-0.17.1/heavyball/cached_delayed_psgd_kron.py +146 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/cached_psgd_kron.py +13 -8
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/delayed_psgd.py +8 -7
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/p_adam.py +9 -7
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/psgd_kron.py +8 -7
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/pure_psgd.py +9 -6
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/utils.py +18 -13
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball.egg-info/PKG-INFO +17 -17
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball.egg-info/SOURCES.txt +2 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/setup.py +1 -1
- heavyball-0.17.1/test/test_bf16_q.py +52 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/test/test_closure.py +1 -1
- {heavyball-0.16.0 → heavyball-0.17.1}/test/test_memory.py +2 -2
- {heavyball-0.16.0 → heavyball-0.17.1}/test/test_merge.py +1 -1
- {heavyball-0.16.0 → heavyball-0.17.1}/test/test_psgd.py +3 -14
- {heavyball-0.16.0 → heavyball-0.17.1}/LICENSE +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/setup.cfg +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/test/test_foreach.py +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/test/test_no_grad.py +0 -0
- {heavyball-0.16.0 → heavyball-0.17.1}/test/test_soap.py +0 -0
{heavyball-0.16.0 → heavyball-0.17.1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.16.0
+Version: 0.17.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,8 +32,8 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-
-recommended experimental optimizer is `
+Currently (2024-11-20, 0.17.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
 
@@ -62,7 +62,7 @@ import heavyball
 model = torch.nn.Linear(16, 1)
 
 # Create an optimizer
-optimizer = heavyball.
+optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)
 
 x = torch.randn(128, 16)
 y = torch.randn(128, 1)
@@ -76,19 +76,19 @@ for _ in range(1000):
 
 ## Optimizers
 
-| Name
-|
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
+| Name | Description | Advantages / Disadvantages |
+|-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
+| **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
+| **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
+| **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
+| **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
+| **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
+| **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
+| **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
+| **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
 
 ## Precond Schedule
 

{heavyball-0.16.0 → heavyball-0.17.1}/README.md
@@ -8,8 +8,8 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-
-recommended experimental optimizer is `
+Currently (2024-11-20, 0.17.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
 
@@ -38,7 +38,7 @@ import heavyball
 model = torch.nn.Linear(16, 1)
 
 # Create an optimizer
-optimizer = heavyball.
+optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)
 
 x = torch.randn(128, 16)
 y = torch.randn(128, 1)
@@ -52,19 +52,19 @@ for _ in range(1000):
 
 ## Optimizers
 
-| Name
-|
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
+| Name | Description | Advantages / Disadvantages |
+|-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
+| **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
+| **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
+| **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
+| **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
+| **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
+| **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
+| **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
+| **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
 
 ## Precond Schedule
 

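For context, here is how the quick-start that the README hunks above now point at fits together when assembled. This is a minimal sketch: the loop body (loss, `backward`, `step`) is standard PyTorch filled in for illustration, not quoted from the package.

```python
import torch
import heavyball

model = torch.nn.Linear(16, 1)

# Create an optimizer (the newly recommended stable default in this release)
optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)

x = torch.randn(128, 16)
y = torch.randn(128, 1)

for _ in range(1000):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)  # assumed loss; any objective works
    loss.backward()
    optimizer.step()
```
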
{heavyball-0.16.0 → heavyball-0.17.1}/heavyball/__init__.py
@@ -14,31 +14,32 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
 from .psgd_kron import ForeachPSGDKron
 from .pure_psgd import ForeachPurePSGD
 from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
+from .cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron
 
 PalmForEachSoap = PaLMForeachSOAP
 
 PaLMSOAP = PaLMForeachSOAP
 PaLMSFAdamW = PaLMForeachSFAdamW
 PaLMSFSoap = SFPaLMForeachSOAP
-PaLMForeachSOAP = PaLMForeachSOAP
 PrecondScheduleSFPaLMSOAP = PrecondScheduleSFPaLMSOAP
 SOAP = ForeachSOAP
 SFAdamW = ForeachSFAdamW
 LaProp = ForeachLaProp
 ADOPT = ForeachADOPT
-
-
+PrecondScheduleSOAP = PrecondScheduleForeachSOAP
+PrecondSchedulePaLMSOAP = PrecondSchedulePaLMForeachSOAP
 PSGDKron = ForeachPSGDKron
 AdamW = ForeachAdamW
 PurePSGD = ForeachPurePSGD
 PaLMPAdam = ForeachPaLMPAdam
 DelayedPSGD = ForeachDelayedPSGD
 CachedPSGDKron = ForeachCachedPSGDKron
+CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
 
 __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
            'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
            'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
-           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', #
-           'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', '
+           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron', #
+           'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP',
            'SOAP', 'SFAdamW', 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD',
-           'CachedPSGDKron']
+           'CachedPSGDKron', 'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP']

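The `__init__.py` hunk wires the new optimizer into the package namespace and adds short aliases for the preconditioner-schedule SOAP variants. A small sanity-check sketch, assuming heavyball 0.17.1 is installed:

```python
import heavyball

# The short names added in this release are plain module-level aliases,
# so both spellings refer to the same optimizer class.
assert heavyball.CachedDelayedPSGDKron is heavyball.ForeachCachedDelayedPSGDKron
assert heavyball.PrecondScheduleSOAP is heavyball.PrecondScheduleForeachSOAP
assert heavyball.PrecondSchedulePaLMSOAP is heavyball.PrecondSchedulePaLMForeachSOAP
```
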
heavyball-0.17.1/heavyball/cached_delayed_psgd_kron.py (new file)
@@ -0,0 +1,146 @@
+"""
+Originally from Evan Walters and Omead Pooladzandi, 2024
+Modified under Creative Commons Attribution 4.0 International
+Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
+"""
+
+from typing import Optional
+
+import torch
+from heavyball.utils import einsum_base
+
+from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
+    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote
+
+
+class ForeachCachedDelayedPSGDKron(PSGDBase):
+    """
+    Implements PSGD with off-by-one preconditioning (akin to ADOPT and SOAP) with cached preconditioners.
+
+
+    Args:
+        params (iterable): Iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float): Learning rate.
+        b1 (float): Momentum parameter.
+        weight_decay (float): Weight decay (L2 penalty).
+        preconditioner_update_probability (callable or float, optional): Probability of
+            updating the preconditioner. If None, defaults to a schedule that anneals
+            from 1.0 to 0.03 by 4000 steps.
+        max_size_triangular (int): Max size for dim's preconditioner to be triangular.
+        min_ndim_triangular (int): Minimum number of dimensions a layer needs
+            to have triangular preconditioners.
+        memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
+            to set all preconditioners to be triangular, 'one_diag' sets the largest
+            or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
+            to be diagonal.
+        momentum_into_precond_update: (bool), whether to send momentum into preconditioner
+            update instead of raw gradients.
+    """
+
+    def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
+                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
+                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
+                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
+                 foreach: bool = True, q_dtype='float32'):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= beta < 1.0:
+            raise ValueError(f"Invalid beta parameter: {beta}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        if preconditioner_update_probability is None:
+            preconditioner_update_probability = precond_update_prob_schedule()
+        if clip_fn is None:
+            clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
+        self.preconditioner_update_probability = preconditioner_update_probability
+        self.clip_fn = clip_fn
+
+        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
+                        min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
+                        # precond lr hardcoded to 0.1
+                        precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line,
+                        q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach)
+
+        self._prob_step = 0
+
+    def _step(self, group):
+        # update preconditioners all together
+        update_prob = self.preconditioner_update_probability
+        if callable(update_prob):
+            update_prob = update_prob(self._prob_step)
+        do_update = self.rng.random() < update_prob
+        self._prob_step += 1
+
+        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        beta = group['beta']
+        store_triu_as_line = group['store_triu_as_line']
+        q_dtype = getattr(torch, group['q_dtype'])
+
+        vals = []
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+
+            if 'Q' not in state:
+                state["exp_avg"] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=q_dtype)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
+                state['Q_cache'] = [torch.empty_like(q) for q in Q]
+
+                expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(einsum_base, Q)]
+                expr = ','.join(expr)
+                grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
+                out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
+                expr = f'{expr},{grad_expr}->{out_expr}'
+
+                state['cache_expr'] = expr
+
+            vals.append((p, g, state["exp_avg"], state["Q"], state['Q_cache']))
+
+        if not vals:
+            return
+
+        p_list, grad_list, exp_avg_list, Q_list, Q_cache_list = zip(*vals)
+        del vals
+
+        group["step"] += 1
+
+        torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+
+        grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
+            exp_avg_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            cached_q = Q_cache_list.pop(0)
+            q_orig = Q_list.pop(0)
+            ea = exp_avg_list.pop(0)
+
+            if do_update:
+                q = line_to_triu(q_orig) if store_triu_as_line else q_orig
+                q32 = [promote(q_) for q_ in q]
+                self.balance([g], [q32])
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                for c_, q_ in zip(cached_q, q):
+                    if q_.ndim == 2:
+                        torch.matmul(q_.T.conj(), q_, out=c_)
+                    else:
+                        torch.mul(q_.conj(), q_, out=c_)
+
+            set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
+        grad_list = self.clip_fn(grad_list)
+
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)

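The core trick in the new file is the cached einsum: whenever the preconditioner is refreshed, each factor's QᴴQ (or q·q for diagonal factors) is stored in `Q_cache`, and a matching einsum expression is built once per parameter so that the hot path is a single `torch.einsum` over the cached factors. The sketch below mirrors that expression-building logic for one 2-D parameter; `einsum_base` is assumed here to be a run of lowercase letters, which this diff does not show.

```python
import torch

# Assumed stand-in for heavyball.utils.einsum_base (a run of lowercase letters).
einsum_base = 'abcdefghijklmnopqrstuvwxyz'

def build_cache_expr(Q, grad):
    # Mirrors the expression construction in ForeachCachedDelayedPSGDKron._step:
    # a 2-D (triangular) factor contributes a pair like 'Aa', a diagonal one just 'a'.
    expr = ','.join(f'{c.upper()}{c}' if q.ndim == 2 else c for c, q in zip(einsum_base, Q))
    grad_expr = ''.join(c for c, _ in zip(einsum_base, grad.shape))
    out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
    return f'{expr},{grad_expr}->{out_expr}'

g = torch.randn(4, 3)                           # gradient of a 4x3 parameter
Q = [torch.randn(4, 4).triu(), torch.randn(3)]  # one triangular factor, one diagonal factor
expr = build_cache_expr(Q, g)                   # 'Aa,b,ab->Ab'
# The cache holds Q^H Q (2-D) or q*q (diagonal), so preconditioning is one einsum:
cache = [q.T.conj() @ q if q.ndim == 2 else q.conj() * q for q in Q]
out = torch.einsum(expr, *cache, g)
print(expr, out.shape)
```
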
{heavyball-0.16.0 → heavyball-0.17.1}/heavyball/cached_psgd_kron.py
@@ -10,7 +10,7 @@ import torch
 from heavyball.utils import einsum_base
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base
+    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote
 
 
 class ForeachCachedPSGDKron(PSGDBase):
@@ -40,7 +40,7 @@ class ForeachCachedPSGDKron(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True):
+                 foreach: bool = True, q_dtype='float32'):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -61,7 +61,8 @@ class ForeachCachedPSGDKron(PSGDBase):
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line
+                        store_triu_as_line=store_triu_as_line,
+                        q_dtype=q_dtype)
         super().__init__(params, defaults, foreach)
 
         self._prob_step = 0
@@ -84,6 +85,7 @@ class ForeachCachedPSGDKron(PSGDBase):
         lr = group['lr']
         beta = group['beta']
         store_triu_as_line = group['store_triu_as_line']
+        q_dtype = getattr(torch, group['q_dtype'])
 
         vals = []
 
@@ -93,7 +95,7 @@ class ForeachCachedPSGDKron(PSGDBase):
             if 'Q' not in state:
                 state["exp_avg"] = torch.zeros_like(g)
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                 memory_save_mode, dtype=
+                                                 memory_save_mode, dtype=q_dtype)
                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
                 state['Q_cache'] = [torch.empty_like(q) for q in Q]
 
@@ -124,18 +126,21 @@ class ForeachCachedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
 
+            new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
+
             if do_update:
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-
-                self.
-
+                q32 = [promote(q_) for q_ in q]
+                self.balance([g], [q32])
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
                 for c_, q_ in zip(cached_q, q):
                     if q_.ndim == 2:
                         torch.matmul(q_.T.conj(), q_, out=c_)
                     else:
                         torch.mul(q_.conj(), q_, out=c_)
 
-            set_(g,
+            set_(g, new)
+
         grad_list = self.clip_fn(grad_list)
 
         lr = -warmup(lr, group['step'], group['warmup_steps'])

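Every PSGD variant in this release gains a `q_dtype` argument. It travels through the param-group defaults as a plain string and is resolved with `getattr(torch, ...)` at step time, exactly as in the hunks above. A minimal sketch of that resolution:

```python
import torch

# q_dtype is stored as a string (e.g. 'float32' or 'bfloat16') and only turned
# into a real torch dtype inside _step.
group = {'q_dtype': 'bfloat16'}          # illustrative param group
q_dtype = getattr(torch, group['q_dtype'])
q = torch.ones(8, 8, dtype=q_dtype)      # a stand-in preconditioner factor
print(q_dtype, q.element_size(), 'bytes per element')  # bf16 halves Q's memory
```
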
{heavyball-0.16.0 → heavyball-0.17.1}/heavyball/delayed_psgd.py
@@ -8,7 +8,7 @@ import torch
 from heavyball.utils import copy_stochastic_list_
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_
+    precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_, promote
 
 
 class ForeachDelayedPSGD(PSGDBase):
@@ -39,7 +39,7 @@ class ForeachDelayedPSGD(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
-                 foreach: bool = True):
+                 foreach: bool = True, q_dtype='float32'):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -60,7 +60,7 @@ class ForeachDelayedPSGD(PSGDBase):
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line)
+                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
         super().__init__(params, defaults, foreach)
 
         self._prob_step = 0
@@ -83,6 +83,7 @@ class ForeachDelayedPSGD(PSGDBase):
         lr = group['lr']
         beta = group['beta']
         store_triu_as_line = group['store_triu_as_line']
+        q_dtype = getattr(torch, group['q_dtype'])
 
         vals = []
 
@@ -92,7 +93,7 @@ class ForeachDelayedPSGD(PSGDBase):
             if 'Q' not in state:
                 state["exp_avg"] = torch.zeros_like(g)
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                 memory_save_mode, dtype=
+                                                 memory_save_mode, dtype=q_dtype)
                 state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["exp_avg"], state["Q"]))
@@ -114,9 +115,9 @@ class ForeachDelayedPSGD(PSGDBase):
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
             new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
             if do_update:
-
-
-                self.balance([g], [
+                q32 = [promote(q_) for q_ in q]
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                self.balance([g], [q32])
             set_(g, new)
 
         grad_list = self.clip_fn(grad_list)

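`ForeachDelayedPSGD` keeps the off-by-one ordering visible in the last hunk: the gradient is preconditioned with the existing Q (`new = psgd_precond_grad(...)`) before the preconditioner update runs, so a refreshed Q only affects later steps. A schematic sketch of that ordering with hypothetical helper names (not heavyball APIs):

```python
# Schematic only: `precondition` and `update_preconditioner` are hypothetical
# stand-ins illustrating the delayed ("off-by-one") order used above.
def delayed_step(q, ea, precondition, update_preconditioner, do_update):
    preconditioned = precondition(q, ea)    # uses the stale q for this step
    if do_update:
        q = update_preconditioner(q, ea)    # refreshed q only affects later steps
    return preconditioned, q
```
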
{heavyball-0.16.0 → heavyball-0.17.1}/heavyball/p_adam.py
@@ -8,7 +8,7 @@ import torch
 from heavyball.utils import triu_to_line, line_to_triu
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
-    exp_avg_sq_, beta_debias, split_p_and_g_in_group
+    exp_avg_sq_, beta_debias, split_p_and_g_in_group, promote
 
 
 class ForeachPaLMPAdam(PSGDBase):
@@ -39,7 +39,7 @@ class ForeachPaLMPAdam(PSGDBase):
                  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
                  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
                  store_triu_as_line: bool = True,
-                 foreach: bool = True):
+                 foreach: bool = True, q_dtype='float32'):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -60,7 +60,7 @@ class ForeachPaLMPAdam(PSGDBase):
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line)
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
         super().__init__(params, defaults, foreach)
 
         self._prob_step = 0
@@ -81,6 +81,7 @@ class ForeachPaLMPAdam(PSGDBase):
         weight_decay = group['weight_decay']
         lr = group['lr']
         store_triu_as_line = group['store_triu_as_line']
+        q_dtype = getattr(torch, group['q_dtype'])
 
         vals = []
 
@@ -91,7 +92,7 @@ class ForeachPaLMPAdam(PSGDBase):
             state['exp_avg'] = torch.zeros_like(g)
             state['exp_avg_sq'] = torch.zeros_like(g)
             Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
-                                             min_ndim_triangular, memory_save_mode, dtype=
+                                             min_ndim_triangular, memory_save_mode, dtype=q_dtype)
             state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
@@ -106,9 +107,10 @@ class ForeachPaLMPAdam(PSGDBase):
 
         Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
         if do_update:
-
-
-
+            for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
+                q32 = [promote(qq_) for qq_ in q_]
+                self.balance([g], [q32])
+                self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
         torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
 
         beta2 = 1 - group['step'] ** -group['beta2_scale']

{heavyball-0.16.0 → heavyball-0.17.1}/heavyball/psgd_kron.py
@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_
+    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, promote
 
 
 class ForeachPSGDKron(PSGDBase):
@@ -39,7 +39,7 @@ class ForeachPSGDKron(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True):
+                 foreach: bool = True, q_dtype='float32'):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -60,7 +60,7 @@ class ForeachPSGDKron(PSGDBase):
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line)
+                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
         super().__init__(params, defaults, foreach)
 
         self._prob_step = 0
@@ -83,6 +83,7 @@ class ForeachPSGDKron(PSGDBase):
         lr = group['lr']
         beta = group['beta']
         store_triu_as_line = group['store_triu_as_line']
+        q_dtype = getattr(torch, group['q_dtype'])
 
         vals = []
 
@@ -92,7 +93,7 @@ class ForeachPSGDKron(PSGDBase):
             if 'Q' not in state:
                 state["exp_avg"] = torch.zeros_like(g)
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                 memory_save_mode, dtype=
+                                                 memory_save_mode, dtype=q_dtype)
                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["exp_avg"], state["Q"]))
@@ -114,9 +115,9 @@ class ForeachPSGDKron(PSGDBase):
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
             if do_update:
-
-                self.
-
+                q32 = [promote(q_) for q_ in q]
+                self.balance([ea if momentum_into_precond_update else g], [q32])
+                self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
             set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
 
         grad_list = self.clip_fn(grad_list)

{heavyball-0.16.0 → heavyball-0.17.1}/heavyball/pure_psgd.py
@@ -5,9 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
+from heavyball.utils import copy_stochastic_list_
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
-    split_p_and_g_in_group, line_to_triu, triu_to_line
+    split_p_and_g_in_group, line_to_triu, triu_to_line, promote
 
 
 class ForeachPurePSGD(PSGDBase):
@@ -37,7 +38,7 @@ class ForeachPurePSGD(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
-                 foreach: bool = True):
+                 foreach: bool = True, q_dtype='float32'):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -56,7 +57,7 @@ class ForeachPurePSGD(PSGDBase):
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line)
+                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
         super().__init__(params, defaults, foreach)
 
         self._prob_step = 0
@@ -77,6 +78,7 @@ class ForeachPurePSGD(PSGDBase):
         weight_decay = group['weight_decay']
         lr = group['lr']
         store_triu_as_line = group['store_triu_as_line']
+        q_dtype = getattr(torch, group['q_dtype'])
 
         vals = []
 
@@ -85,7 +87,7 @@ class ForeachPurePSGD(PSGDBase):
 
             if 'Q' not in state:
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                 memory_save_mode, dtype=
+                                                 memory_save_mode, dtype=q_dtype)
                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["Q"]))
@@ -104,8 +106,9 @@ class ForeachPurePSGD(PSGDBase):
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
             if do_update:
-
-                self.
+                q32 = [promote(q_) for q_ in q]
+                self.balance([g], [q32])
+                self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
             psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
 
         grad_list = self.clip_fn(grad_list)

{heavyball-0.16.0 → heavyball-0.17.1}/heavyball/utils.py
@@ -325,9 +325,9 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
 
 
 def promote(x):
-    if x
+    if x in (torch.bfloat16, torch.float16):
         return torch.float32
-    if x.dtype in (torch.bfloat16, torch.float16):
+    if hasattr(x, 'dtype') and x.dtype in (torch.bfloat16, torch.float16):
         return x.float()
     return x
 
@@ -468,15 +468,15 @@ class ScheduleFree(StatefulOptimizer):
 
 def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]):
     for t, s in zip(target, source):
-
-            copy_stochastic_(t, s)
-        else:
-            set_(t, s)
+        copy_stochastic_(t, s)
 
 
 def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     if target.data_ptr() == source.data_ptr():
         return
+    if target.dtype != torch.bfloat16:
+        set_(target, source)
+        return
 
     """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
     # create a random 16 bit integer
@@ -555,7 +555,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
     for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
         if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d:
             # use diagonal matrix as preconditioner for this dim
-            Q.append(scale * torch.ones(size, dtype=dtype, device=t.device))
+            Q.append(scale * torch.ones(size, dtype=promote(dtype), device=t.device))
 
             piece1A.append(letters[i])
             piece2A = piece2A + letters[i]
@@ -669,11 +669,11 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
 @decorator
 def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
     """Precondition gradient G with preconditioner Q."""
-    out = torch.einsum(exprs[-1], *[q.conj() for q in Q], *Q, G)
+    out = torch.einsum(exprs[-1], *[q.conj() for q in Q], *Q, G.to(Q[0].dtype))
     if inplace:
         set_(G, out)
         return G
-    return out
+    return out.to(G.dtype)
 
 
 def norm_clip_(x, scale=None):
@@ -768,28 +768,33 @@ def line_to_triu(Q_list: List[Tuple[Optional[List[int]], torch.Tensor]]):
 def update_triu_(q_state, materialised):
     for (shape0, q), (shape1, m) in zip(q_state, triu_to_line(materialised)):
         assert shape0 == shape1
-
+        copy_stochastic_(q, m)
 
 
 class PSGDBase(StatefulOptimizer):
+    balance_probability: float = 0.01
+
     def __init__(self, parameters, groups, foreach: bool = True):
         super().__init__(parameters, groups, foreach)
         self.rng = random.Random(0x1923213)
         self._tiny = torch.finfo(torch.bfloat16).tiny
 
     def balance(self, grad_list, Q_list):
-        if self.rng.random() >
+        if self.rng.random() > self.balance_probability:
             return
 
        for g, q in zip(grad_list, Q_list):
            if g.dim() > 1:
                psgd_balance_Q(q)
 
-    def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None):
+    def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None, store_triu_as_line=False):
         for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
             psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
             if original_q:
-
+                if store_triu_as_line:
+                    update_triu_(original_q[i], Q)
+                else:
+                    copy_stochastic_(original_q[i], Q)
 
 
 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):

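The `utils.py` changes are what make bf16 preconditioners workable: `promote` now accepts either a dtype or a tensor, `psgd_precond_grad` casts the gradient to the factors' dtype and back, and `copy_stochastic_` falls back to a plain copy for non-bf16 targets. A rough sketch of the resulting round-trip for one stored factor, with an ordinary `copy_` standing in for heavyball's stochastic rounding:

```python
import torch

# Factors stored in bf16 are promoted to fp32 for the preconditioner update,
# then written back into the stored bf16 copy. heavyball uses copy_stochastic_
# (stochastic rounding) for that write-back; a plain copy_ stands in here.
q_stored = torch.randn(64, 64, dtype=torch.bfloat16).triu()  # what the optimizer keeps
q32 = q_stored.float()                                       # promote(...) on the tensor path
q32 += 0.1 * torch.randn_like(q32).triu()                    # stand-in for psgd_update_precond
q_stored.copy_(q32)                                          # round back into bf16 storage
print(q_stored.dtype, q32.dtype)                             # torch.bfloat16 torch.float32
```
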
{heavyball-0.16.0 → heavyball-0.17.1}/heavyball.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.16.0
+Version: 0.17.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,8 +32,8 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-
-recommended experimental optimizer is `
+Currently (2024-11-20, 0.17.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
 
@@ -62,7 +62,7 @@ import heavyball
 model = torch.nn.Linear(16, 1)
 
 # Create an optimizer
-optimizer = heavyball.
+optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)
 
 x = torch.randn(128, 16)
 y = torch.randn(128, 1)
@@ -76,19 +76,19 @@ for _ in range(1000):
 
 ## Optimizers
 
-| Name
-|
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
-| **
+| Name | Description | Advantages / Disadvantages |
+|-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
+| **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
+| **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
+| **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
+| **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
+| **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+| **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
+| **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
+| **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
+| **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
 
 ## Precond Schedule
 

{heavyball-0.16.0 → heavyball-0.17.1}/heavyball.egg-info/SOURCES.txt
@@ -2,6 +2,7 @@ LICENSE
 README.md
 setup.py
 heavyball/__init__.py
+heavyball/cached_delayed_psgd_kron.py
 heavyball/cached_psgd_kron.py
 heavyball/delayed_psgd.py
 heavyball/foreach_adamw.py
@@ -24,6 +25,7 @@ heavyball.egg-info/SOURCES.txt
 heavyball.egg-info/dependency_links.txt
 heavyball.egg-info/requires.txt
 heavyball.egg-info/top_level.txt
+test/test_bf16_q.py
 test/test_closure.py
 test/test_foreach.py
 test/test_memory.py

heavyball-0.17.1/test/test_bf16_q.py (new file)
@@ -0,0 +1,52 @@
+import heavyball
+import heavyball.utils
+import pytest
+import torch
+from benchmark.utils import get_optim
+from heavyball.utils import clean, set_torch, PSGDBase
+from torch import nn
+
+
+def get_memory():
+    clean()
+    torch.cuda.synchronize()
+    clean()
+    torch.cuda.synchronize()
+    return torch.cuda.memory_allocated()
+
+
+@pytest.mark.parametrize("opt", heavyball.__all__)
+@pytest.mark.parametrize("size,depth", [(256, 2)])
+def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3):
+    set_torch()
+
+    opt = getattr(heavyball, opt)
+    if not issubclass(opt, PSGDBase):
+        raise pytest.skip('Only PSGD is supported')
+
+    peaks = []
+    losses = []
+
+    for q_dtype in ['float32', 'bfloat16']:
+        peaks.append([])
+        losses.append([])
+
+        for i in range(outer_iterations):
+            torch.manual_seed(0x2131290)
+            model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
+            o = get_optim(opt, model.parameters(), lr=1e-3, q_dtype=q_dtype)
+
+            for _ in range(iterations):
+                loss = model(torch.randn((1024, size)).cuda()).square().mean()
+                loss.backward()
+                o.step()
+                o.zero_grad()
+                losses[-1].append(loss.detach())
+
+            del model, o
+            clean()
+
+
+    for i, (l0, l1) in enumerate(zip(*losses)):
+        print(i, l0.item(), l1.item())
+        assert torch.allclose(l0, l1, rtol=0.1)

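The new `test/test_bf16_q.py` trains the same model twice from the same seed, once with float32 factors and once with bfloat16, and requires the per-step losses to stay within a 10% relative tolerance. In user-facing terms it exercises roughly the following; this is a sketch (the test itself runs on CUDA and constructs the optimizer through `benchmark.utils.get_optim`):

```python
import torch
from torch import nn
import heavyball

# Any PSGD-family optimizer now accepts q_dtype, so the Kronecker factors can
# be held in bfloat16 instead of float32.
model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256))
opt = heavyball.ForeachCachedDelayedPSGDKron(model.parameters(), lr=1e-3, q_dtype='bfloat16')

loss = model(torch.randn(32, 256)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```
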
{heavyball-0.16.0 → heavyball-0.17.1}/test/test_closure.py
@@ -20,7 +20,7 @@ class Param(nn.Module):
 
 @pytest.mark.parametrize("opt", heavyball.__all__)
 @pytest.mark.parametrize("size", [(4, 4, 4, 4), ])
-def
+def test_closure(opt, size: List[int], depth: int = 2, iterations: int = 5, outer_iterations: int = 3):
     clean()
     set_torch()
 

{heavyball-0.16.0 → heavyball-0.17.1}/test/test_memory.py
@@ -25,14 +25,14 @@ expected_memory = {'adamw': {'after': 4, 'peak': 5.1}, 'soap': {'after': 7, 'pea
 @pytest.mark.parametrize("size,depth", [(8192, 1), (2048, 16)])
 def test_memory(opt, method, size, depth: int, iterations: int = 5, outer_iterations: int = 3):
     if 'soap' not in opt.lower() and method != 'qr':
-
+        raise pytest.skip('Only SOAP supports `method` argument')
     set_torch()
 
     for k, v in expected_memory.items():
         if k in opt.lower():
             break
     else:
-        raise
+        raise pytest.skip(f'Opt {opt} not supported')
 
     opt = getattr(heavyball, opt)
     heavyball.utils.zeroth_power_mode = method

{heavyball-0.16.0 → heavyball-0.17.1}/test/test_merge.py
@@ -26,7 +26,7 @@ class Param(nn.Module):
 def test_merge(opt, method, size: List[int], merge, split, depth: int = 2, iterations: int = 5,
                outer_iterations: int = 3):
     if 'soap' not in opt.lower() and method != 'qr':
-
+        raise pytest.skip('Only SOAP supports `method` argument')
     clean()
     set_torch()
 

{heavyball-0.16.0 → heavyball-0.17.1}/test/test_psgd.py
@@ -1,11 +1,10 @@
-import pytest
-import torch
-from torch import nn
-
 import heavyball
 import heavyball.utils
+import pytest
+import torch
 from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
+from torch import nn
 
 
 def get_memory():
@@ -16,10 +15,6 @@ def get_memory():
     return torch.cuda.memory_allocated()
 
 
-expected_memory = {'adamw': {'after': 4, 'peak': 5.1}, 'soap': {'after': 7, 'peak': 14},
-                   'psgd': {'after': 4, 'peak': 11.5}, 'padam': {'after': 5, 'peak': 11.4}}
-
-
 @pytest.mark.parametrize("opt", ['ForeachPSGDKron', 'ForeachPaLMPAdam', 'ForeachPurePSGD', 'ForeachDelayedPSGD'])
 @pytest.mark.parametrize("method",
                          ['norm_clip_', 'mu_law_compress', 'a_law_compress', 'trust_region_clip_', 'identity'])
@@ -27,12 +22,6 @@ expected_memory = {'adamw': {'after': 4, 'peak': 5.1}, 'soap': {'after': 7, 'pea
 def test_clip(opt, method, size, depth: int, iterations: int = 100, outer_iterations: int = 3):
     set_torch()
 
-    for k, v in expected_memory.items():
-        if k in opt.lower():
-            break
-    else:
-        raise ValueError(f'Unknown optimizer {opt}')
-
     opt = getattr(heavyball, opt)
 
     for i in range(outer_iterations):
