heavyball 0.14.7__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- heavyball/__init__.py +25 -3
- heavyball/cached_psgd_kron.py +141 -0
- heavyball/delayed_psgd.py +43 -51
- heavyball/foreach_adamw.py +22 -32
- heavyball/foreach_adopt.py +38 -48
- heavyball/foreach_laprop.py +25 -35
- heavyball/foreach_sfadamw.py +28 -38
- heavyball/foreach_soap.py +56 -70
- heavyball/p_adam.py +46 -50
- heavyball/palm_foreach_sfadamw.py +31 -41
- heavyball/palm_foreach_soap.py +56 -70
- heavyball/precond_schedule_foreach_soap.py +57 -71
- heavyball/precond_schedule_palm_foreach_soap.py +58 -73
- heavyball/precond_schedule_sfpsoap.py +60 -72
- heavyball/psgd_kron.py +43 -49
- heavyball/pure_psgd.py +36 -43
- heavyball/schedule_free_palm_foreach_soap.py +61 -72
- heavyball/utils.py +23 -7
- {heavyball-0.14.7.dist-info → heavyball-0.15.1.dist-info}/METADATA +1 -1
- heavyball-0.15.1.dist-info/RECORD +23 -0
- heavyball-0.14.7.dist-info/RECORD +0 -22
- {heavyball-0.14.7.dist-info → heavyball-0.15.1.dist-info}/LICENSE +0 -0
- {heavyball-0.14.7.dist-info → heavyball-0.15.1.dist-info}/WHEEL +0 -0
- {heavyball-0.14.7.dist-info → heavyball-0.15.1.dist-info}/top_level.txt +0 -0
heavyball/foreach_soap.py
CHANGED
@@ -34,73 +34,59 @@ class ForeachSOAP(StatefulOptimizer):
         super().__init__(params, defaults)
         self._data_format = data_format
 
-        ...
-                exp_avg_projected = project(ea, state['Q'], False)
-
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-                set_(d, project(exp_avg_projected / d, state['Q'], True))
-
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                      step > 0 and step % group['precondition_frequency'] == 0)
-
-            # Why does this have to be rebiased here?
-            step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-            update_param_(p_list, denom, step_size, group["weight_decay"])
-        return loss
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1, beta2 = group["betas"]
+
+        old_debiased1 = beta_debias(beta1, step)
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        torch._foreach_mul_(exp_avg, old_debiased1)
+        torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+
+        for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
+            state = self.state_(p)
+            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            exp_avg_projected = project(ea, state['Q'], False)
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            set_(d, project(exp_avg_projected / d, state['Q'], True))
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
+                                  step > 0 and step % group['precondition_frequency'] == 0)
+
+        # Why does this have to be rebiased here?
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+        update_param_(p_list, denom, step_size, group["weight_decay"])
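Every optimizer in this release replaces its public `step(closure)` method with a private `_step(group)` hook, so closure evaluation and the loop over parameter groups must now live in a shared base class (`StatefulOptimizer`, `PSGDBase`, `ScheduleFree`). That base-class change sits in heavyball/utils.py and is not rendered in this diff; the sketch below shows, under that assumption, the minimal dispatch a base class would need for the per-file refactor above to work. `GroupedOptimizerSketch` is a hypothetical name, not heavyball's API.

# Hypothetical sketch, not heavyball's actual utils.py: the shared base class is
# assumed to own closure handling and the param-group loop after this refactor.
import torch


class GroupedOptimizerSketch(torch.optim.Optimizer):
    def _step(self, group):
        raise NotImplementedError  # concrete optimizers (e.g. a SOAP variant) implement this

    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            self._step(group)
        return loss

This removes the duplicated closure/loss boilerplate that the deleted `step()` methods carried in 0.14.7 and matches the `def _step(self, group)` signatures added throughout this release.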
heavyball/p_adam.py
CHANGED
@@ -5,6 +5,7 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
+from heavyball.utils import triu_to_line, line_to_triu
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
     exp_avg_sq_, beta_debias, split_p_and_g_in_group
@@ -36,7 +37,8 @@ class ForeachPaLMPAdam(PSGDBase):
     def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
-                 beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None):
+                 beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
+                 store_triu_as_line: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -57,18 +59,12 @@ class ForeachPaLMPAdam(PSGDBase):
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
-                        split=split)
+                        split=split, store_triu_as_line=store_triu_as_line)
         super().__init__(params, defaults)
 
         self._prob_step = 0
 
-
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
@@ -76,57 +72,57 @@ class ForeachPaLMPAdam(PSGDBase):
             do_update = self.rng.random() < update_prob
         self._prob_step += 1
 
-        ...
-        vals = []
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        store_triu_as_line = group['store_triu_as_line']
 
-        ...
-            state = self.state_(p)
+        vals = []
 
-        ...
-                state['exp_avg_sq'] = torch.zeros_like(g)
-                state["Q"], state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
-                                                          min_ndim_triangular, memory_save_mode, dtype=g.dtype)
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
 
-        ...
+            if 'Q' not in state:
+                state['exp_avg'] = torch.zeros_like(g)
+                state['exp_avg_sq'] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
+                                                 min_ndim_triangular, memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
-        ...
-                continue
+            vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
 
-        ...
+        if not vals:
+            return
 
-        ...
+        p_list, grad_list, Q_list, exp_avg, exp_avg_sq = zip(*vals)
+        del vals
 
-        ...
-        if do_update:
-            self.do_update(p_list, grad_list, Q_list, precond_lr)
+        group["step"] += 1
 
-        ...
+        Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
+        if do_update:
+            self.balance(grad_list, Q_triu)
+            self.do_update(p_list, grad_list, Q_triu, precond_lr, Q_list if store_triu_as_line else None)
 
-        ...
+        torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
 
-        ...
-            psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
-            ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
-            exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
-            torch.div(ea, g, out=g)
-            """
-            divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
-            divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
-            """
+        beta2 = 1 - group['step'] ** -group['beta2_scale']
 
-        ...
+        for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
+            psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
+            ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
+            exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
+            torch.div(ea, g, out=g)
+            """
+            divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
+            divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
+            """
 
-        ...
-        update_param_(p_list, grad_list, lr, weight_decay)
+        grad_list = self.clip_fn(grad_list)
 
-        ...
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
heavyball/palm_foreach_sfadamw.py
CHANGED
@@ -9,58 +9,48 @@ class PaLMForeachSFAdamW(ScheduleFree):
                  weight_lr_power=2.0, beta2_scale: float = 0.8):
         if betas[0] is not None:
             beta = betas[0]
-        defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
-        ...
+        defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
                         beta2_scale=beta2_scale)
         super().__init__(params, defaults)
 
-    def step(self, closure=None):
-        ...
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
 
-        ...
-        and returns the loss.
-        """
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
 
-        ...
-        if closure is not None:
-            loss = closure()
+        active_p = [p for p in group['params'] if p.grad is not None]
 
-        ...
-            decay = group['weight_decay']
-            k = group['k']
+        if not active_p:
+            return
 
-        ...
+        for p in active_p:
+            if 'z' not in self.state_(p):
+                self.state_(p)['z'] = torch.clone(p.data)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
 
-        ...
+        y, grad, exp_avg_sq, z = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
 
-        ...
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+        # Decay the first moment running average coefficient
+        beta2 = 1 - (k + 1) ** -group['beta2_scale']
+        old_debiased = beta_debias(beta2, k + 1)
 
-        ...
+        # Decay the first and second moment running average coefficient
+        denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
 
-        ...
-            old_debiased = beta_debias(beta2, k + 1)
+        # Normalize grad in-place for memory efficiency
+        torch._foreach_div_(grad, denom)
 
-        ...
+        # Weight decay calculated at y
+        if decay != 0:
+            torch._foreach_add_(grad, y, alpha=decay)
 
-        ...
+        lr = warmup(group['lr'], k + 1, group['warmup_steps'])
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], y, z,
+                                             grad, group['r'], k + 1)
 
-        ...
-            if decay != 0:
-                torch._foreach_add_(grad, y, alpha=decay)
-
-            lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-            group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
-                                                 y, z, grad, group['r'], k + 1)
-
-            group['k'] = k + 1
-        return loss
+        group['k'] = k + 1
heavyball/palm_foreach_soap.py
CHANGED
@@ -43,73 +43,59 @@ class PaLMForeachSOAP(StatefulOptimizer):
         super().__init__(params, defaults)
         self._data_format = data_format
 
-        ...
-                exp_avg_projected = project(ea, state['Q'], False)
-
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-                set_(d, project(exp_avg_projected / d, state['Q'], True))
-
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                      step > 0 and step % group['precondition_frequency'] == 0)
-
-            # Why does this have to be rebiased here?
-            step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-            update_param_(p_list, denom, step_size, group["weight_decay"])
-        return loss
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1 = group["beta"]
+
+        beta2 = 1 - step ** -group['beta2_scale']
+        old_debiased1 = beta_debias(beta1, step)
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        torch._foreach_lerp_(exp_avg, grad, 1 - old_debiased1)
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+
+        for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
+            state = self.state_(p)
+            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            exp_avg_projected = project(ea, state['Q'], False)
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            set_(d, project(exp_avg_projected / d, state['Q'], True))
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
+                                  step > 0 and step % group['precondition_frequency'] == 0)
+
+        # Why does this have to be rebiased here?
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+        update_param_(p_list, denom, step_size, group["weight_decay"])
heavyball/precond_schedule_foreach_soap.py
CHANGED
@@ -37,74 +37,60 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
 
-        ...
-                exp_avg_projected = project(ea, state['Q'], False)
-
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-                set_(d, project(exp_avg_projected / d, state['Q'], True))
-
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                      update_precond)
-
-            # Why does this have to be rebiased here?
-            step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-            update_param_(p_list, denom, step_size, group["weight_decay"])
-        return loss
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1, beta2 = group["betas"]
+
+        old_debiased1 = beta_debias(beta1, step)
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        torch._foreach_mul_(exp_avg, old_debiased1)
+        torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+
+        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+        for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
+            state = self.state_(p)
+            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            exp_avg_projected = project(ea, state['Q'], False)
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            set_(d, project(exp_avg_projected / d, state['Q'], True))
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
+                                  update_precond)
+
+        # Why does this have to be rebiased here?
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+        update_param_(p_list, denom, step_size, group["weight_decay"])