heavyball 0.15.1__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +4 -2
- heavyball/cached_delayed_psgd_kron.py +146 -0
- heavyball/cached_psgd_kron.py +15 -9
- heavyball/delayed_psgd.py +10 -7
- heavyball/foreach_adamw.py +3 -2
- heavyball/foreach_adopt.py +3 -2
- heavyball/foreach_laprop.py +3 -2
- heavyball/foreach_sfadamw.py +4 -4
- heavyball/foreach_soap.py +4 -3
- heavyball/p_adam.py +11 -8
- heavyball/palm_foreach_sfadamw.py +3 -2
- heavyball/palm_foreach_soap.py +3 -2
- heavyball/precond_schedule_foreach_soap.py +3 -2
- heavyball/precond_schedule_palm_foreach_soap.py +3 -2
- heavyball/precond_schedule_sfpsoap.py +3 -3
- heavyball/psgd_kron.py +10 -7
- heavyball/pure_psgd.py +11 -7
- heavyball/schedule_free_palm_foreach_soap.py +4 -3
- heavyball/utils.py +41 -18
- {heavyball-0.15.1.dist-info → heavyball-0.17.0.dist-info}/METADATA +4 -2
- heavyball-0.17.0.dist-info/RECORD +24 -0
- heavyball-0.15.1.dist-info/RECORD +0 -23
- {heavyball-0.15.1.dist-info → heavyball-0.17.0.dist-info}/LICENSE +0 -0
- {heavyball-0.15.1.dist-info → heavyball-0.17.0.dist-info}/WHEEL +0 -0
- {heavyball-0.15.1.dist-info → heavyball-0.17.0.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -14,6 +14,7 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
 from .psgd_kron import ForeachPSGDKron
 from .pure_psgd import ForeachPurePSGD
 from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
+from .cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron

 PalmForEachSoap = PaLMForeachSOAP

@@ -34,11 +35,12 @@ PurePSGD = ForeachPurePSGD
 PaLMPAdam = ForeachPaLMPAdam
 DelayedPSGD = ForeachDelayedPSGD
 CachedPSGDKron = ForeachCachedPSGDKron
+CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron

 __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
            'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
            'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
-           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron',  #
+           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron'  #
            'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
            'SOAP', 'SFAdamW', 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD',
-           'CachedPSGDKron']
+           'CachedPSGDKron', 'CachedDelayedPSGDKron']
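The main user-facing addition in this release is the `ForeachCachedDelayedPSGDKron` export shown above. A minimal usage sketch (model, data, and hyperparameters are placeholders, not taken from the package):

```python
import torch
from heavyball import ForeachCachedDelayedPSGDKron  # new export in 0.17.0

model = torch.nn.Linear(32, 1)
opt = ForeachCachedDelayedPSGDKron(model.parameters(), lr=1e-3)

for _ in range(10):
    x = torch.randn(16, 32)

    def closure():
        opt.zero_grad()
        loss = model(x).square().mean()
        loss.backward()
        return loss

    opt.step(closure)
```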
heavyball/cached_delayed_psgd_kron.py
ADDED
@@ -0,0 +1,146 @@
+"""
+Originally from Evan Walters and Omead Pooladzandi, 2024
+Modified under Creative Commons Attribution 4.0 International
+Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
+"""
+
+from typing import Optional
+
+import torch
+from heavyball.utils import einsum_base
+
+from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
+    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote
+
+
+class ForeachCachedDelayedPSGDKron(PSGDBase):
+    """
+    Implements PSGD with off-by-one preconditioning (akin to ADOPT and SOAP) with cached preconditioners.
+
+
+    Args:
+        params (iterable): Iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float): Learning rate.
+        b1 (float): Momentum parameter.
+        weight_decay (float): Weight decay (L2 penalty).
+        preconditioner_update_probability (callable or float, optional): Probability of
+            updating the preconditioner. If None, defaults to a schedule that anneals
+            from 1.0 to 0.03 by 4000 steps.
+        max_size_triangular (int): Max size for dim's preconditioner to be triangular.
+        min_ndim_triangular (int): Minimum number of dimensions a layer needs
+            to have triangular preconditioners.
+        memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
+            to set all preconditioners to be triangular, 'one_diag' sets the largest
+            or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
+            to be diagonal.
+        momentum_into_precond_update: (bool), whether to send momentum into preconditioner
+            update instead of raw gradients.
+    """
+
+    def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
+                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
+                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
+                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
+                 foreach: bool = True, q_dtype='float32'):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= beta < 1.0:
+            raise ValueError(f"Invalid beta parameter: {beta}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        if preconditioner_update_probability is None:
+            preconditioner_update_probability = precond_update_prob_schedule()
+        if clip_fn is None:
+            clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
+        self.preconditioner_update_probability = preconditioner_update_probability
+        self.clip_fn = clip_fn
+
+        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
+                        min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
+                        # precond lr hardcoded to 0.1
+                        precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line,
+                        q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach)
+
+        self._prob_step = 0
+
+    def _step(self, group):
+        # update preconditioners all together
+        update_prob = self.preconditioner_update_probability
+        if callable(update_prob):
+            update_prob = update_prob(self._prob_step)
+        do_update = self.rng.random() < update_prob
+        self._prob_step += 1
+
+        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        beta = group['beta']
+        store_triu_as_line = group['store_triu_as_line']
+        q_dtype = getattr(torch, group['q_dtype'])
+
+        vals = []
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+
+            if 'Q' not in state:
+                state["exp_avg"] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=q_dtype)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
+                state['Q_cache'] = [torch.empty_like(q) for q in Q]
+
+                expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(einsum_base, Q)]
+                expr = ','.join(expr)
+                grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
+                out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
+                expr = f'{expr},{grad_expr}->{out_expr}'
+
+                state['cache_expr'] = expr
+
+            vals.append((p, g, state["exp_avg"], state["Q"], state['Q_cache']))
+
+        if not vals:
+            return
+
+        p_list, grad_list, exp_avg_list, Q_list, Q_cache_list = zip(*vals)
+        del vals
+
+        group["step"] += 1
+
+        torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+
+        grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
+            exp_avg_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            cached_q = Q_cache_list.pop(0)
+            q_orig = Q_list.pop(0)
+            ea = exp_avg_list.pop(0)
+
+            if do_update:
+                q = line_to_triu(q_orig) if store_triu_as_line else q_orig
+                q32 = [promote(q_) for q_ in q]
+                self.balance([g], [q32])
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                for c_, q_ in zip(cached_q, q):
+                    if q_.ndim == 2:
+                        torch.matmul(q_.T.conj(), q_, out=c_)
+                    else:
+                        torch.mul(q_.conj(), q_, out=c_)
+
+            set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
+        grad_list = self.clip_fn(grad_list)
+
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
heavyball/cached_psgd_kron.py
CHANGED
@@ -10,7 +10,7 @@ import torch
 from heavyball.utils import einsum_base

 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base
+    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote


 class ForeachCachedPSGDKron(PSGDBase):
@@ -39,7 +39,8 @@ class ForeachCachedPSGDKron(PSGDBase):
     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True
+                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
+                 foreach: bool = True, q_dtype='float32'):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -60,8 +61,9 @@ class ForeachCachedPSGDKron(PSGDBase):
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line
-
+                        store_triu_as_line=store_triu_as_line,
+                        q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach)

         self._prob_step = 0

@@ -83,6 +85,7 @@ class ForeachCachedPSGDKron(PSGDBase):
         lr = group['lr']
         beta = group['beta']
         store_triu_as_line = group['store_triu_as_line']
+        q_dtype = getattr(torch, group['q_dtype'])

         vals = []

@@ -92,7 +95,7 @@ class ForeachCachedPSGDKron(PSGDBase):
             if 'Q' not in state:
                 state["exp_avg"] = torch.zeros_like(g)
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                 memory_save_mode, dtype=
+                                                 memory_save_mode, dtype=q_dtype)
                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
                 state['Q_cache'] = [torch.empty_like(q) for q in Q]

@@ -123,18 +126,21 @@ class ForeachCachedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)

+            new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
+
             if do_update:
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-
-                self.
-
+                q32 = [promote(q_) for q_ in q]
+                self.balance([g], [q32])
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
                 for c_, q_ in zip(cached_q, q):
                     if q_.ndim == 2:
                         torch.matmul(q_.T.conj(), q_, out=c_)
                     else:
                         torch.mul(q_.conj(), q_, out=c_)

-            set_(g,
+            set_(g, new)
+
         grad_list = self.clip_fn(grad_list)

         lr = -warmup(lr, group['step'], group['warmup_steps'])
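Note the `q_dtype` keyword threaded through this file: it is stored as a string in the param group, resolved with `getattr(torch, ...)` inside `_step`, and half-precision factors are promoted to float32 (`q32`) for the preconditioner update itself. A small usage sketch (assumes heavyball 0.17.0 is installed; the model is a placeholder):

```python
import torch
from heavyball import ForeachCachedPSGDKron

model = torch.nn.Linear(64, 64)
# store the Kronecker factors in bfloat16; updates still run in float32 via promote()
opt = ForeachCachedPSGDKron(model.parameters(), lr=1e-3, q_dtype='bfloat16')
assert getattr(torch, 'bfloat16') is torch.bfloat16  # how the string is resolved at step time
```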
heavyball/delayed_psgd.py
CHANGED
@@ -8,7 +8,7 @@ import torch
 from heavyball.utils import copy_stochastic_list_

 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_
+    precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_, promote


 class ForeachDelayedPSGD(PSGDBase):
@@ -38,7 +38,8 @@ class ForeachDelayedPSGD(PSGDBase):
     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
+                 foreach: bool = True, q_dtype='float32'):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -59,8 +60,8 @@ class ForeachDelayedPSGD(PSGDBase):
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line)
-        super().__init__(params, defaults)
+                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach)

         self._prob_step = 0

@@ -82,6 +83,7 @@ class ForeachDelayedPSGD(PSGDBase):
         lr = group['lr']
         beta = group['beta']
         store_triu_as_line = group['store_triu_as_line']
+        q_dtype = getattr(torch, group['q_dtype'])

         vals = []

@@ -91,7 +93,7 @@ class ForeachDelayedPSGD(PSGDBase):
             if 'Q' not in state:
                 state["exp_avg"] = torch.zeros_like(g)
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                 memory_save_mode, dtype=
+                                                 memory_save_mode, dtype=q_dtype)
                 state["Q"] = triu_to_line(Q) if store_triu_as_line else Q

                 vals.append((p, g, state["exp_avg"], state["Q"]))
@@ -113,8 +115,9 @@ class ForeachDelayedPSGD(PSGDBase):
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
             new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
             if do_update:
-
-                self.
+                q32 = [promote(q_) for q_ in q]
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                self.balance([g], [q32])
             set_(g, new)

         grad_list = self.clip_fn(grad_list)
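For orientation, the loop above keeps the delayed ordering: `new` is computed with the current factors before `do_update` refreshes them, so a refreshed preconditioner only takes effect on the next step. A toy, runnable illustration of that ordering (not heavyball code):

```python
import torch

q = torch.eye(2)                       # stand-in preconditioner factor
g = torch.tensor([1.0, 2.0])           # stand-in momentum / gradient

preconditioned = (q.T @ q) @ g         # apply the *current* q first ...
q = q * 1.1                            # ... then refresh q for future steps
assert torch.equal(preconditioned, g)  # the old (identity) q was applied, not the new one
```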
heavyball/foreach_adamw.py
CHANGED
@@ -5,10 +5,11 @@ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOpti


 class ForeachAdamW(StatefulOptimizer):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True):
         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                         lr_max=-1.0, weight_decay=weight_decay)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)

     def _step(self, group):
         eps = group['eps']
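Every optimizer in this release gains the same `foreach` keyword shown here; a usage sketch with the documented defaults (the model is a placeholder):

```python
import torch
from heavyball import ForeachAdamW

model = torch.nn.Linear(128, 128)
# foreach=True (default) uses fused multi-tensor ops over the whole parameter group;
# foreach=False steps parameter-by-parameter, trading speed for lower peak memory.
opt = ForeachAdamW(model.parameters(), lr=0.0025, betas=(0.9, 0.99), foreach=False)
```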
heavyball/foreach_adopt.py
CHANGED
@@ -6,10 +6,11 @@ from .utils import warmup, beta_debias, update_param_, StatefulOptimizer

 class ForeachADOPT(StatefulOptimizer):

-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True):
         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                         lr_max=-1.0, weight_decay=weight_decay)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)

     def _step(self, group):
         eps = group['eps']
heavyball/foreach_laprop.py
CHANGED
@@ -6,10 +6,11 @@ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOpti

 class ForeachLaProp(StatefulOptimizer):

-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
+                 foreach: bool = True):
         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                         lr_max=-1.0, weight_decay=weight_decay)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)

     def _step(self, group):
         eps = group['eps']
heavyball/foreach_sfadamw.py
CHANGED
@@ -6,12 +6,12 @@ from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debia

 class ForeachSFAdamW(ScheduleFree):
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                 weight_lr_power=2.0, foreach=
+                 weight_lr_power=2.0, foreach: bool = True):

         defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
                         weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
                         foreach=foreach)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)

     def _step(self, group):
         eps = group['eps']
@@ -48,7 +48,7 @@ class ForeachSFAdamW(ScheduleFree):
         torch._foreach_add_(grad, y, alpha=decay)

         lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
-
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0], y, z,
+                                             grad, group['r'], k + 1)

         group['k'] = k + 1
heavyball/foreach_soap.py
CHANGED
@@ -26,12 +26,13 @@ class ForeachSOAP(StatefulOptimizer):
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 split: bool = False
+                 split: bool = False,
+                 foreach: bool = True):
         defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format

     def _step(self, group):
@@ -59,7 +60,7 @@ class ForeachSOAP(StatefulOptimizer):
             vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))

         if not vals:
-            return
+            return

         p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
         beta1, beta2 = group["betas"]
heavyball/p_adam.py
CHANGED
@@ -8,7 +8,7 @@ import torch
 from heavyball.utils import triu_to_line, line_to_triu

 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
-    exp_avg_sq_, beta_debias, split_p_and_g_in_group
+    exp_avg_sq_, beta_debias, split_p_and_g_in_group, promote


 class ForeachPaLMPAdam(PSGDBase):
@@ -38,7 +38,8 @@ class ForeachPaLMPAdam(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
                  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
-                 store_triu_as_line: bool = True
+                 store_triu_as_line: bool = True,
+                 foreach: bool = True, q_dtype='float32'):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -59,8 +60,8 @@ class ForeachPaLMPAdam(PSGDBase):
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line)
-        super().__init__(params, defaults)
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach)

         self._prob_step = 0

@@ -80,6 +81,7 @@ class ForeachPaLMPAdam(PSGDBase):
         weight_decay = group['weight_decay']
         lr = group['lr']
         store_triu_as_line = group['store_triu_as_line']
+        q_dtype = getattr(torch, group['q_dtype'])

         vals = []

@@ -90,7 +92,7 @@ class ForeachPaLMPAdam(PSGDBase):
                 state['exp_avg'] = torch.zeros_like(g)
                 state['exp_avg_sq'] = torch.zeros_like(g)
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
-
+                                                 min_ndim_triangular, memory_save_mode, dtype=q_dtype)
                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q

             vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
@@ -105,9 +107,10 @@ class ForeachPaLMPAdam(PSGDBase):

         Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
         if do_update:
-
-
-
+            for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
+                q32 = [promote(qq_) for qq_ in q_]
+                self.balance([g], [q32])
+                self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
         torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))

         beta2 = 1 - group['step'] ** -group['beta2_scale']
heavyball/palm_foreach_sfadamw.py
CHANGED
@@ -6,13 +6,14 @@ from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debia

 class PaLMForeachSFAdamW(ScheduleFree):
     def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                 weight_lr_power=2.0, beta2_scale: float = 0.8
+                 weight_lr_power=2.0, beta2_scale: float = 0.8,
+                 foreach: bool = True):
         if betas[0] is not None:
             beta = betas[0]
         defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                         lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
                         beta2_scale=beta2_scale)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)

     def _step(self, group):
         eps = group['eps']
heavyball/palm_foreach_soap.py
CHANGED
@@ -32,7 +32,8 @@ class PaLMForeachSOAP(StatefulOptimizer):
                  max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 beta2_scale: float = 0.8, split: bool = False
+                 beta2_scale: float = 0.8, split: bool = False,
+                 foreach: bool = True):
         if betas[0] is not None:
             beta = betas[0]
         defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -40,7 +41,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
                     'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format

     def _step(self, group):
heavyball/precond_schedule_foreach_soap.py
CHANGED
@@ -27,13 +27,14 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 precond_scheduler=(1 / 3, 9), split: bool = False
+                 precond_scheduler=(1 / 3, 9), split: bool = False,
+                 foreach: bool = True):
         defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
                     'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)

heavyball/precond_schedule_palm_foreach_soap.py
CHANGED
@@ -32,7 +32,8 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False
+                 precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
+                 foreach: bool = True):
         if betas[0] is not None:
             beta = betas[0]
         defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -40,7 +41,7 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
                     'beta2_scale': beta2_scale, 'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)

heavyball/precond_schedule_sfpsoap.py
CHANGED
@@ -41,7 +41,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
                  weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9),
-                 betas=(None, None), split: bool = False):
+                 betas=(None, None), split: bool = False, foreach: bool = True):
         if betas[0] is not None:
             beta = betas[0]
         defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -50,7 +50,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
                     'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1, 'weight_sum': 0,
                     'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)

@@ -59,7 +59,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
         max_precond_dim = group['max_precond_dim']
         precondition_1d = group['precondition_1d']

-        step = group['step'] = group.get("step",
+        step = group['step'] = group.get("step", 0) + 1

         for p in group["params"]:
             if p.grad is None:
heavyball/psgd_kron.py
CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
 import torch

 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_
+    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, promote


 class ForeachPSGDKron(PSGDBase):
@@ -38,7 +38,8 @@ class ForeachPSGDKron(PSGDBase):
     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True
+                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
+                 foreach: bool = True, q_dtype='float32'):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -59,8 +60,8 @@ class ForeachPSGDKron(PSGDBase):
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line)
-        super().__init__(params, defaults)
+                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach)

         self._prob_step = 0

@@ -82,6 +83,7 @@ class ForeachPSGDKron(PSGDBase):
         lr = group['lr']
         beta = group['beta']
         store_triu_as_line = group['store_triu_as_line']
+        q_dtype = getattr(torch, group['q_dtype'])

         vals = []

@@ -91,7 +93,7 @@ class ForeachPSGDKron(PSGDBase):
             if 'Q' not in state:
                 state["exp_avg"] = torch.zeros_like(g)
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                 memory_save_mode, dtype=
+                                                 memory_save_mode, dtype=q_dtype)
                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q

             vals.append((p, g, state["exp_avg"], state["Q"]))
@@ -113,8 +115,9 @@ class ForeachPSGDKron(PSGDBase):
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig

             if do_update:
-
-                self.
+                q32 = [promote(q_) for q_ in q]
+                self.balance([ea if momentum_into_precond_update else g], [q32])
+                self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
             set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))

         grad_list = self.clip_fn(grad_list)
heavyball/pure_psgd.py
CHANGED
@@ -5,9 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """

 import torch
+from heavyball.utils import copy_stochastic_list_

 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
-    split_p_and_g_in_group, line_to_triu, triu_to_line
+    split_p_and_g_in_group, line_to_triu, triu_to_line, promote


 class ForeachPurePSGD(PSGDBase):
@@ -36,7 +37,8 @@ class ForeachPurePSGD(PSGDBase):
     def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
+                 foreach: bool = True, q_dtype='float32'):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -55,8 +57,8 @@ class ForeachPurePSGD(PSGDBase):
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line)
-        super().__init__(params, defaults)
+                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach)

         self._prob_step = 0

@@ -76,6 +78,7 @@ class ForeachPurePSGD(PSGDBase):
         weight_decay = group['weight_decay']
         lr = group['lr']
         store_triu_as_line = group['store_triu_as_line']
+        q_dtype = getattr(torch, group['q_dtype'])

         vals = []

@@ -84,7 +87,7 @@ class ForeachPurePSGD(PSGDBase):

             if 'Q' not in state:
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                 memory_save_mode, dtype=
+                                                 memory_save_mode, dtype=q_dtype)
                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q

             vals.append((p, g, state["Q"]))
@@ -103,8 +106,9 @@ class ForeachPurePSGD(PSGDBase):
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig

             if do_update:
-
-                self.
+                q32 = [promote(q_) for q_ in q]
+                self.balance([g], [q32])
+                self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
             psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)

         grad_list = self.clip_fn(grad_list)
heavyball/schedule_free_palm_foreach_soap.py
CHANGED
@@ -33,7 +33,8 @@ class SFPaLMForeachSOAP(ScheduleFree):
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
-                 weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False
+                 weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False,
+                 foreach: bool = True):
         if betas[0] is not None:
             beta = betas[0]
         defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -42,7 +43,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
                     'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1,
                     'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)

@@ -51,7 +52,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
         max_precond_dim = group['max_precond_dim']
         precondition_1d = group['precondition_1d']

-        step = group['step'] = group.get("step",
+        step = group['step'] = group.get("step", 0) + 1

         for p in group["params"]:
             if p.grad is None:
heavyball/utils.py
CHANGED
@@ -325,9 +325,9 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):


 def promote(x):
-    if x
+    if x in (torch.bfloat16, torch.float16):
         return torch.float32
-    if x.dtype in (torch.bfloat16, torch.float16):
+    if hasattr(x, 'dtype') and x.dtype in (torch.bfloat16, torch.float16):
         return x.float()
     return x

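`promote` now accepts either a dtype or a tensor; assuming it behaves as the diff above suggests, a usage sketch looks like:

```python
import torch
from heavyball.utils import promote

assert promote(torch.bfloat16) is torch.float32   # half-precision dtypes widen to float32
x = torch.randn(4, dtype=torch.bfloat16)
assert promote(x).dtype == torch.float32          # half-precision tensors are upcast
y = torch.randn(4)
assert promote(y) is y                            # float32 tensors pass through unchanged
```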
@@ -383,8 +383,25 @@ def project(grad, Q, back: bool):


 class StatefulOptimizer(torch.optim.Optimizer):
+    def __init__(self, params, defaults, foreach: bool = True):
+        super().__init__(params, {**defaults, 'foreach': foreach})
+        self.fake_groups = {}
+
+    def key(self, param: torch.Tensor):
+        return (param.data_ptr(), tuple(param.shape))
+
+    def get_groups(self, group):
+        if group['foreach']:
+            return [group]
+
+        for p in group['params']:
+            if self.key(p) not in self.fake_groups:
+                self.fake_groups[self.key(p)] = {**group, 'params': [p]}
+
+        return [self.fake_groups[self.key(p)] for p in group['params']]
+
     def state_(self, arg: torch.Tensor):
-        return self.state[
+        return self.state[self.key(arg)]

     def state_size(self) -> int:
         total_bytes = 0
@@ -409,8 +426,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
         with torch.enable_grad():
             loss = closure()
         with torch.no_grad():
-            for
-            self.
+            for top_group in self.param_groups:
+                for group in self.get_groups(top_group):
+                    self._step(group)
         return loss

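To illustrate the new `foreach` switch: with `foreach=False`, `get_groups` wraps every parameter in its own single-parameter "fake" group keyed by `(data_ptr, shape)`, so `_step` runs tensor-by-tensor instead of once over the fused group. A simplified, standalone re-statement of that splitting logic (not the packaged class):

```python
import torch

def get_groups(group, fake_groups):
    if group['foreach']:
        return [group]  # one fused step over the whole group
    for p in group['params']:
        key = (p.data_ptr(), tuple(p.shape))
        # cache a single-parameter copy of the group so its settings are reused every step
        fake_groups.setdefault(key, {**group, 'params': [p]})
    return [fake_groups[(p.data_ptr(), tuple(p.shape))] for p in group['params']]

params = [torch.randn(3, 3), torch.randn(5)]
group = {'params': params, 'lr': 1e-3, 'foreach': False}
fake = {}
assert len(get_groups(group, fake)) == 2                    # one group per tensor
assert all(len(g['params']) == 1 for g in fake.values())
```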
@@ -450,15 +468,15 @@ class ScheduleFree(StatefulOptimizer):

 def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]):
     for t, s in zip(target, source):
-
-            copy_stochastic_(t, s)
-        else:
-            set_(t, s)
+        copy_stochastic_(t, s)


 def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     if target.data_ptr() == source.data_ptr():
         return
+    if target.dtype != torch.bfloat16:
+        set_(target, source)
+        return

     """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
     # create a random 16 bit integer
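The early return added above routes non-bf16 targets through a plain copy, so the stochastic-rounding path (from the linked PyTorch issue) only runs for bfloat16 targets. For readers unfamiliar with the trick, an illustrative, standalone sketch of stochastic rounding float32 → bfloat16 follows; the helper name is hypothetical and this is not the packaged `copy_stochastic_`:

```python
import torch

def stochastic_round_to_bf16(source: torch.Tensor) -> torch.Tensor:
    # reinterpret the float32 bits as int32
    bits = source.float().view(torch.int32)
    # add uniform noise to the 16 bits that bfloat16 truncation would discard
    noise = torch.randint_like(bits, 0, 1 << 16)
    rounded = (bits + noise) & -65536          # keep only the top 16 bits
    return rounded.view(torch.float32).bfloat16()

x = torch.randn(1024)
# the rounded values stay close to the originals, and the rounding error is unbiased on average
assert (stochastic_round_to_bf16(x).float() - x).abs().max() < 1e-1
```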
@@ -537,7 +555,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
     for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
         if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d:
             # use diagonal matrix as preconditioner for this dim
-            Q.append(scale * torch.ones(size, dtype=dtype, device=t.device))
+            Q.append(scale * torch.ones(size, dtype=promote(dtype), device=t.device))

             piece1A.append(letters[i])
             piece2A = piece2A + letters[i]
@@ -651,11 +669,11 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
 @decorator
 def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
     """Precondition gradient G with preconditioner Q."""
-    out = torch.einsum(exprs[-1], *[q.conj() for q in Q], *Q, G)
+    out = torch.einsum(exprs[-1], *[q.conj() for q in Q], *Q, G.to(Q[0].dtype))
     if inplace:
         set_(G, out)
         return G
-    return out
+    return out.to(G.dtype)


 def norm_clip_(x, scale=None):
@@ -750,28 +768,33 @@ def line_to_triu(Q_list: List[Tuple[Optional[List[int]], torch.Tensor]]):
 def update_triu_(q_state, materialised):
     for (shape0, q), (shape1, m) in zip(q_state, triu_to_line(materialised)):
         assert shape0 == shape1
-
+        copy_stochastic_(q, m)


 class PSGDBase(StatefulOptimizer):
-
-
+    balance_probability: float = 0.01
+
+    def __init__(self, parameters, groups, foreach: bool = True):
+        super().__init__(parameters, groups, foreach)
         self.rng = random.Random(0x1923213)
         self._tiny = torch.finfo(torch.bfloat16).tiny

     def balance(self, grad_list, Q_list):
-        if self.rng.random() >
+        if self.rng.random() > self.balance_probability:
             return

         for g, q in zip(grad_list, Q_list):
             if g.dim() > 1:
                 psgd_balance_Q(q)

-    def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None):
+    def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None, store_triu_as_line=False):
         for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
             psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
             if original_q:
-
+                if store_triu_as_line:
+                    update_triu_(original_q[i], Q)
+                else:
+                    copy_stochastic_(original_q[i], Q)


 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
{heavyball-0.15.1.dist-info → heavyball-0.17.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.15.1
+Version: 0.17.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -39,12 +39,14 @@ recommended experimental optimizer is `ForeachPSGDKron`.

 * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
 * **Inplace EMA**: Same math, but less memory, less compute and higher stability
-* **Foreach**: Fast multi-tensor application
+* **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
 * **PaLM Beta2**: Fast initial
   convergence, [stable late convergence](https://x.com/_clashluke/status/1820810798693818761)
 * **ScheduleFree**: No learning rate schedule, but better convergence
 * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
   better step-per-second in late convergence (explained below)
+* **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
+  bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))

 ## Getting started

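The `store_triu_as_line` bullet refers to packing each triangular factor into a flat vector of its nonzero entries. The helpers below are hypothetical and simplified (`triu_to_line`/`line_to_triu` in `heavyball.utils` differ in detail), but they show the same memory-for-bandwidth trade-off:

```python
import torch

def pack_triu(q: torch.Tensor) -> torch.Tensor:
    idx = torch.triu_indices(*q.shape)
    return q[idx[0], idx[1]]              # store only the upper-triangular entries

def unpack_triu(line: torch.Tensor, n: int) -> torch.Tensor:
    q = torch.zeros(n, n, dtype=line.dtype)
    idx = torch.triu_indices(n, n)
    q[idx[0], idx[1]] = line              # rebuild the dense factor when it is needed
    return q

q = torch.triu(torch.randn(4, 4))
assert torch.equal(unpack_triu(pack_triu(q), 4), q)
```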
heavyball-0.17.0.dist-info/RECORD
ADDED
@@ -0,0 +1,24 @@
+heavyball/__init__.py,sha256=mDHahP4u0fy2YKWA4FPMAp7jLPMt5WwUkEiOrwE4u3E,2199
+heavyball/cached_delayed_psgd_kron.py,sha256=DvjNNHzbnS-NDq965wve-VQ-ol7IFljYYGTuTwPHOhU,6971
+heavyball/cached_psgd_kron.py,sha256=xy3-yRKFUvRTstJb_asMVp-k-5Zuw_HyILPi7BsuMKQ,6974
+heavyball/delayed_psgd.py,sha256=rDDUj3miEn6HRJmKl-ZImsqkqBASSn8aC7MEV_06fzU,6017
+heavyball/foreach_adamw.py,sha256=h_ar0ZZRM0q_wxuEkxEOEYe-2p-mB4OMgAHivrUnPl8,1777
+heavyball/foreach_adopt.py,sha256=ogOw2JjwEQNj7AKlweAphQFdMJ_GcMDm-RyDvEzugoc,1911
+heavyball/foreach_laprop.py,sha256=yGVmGqWiSw8Y2Xj70ndkR8ZMygakTB4_iRwV02Svkqg,1816
+heavyball/foreach_sfadamw.py,sha256=15-n6-lx4PAHYsKYmXbugxsR5MnqaPYy2vUudPRiitg,2087
+heavyball/foreach_soap.py,sha256=h6ptMch7oaynvu3eIJtWnVXypDA_5JDVm3Zb3PNEma0,4634
+heavyball/p_adam.py,sha256=F2b-xGNROi9VfX7isa3kffWePojpBl5BI1n464w4tGQ,6334
+heavyball/palm_foreach_sfadamw.py,sha256=yvZbPyjDW8qd3r4qDXb6uTr5RozQ7JSDj4aYYRnKGLA,2248
+heavyball/palm_foreach_soap.py,sha256=g4hbiGRcti-J-a0SwAkP4ii5pU-aalsZH5bssyhroLk,5938
+heavyball/precond_schedule_foreach_soap.py,sha256=WLg5SzpJnKPZUvFyIvdwSZa1Umt5cpr3Kow_42orM-E,4863
+heavyball/precond_schedule_palm_foreach_soap.py,sha256=ammQrvRZFF-wc-wEiPEoFhS_7b8pdV61QfcLoQfimSo,6211
+heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdojklrY8pA,6770
+heavyball/psgd_kron.py,sha256=2IpPj2TOExNGm8hSewi3er2GczJRNgC7r2J5yYSSA_0,5998
+heavyball/pure_psgd.py,sha256=uA7W9a3Qm1sxHQhtNxaUYrmE5x55lP5iJOKy_qT8XaQ,5341
+heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
+heavyball/utils.py,sha256=Jqh7VdWGeiSdwaPtUNB9l14wuuFPSReLaTwJA3juFbM,28765
+heavyball-0.17.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-0.17.0.dist-info/METADATA,sha256=GIJQ4ha-fcYR6ltOs4WUO8L_LhWGiZv2UrEZcuJD0LI,11941
+heavyball-0.17.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-0.17.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-0.17.0.dist-info/RECORD,,
heavyball-0.15.1.dist-info/RECORD
DELETED
@@ -1,23 +0,0 @@
-heavyball/__init__.py,sha256=KbT0GMU0DKqZxq9laCrD7XgiqS9yxC1W52zhte5kjKs,2054
-heavyball/cached_psgd_kron.py,sha256=mXDtxq2WJST_aUJhrLr_xCCXSFaDvD5gCTSEveBUtac,6754
-heavyball/delayed_psgd.py,sha256=dN3NW1jmjxmUkgqxPwUVrqLY8nnBOFp4TVtJ_BhPDR4,5814
-heavyball/foreach_adamw.py,sha256=NSzoIgNm7eavzbJgkAF0k7TUEnWAgOpt9-4juIFoaSA,1729
-heavyball/foreach_adopt.py,sha256=WA07m5jocLfb1GPU8s6mJ2PteS-03ronkKm-VJrAm5I,1863
-heavyball/foreach_laprop.py,sha256=mE2NDGX9XgvRhsewcWnk_-FulZPqGA65ejYF_9-A1Xk,1768
-heavyball/foreach_sfadamw.py,sha256=ussHfPd99u3RTfMrCuu5oIbwNFLXK19wO1Fbz3JShlc,2097
-heavyball/foreach_soap.py,sha256=WWvssYKg607uoEJHftp8ag8mtKSKSeHrT0QTgqBucVg,4587
-heavyball/p_adam.py,sha256=ms7BoMHu3jKGsuztUeECrsXufGAwBpqGsxgZ5LBXLQg,6073
-heavyball/palm_foreach_sfadamw.py,sha256=wjUb_fNZNUmzWXyKvwB0unP9lvNMmaYSQo5YoeS5cj0,2200
-heavyball/palm_foreach_soap.py,sha256=2Sb4hUHQeexJcCgjHeQM_ENkZ6lG1DVxW72ryrvR6iY,5890
-heavyball/precond_schedule_foreach_soap.py,sha256=bHsDyh-UvHpHjumjqqy0PePoR1ZMsJV6o5wWvpLAA04,4815
-heavyball/precond_schedule_palm_foreach_soap.py,sha256=myLTJNQKLtZ3Xi3MVTB-RYtx_XeMRJw5CIMJW75ndUY,6163
-heavyball/precond_schedule_sfpsoap.py,sha256=xeNWetBzBEYqfOSzl98aAVJsHk43QkrUUhHH_YD_mS4,6740
-heavyball/psgd_kron.py,sha256=rMG5UPEgyfQs_n1MHSEicekVDpbbIzinlL8akEyY918,5795
-heavyball/pure_psgd.py,sha256=LLVJhUAb04hgAmT3BTz_faswwQEQUkLhm_VwGQmbBUo,5088
-heavyball/schedule_free_palm_foreach_soap.py,sha256=w0P7lMmoijTpL9V7NwOHcNBFJQ7S1TS9aCiwPhY2yVw,6319
-heavyball/utils.py,sha256=PWmwjZPL4oxMjK79a5R1e7JHykphNi5GdpYqO_xmmFU,27829
-heavyball-0.15.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-0.15.1.dist-info/METADATA,sha256=0wImMJNYM-Zg0akh9hRf7X8ofVW6zlmpyDGgAkK5GFA,11667
-heavyball-0.15.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-0.15.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-0.15.1.dist-info/RECORD,,
{heavyball-0.15.1.dist-info → heavyball-0.17.0.dist-info}/LICENSE
File without changes

{heavyball-0.15.1.dist-info → heavyball-0.17.0.dist-info}/WHEEL
File without changes

{heavyball-0.15.1.dist-info → heavyball-0.17.0.dist-info}/top_level.txt
File without changes