heavyball 0.15.0__tar.gz → 0.16.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.15.0 → heavyball-0.16.0}/PKG-INFO +4 -2
- {heavyball-0.15.0 → heavyball-0.16.0}/README.md +3 -1
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/__init__.py +24 -2
- heavyball-0.16.0/heavyball/cached_psgd_kron.py +142 -0
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/delayed_psgd.py +11 -8
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/foreach_adamw.py +3 -2
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/foreach_adopt.py +3 -2
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/foreach_laprop.py +3 -2
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/foreach_sfadamw.py +4 -4
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/foreach_soap.py +4 -3
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/p_adam.py +14 -8
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/palm_foreach_sfadamw.py +3 -2
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/palm_foreach_soap.py +3 -2
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/precond_schedule_foreach_soap.py +3 -2
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/precond_schedule_palm_foreach_soap.py +3 -2
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/precond_schedule_sfpsoap.py +3 -3
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/psgd_kron.py +11 -7
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/pure_psgd.py +10 -7
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/schedule_free_palm_foreach_soap.py +4 -3
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/utils.py +29 -11
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball.egg-info/PKG-INFO +4 -2
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball.egg-info/SOURCES.txt +2 -0
- {heavyball-0.15.0 → heavyball-0.16.0}/setup.py +1 -1
- heavyball-0.16.0/test/test_foreach.py +65 -0
- {heavyball-0.15.0 → heavyball-0.16.0}/LICENSE +0 -0
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.15.0 → heavyball-0.16.0}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.15.0 → heavyball-0.16.0}/setup.cfg +0 -0
- {heavyball-0.15.0 → heavyball-0.16.0}/test/test_closure.py +0 -0
- {heavyball-0.15.0 → heavyball-0.16.0}/test/test_memory.py +0 -0
- {heavyball-0.15.0 → heavyball-0.16.0}/test/test_merge.py +0 -0
- {heavyball-0.15.0 → heavyball-0.16.0}/test/test_no_grad.py +0 -0
- {heavyball-0.15.0 → heavyball-0.16.0}/test/test_psgd.py +0 -0
- {heavyball-0.15.0 → heavyball-0.16.0}/test/test_soap.py +0 -0
{heavyball-0.15.0 → heavyball-0.16.0}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.15.0
+Version: 0.16.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -39,12 +39,14 @@ recommended experimental optimizer is `ForeachPSGDKron`.
 
 * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
 * **Inplace EMA**: Same math, but less memory, less compute and higher stability
-* **Foreach**: Fast multi-tensor application
+* **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
 * **PaLM Beta2**: Fast initial
   convergence, [stable late convergence](https://x.com/_clashluke/status/1820810798693818761)
 * **ScheduleFree**: No learning rate schedule, but better convergence
 * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
   better step-per-second in late convergence (explained below)
+* **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
+  bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
 
 ## Getting started
 
```
{heavyball-0.15.0 → heavyball-0.16.0}/README.md

```diff
@@ -15,12 +15,14 @@ recommended experimental optimizer is `ForeachPSGDKron`.
 
 * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
 * **Inplace EMA**: Same math, but less memory, less compute and higher stability
-* **Foreach**: Fast multi-tensor application
+* **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
 * **PaLM Beta2**: Fast initial
   convergence, [stable late convergence](https://x.com/_clashluke/status/1820810798693818761)
 * **ScheduleFree**: No learning rate schedule, but better convergence
 * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
   better step-per-second in late convergence (explained below)
+* **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
+  bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
 
 ## Getting started
 
```
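For context, a minimal usage sketch of the two knobs highlighted above (illustrative only, not taken from the package docs; the model and hyperparameters are placeholders):

```python
import torch
from torch import nn
import heavyball

model = nn.Linear(128, 128)
# foreach=False lowers peak memory by stepping one tensor at a time;
# store_triu_as_line=True (the default) keeps PSGD's triangular factors in a
# packed form, trading extra bandwidth for less memory.
opt = heavyball.ForeachPSGDKron(model.parameters(), lr=1e-3,
                                foreach=False, store_triu_as_line=True)

loss = model(torch.randn(16, 128)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```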
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/__init__.py

```diff
@@ -1,3 +1,4 @@
+from .cached_psgd_kron import ForeachCachedPSGDKron
 from .delayed_psgd import ForeachDelayedPSGD
 from .foreach_adamw import ForeachAdamW
 from .foreach_adopt import ForeachADOPT
@@ -16,7 +17,28 @@ from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
 
 PalmForEachSoap = PaLMForeachSOAP
 
+PaLMSOAP = PaLMForeachSOAP
+PaLMSFAdamW = PaLMForeachSFAdamW
+PaLMSFSoap = SFPaLMForeachSOAP
+PaLMForeachSOAP = PaLMForeachSOAP
+PrecondScheduleSFPaLMSOAP = PrecondScheduleSFPaLMSOAP
+SOAP = ForeachSOAP
+SFAdamW = ForeachSFAdamW
+LaProp = ForeachLaProp
+ADOPT = ForeachADOPT
+PrecondScheduleForeachSOAP = PrecondScheduleForeachSOAP
+PrecondSchedulePaLMForeachSOAP = PrecondSchedulePaLMForeachSOAP
+PSGDKron = ForeachPSGDKron
+AdamW = ForeachAdamW
+PurePSGD = ForeachPurePSGD
+PaLMPAdam = ForeachPaLMPAdam
+DelayedPSGD = ForeachDelayedPSGD
+CachedPSGDKron = ForeachCachedPSGDKron
+
 __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
            'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
-           'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD',
-           '
+           'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
+           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron',  #
+           'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
+           'SOAP', 'SFAdamW', 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD',
+           'CachedPSGDKron']
```
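The new short names are plain aliases of the existing Foreach classes, so the following sanity check holds (assuming the package is installed):

```python
import heavyball

# Each short alias refers to the same class object as its Foreach counterpart.
assert heavyball.AdamW is heavyball.ForeachAdamW
assert heavyball.PSGDKron is heavyball.ForeachPSGDKron
assert heavyball.CachedPSGDKron is heavyball.ForeachCachedPSGDKron
```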
heavyball-0.16.0/heavyball/cached_psgd_kron.py (new file, +142 lines)

```python
"""
Originally from Evan Walters and Omead Pooladzandi, 2024
Modified under Creative Commons Attribution 4.0 International
Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
"""

from typing import Optional

import torch
from heavyball.utils import einsum_base

from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base


class ForeachCachedPSGDKron(PSGDBase):
    """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch with cached preconditioners.

    Args:
        params (iterable): Iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float): Learning rate.
        b1 (float): Momentum parameter.
        weight_decay (float): Weight decay (L2 penalty).
        preconditioner_update_probability (callable or float, optional): Probability of
            updating the preconditioner. If None, defaults to a schedule that anneals
            from 1.0 to 0.03 by 4000 steps.
        max_size_triangular (int): Max size for dim's preconditioner to be triangular.
        min_ndim_triangular (int): Minimum number of dimensions a layer needs
            to have triangular preconditioners.
        memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
            to set all preconditioners to be triangular, 'one_diag' sets the largest
            or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
            to be diagonal.
        momentum_into_precond_update: (bool), whether to send momentum into preconditioner
            update instead of raw gradients.
    """

    def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
                 foreach: bool = True):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= beta < 1.0:
            raise ValueError(f"Invalid beta parameter: {beta}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        if preconditioner_update_probability is None:
            preconditioner_update_probability = precond_update_prob_schedule()
        if clip_fn is None:
            clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
        self.preconditioner_update_probability = preconditioner_update_probability
        self.clip_fn = clip_fn

        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                        min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                        # precond lr hardcoded to 0.1
                        precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
                        store_triu_as_line=store_triu_as_line)
        super().__init__(params, defaults, foreach)

        self._prob_step = 0

    def _step(self, group):
        # update preconditioners all together
        update_prob = self.preconditioner_update_probability
        if callable(update_prob):
            update_prob = update_prob(self._prob_step)
        do_update = self.rng.random() < update_prob
        self._prob_step += 1

        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
        precond_init_scale = group['precond_init_scale']
        max_size_triangular = group['max_size_triangular']
        min_ndim_triangular = group['min_ndim_triangular']
        memory_save_mode = group['memory_save_mode']
        precond_lr = group['precond_lr']
        weight_decay = group['weight_decay']
        lr = group['lr']
        beta = group['beta']
        store_triu_as_line = group['store_triu_as_line']

        vals = []

        for p, g in split_p_and_g_in_group(group):
            state = self.state_(p)

            if 'Q' not in state:
                state["exp_avg"] = torch.zeros_like(g)
                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
                                                 memory_save_mode, dtype=g.dtype)
                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
                state['Q_cache'] = [torch.empty_like(q) for q in Q]

                expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(einsum_base, Q)]
                expr = ','.join(expr)
                grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
                out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
                expr = f'{expr},{grad_expr}->{out_expr}'

                state['cache_expr'] = expr

            vals.append((p, g, state["exp_avg"], state["Q"], state['Q_cache']))

        if not vals:
            return

        p_list, grad_list, exp_avg_list, Q_list, Q_cache_list = zip(*vals)
        del vals

        group["step"] += 1

        torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))

        grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
            exp_avg_list)
        for i, (p, g) in enumerate(zip(p_list, grad_list)):
            cached_q = Q_cache_list.pop(0)
            q_orig = Q_list.pop(0)
            ea = exp_avg_list.pop(0)

            if do_update:
                q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                self.balance([g], [q])
                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
                               [q_orig] if store_triu_as_line else None)
                for c_, q_ in zip(cached_q, q):
                    if q_.ndim == 2:
                        torch.matmul(q_.T.conj(), q_, out=c_)
                    else:
                        torch.mul(q_.conj(), q_, out=c_)

            set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
        grad_list = self.clip_fn(grad_list)

        lr = -warmup(lr, group['step'], group['warmup_steps'])
        update_param_(p_list, grad_list, lr, weight_decay)
```
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/delayed_psgd.py

```diff
@@ -38,7 +38,8 @@ class ForeachDelayedPSGD(PSGDBase):
     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None):
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
+                 foreach: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -58,8 +59,9 @@ class ForeachDelayedPSGD(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split)
-        super().__init__(params, defaults)
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
+        super().__init__(params, defaults, foreach)
 
         self._prob_step = 0
 
@@ -80,6 +82,7 @@ class ForeachDelayedPSGD(PSGDBase):
         weight_decay = group['weight_decay']
         lr = group['lr']
         beta = group['beta']
+        store_triu_as_line = group['store_triu_as_line']
 
         vals = []
 
@@ -90,7 +93,7 @@ class ForeachDelayedPSGD(PSGDBase):
                 state["exp_avg"] = torch.zeros_like(g)
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
                                                  memory_save_mode, dtype=g.dtype)
-                state["Q"] = triu_to_line(Q)
+                state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["exp_avg"], state["Q"]))
 
@@ -108,12 +111,12 @@ class ForeachDelayedPSGD(PSGDBase):
         for i, (p, g) in enumerate(zip(p_list, grad_list)):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
-            q = line_to_triu(q_orig)
-            self.balance(do_update, [g], [q])
+            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
             new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
-
             if do_update:
-                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
+                               [q_orig] if store_triu_as_line else None)
+                self.balance([g], [q])
             set_(g, new)
 
         grad_list = self.clip_fn(grad_list)
```
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/foreach_adamw.py

```diff
@@ -5,10 +5,11 @@ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOpti
 
 
 class ForeachAdamW(StatefulOptimizer):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True):
         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                         lr_max=-1.0, weight_decay=weight_decay)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
 
     def _step(self, group):
         eps = group['eps']
```
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/foreach_adopt.py

```diff
@@ -6,10 +6,11 @@ from .utils import warmup, beta_debias, update_param_, StatefulOptimizer
 
 class ForeachADOPT(StatefulOptimizer):
 
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True):
         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                         lr_max=-1.0, weight_decay=weight_decay)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
 
     def _step(self, group):
         eps = group['eps']
```
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/foreach_laprop.py

```diff
@@ -6,10 +6,11 @@ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOpti
 
 class ForeachLaProp(StatefulOptimizer):
 
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
+                 foreach: bool = True):
         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                         lr_max=-1.0, weight_decay=weight_decay)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
 
     def _step(self, group):
         eps = group['eps']
```
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/foreach_sfadamw.py

```diff
@@ -6,12 +6,12 @@ from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debia
 
 class ForeachSFAdamW(ScheduleFree):
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                 weight_lr_power=2.0, foreach=
+                 weight_lr_power=2.0, foreach: bool = True):
 
         defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
                         weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
                         foreach=foreach)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
 
     def _step(self, group):
         eps = group['eps']
@@ -48,7 +48,7 @@ class ForeachSFAdamW(ScheduleFree):
         torch._foreach_add_(grad, y, alpha=decay)
 
         lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
-
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0], y, z,
+                                             grad, group['r'], k + 1)
 
         group['k'] = k + 1
```
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/foreach_soap.py

```diff
@@ -26,12 +26,13 @@ class ForeachSOAP(StatefulOptimizer):
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 split: bool = False):
+                 split: bool = False,
+                 foreach: bool = True):
         defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format
 
     def _step(self, group):
@@ -59,7 +60,7 @@ class ForeachSOAP(StatefulOptimizer):
             vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
 
         if not vals:
-            return
+            return
 
         p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
         beta1, beta2 = group["betas"]
```
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/p_adam.py

```diff
@@ -5,6 +5,7 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
+from heavyball.utils import triu_to_line, line_to_triu
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
     exp_avg_sq_, beta_debias, split_p_and_g_in_group
@@ -36,7 +37,9 @@ class ForeachPaLMPAdam(PSGDBase):
     def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
-                 beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None):
+                 beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
+                 store_triu_as_line: bool = True,
+                 foreach: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -57,8 +60,8 @@ class ForeachPaLMPAdam(PSGDBase):
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
-                        split=split)
-        super().__init__(params, defaults)
+                        split=split, store_triu_as_line=store_triu_as_line)
+        super().__init__(params, defaults, foreach)
 
         self._prob_step = 0
 
@@ -77,6 +80,7 @@ class ForeachPaLMPAdam(PSGDBase):
         precond_lr = group['precond_lr']
         weight_decay = group['weight_decay']
         lr = group['lr']
+        store_triu_as_line = group['store_triu_as_line']
 
         vals = []
 
@@ -86,8 +90,9 @@ class ForeachPaLMPAdam(PSGDBase):
             if 'Q' not in state:
                 state['exp_avg'] = torch.zeros_like(g)
                 state['exp_avg_sq'] = torch.zeros_like(g)
-
-
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
+                                                 min_ndim_triangular, memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
 
@@ -99,15 +104,16 @@ class ForeachPaLMPAdam(PSGDBase):
 
         group["step"] += 1
 
-
+        Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
         if do_update:
-            self.
+            self.balance(grad_list, Q_triu)
+            self.do_update(p_list, grad_list, Q_triu, precond_lr, Q_list if store_triu_as_line else None)
 
         torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
 
         beta2 = 1 - group['step'] ** -group['beta2_scale']
 
-        for p, Q, g, ea, eas in zip(p_list,
+        for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
             psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
             ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
             exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
```
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/palm_foreach_sfadamw.py

```diff
@@ -6,13 +6,14 @@ from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debia
 
 class PaLMForeachSFAdamW(ScheduleFree):
     def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                 weight_lr_power=2.0, beta2_scale: float = 0.8):
+                 weight_lr_power=2.0, beta2_scale: float = 0.8,
+                 foreach: bool = True):
         if betas[0] is not None:
             beta = betas[0]
         defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                         lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
                         beta2_scale=beta2_scale)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
 
     def _step(self, group):
         eps = group['eps']
```
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/palm_foreach_soap.py

```diff
@@ -32,7 +32,8 @@ class PaLMForeachSOAP(StatefulOptimizer):
                  max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 beta2_scale: float = 0.8, split: bool = False):
+                 beta2_scale: float = 0.8, split: bool = False,
+                 foreach: bool = True):
         if betas[0] is not None:
             beta = betas[0]
         defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -40,7 +41,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
                     'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format
 
     def _step(self, group):
```
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/precond_schedule_foreach_soap.py

```diff
@@ -27,13 +27,14 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 precond_scheduler=(1 / 3, 9), split: bool = False):
+                 precond_scheduler=(1 / 3, 9), split: bool = False,
+                 foreach: bool = True):
         defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
                     'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
 
```
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/precond_schedule_palm_foreach_soap.py

```diff
@@ -32,7 +32,8 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False):
+                 precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
+                 foreach: bool = True):
         if betas[0] is not None:
             beta = betas[0]
         defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -40,7 +41,7 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
                     'beta2_scale': beta2_scale, 'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
 
```
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/precond_schedule_sfpsoap.py

```diff
@@ -41,7 +41,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
                  weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9),
-                 betas=(None, None), split: bool = False):
+                 betas=(None, None), split: bool = False, foreach: bool = True):
         if betas[0] is not None:
             beta = betas[0]
         defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -50,7 +50,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
                     'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1, 'weight_sum': 0,
                     'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
 
@@ -59,7 +59,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
         max_precond_dim = group['max_precond_dim']
         precondition_1d = group['precondition_1d']
 
-        step = group['step'] = group.get("step",
+        step = group['step'] = group.get("step", 0) + 1
 
         for p in group["params"]:
             if p.grad is None:
```
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/psgd_kron.py

```diff
@@ -38,7 +38,8 @@ class ForeachPSGDKron(PSGDBase):
     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: Optional[callable] = None):
+                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
+                 foreach: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -58,8 +59,9 @@ class ForeachPSGDKron(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split)
-        super().__init__(params, defaults)
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
+        super().__init__(params, defaults, foreach)
 
         self._prob_step = 0
 
@@ -80,6 +82,7 @@ class ForeachPSGDKron(PSGDBase):
         weight_decay = group['weight_decay']
         lr = group['lr']
         beta = group['beta']
+        store_triu_as_line = group['store_triu_as_line']
 
         vals = []
 
@@ -90,7 +93,7 @@ class ForeachPSGDKron(PSGDBase):
                 state["exp_avg"] = torch.zeros_like(g)
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
                                                  memory_save_mode, dtype=g.dtype)
-                state['Q'] = triu_to_line(Q)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["exp_avg"], state["Q"]))
 
@@ -108,11 +111,12 @@ class ForeachPSGDKron(PSGDBase):
         for i, (p, g) in enumerate(zip(p_list, grad_list)):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
-            q = line_to_triu(q_orig)
+            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            self.balance(do_update, [g], [q])
             if do_update:
-                self.
+                self.balance([g], [q])
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
+                               [q_orig] if store_triu_as_line else None)
             set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
 
         grad_list = self.clip_fn(grad_list)
```
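A toy sketch of the `store_triu_as_line` trade-off that these PSGD diffs thread through (the helpers below are illustrative stand-ins, not the package's `triu_to_line`/`line_to_triu`): a triangular factor can be kept as its packed upper-triangular values to save memory and re-materialised into a square matrix when the preconditioner is applied.

```python
import torch

def to_line(q: torch.Tensor) -> torch.Tensor:
    # Pack the upper-triangular entries of a square factor into a flat vector.
    idx = torch.triu_indices(*q.shape)
    return q[idx[0], idx[1]]

def to_triu(line: torch.Tensor, n: int) -> torch.Tensor:
    # Re-materialise the square upper-triangular matrix from the packed vector.
    q = torch.zeros(n, n, dtype=line.dtype)
    idx = torch.triu_indices(n, n)
    q[idx[0], idx[1]] = line
    return q

q = torch.randn(4, 4).triu()
assert torch.equal(to_triu(to_line(q), 4), q)  # round-trip is lossless for triangular factors
```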
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/pure_psgd.py

```diff
@@ -36,7 +36,8 @@ class ForeachPurePSGD(PSGDBase):
     def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None):
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
+                 foreach: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -54,8 +55,9 @@ class ForeachPurePSGD(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split)
-        super().__init__(params, defaults)
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
+        super().__init__(params, defaults, foreach)
 
         self._prob_step = 0
 
@@ -74,6 +76,7 @@ class ForeachPurePSGD(PSGDBase):
         precond_lr = group['precond_lr']
         weight_decay = group['weight_decay']
         lr = group['lr']
+        store_triu_as_line = group['store_triu_as_line']
 
         vals = []
 
@@ -83,7 +86,7 @@ class ForeachPurePSGD(PSGDBase):
             if 'Q' not in state:
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
                                                  memory_save_mode, dtype=g.dtype)
-                state['Q'] = triu_to_line(Q)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["Q"]))
 
@@ -98,11 +101,11 @@ class ForeachPurePSGD(PSGDBase):
         Q_list = list(Q_list)
         for i, (p, g) in enumerate(zip(p_list, grad_list)):
             q_orig = Q_list.pop(0)
-            q = line_to_triu(q_orig)
+            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            self.balance(do_update, [g], [q])
             if do_update:
-                self.
+                self.balance([g], [q])
+                self.do_update([p], [g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
             psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
 
         grad_list = self.clip_fn(grad_list)
```
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/schedule_free_palm_foreach_soap.py

```diff
@@ -33,7 +33,8 @@ class SFPaLMForeachSOAP(ScheduleFree):
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
-                 weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False):
+                 weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False,
+                 foreach: bool = True):
         if betas[0] is not None:
             beta = betas[0]
         defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -42,7 +43,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
                     'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1,
                     'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
 
@@ -51,7 +52,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
         max_precond_dim = group['max_precond_dim']
         precondition_1d = group['precondition_1d']
 
-        step = group['step'] = group.get("step",
+        step = group['step'] = group.get("step", 0) + 1
 
         for p in group["params"]:
             if p.grad is None:
```
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball/utils.py

```diff
@@ -29,7 +29,7 @@ def decorator(func):
     return _fn
 
 
-
+einsum_base = string.ascii_lowercase + string.ascii_uppercase
 
 
 def warmup(lr: float, step: int, warmup_steps: int):
@@ -317,8 +317,8 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
     for idx, sh in enumerate(grad.shape):
         if sh > max_precond_dim:
             continue
-        b =
-        g0 =
+        b = einsum_base[idx]
+        g0 = einsum_base[:grad.dim()]
         g1 = g0.replace(b, b.upper())
         outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
         GG[idx].lerp_(promote(outer_product), 1 - beta)
@@ -374,7 +374,7 @@ def project(grad, Q, back: bool):
     :param back: whether to project to Shampoo eigenbases or back to original space
     :return:
     """
-    param =
+    param = einsum_base[:grad.dim()]
     preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if len(m) > 0])
     if preconditioners:
         out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
@@ -383,8 +383,25 @@ def project(grad, Q, back: bool):
 
 
 class StatefulOptimizer(torch.optim.Optimizer):
+    def __init__(self, params, defaults, foreach: bool = True):
+        super().__init__(params, {**defaults, 'foreach': foreach})
+        self.fake_groups = {}
+
+    def key(self, param: torch.Tensor):
+        return (param.data_ptr(), tuple(param.shape))
+
+    def get_groups(self, group):
+        if group['foreach']:
+            return [group]
+
+        for p in group['params']:
+            if self.key(p) not in self.fake_groups:
+                self.fake_groups[self.key(p)] = {**group, 'params': [p]}
+
+        return [self.fake_groups[self.key(p)] for p in group['params']]
+
     def state_(self, arg: torch.Tensor):
-        return self.state[
+        return self.state[self.key(arg)]
 
     def state_size(self) -> int:
         total_bytes = 0
@@ -409,8 +426,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
         with torch.enable_grad():
             loss = closure()
         with torch.no_grad():
-            for
-                self.
+            for top_group in self.param_groups:
+                for group in self.get_groups(top_group):
+                    self._step(group)
         return loss
 
 
@@ -754,13 +772,13 @@ def update_triu_(q_state, materialised):
 
 
 class PSGDBase(StatefulOptimizer):
-    def __init__(self, parameters, groups):
-        super().__init__(parameters, groups)
+    def __init__(self, parameters, groups, foreach: bool = True):
+        super().__init__(parameters, groups, foreach)
         self.rng = random.Random(0x1923213)
         self._tiny = torch.finfo(torch.bfloat16).tiny
 
-    def balance(self,
-        if
+    def balance(self, grad_list, Q_list):
+        if self.rng.random() > 0.01:
             return
 
         for g, q in zip(grad_list, Q_list):
```
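Sketch of the dispatch the new `StatefulOptimizer` methods implement (paraphrased here as a free function for illustration): with `foreach=True` the whole parameter group is stepped at once, while `foreach=False` wraps each parameter in its own single-tensor "fake group", so `_step` runs unchanged but over smaller working sets and a lower memory peak.

```python
import torch

def get_groups(group, fake_groups):
    # Mirror of StatefulOptimizer.get_groups: fake groups are keyed by (data_ptr, shape).
    if group['foreach']:
        return [group]
    for p in group['params']:
        key = (p.data_ptr(), tuple(p.shape))
        if key not in fake_groups:
            fake_groups[key] = {**group, 'params': [p]}
    return [fake_groups[(p.data_ptr(), tuple(p.shape))] for p in group['params']]

params = [torch.randn(3, requires_grad=True), torch.randn(5, requires_grad=True)]
group = {'params': params, 'lr': 1e-3, 'foreach': False}
print(len(get_groups(group, {})))  # 2 -> one single-parameter group per tensor
```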
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball.egg-info/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.15.0
+Version: 0.16.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -39,12 +39,14 @@ recommended experimental optimizer is `ForeachPSGDKron`.
 
 * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
 * **Inplace EMA**: Same math, but less memory, less compute and higher stability
-* **Foreach**: Fast multi-tensor application
+* **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
 * **PaLM Beta2**: Fast initial
   convergence, [stable late convergence](https://x.com/_clashluke/status/1820810798693818761)
 * **ScheduleFree**: No learning rate schedule, but better convergence
 * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
   better step-per-second in late convergence (explained below)
+* **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
+  bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
 
 ## Getting started
 
```
{heavyball-0.15.0 → heavyball-0.16.0}/heavyball.egg-info/SOURCES.txt

```diff
@@ -2,6 +2,7 @@ LICENSE
 README.md
 setup.py
 heavyball/__init__.py
+heavyball/cached_psgd_kron.py
 heavyball/delayed_psgd.py
 heavyball/foreach_adamw.py
 heavyball/foreach_adopt.py
@@ -24,6 +25,7 @@ heavyball.egg-info/dependency_links.txt
 heavyball.egg-info/requires.txt
 heavyball.egg-info/top_level.txt
 test/test_closure.py
+test/test_foreach.py
 test/test_memory.py
 test/test_merge.py
 test/test_no_grad.py
```
heavyball-0.16.0/test/test_foreach.py (new file, +65 lines)

```python
import heavyball
import heavyball.utils
import pytest
import torch
from benchmark.utils import get_optim
from heavyball.utils import clean, set_torch, PSGDBase
from torch import nn


def get_memory():
    clean()
    torch.cuda.synchronize()
    clean()
    torch.cuda.synchronize()
    return torch.cuda.memory_allocated()


@pytest.mark.parametrize("opt", heavyball.__all__)
@pytest.mark.parametrize("size,depth", [(256, 128)])
def test_foreach(opt, size, depth: int, iterations: int = 5, outer_iterations: int = 3):
    set_torch()

    opt = getattr(heavyball, opt)

    peaks = []
    losses = []

    for foreach in [True, False]:
        peaks.append([])
        losses.append([])

        for i in range(outer_iterations):
            torch.manual_seed(0x2131290)
            clean()
            model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
            clean()

            torch.cuda.reset_peak_memory_stats()
            torch.cuda.reset_max_memory_allocated()
            torch.cuda.reset_max_memory_cached()
            torch.cuda.reset_accumulated_memory_stats()

            clean()
            o = get_optim(opt, model.parameters(), lr=1e-3, foreach=foreach)
            clean()

            for _ in range(iterations):
                loss = model(torch.randn((1, size)).cuda()).sum()
                loss.backward()
                o.step()
                o.zero_grad()
                losses[-1].append(loss.detach())

            del model, o
            clean()

            peak = torch.cuda.memory_stats()['allocated_bytes.all.peak']

            if i > 0:
                peaks[-1].append(peak)

    for p0, p1 in zip(*peaks):
        assert p0 > p1
    for l0, l1 in zip(*losses):  # increase error tolerance for PSGD, as we have different RNGs -> expected differences
        assert torch.allclose(l0, l1, rtol=0.01 if isinstance(opt, PSGDBase) else 1e-5)
```
Files without changes: LICENSE, heavyball.egg-info/dependency_links.txt, heavyball.egg-info/requires.txt, heavyball.egg-info/top_level.txt, setup.cfg, test/test_closure.py, test/test_memory.py, test/test_merge.py, test/test_no_grad.py, test/test_psgd.py, test/test_soap.py