heavyball 0.18.8__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/cached_delayed_psgd_kron.py +11 -11
- heavyball/cached_psgd_kron.py +13 -12
- heavyball/delayed_psgd.py +15 -18
- heavyball/foreach_adamw.py +7 -5
- heavyball/foreach_adopt.py +6 -4
- heavyball/foreach_laprop.py +10 -5
- heavyball/foreach_sfadamw.py +7 -4
- heavyball/foreach_soap.py +4 -7
- heavyball/p_adam.py +9 -9
- heavyball/palm_foreach_sfadamw.py +9 -4
- heavyball/palm_foreach_soap.py +6 -6
- heavyball/precond_schedule_foreach_soap.py +6 -10
- heavyball/precond_schedule_palm_foreach_soap.py +4 -4
- heavyball/precond_schedule_sfpsoap.py +20 -10
- heavyball/psgd_kron.py +15 -12
- heavyball/pure_psgd.py +3 -6
- heavyball/schedule_free_palm_foreach_soap.py +17 -8
- heavyball/utils.py +154 -56
- {heavyball-0.18.8.dist-info → heavyball-0.20.0.dist-info}/METADATA +18 -16
- heavyball-0.20.0.dist-info/RECORD +24 -0
- heavyball-0.18.8.dist-info/RECORD +0 -24
- {heavyball-0.18.8.dist-info → heavyball-0.20.0.dist-info}/LICENSE +0 -0
- {heavyball-0.18.8.dist-info → heavyball-0.20.0.dist-info}/WHEEL +0 -0
- {heavyball-0.18.8.dist-info → heavyball-0.20.0.dist-info}/top_level.txt +0 -0
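The recurring change across the optimizer modules below is a new `storage_dtype` option that controls the dtype of the optimizer state buffers (`exp_avg`, `exp_avg_sq`); it defaults to `'float32'` and is resolved with `getattr(torch, ...)`. A minimal usage sketch, not taken from the package docs (it assumes `ForeachAdamW` is re-exported at the heavyball package root; otherwise import it from `heavyball.foreach_adamw`):

```python
# Hypothetical usage of the storage_dtype option added in 0.20.0.
# The diff shows it defaults to 'float32' and is resolved via
# getattr(torch, group['storage_dtype']), so any torch dtype name should work.
import torch
import heavyball

model = torch.nn.Linear(16, 16)
opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-3,
                             storage_dtype='bfloat16')  # keep exp_avg/exp_avg_sq in bf16

loss = model(torch.randn(4, 16)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```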
heavyball/cached_delayed_psgd_kron.py CHANGED

```diff
@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 
 from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
-line_to_triu, triu_to_line,
+line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias
 
 
 class ForeachCachedDelayedPSGDKron(PSGDBase):
@@ -41,7 +41,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+storage_dtype: str = 'float32', #
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -58,7 +59,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
 def _step(self, group):
@@ -74,14 +75,15 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 beta = group['beta']
 store_triu_as_line = group['store_triu_as_line']
 q_dtype = getattr(torch, group['q_dtype'])
+storage_dtype = getattr(torch, group['storage_dtype'])
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group):
+for p, g in split_p_and_g_in_group(group, should_promote=False):
 state = self.state_(p)
 
 if 'Q' not in state:
-state["exp_avg"] = torch.zeros_like(g)
+state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
 memory_save_mode, dtype=q_dtype)
 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -105,7 +107,9 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 
 group["step"] += 1
 
-
+stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))
+
+lr = -warmup(lr, group['step'], group['warmup_steps'])
 
 grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
 exp_avg_list)
@@ -127,8 +131,4 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 else:
 torch.mul(q_.conj(), q_, out=c_)
 
-
-grad_list = self.clip_fn(grad_list)
-
-lr = -warmup(lr, group['step'], group['warmup_steps'])
-update_param_(p_list, grad_list, lr, weight_decay)
+update_param_([p], self.clip_fn([new]), lr, weight_decay)
```
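The hunk above folds the first-moment update into `stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))` ahead of the preconditioned update. A rough illustration of what that interpolation weight does (the `beta_debias` body below is an assumption, not heavyball's code):

```python
# Sketch only: with w = 1 - beta_debias(beta, step), exp_avg.lerp_(grad, w) is a
# bias-corrected EMA -- at step 1 the state becomes exactly grad rather than
# (1 - beta) * grad. The "stochastic" in stochastic_lerp_ presumably refers to
# stochastic rounding when the state buffer is low precision (storage_dtype).
import torch

def beta_debias(beta: float, step: int) -> float:  # assumed implementation
    return 1 - (1 - beta) / (1 - beta ** step)

beta, exp_avg = 0.9, torch.zeros(4)
for step, grad in enumerate([torch.ones(4), 2 * torch.ones(4)], start=1):
    exp_avg.lerp_(grad, 1 - beta_debias(beta, step))
    print(step, exp_avg[0].item())  # 1 -> 1.0 exactly, 2 -> ~1.526 (debiased Adam moment)
```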
heavyball/cached_psgd_kron.py CHANGED

```diff
@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 
 from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
-line_to_triu, triu_to_line,
+line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias
 
 
 class ForeachCachedPSGDKron(PSGDBase):
@@ -39,7 +39,8 @@ class ForeachCachedPSGDKron(PSGDBase):
 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+storage_dtype: str = 'float32', #
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -56,7 +57,8 @@ class ForeachCachedPSGDKron(PSGDBase):
 min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype
+split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+storage_dtype=storage_dtype)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
 def _step(self, group):
@@ -71,15 +73,16 @@ class ForeachCachedPSGDKron(PSGDBase):
 beta = group['beta']
 store_triu_as_line = group['store_triu_as_line']
 q_dtype = getattr(torch, group['q_dtype'])
+storage_dtype = getattr(torch, group['storage_dtype'])
 should_update = self.should_update(group)
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group):
+for p, g in split_p_and_g_in_group(group, should_promote=False):
 state = self.state_(p)
 
 if 'Q' not in state:
-state["exp_avg"] = torch.zeros_like(g)
+state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
 memory_save_mode, dtype=q_dtype)
 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -103,7 +106,9 @@ class ForeachCachedPSGDKron(PSGDBase):
 
 group["step"] += 1
 
-
+stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))
+
+lr = -warmup(lr, group['step'], group['warmup_steps'])
 
 grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
 exp_avg_list)
@@ -123,9 +128,5 @@ class ForeachCachedPSGDKron(PSGDBase):
 else:
 torch.mul(q_.conj(), q_, out=c_)
 
-
-
-grad_list = self.clip_fn(grad_list)
-
-lr = -warmup(lr, group['step'], group['warmup_steps'])
-update_param_(p_list, grad_list, lr, weight_decay)
+g = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
+update_param_([p], self.clip_fn([g]), lr, weight_decay)
```
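Both cached PSGD variants also move the final parameter update from one batched call at the end of `_step` into the per-tensor loop (`update_param_([p], self.clip_fn([g]), lr, weight_decay)`). A pattern-only sketch of why that matters for peak memory; the callables below are placeholders, not heavyball's API:

```python
# Pattern-only sketch with placeholder callables, not heavyball code.
import torch

def step_batched(params, grads, precondition, clip_fn, apply_update):
    out = [precondition(p, g) for p, g in zip(params, grads)]
    apply_update(params, clip_fn(out))            # all preconditioned grads alive at once

def step_per_param(params, grads, precondition, clip_fn, apply_update):
    for p, g in zip(params, grads):
        new = precondition(p, g)
        apply_update([p], clip_fn([new]))         # only one preconditioned grad alive

ps = [torch.randn(3) for _ in range(2)]
gs = [torch.randn(3) for _ in range(2)]
step_per_param(ps, gs, lambda p, g: g, lambda x: x,
               lambda p, u: torch._foreach_add_(p, u, alpha=-0.1))
```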
heavyball/delayed_psgd.py CHANGED

```diff
@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-from heavyball.utils import copy_stochastic_list_
 
+from heavyball.utils import stochastic_lerp_, beta_debias
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-
+split_p_and_g_in_group, triu_to_line, line_to_triu, promote
 
 
 class ForeachDelayedPSGD(PSGDBase):
@@ -38,8 +38,8 @@ class ForeachDelayedPSGD(PSGDBase):
 def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
-
+split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
+q_dtype='float32', stochastic_schedule: bool = True, storage_dtype:str='float32', #
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -55,12 +55,10 @@ class ForeachDelayedPSGD(PSGDBase):
 defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
 min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
-precond_init_scale=precond_init_scale,
-
-store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
-
 def _step(self, group):
 should_update = self.should_update(group)
 momentum_into_precond_update = group.get("momentum_into_precond_update", True)
@@ -74,14 +72,15 @@ class ForeachDelayedPSGD(PSGDBase):
 beta = group['beta']
 store_triu_as_line = group['store_triu_as_line']
 q_dtype = getattr(torch, group['q_dtype'])
+storage_dtype = getattr(torch, group['storage_dtype'])
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group):
+for p, g in split_p_and_g_in_group(group, should_promote=False):
 state = self.state_(p)
 
 if 'Q' not in state:
-state["exp_avg"] = torch.zeros_like(g)
+state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
 memory_save_mode, dtype=q_dtype)
 state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
@@ -96,7 +95,9 @@ class ForeachDelayedPSGD(PSGDBase):
 
 group["step"] += 1
 
-
+stochastic_lerp_(exp_avg_list, grad_list, beta_debias(beta, group["step"]))
+
+lr = -warmup(lr, group['step'], group['warmup_steps'])
 
 Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
 for i, (p, g) in enumerate(zip(p_list, grad_list)):
@@ -106,10 +107,6 @@ class ForeachDelayedPSGD(PSGDBase):
 new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
 if should_update:
 q32 = [promote(q_) for q_ in q]
-self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
-
-
-grad_list = self.clip_fn(grad_list)
-
-lr = -warmup(lr, group['step'], group['warmup_steps'])
-update_param_(p_list, grad_list, lr, weight_decay)
+self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
+store_triu_as_line)
+update_param_([p], self.clip_fn([new]), lr, weight_decay)
```
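The PSGD-family optimizers all gain the same state-dtype plumbing: the group stores a dtype name, `_step` resolves it with `getattr(torch, ...)`, and the momentum buffer is allocated in that dtype. A small standalone sketch of just that mechanism (names copied from the diff, surrounding optimizer code elided):

```python
import torch

group = {'storage_dtype': 'bfloat16'}                 # 'float32' is the default in the diff
storage_dtype = getattr(torch, group['storage_dtype'])

g = torch.randn(128, 64)                              # stands in for a (split) gradient
exp_avg = torch.zeros_like(g, dtype=storage_dtype)    # as in the new state init
assert exp_avg.dtype is torch.bfloat16
```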
heavyball/foreach_adamw.py CHANGED

```diff
@@ -1,7 +1,7 @@
 import torch
 import torch.optim
-from heavyball.utils import copy_stochastic_list_
 
+from heavyball.utils import copy_stochastic_list_
 from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote
 
 
@@ -20,9 +20,9 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
 
 class ForeachAdamW(StatefulOptimizer):
 def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-foreach: bool = True):
+foreach: bool = True, storage_dtype: str = 'float32'):
 defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-lr_max=-1.0, weight_decay=weight_decay)
+lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
 super().__init__(params, defaults, foreach)
 
 def _step(self, group):
@@ -38,10 +38,12 @@ class ForeachAdamW(StatefulOptimizer):
 if not active_p:
 return
 
+storage_dtype = getattr(torch, group['storage_dtype'])
+
 for p in active_p:
 if 'exp_avg' not in self.state_(p):
-self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=
-self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=
+self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=storage_dtype)
+self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype)
 
 y, grad, exp_avg_sq, exp_avg = zip(
 *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
```
heavyball/foreach_adopt.py CHANGED

```diff
@@ -27,9 +27,9 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
 class ForeachADOPT(StatefulOptimizer):
 
 def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-foreach: bool = True):
+foreach: bool = True, storage_dtype: str = 'float32'):
 defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-lr_max=-1.0, weight_decay=weight_decay)
+lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
 super().__init__(params, defaults, foreach)
 
 def _step(self, group):
@@ -45,10 +45,12 @@ class ForeachADOPT(StatefulOptimizer):
 if not active_p:
 return
 
+storage_dtype = getattr(torch, group['storage_dtype'])
+
 for p in active_p:
 if 'exp_avg' not in self.state_(p):
-self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=
-self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=
+self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=storage_dtype)
+self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype)
 
 y, grad, exp_avg_sq, exp_avg = zip(
 *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
```
|
heavyball/foreach_laprop.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import torch
|
2
2
|
import torch.optim
|
3
3
|
|
4
|
-
from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote
|
4
|
+
from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote, copy_stochastic_list_
|
5
5
|
|
6
6
|
|
7
7
|
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
|
@@ -16,13 +16,16 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
|
|
16
16
|
|
17
17
|
update_param_(y, exp_avg32, lr, decay)
|
18
18
|
|
19
|
+
copy_stochastic_list_(exp_avg, exp_avg32)
|
20
|
+
copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
|
21
|
+
|
19
22
|
|
20
23
|
class ForeachLaProp(StatefulOptimizer):
|
21
24
|
|
22
25
|
def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
|
23
|
-
foreach: bool = True):
|
26
|
+
foreach: bool = True, storage_dtype: str = 'float32'):
|
24
27
|
defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
|
25
|
-
lr_max=-1.0, weight_decay=weight_decay)
|
28
|
+
lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
|
26
29
|
super().__init__(params, defaults, foreach)
|
27
30
|
|
28
31
|
def _step(self, group):
|
@@ -38,10 +41,12 @@ class ForeachLaProp(StatefulOptimizer):
|
|
38
41
|
if not active_p:
|
39
42
|
return
|
40
43
|
|
44
|
+
storage_dtype = getattr(torch, group['storage_dtype'])
|
45
|
+
|
41
46
|
for p in active_p:
|
42
47
|
if 'exp_avg' not in self.state_(p):
|
43
|
-
self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=
|
44
|
-
self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=
|
48
|
+
self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=storage_dtype)
|
49
|
+
self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype)
|
45
50
|
|
46
51
|
y, grad, exp_avg_sq, exp_avg = zip(
|
47
52
|
*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) #
|
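ForeachLaProp's compiled step now ends with `copy_stochastic_list_(exp_avg, exp_avg32)` and `copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)`: the update is computed on float32 copies and written back into the (possibly bfloat16) state buffers. A rough standalone approximation of that write-back, with an assumed stochastic-rounding implementation rather than heavyball's actual one:

```python
import torch

def copy_stochastic_(target: torch.Tensor, source: torch.Tensor) -> None:
    # Assumed sketch: perturb the float32 bits that truncation to bfloat16 would
    # drop, then truncate, so the rounding error is zero in expectation.
    if target.dtype == torch.bfloat16:
        noise = torch.randint_like(source, 1 << 16, dtype=torch.int32)
        source = ((source.view(torch.int32) + noise) & -65536).view(torch.float32)
    target.copy_(source)

def copy_stochastic_list_(targets, sources):
    for t, s in zip(targets, sources):
        copy_stochastic_(t, s)

exp_avg = [torch.zeros(4, dtype=torch.bfloat16)]      # storage_dtype buffer
exp_avg32 = [ea.float() for ea in exp_avg]            # promote for the update
exp_avg32[0].lerp_(torch.randn(4), 0.1)               # update in float32
copy_stochastic_list_(exp_avg, exp_avg32)             # write back with rounding
```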
heavyball/foreach_sfadamw.py CHANGED

```diff
@@ -1,6 +1,6 @@
 import torch
 import torch.optim
-from heavyball.utils import get_ckp1
+from heavyball.utils import get_ckp1, copy_stochastic_list_
 
 from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, promote, _compilable_schedule_free_
 
@@ -19,14 +19,15 @@ def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, dec
 for p, z_, g in zip(y, z, g32):
 _compilable_schedule_free_(p, z_, ckp1, g, lr, beta1)
 
+copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
 
 class ForeachSFAdamW(ScheduleFree):
 def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-weight_lr_power=2.0, foreach: bool = True):
+weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32'):
 
 defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
 weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
-foreach=foreach)
+foreach=foreach, storage_dtype=storage_dtype)
 super().__init__(params, defaults, foreach)
 
 def _step(self, group):
@@ -42,10 +43,12 @@ class ForeachSFAdamW(ScheduleFree):
 if not active_p:
 return
 
+storage_dtype = getattr(torch, group['storage_dtype'])
+
 for p in active_p:
 if 'z' not in self.state_(p):
 self.state_(p)['z'] = torch.clone(p.data)
-self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=
+self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype)
 
 y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z']) #
 for p in active_p])
```
heavyball/foreach_soap.py CHANGED

```diff
@@ -1,7 +1,7 @@
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-split_p_and_g_in_group, StatefulOptimizer
+split_p_and_g_in_group, StatefulOptimizer, exp_avg_
 
 
 class ForeachSOAP(StatefulOptimizer):
@@ -26,8 +26,7 @@ class ForeachSOAP(StatefulOptimizer):
 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-split: bool = False,
-foreach: bool = True):
+split: bool = False, foreach: bool = True):
 defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
 "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
 "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
@@ -65,14 +64,12 @@ class ForeachSOAP(StatefulOptimizer):
 p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
 beta1, beta2 = group["betas"]
 
-old_debiased1 = beta_debias(beta1, step)
 old_debiased2 = beta_debias(beta2, step)
 
 # Decay the first and second moment running average coefficient
 # In-place operations to update the averages at the same time
-torch.
-
-denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
 for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
 state = self.state_(p)
```
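In the SOAP variants, the separate first-moment update and `exp_avg_sq_` call are replaced by a single fused `exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)` that returns the denominator. The real helper lives in `heavyball/utils.py` (changed in this release but not shown in the hunks above); the sketch below is an assumed reconstruction from the call site, not the actual implementation:

```python
import torch

def exp_avg_sketch(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step, eps=1e-8):
    db1 = 1 - (1 - beta1) / (1 - beta1 ** step)        # assumed debiased decays
    db2 = 1 - (1 - beta2) / (1 - beta2 ** step)
    denom = []
    for ea, eas, g, gp in zip(exp_avg, exp_avg_sq, grad, grad_projected):
        ea.lerp_(g, 1 - db1)                           # first moment from the raw grad
        eas.mul_(db2).addcmul_(gp, gp, value=1 - db2)  # second moment from the projected grad
        denom.append(eas.sqrt().clamp_(min=eps))       # what the SOAP step later divides by
    return denom

exp_avg, exp_avg_sq = [torch.zeros(4)], [torch.zeros(4)]
g = [torch.randn(4)]
denom = exp_avg_sketch(exp_avg, exp_avg_sq, g, g, beta1=0.9, beta2=0.99, step=1)
```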
heavyball/p_adam.py CHANGED

```diff
@@ -39,7 +39,7 @@ class ForeachPaLMPAdam(PSGDBase):
 momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
 beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
 store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
-stochastic_schedule: bool = True,
+stochastic_schedule: bool = True, storage_dtype:str ='float32',#
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -57,7 +57,7 @@ class ForeachPaLMPAdam(PSGDBase):
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, beta=beta,
 beta2_scale=beta2_scale, merge_dims=merge_dims, split=split,
-store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
 def _step(self, group):
@@ -71,15 +71,16 @@ class ForeachPaLMPAdam(PSGDBase):
 lr = group['lr']
 store_triu_as_line = group['store_triu_as_line']
 q_dtype = getattr(torch, group['q_dtype'])
+storage_dtype = getattr(torch, group['storage_dtype'])
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group):
+for p, g in split_p_and_g_in_group(group, should_promote=False):
 state = self.state_(p)
 
 if 'Q' not in state:
-state['exp_avg'] = torch.zeros_like(g)
-state['exp_avg_sq'] = torch.zeros_like(g)
+state['exp_avg'] = torch.zeros_like(g, dtype=storage_dtype)
+state['exp_avg_sq'] = torch.zeros_like(g, dtype=storage_dtype)
 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
 memory_save_mode, dtype=q_dtype)
 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -103,6 +104,8 @@ class ForeachPaLMPAdam(PSGDBase):
 
 beta2 = 1 - group['step'] ** -group['beta2_scale']
 
+lr = -warmup(lr, group['step'], group['warmup_steps'])
+
 for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
 psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
 ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
@@ -112,8 +115,5 @@ class ForeachPaLMPAdam(PSGDBase):
 divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
 divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
 """
+update_param_([p], self.clip_fn([g]), lr, weight_decay)
 
-grad_list = self.clip_fn(grad_list)
-
-lr = -warmup(lr, group['step'], group['warmup_steps'])
-update_param_(p_list, grad_list, lr, weight_decay)
```
heavyball/palm_foreach_sfadamw.py CHANGED

```diff
@@ -1,7 +1,8 @@
 import torch
 import torch.optim
 
-from .utils import
+from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, get_ckp1, promote, \
+_compilable_schedule_free_, copy_stochastic_list_
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
@@ -18,15 +19,17 @@ def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, dec
 for p, z_, g in zip(y, z, g32):
 _compilable_schedule_free_(p, z_, ckp1, g, lr, beta1)
 
+copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
 
 class PaLMForeachSFAdamW(ScheduleFree):
 def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-weight_lr_power=2.0, beta2_scale: float = 0.8, foreach: bool = True):
+weight_lr_power=2.0, beta2_scale: float = 0.8, foreach: bool = True, storage_dtype: str = 'float32'):
 if betas[0] is not None:
 beta = betas[0]
 defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
 lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
-beta2_scale=beta2_scale)
+beta2_scale=beta2_scale, storage_dtype=storage_dtype)
 super().__init__(params, defaults, foreach)
 
 def _step(self, group):
@@ -42,10 +45,12 @@ class PaLMForeachSFAdamW(ScheduleFree):
 if not active_p:
 return
 
+storage_dtype = getattr(torch, group['storage_dtype'])
+
 for p in active_p:
 if 'z' not in self.state_(p):
 self.state_(p)['z'] = torch.clone(p.data)
-self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=
+self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype)
 
 # Decay the first moment running average coefficient
 beta2 = 1 - (k + 1) ** -group['beta2_scale']
```
heavyball/palm_foreach_soap.py CHANGED

```diff
@@ -1,7 +1,8 @@
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-split_p_and_g_in_group, StatefulOptimizer
+split_p_and_g_in_group, StatefulOptimizer, exp_avg_
+
 
 
 class PaLMForeachSOAP(StatefulOptimizer):
@@ -32,8 +33,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
 max_precond_dim: int = 2048, #
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-beta2_scale: float = 0.8, split: bool = False,
-foreach: bool = True):
+beta2_scale: float = 0.8, split: bool = False, foreach: bool = True):
 if betas[0] is not None:
 beta = betas[0]
 defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -75,13 +75,13 @@ class PaLMForeachSOAP(StatefulOptimizer):
 beta1 = group["beta"]
 
 beta2 = 1 - step ** -group['beta2_scale']
-old_debiased1 = beta_debias(beta1, step)
 old_debiased2 = beta_debias(beta2, step)
 
 # Decay the first and second moment running average coefficient
 # In-place operations to update the averages at the same time
-torch.
-
+beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
+step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
 for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
 state = self.state_(p)
```
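PaLMForeachSOAP (and the precond-schedule variants below) now wrap the Python scalars `beta2` and `step` into 0-dim tensors before calling `exp_avg_`. The snippet only demonstrates the idiom; the motivation in the comment is an assumption (the step helpers are `torch.compile`d, and tensor scalars keep the compiled graph from specializing on each new scalar value):

```python
import torch

step = 7
beta2 = 1 - step ** -0.8                 # PaLM-style beta2 schedule from the diff
device = torch.device('cpu')             # in the optimizer this is p_list[0].device

# 0-dim ("scalar") tensors, as created in the new code
beta2_t = torch.empty((), dtype=torch.float32, device=device).fill_(beta2)
step_t = torch.empty((), dtype=torch.int32, device=device).fill_(step)
print(beta2_t.shape, step_t.item())      # torch.Size([]) 7
```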
heavyball/precond_schedule_foreach_soap.py CHANGED

```diff
@@ -2,8 +2,8 @@ import random
 
 import torch
 
-from .utils import init_preconditioner, update_preconditioner, project, beta_debias,
-precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer
+from .utils import init_preconditioner, update_preconditioner, project, beta_debias, update_param_, \
+precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer, exp_avg_
 
 
 class PrecondScheduleForeachSOAP(StatefulOptimizer):
@@ -27,8 +27,7 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-precond_scheduler=(1 / 3, 9), split: bool = False,
-foreach: bool = True):
+precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True):
 defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
 "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
 "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
@@ -68,14 +67,12 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
 p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
 beta1, beta2 = group["betas"]
 
-old_debiased1 = beta_debias(beta1, step)
 old_debiased2 = beta_debias(beta2, step)
 
 # Decay the first and second moment running average coefficient
 # In-place operations to update the averages at the same time
-torch.
-
-denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
 update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
 for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
@@ -89,8 +86,7 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
 # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
 set_(d, project(exp_avg_projected / d, state['Q'], True))
 
-update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-update_precond)
+update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
 
 # Why does this have to be rebiased here?
 step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
```
heavyball/precond_schedule_palm_foreach_soap.py CHANGED

```diff
@@ -2,7 +2,7 @@ import random
 
 import torch
 
-from .utils import init_preconditioner, update_preconditioner, project, beta_debias,
+from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_, update_param_, \
 precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer
 
 
@@ -81,9 +81,9 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
 
 # Decay the first and second moment running average coefficient
 # In-place operations to update the averages at the same time
-torch.
-torch.
-denom =
+beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
+step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
 update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
 for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
```