heavyball 0.19.0__py3-none-any.whl → 0.20.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/cached_delayed_psgd_kron.py +11 -11
- heavyball/cached_psgd_kron.py +13 -12
- heavyball/delayed_psgd.py +15 -18
- heavyball/foreach_soap.py +4 -7
- heavyball/p_adam.py +9 -9
- heavyball/palm_foreach_soap.py +6 -6
- heavyball/precond_schedule_foreach_soap.py +6 -10
- heavyball/precond_schedule_palm_foreach_soap.py +4 -4
- heavyball/precond_schedule_sfpsoap.py +20 -10
- heavyball/psgd_kron.py +15 -12
- heavyball/pure_psgd.py +3 -6
- heavyball/schedule_free_palm_foreach_soap.py +17 -8
- heavyball/utils.py +146 -57
- {heavyball-0.19.0.dist-info → heavyball-0.20.1.dist-info}/METADATA +2 -2
- heavyball-0.20.1.dist-info/RECORD +24 -0
- heavyball-0.19.0.dist-info/RECORD +0 -24
- {heavyball-0.19.0.dist-info → heavyball-0.20.1.dist-info}/LICENSE +0 -0
- {heavyball-0.19.0.dist-info → heavyball-0.20.1.dist-info}/WHEEL +0 -0
- {heavyball-0.19.0.dist-info → heavyball-0.20.1.dist-info}/top_level.txt +0 -0
heavyball/cached_delayed_psgd_kron.py
CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 
 from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
-line_to_triu, triu_to_line,
+line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias
 
 
 class ForeachCachedDelayedPSGDKron(PSGDBase):
@@ -41,7 +41,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+storage_dtype: str = 'float32', #
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -58,7 +59,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
 def _step(self, group):
@@ -74,14 +75,15 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 beta = group['beta']
 store_triu_as_line = group['store_triu_as_line']
 q_dtype = getattr(torch, group['q_dtype'])
+storage_dtype = getattr(torch, group['storage_dtype'])
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group):
+for p, g in split_p_and_g_in_group(group, should_promote=False):
 state = self.state_(p)
 
 if 'Q' not in state:
-state["exp_avg"] = torch.zeros_like(g)
+state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
 memory_save_mode, dtype=q_dtype)
 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -105,7 +107,9 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 
 group["step"] += 1
 
-
+stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))
+
+lr = -warmup(lr, group['step'], group['warmup_steps'])
 
 grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
 exp_avg_list)
@@ -127,8 +131,4 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 else:
 torch.mul(q_.conj(), q_, out=c_)
 
-
-grad_list = self.clip_fn(grad_list)
-
-lr = -warmup(lr, group['step'], group['warmup_steps'])
-update_param_(p_list, grad_list, lr, weight_decay)
+update_param_([p], self.clip_fn([new]), lr, weight_decay)
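The hunks above show the pattern that repeats across every PSGD-based optimizer in 0.20.1: the momentum buffer is allocated in a configurable `storage_dtype`, it is updated through `stochastic_lerp_` with a `beta_debias`-corrected weight, and the parameter update moves inside the per-tensor loop (`update_param_([p], self.clip_fn([new]), lr, weight_decay)`) instead of one foreach call at the end. A minimal single-tensor sketch of that update shape, written against plain PyTorch rather than heavyball's optimizer state (the helper names below only mirror the diff; this is not the library code):

    import torch

    def beta_debias(beta, step):
        # same bias correction as heavyball.utils.beta_debias
        return 1 - (1 - beta) / (1 - beta ** step)

    def sketch_step(param, grad, exp_avg, beta, lr, step, clip_fn=lambda xs: xs):
        # keep the buffer in its storage dtype, do the math in fp32
        ea32 = exp_avg.float()
        ea32.lerp_(grad.float(), 1 - beta_debias(beta, step))
        exp_avg.copy_(ea32)  # heavyball writes back with stochastic rounding instead of a plain copy
        update = clip_fn([ea32])[0]
        param.data.add_(update, alpha=-lr)  # per-parameter update, mirroring update_param_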
heavyball/cached_psgd_kron.py
CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 
 from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
-line_to_triu, triu_to_line,
+line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias
 
 
 class ForeachCachedPSGDKron(PSGDBase):
@@ -39,7 +39,8 @@ class ForeachCachedPSGDKron(PSGDBase):
 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+storage_dtype: str = 'float32', #
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -56,7 +57,8 @@ class ForeachCachedPSGDKron(PSGDBase):
 min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype
+split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+storage_dtype=storage_dtype)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
 def _step(self, group):
@@ -71,15 +73,16 @@ class ForeachCachedPSGDKron(PSGDBase):
 beta = group['beta']
 store_triu_as_line = group['store_triu_as_line']
 q_dtype = getattr(torch, group['q_dtype'])
+storage_dtype = getattr(torch, group['storage_dtype'])
 should_update = self.should_update(group)
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group):
+for p, g in split_p_and_g_in_group(group, should_promote=False):
 state = self.state_(p)
 
 if 'Q' not in state:
-state["exp_avg"] = torch.zeros_like(g)
+state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
 memory_save_mode, dtype=q_dtype)
 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -103,7 +106,9 @@ class ForeachCachedPSGDKron(PSGDBase):
 
 group["step"] += 1
 
-
+stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))
+
+lr = -warmup(lr, group['step'], group['warmup_steps'])
 
 grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
 exp_avg_list)
@@ -123,9 +128,5 @@ class ForeachCachedPSGDKron(PSGDBase):
 else:
 torch.mul(q_.conj(), q_, out=c_)
 
-
-
-grad_list = self.clip_fn(grad_list)
-
-lr = -warmup(lr, group['step'], group['warmup_steps'])
-update_param_(p_list, grad_list, lr, weight_decay)
+g = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
+update_param_([p], self.clip_fn([g]), lr, weight_decay)
heavyball/delayed_psgd.py
CHANGED
@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-from heavyball.utils import copy_stochastic_list_
 
+from heavyball.utils import stochastic_lerp_, beta_debias
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-
+split_p_and_g_in_group, triu_to_line, line_to_triu, promote
 
 
 class ForeachDelayedPSGD(PSGDBase):
@@ -38,8 +38,8 @@ class ForeachDelayedPSGD(PSGDBase):
 def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
-
+split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
+q_dtype='float32', stochastic_schedule: bool = True, storage_dtype: str = 'float32', #
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -55,12 +55,10 @@ class ForeachDelayedPSGD(PSGDBase):
 defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
 min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
-precond_init_scale=precond_init_scale,
-
-store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
-
 def _step(self, group):
 should_update = self.should_update(group)
 momentum_into_precond_update = group.get("momentum_into_precond_update", True)
@@ -74,14 +72,15 @@ class ForeachDelayedPSGD(PSGDBase):
 beta = group['beta']
 store_triu_as_line = group['store_triu_as_line']
 q_dtype = getattr(torch, group['q_dtype'])
+storage_dtype = getattr(torch, group['storage_dtype'])
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group):
+for p, g in split_p_and_g_in_group(group, should_promote=False):
 state = self.state_(p)
 
 if 'Q' not in state:
-state["exp_avg"] = torch.zeros_like(g)
+state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
 memory_save_mode, dtype=q_dtype)
 state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
@@ -96,7 +95,9 @@ class ForeachDelayedPSGD(PSGDBase):
 
 group["step"] += 1
 
-
+stochastic_lerp_(exp_avg_list, grad_list, beta_debias(beta, group["step"]))
+
+lr = -warmup(lr, group['step'], group['warmup_steps'])
 
 Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
 for i, (p, g) in enumerate(zip(p_list, grad_list)):
@@ -106,10 +107,6 @@ class ForeachDelayedPSGD(PSGDBase):
 new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
 if should_update:
 q32 = [promote(q_) for q_ in q]
-self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
-
-
-grad_list = self.clip_fn(grad_list)
-
-lr = -warmup(lr, group['step'], group['warmup_steps'])
-update_param_(p_list, grad_list, lr, weight_decay)
+self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
+store_triu_as_line)
+update_param_([p], self.clip_fn([new]), lr, weight_decay)
heavyball/foreach_soap.py
CHANGED
@@ -1,7 +1,7 @@
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-split_p_and_g_in_group, StatefulOptimizer
+split_p_and_g_in_group, StatefulOptimizer, exp_avg_
 
 
 class ForeachSOAP(StatefulOptimizer):
@@ -26,8 +26,7 @@ class ForeachSOAP(StatefulOptimizer):
 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-split: bool = False,
-foreach: bool = True):
+split: bool = False, foreach: bool = True):
 defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
 "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
 "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
@@ -65,14 +64,12 @@ class ForeachSOAP(StatefulOptimizer):
 p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
 beta1, beta2 = group["betas"]
 
-old_debiased1 = beta_debias(beta1, step)
 old_debiased2 = beta_debias(beta2, step)
 
 # Decay the first and second moment running average coefficient
 # In-place operations to update the averages at the same time
-torch.
-
-denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
 for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
 state = self.state_(p)
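The separate first- and second-moment updates are replaced by one compiled helper, `exp_avg_` (added to `heavyball/utils.py` further down), which debiases both betas, updates both moments, and returns the denominator in one pass. A rough single-tensor sketch of what such a fused update computes; the real helper works on lists of tensors, promotes to fp32, copies back with stochastic rounding, and its exact eps handling follows `exp_avg_sq_`:

    import torch

    def fused_exp_avg(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step, eps=1e-8):
        # bias-corrected decay rates, as in beta_debias
        b1 = 1 - (1 - beta1) / (1 - beta1 ** step)
        b2 = 1 - (1 - beta2) / (1 - beta2 ** step)
        exp_avg.lerp_(grad, 1 - b1)                                                  # first moment
        exp_avg_sq.mul_(b2).addcmul_(grad_projected, grad_projected, value=1 - b2)   # second moment
        return exp_avg_sq.sqrt().add_(eps)                                           # denominator (eps placement is illustrative)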
heavyball/p_adam.py
CHANGED
@@ -39,7 +39,7 @@ class ForeachPaLMPAdam(PSGDBase):
 momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
 beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
 store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
-stochastic_schedule: bool = True,
+stochastic_schedule: bool = True, storage_dtype: str = 'float32', #
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -57,7 +57,7 @@ class ForeachPaLMPAdam(PSGDBase):
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, beta=beta,
 beta2_scale=beta2_scale, merge_dims=merge_dims, split=split,
-store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
 def _step(self, group):
@@ -71,15 +71,16 @@ class ForeachPaLMPAdam(PSGDBase):
 lr = group['lr']
 store_triu_as_line = group['store_triu_as_line']
 q_dtype = getattr(torch, group['q_dtype'])
+storage_dtype = getattr(torch, group['storage_dtype'])
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group):
+for p, g in split_p_and_g_in_group(group, should_promote=False):
 state = self.state_(p)
 
 if 'Q' not in state:
-state['exp_avg'] = torch.zeros_like(g)
-state['exp_avg_sq'] = torch.zeros_like(g)
+state['exp_avg'] = torch.zeros_like(g, dtype=storage_dtype)
+state['exp_avg_sq'] = torch.zeros_like(g, dtype=storage_dtype)
 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
 memory_save_mode, dtype=q_dtype)
 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -103,6 +104,8 @@ class ForeachPaLMPAdam(PSGDBase):
 
 beta2 = 1 - group['step'] ** -group['beta2_scale']
 
+lr = -warmup(lr, group['step'], group['warmup_steps'])
+
 for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
 psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
 ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
@@ -112,8 +115,5 @@ class ForeachPaLMPAdam(PSGDBase):
 divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
 divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
 """
+update_param_([p], self.clip_fn([g]), lr, weight_decay)
 
-grad_list = self.clip_fn(grad_list)
-
-lr = -warmup(lr, group['step'], group['warmup_steps'])
-update_param_(p_list, grad_list, lr, weight_decay)
heavyball/palm_foreach_soap.py
CHANGED
@@ -1,7 +1,8 @@
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-split_p_and_g_in_group, StatefulOptimizer
+split_p_and_g_in_group, StatefulOptimizer, exp_avg_
+
 
 
 class PaLMForeachSOAP(StatefulOptimizer):
@@ -32,8 +33,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
 max_precond_dim: int = 2048, #
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-beta2_scale: float = 0.8, split: bool = False,
-foreach: bool = True):
+beta2_scale: float = 0.8, split: bool = False, foreach: bool = True):
 if betas[0] is not None:
 beta = betas[0]
 defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -75,13 +75,13 @@ class PaLMForeachSOAP(StatefulOptimizer):
 beta1 = group["beta"]
 
 beta2 = 1 - step ** -group['beta2_scale']
-old_debiased1 = beta_debias(beta1, step)
 old_debiased2 = beta_debias(beta2, step)
 
 # Decay the first and second moment running average coefficient
 # In-place operations to update the averages at the same time
-torch.
-
+beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
+step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
 for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
 state = self.state_(p)
heavyball/precond_schedule_foreach_soap.py
CHANGED
@@ -2,8 +2,8 @@ import random
 
 import torch
 
-from .utils import init_preconditioner, update_preconditioner, project, beta_debias,
-precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer
+from .utils import init_preconditioner, update_preconditioner, project, beta_debias, update_param_, \
+precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer, exp_avg_
 
 
 class PrecondScheduleForeachSOAP(StatefulOptimizer):
@@ -27,8 +27,7 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-precond_scheduler=(1 / 3, 9), split: bool = False,
-foreach: bool = True):
+precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True):
 defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
 "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
 "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
@@ -68,14 +67,12 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
 p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
 beta1, beta2 = group["betas"]
 
-old_debiased1 = beta_debias(beta1, step)
 old_debiased2 = beta_debias(beta2, step)
 
 # Decay the first and second moment running average coefficient
 # In-place operations to update the averages at the same time
-torch.
-
-denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
 update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
 for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
@@ -89,8 +86,7 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
 # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
 set_(d, project(exp_avg_projected / d, state['Q'], True))
 
-update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-update_precond)
+update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
 
 # Why does this have to be rebiased here?
 step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
heavyball/precond_schedule_palm_foreach_soap.py
CHANGED
@@ -2,7 +2,7 @@ import random
 
 import torch
 
-from .utils import init_preconditioner, update_preconditioner, project, beta_debias,
+from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_, update_param_, \
 precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer
 
 
@@ -81,9 +81,9 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
 
 # Decay the first and second moment running average coefficient
 # In-place operations to update the averages at the same time
-torch.
-torch.
-denom =
+beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
+step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
 update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
 for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
heavyball/precond_schedule_sfpsoap.py
CHANGED
@@ -2,8 +2,19 @@ import random
 
 import torch
 
-from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, \
-
+from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
+beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, split_p_and_g_in_group, copy_stochastic_list_, \
+promote
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
+    eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
+    denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
+    torch._foreach_div_(gp32, denom)
+
+    copy_stochastic_list_(exp_avg_sq, eas32)
+    copy_stochastic_list_(grad_projected, gp32)
 
 
 class PrecondScheduleSFPaLMSOAP(ScheduleFree):
@@ -40,8 +51,8 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
-weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9),
-
+weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9), betas=(None, None),
+split: bool = False, foreach: bool = True):
 if betas[0] is not None:
 beta = betas[0]
 defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -103,8 +114,8 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
 
 # Decay the first and second moment running average coefficient
 # In-place operations to update the averages at the same time
-
-
+old_debiased_tensor = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(old_debiased2)
+_compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased_tensor, group["eps"])
 
 update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
 
@@ -114,13 +125,12 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
 # to the original space
 set_(gp, project(gp, state['Q'], back=True))
 
-update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-update_precond)
+update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
 
 # Weight decay calculated at y
 if group["weight_decay"] > 0:
 torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])
 
 lr = warmup(group['lr'], step, group['warmup_steps'])
-group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
-
+group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], p_list,
+z, grad_projected, group['r'], step)
heavyball/psgd_kron.py
CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-split_p_and_g_in_group, line_to_triu, triu_to_line,
+split_p_and_g_in_group, line_to_triu, triu_to_line, promote, stochastic_lerp_, beta_debias
 
 
 class ForeachPSGDKron(PSGDBase):
@@ -39,7 +39,8 @@ class ForeachPSGDKron(PSGDBase):
 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+storage_dtype: str = 'float32', #
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -56,7 +57,7 @@ class ForeachPSGDKron(PSGDBase):
 min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
 def _step(self, group):
@@ -72,14 +73,15 @@ class ForeachPSGDKron(PSGDBase):
 beta = group['beta']
 store_triu_as_line = group['store_triu_as_line']
 q_dtype = getattr(torch, group['q_dtype'])
+storage_dtype = getattr(torch, group['storage_dtype'])
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group):
+for p, g in split_p_and_g_in_group(group, should_promote=False):
 state = self.state_(p)
 
 if 'Q' not in state:
-state["exp_avg"] = torch.zeros_like(g)
+state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
 memory_save_mode, dtype=q_dtype)
 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -94,9 +96,14 @@ class ForeachPSGDKron(PSGDBase):
 
 group["step"] += 1
 
-
+beta = beta_debias(beta, group["step"])
+beta = torch.empty((), dtype=torch.float32, device=grad_list[0].device).fill_(1 - beta)
+stochastic_lerp_(exp_avg_list, grad_list, 1 - beta)
 
 grad_list, Q_list, exp_avg_list = list(grad_list), list(Q_list), list(exp_avg_list)
+
+lr = -warmup(lr, group['step'], group['warmup_steps'])
+
 for i, (p, g) in enumerate(zip(p_list, grad_list)):
 q_orig = Q_list.pop(0)
 ea = exp_avg_list.pop(0)
@@ -106,9 +113,5 @@ class ForeachPSGDKron(PSGDBase):
 q32 = [promote(q_) for q_ in q]
 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
 store_triu_as_line)
-
-
-grad_list = self.clip_fn(grad_list)
-
-lr = -warmup(lr, group['step'], group['warmup_steps'])
-update_param_(p_list, grad_list, lr, weight_decay)
+g = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+update_param_([p], self.clip_fn([g]), lr, weight_decay)
heavyball/pure_psgd.py
CHANGED
@@ -70,7 +70,7 @@ class ForeachPurePSGD(PSGDBase):
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group):
+for p, g in split_p_and_g_in_group(group, should_promote=False):
 state = self.state_(p)
 
 if 'Q' not in state:
@@ -89,6 +89,7 @@ class ForeachPurePSGD(PSGDBase):
 group["step"] += 1
 
 Q_list = list(Q_list)
+lr = -warmup(lr, group['step'], group['warmup_steps'])
 for i, (p, g) in enumerate(zip(p_list, grad_list)):
 q_orig = Q_list.pop(0)
 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
@@ -97,8 +98,4 @@ class ForeachPurePSGD(PSGDBase):
 q32 = [promote(q_) for q_ in q]
 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
 psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
-
-grad_list = self.clip_fn(grad_list)
-
-lr = -warmup(lr, group['step'], group['warmup_steps'])
-update_param_(p_list, grad_list, lr, weight_decay)
+update_param_([p], self.clip_fn([g]), lr, weight_decay)
heavyball/schedule_free_palm_foreach_soap.py
CHANGED
@@ -2,8 +2,18 @@ import random
 
 import torch
 
-from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, \
-
+from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
+beta_debias, schedule_free_, warmup, ScheduleFree, split_p_and_g_in_group, copy_stochastic_list_, promote
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
+    eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
+    denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
+    torch._foreach_div_(gp32, denom)
+
+    copy_stochastic_list_(exp_avg_sq, eas32)
+    copy_stochastic_list_(grad_projected, gp32)
 
 
 class SFPaLMForeachSOAP(ScheduleFree):
@@ -95,8 +105,8 @@ class SFPaLMForeachSOAP(ScheduleFree):
 
 # Decay the first and second moment running average coefficient
 # In-place operations to update the averages at the same time
-
-
+old_debiased_tensor = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(new_debiased2)
+_compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased_tensor, group["eps"])
 
 update_precond = group['step'] > 0 and group['step'] % group['precondition_frequency'] == 0
 
@@ -107,13 +117,12 @@ class SFPaLMForeachSOAP(ScheduleFree):
 # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
 set_(gp, project(gp, state['Q'], back=True))
 
-update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2,
-update_precond)
+update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2, update_precond)
 
 # Weight decay calculated at y
 if group["weight_decay"] > 0:
 torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])
 
 lr = warmup(group['lr'], step, group['warmup_steps'])
-group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
-
+group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], p_list,
+z, grad_projected, group['r'], step)
heavyball/utils.py
CHANGED
@@ -3,7 +3,7 @@ import gc
 import math
 import random
 import string
-from typing import List, Optional, Tuple, Callable
+from typing import List, Optional, Tuple, Callable, Union
 
 import numpy as np
 import torch
@@ -141,6 +141,7 @@ def beta_debias(beta, step):
 return 1 - (1 - beta) / (1 - beta ** step)
 
 
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
 def exp_avg_sq_(state, grad, beta2, eps, out=None):
 if isinstance(state, torch.Tensor):
 state.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
@@ -327,6 +328,26 @@ def get_orthogonal_matrix(mat):
 return final
 
 
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+    x32 = [promote(x_) for x_ in x]
+    y32 = [promote(y_) for y_ in y]
+
+    torch._foreach_lerp_(x32, y32, a)
+
+    copy_stochastic_list_(x, x32)
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+    x32 = [promote(x_) for x_ in x]
+    y32 = [promote(y_) for y_ in y]
+
+    [x_.add_(y_, alpha=alpha) for x_, y_ in zip(x32, y32)]
+
+    copy_stochastic_list_(x, x32)
+
+
 @decorator
 def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
 if grad.dim() == 1 and (not precondition_1d or grad.shape[0] > max_precond_dim):
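`stochastic_lerp_` and `stochastic_add_` both follow the same recipe: promote the (typically bfloat16) tensors to fp32, apply the operation, then write the result back through `copy_stochastic_list_`, which rounds stochastically rather than truncating. The toy below is not heavyball code, just a demonstration of the failure mode the fp32 detour avoids when tiny updates are accumulated into bf16 storage:

    import torch

    step = 1e-3                                     # far below bf16 resolution near 1.0

    # rounding to bf16 after every update: nearest rounding drops the update each time
    x = torch.ones((), dtype=torch.bfloat16)
    for _ in range(1000):
        x = (x.float() + step).to(torch.bfloat16)   # stays exactly 1.0
    print(x.item())

    # accumulating in fp32 and only storing bf16 at the end (what promote + copy_stochastic_ enable)
    acc = torch.ones((), dtype=torch.float32)
    for _ in range(1000):
        acc += step
    print(acc.to(torch.bfloat16).item())            # ~2.0

Roughly speaking, `copy_stochastic_` closes this gap by adding random bits to the fp32 mantissa before truncating to bf16, so repeated small writes survive in expectation even though the state itself stays in 16 bits.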
@@ -409,9 +430,12 @@ def project(grad, Q, back: bool):
 
 
 class StatefulOptimizer(torch.optim.Optimizer):
-
+ema_decay: float = 0.001
+
+def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
 super().__init__(params, {**defaults, 'foreach': foreach})
 self.fake_groups = {}
+self.use_ema = use_ema
 
 def key(self, param: torch.Tensor):
 return (param.data_ptr(), tuple(param.shape))
@@ -445,6 +469,54 @@ class StatefulOptimizer(torch.optim.Optimizer):
 def _step(self, group):
 raise NotImplementedError
 
+def ema_update(self):
+    with torch.no_grad():
+        for top_group in self.param_groups:
+            for group in self.get_groups(top_group):
+                active_p = [p for p in group['params']]
+
+                if not active_p:
+                    return
+
+                k = group['ema_step'] = group.get('ema_step', -1) + 1
+
+                for p in active_p:
+                    if 'param_ema' not in self.state_(p):
+                        self.state_(p)['param_ema'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+
+                y, param_ema = zip(*[(p.data, self.state_(p)['param_ema']) for p in active_p])
+                torch._foreach_lerp_(param_ema, y, weight=beta_debias(1 - self.ema_decay, k + 1))
+
+def copy_emas_to_params(self):
+    with torch.no_grad():
+        for top_group in self.param_groups:
+            for group in self.get_groups(top_group):
+                active_p = [p for p in group['params']]
+
+                if not active_p:
+                    return
+
+                for p in active_p:
+                    if 'param_ema' in self.state_(p):
+                        p_clone = p.data.clone()
+                        set_(p.data, self.state_(p)['param_ema'])
+                        set_(self.state_(p)['param_ema'], p_clone)
+
+def copy_params_to_emas(self):
+    with torch.no_grad():
+        for top_group in self.param_groups:
+            for group in self.get_groups(top_group):
+                active_p = [p for p in group['params']]
+
+                if not active_p:
+                    return
+
+                for p in active_p:
+                    if 'param_ema' in self.state_(p):
+                        ema_clone = self.state_(p)['param_ema'].data.clone()
+                        set_(self.state_(p)['param_ema'], p.data)
+                        set_(p.data, ema_clone)
+
 def step(self, closure: Optional[Callable] = None):
 if closure is None:
 loss = None
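The new `use_ema` machinery keeps a per-parameter EMA in optimizer state, refreshed after every step with a debiased decay, and `copy_emas_to_params` / `copy_params_to_emas` swap the EMA and the live weights in place (note that the `step` hunk below passes `group` to `ema_update`, which takes no parameter besides `self`). A hedged sketch of how the swap could be used around evaluation; `opt`, `model`, and `validate` are placeholders, and the optimizer is assumed to have been built with `use_ema=True`:

    opt.copy_emas_to_params()   # parameters now hold the EMA weights, originals parked in state
    validate(model)             # user-supplied evaluation
    opt.copy_emas_to_params()   # the swap is symmetric, so calling it again restores training weights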
@@ -455,6 +527,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
 for top_group in self.param_groups:
 for group in self.get_groups(top_group):
 self._step(group)
+if self.use_ema:
+    self.ema_update(group)
 return loss
 
 
@@ -497,6 +571,20 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
 copy_stochastic_(t, s)
 
 
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
+    beta1 = beta_debias(beta1, step)
+    beta2 = beta_debias(beta2, step)
+
+    g32, gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, grad_projected, exp_avg_sq]]
+
+    stochastic_lerp_(exp_avg, g32, 1 - beta1)
+    denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
+
+    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+    return denom
+
+
 # this can be dynamic for most optimizers - just not for PSGD. So, it's disabled for all
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True)
 def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
@@ -523,23 +611,26 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
-def
-
-u32 =
+def _compilable_update_(p, u, decay, add_fn, lr):
+    u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
+    p32, u32 = [list(map(promote, x)) for x in [p, u]]
+
 if decay > 0:
-
-
-
-
-
-
+    torch._foreach_mul_(p32, 1 - decay * lr)
+
+for p32_, u32_ in zip(p32, u32):  # lr is data-dependent -> can't compile a foreach
+    if add_fn is None:
+        p32_.add_(u32_, alpha=lr)
+    else:
+        add_fn(p32_, u32_, lr)
+
+copy_stochastic_list_(p, p32)
 
 
 def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
 add_fn: callable = None):
 lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
-
-_compilable_update_one_(p, u, decay, add_fn, lr_tensor)
+_compilable_update_(param, update, decay, add_fn, lr_tensor)
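`_compilable_update_` applies the update with `p.add_(u, alpha=lr)` after the optional decoupled weight-decay multiply, which is why every optimizer in this diff now negates the learning rate up front (`lr = -warmup(...)`) before calling `update_param_`. A single-tensor paraphrase of the helper for readability; it keeps the same argument order but is a sketch, not the library function:

    import torch

    def update_one(p: torch.Tensor, u: torch.Tensor, decay: float, lr: float):
        # lr is expected to arrive already negated, mirroring lr = -warmup(...) in the optimizers
        p32 = p.float()
        if decay > 0:
            p32.mul_(1 - decay * lr)          # decoupled weight decay, same expression as the diff
        p32.add_(u.view_as(p).float(), alpha=lr)
        p.copy_(p32)                          # heavyball copies back with stochastic rounding instead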
@@ -638,12 +729,13 @@ def psgd_balance_Q(Q_in):
 torch._foreach_mul_(Q_in, list(norms))
 
 
-def psgd_calc_A_and_conjB(exprA, G, Q
-md = min_dtype(Q)
-A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md))
+def psgd_calc_A_and_conjB(exprA, G, Q):
+    md = min_dtype(Q + [G])
+    A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
 order = G.dim()
 p = list(range(order))
-conjB = torch.
+conjB = torch.randn_like(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
+Q = [promote(q) for q in Q]
 for i, q in enumerate(Q):
 if q.dim() <= 1:
 conjB /= q
@@ -651,7 +743,7 @@ def psgd_calc_A_and_conjB(exprA, G, Q, V):
 unsqueeze = conjB.dim() <= 1
 if unsqueeze:
 conjB = conjB.unsqueeze(0)
-conjB = torch.linalg.solve_triangular(q, conjB, upper=True, left=False
+conjB = torch.linalg.solve_triangular(q, conjB, upper=True, left=False)
 if unsqueeze:
 conjB = conjB.squeeze(0)
 if i < order - 1:
@@ -661,33 +753,29 @@ def psgd_calc_A_and_conjB(exprA, G, Q, V):
 
 def psgd_lb(A, max_abs):
 A /= max_abs
-
-
-
-
-
-
-x
-x =
-
-x = torch.where(comp, x, x.T)
-torch.matmul(x, torch.where(comp, A, A.T), out=x.view(1, -1))
-x /= torch.linalg.vector_norm(x)
-torch.matmul(x, torch.where(comp, ah, ah.T), out=x.view(1, -1))
-x = torch.linalg.vector_norm(x)
+a0 = torch.einsum('ij,ij->j', A, A)
+i = torch.argmax(a0)
+
+x = torch.index_select(a, 1, i).flatten().contiguous()
+
+x = torch.einsum('i,ij->j', x_, a)
+x /= x.norm()
+x = torch.einsum('j,kj->k', x_, a)
+x = x.norm()
 x *= max_abs
 return x
 
 
-
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
 """Update Kronecker product preconditioner Q with pair (V, G)."""
 exprA, exprGs, _ = exprs
 
-A, conjB = psgd_calc_A_and_conjB(exprA, G, Q
+A, conjB = psgd_calc_A_and_conjB(exprA, G, Q)
 
-for q, exprG in zip(Q, exprGs):
-term1 = torch.einsum(exprG, A, A
-term2 = torch.einsum(exprG, conjB
+for q, exprG, o in zip(Q, exprGs, oq):
+term1 = promote(torch.einsum(exprG, A, A))
+term2 = promote(torch.einsum(exprG, conjB, conjB))
 
 term2 += term1  # a + b
 term1 *= 2  # 2a
@@ -696,15 +784,19 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
 else:
 term1 = term1 - term2
 
-term1 *=
+term1 *= precond_lr
 norm = term2.norm(float('inf'))
 if q.dim() < 2:
-term1 *= q
-
+term1 *= q.to(term1.dtype)
+term1 /= norm.clamp_(min=tiny)
 else:
 torch.triu(term1, out=term1)
-term1 /=
-
+term1 /= psgd_lb(term2, norm).clamp_(tiny)
+torch.matmul(term1, q, out=term1)
+if store_triu_as_line:
+    term1 = triu_to_line([term1])[0][1]
+    o = o[1]
+stochastic_add_([o], [term1], -1)
@@ -838,18 +930,9 @@ class PSGDBase(StatefulOptimizer):
 group[name] = cumulative_prob + prob
 return int(group[name]) > int(cumulative_prob)
 
-def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q:
-
-
-if store_triu_as_line:
-update_fn = update_triu_
-else:
-update_fn = copy_stochastic_list_
-else:
-update_fn = lambda x, y: None
-for i, (p, grad, Q, oq) in enumerate(zip(p_list, grad_list, q_list, original_q)):
-psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
-update_fn(oq, Q)
+def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: List, store_triu_as_line=False):
+    for p, grad, Q, oq in zip(p_list, grad_list, q_list, original_q):
+        psgd_update_precond(Q, self.state_(p)["exprs"], grad, precond_lr, self._tiny, oq, store_triu_as_line)
 
 if self.should_update(group, self.balance_probability, "balance_prob"):
 for g, q in zip(grad_list, original_q if original_q else q_list):
@@ -896,13 +979,19 @@ def merge_group(group, *tensors):
 return out
 
 
-def split_p_and_g_in_group(group: dict, skip_none: bool = True):
+def split_p_and_g_in_group(group: dict, skip_none: bool = True, should_promote: bool = True):
 for p in group["params"]:
 if skip_none and p.grad is None:
 continue
 
-
-
+if p.grad is None:
+    grad = None
+else:
+    if should_promote:
+        grad = promote(p.grad)
+    else:
+        grad = p.grad
+    p.grad = None
 
 p_views = merge_group(group, p)
 if grad is not None:
{heavyball-0.19.0.dist-info → heavyball-0.20.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.
+Version: 0.20.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-22, 0.19), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+Currently (2024-11-22, 0.19.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
heavyball-0.20.1.dist-info/RECORD
ADDED
@@ -0,0 +1,24 @@
+heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
+heavyball/cached_delayed_psgd_kron.py,sha256=apVzESMaQ8uxunHvfvYfyWA8HLbS25wQSd3j_YNEjGs,6603
+heavyball/cached_psgd_kron.py,sha256=3IETfsC0Ufu_8TPfo9SByGmztwjW6ktSFPwHNrUWkys,6601
+heavyball/delayed_psgd.py,sha256=0LaazbiBZOdx78EDS-945cW3bmeORjUvdFOGqdw3aMs,5631
+heavyball/foreach_adamw.py,sha256=Rb5U80cgUcEqlEbUU250UTWdoqA7nyiqkV5w1U4bWX4,2445
+heavyball/foreach_adopt.py,sha256=ecdi1fKg9i087OGjtKWVbE_DD6Yf4pvpzv4ELCcusvQ,3211
+heavyball/foreach_laprop.py,sha256=vi6C_gfjXxw5uN0KHgzxI9itUI1dcgOf3ufoO_VVMp0,2471
+heavyball/foreach_sfadamw.py,sha256=rLZORmCIMu9G09FdDgMSiI6pNq34IVoxsPVWtmeDdbQ,2753
+heavyball/foreach_soap.py,sha256=4mWSMWYTdjgiXiboI5DwdigecruDtNGKylGAFAVhCRA,4562
+heavyball/p_adam.py,sha256=J5QqFAlyLTQ1eQzM0LGxPdv4fEtZikIv9mJ_SSkO3ZY,6033
+heavyball/palm_foreach_sfadamw.py,sha256=JbNrcoquBGGUI5XNMFouDjpNurVHUW9DbX1A3tSrtno,3025
+heavyball/palm_foreach_soap.py,sha256=GzAwM8kOt1X0QCmUZDTdHwPxbJwjH8ic43dyAK5BYCA,6015
+heavyball/precond_schedule_foreach_soap.py,sha256=HcObXLfSNN_lKNb4nmC6tkdHcqDIMNX6hILpHKScqLc,4744
+heavyball/precond_schedule_palm_foreach_soap.py,sha256=xZ7CJvIfdu2RNAZt2g1S7Xb0Jyy1hNC4MowOFU3nWkk,6283
+heavyball/precond_schedule_sfpsoap.py,sha256=PNneiOkrRyV1yIZn91lPmYofd1_OiLqJTDy75RLpXJk,7273
+heavyball/psgd_kron.py,sha256=RSLJi5FnSmYjvYBufNDClnYRm-eN_Kpa1Ar2tNP6-X0,5824
+heavyball/pure_psgd.py,sha256=LZK0qmvZkBF8g00evaVLtW-sIUJmdoxag1K7O26AqEo,4820
+heavyball/schedule_free_palm_foreach_soap.py,sha256=_AqrnChY6iDlQUkF2YUxS7eLjSWCIuvEUOHvMHVM1yY,6873
+heavyball/utils.py,sha256=14vt4r_MeTsp1q3m0lpgF-Q3PCJg6GLGJrhjRxnbWwQ,35174
+heavyball-0.20.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-0.20.1.dist-info/METADATA,sha256=qzF2P7e2EREeTy_4h85tvUY53omjNm32z83CUHTqt3U,11926
+heavyball-0.20.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-0.20.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-0.20.1.dist-info/RECORD,,
heavyball-0.19.0.dist-info/RECORD
REMOVED
@@ -1,24 +0,0 @@
-heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
-heavyball/cached_delayed_psgd_kron.py,sha256=PQAER6UgVh5l87DGRZrJ8CVP9UhyCG5wJD9rPLnj_G8,6460
-heavyball/cached_psgd_kron.py,sha256=GaeneBp0irksCSBIrJY4D_0hCpZ-uSRPMhqVX_a-og8,6417
-heavyball/delayed_psgd.py,sha256=fhBWFLTSl1S2gHWCeYak-STaXRwpC56sWZGLFMKFEJM,5589
-heavyball/foreach_adamw.py,sha256=Rb5U80cgUcEqlEbUU250UTWdoqA7nyiqkV5w1U4bWX4,2445
-heavyball/foreach_adopt.py,sha256=ecdi1fKg9i087OGjtKWVbE_DD6Yf4pvpzv4ELCcusvQ,3211
-heavyball/foreach_laprop.py,sha256=vi6C_gfjXxw5uN0KHgzxI9itUI1dcgOf3ufoO_VVMp0,2471
-heavyball/foreach_sfadamw.py,sha256=rLZORmCIMu9G09FdDgMSiI6pNq34IVoxsPVWtmeDdbQ,2753
-heavyball/foreach_soap.py,sha256=h6ptMch7oaynvu3eIJtWnVXypDA_5JDVm3Zb3PNEma0,4634
-heavyball/p_adam.py,sha256=4zJDGJrpgUyVzr3GiELETFre4xr3-PE10OuAZj-jFM8,5883
-heavyball/palm_foreach_sfadamw.py,sha256=JbNrcoquBGGUI5XNMFouDjpNurVHUW9DbX1A3tSrtno,3025
-heavyball/palm_foreach_soap.py,sha256=g4hbiGRcti-J-a0SwAkP4ii5pU-aalsZH5bssyhroLk,5938
-heavyball/precond_schedule_foreach_soap.py,sha256=WLg5SzpJnKPZUvFyIvdwSZa1Umt5cpr3Kow_42orM-E,4863
-heavyball/precond_schedule_palm_foreach_soap.py,sha256=ammQrvRZFF-wc-wEiPEoFhS_7b8pdV61QfcLoQfimSo,6211
-heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdojklrY8pA,6770
-heavyball/psgd_kron.py,sha256=wKjtI56iUnL5D8DseW60kxiXTAlMYNEf52CrvQaQMnI,5547
-heavyball/pure_psgd.py,sha256=iUy7mMKWxwNiVUMYrQ7SBnreu3t_XSbnhTW3a1yw4m0,4835
-heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
-heavyball/utils.py,sha256=BWscCHlGOw1_zfKYxNAAmfFeOXVpSJHuvqqlfL5A7_0,31690
-heavyball-0.19.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-0.19.0.dist-info/METADATA,sha256=1wORoS9rrjlug9tuJqXsbtVA9PphOBGcifiLRxmZNjs,11924
-heavyball-0.19.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-0.19.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-0.19.0.dist-info/RECORD,,
{heavyball-0.19.0.dist-info → heavyball-0.20.1.dist-info}/LICENSE: File without changes
{heavyball-0.19.0.dist-info → heavyball-0.20.1.dist-info}/WHEEL: File without changes
{heavyball-0.19.0.dist-info → heavyball-0.20.1.dist-info}/top_level.txt: File without changes