heavyball-0.14.7-py3-none-any.whl → heavyball-0.15.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +1 -1
- heavyball/delayed_psgd.py +39 -48
- heavyball/foreach_adamw.py +22 -32
- heavyball/foreach_adopt.py +38 -48
- heavyball/foreach_laprop.py +25 -35
- heavyball/foreach_sfadamw.py +28 -38
- heavyball/foreach_soap.py +56 -70
- heavyball/p_adam.py +39 -48
- heavyball/palm_foreach_sfadamw.py +31 -41
- heavyball/palm_foreach_soap.py +56 -70
- heavyball/precond_schedule_foreach_soap.py +57 -71
- heavyball/precond_schedule_palm_foreach_soap.py +58 -73
- heavyball/precond_schedule_sfpsoap.py +60 -72
- heavyball/psgd_kron.py +39 -47
- heavyball/pure_psgd.py +32 -41
- heavyball/schedule_free_palm_foreach_soap.py +61 -72
- heavyball/utils.py +17 -1
- {heavyball-0.14.7.dist-info → heavyball-0.15.0.dist-info}/METADATA +1 -1
- heavyball-0.15.0.dist-info/RECORD +22 -0
- heavyball-0.14.7.dist-info/RECORD +0 -22
- {heavyball-0.14.7.dist-info → heavyball-0.15.0.dist-info}/LICENSE +0 -0
- {heavyball-0.14.7.dist-info → heavyball-0.15.0.dist-info}/WHEEL +0 -0
- {heavyball-0.14.7.dist-info → heavyball-0.15.0.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -1,3 +1,4 @@
+from .delayed_psgd import ForeachDelayedPSGD
 from .foreach_adamw import ForeachAdamW
 from .foreach_adopt import ForeachADOPT
 from .foreach_laprop import ForeachLaProp
@@ -12,7 +13,6 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
 from .psgd_kron import ForeachPSGDKron
 from .pure_psgd import ForeachPurePSGD
 from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
-from .delayed_psgd import ForeachDelayedPSGD
 
 PalmForEachSoap = PaLMForeachSOAP
 
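The import reordering in `__init__.py` is cosmetic; the substantive change in this release is that every optimizer below loses its public `step(closure=None)` method and gains a private `_step(self, group)` hook, so the closure evaluation and the loop over parameter groups presumably move into the shared base classes in `heavyball/utils.py` (the only other module that changes). A minimal sketch of that dispatch pattern, assuming a hypothetical base-class wrapper; heavyball's actual `StatefulOptimizer`/`PSGDBase` may differ in detail:

import torch


class BaseOptimizer(torch.optim.Optimizer):
    # Sketch only: the base class owns closure handling and the group loop,
    # subclasses implement the per-group update in _step(group).
    def _step(self, group):
        raise NotImplementedError

    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            self._step(group)
        return loss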
heavyball/delayed_psgd.py
CHANGED
@@ -5,8 +5,8 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-
 from heavyball.utils import copy_stochastic_list_
+
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
     precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_
 
@@ -63,13 +63,7 @@ class ForeachDelayedPSGD(PSGDBase):
 
         self._prob_step = 0
 
-
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
@@ -77,55 +71,52 @@ class ForeachDelayedPSGD(PSGDBase):
         do_update = self.rng.random() < update_prob
         self._prob_step += 1
 
-
-
-
-
-
-
-
-
-
-            beta = group['beta']
-
-            vals = []
+        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        beta = group['beta']
 
-
-                state = self.state_(p)
+        vals = []
 
-
-
-                    Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                     memory_save_mode, dtype=g.dtype)
-                    state["Q"] = triu_to_line(Q)
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
 
-
+            if 'Q' not in state:
+                state["exp_avg"] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state["Q"] = triu_to_line(Q)
 
-
-                    continue
+            vals.append((p, g, state["exp_avg"], state["Q"]))
 
-
-
+        if not vals:
+            return
 
-
+        p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
+        del vals
 
-
+        group["step"] += 1
 
-
-            for i, (p, g) in enumerate(zip(p_list, grad_list)):
-                q_orig = Q_list.pop(0)
-                ea = exp_avg_list.pop(0)
-                q = line_to_triu(q_orig)
-                self.balance(do_update, [g], [q])
-                new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+        torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
 
-
-
-
+        Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            q_orig = Q_list.pop(0)
+            ea = exp_avg_list.pop(0)
+            q = line_to_triu(q_orig)
+            self.balance(do_update, [g], [q])
+            new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
 
-
+            if do_update:
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
+            set_(g, new)
 
-
-            update_param_(p_list, grad_list, lr, weight_decay)
+        grad_list = self.clip_fn(grad_list)
 
-
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
heavyball/foreach_adamw.py
CHANGED
@@ -10,42 +10,32 @@ class ForeachAdamW(StatefulOptimizer):
                         lr_max=-1.0, weight_decay=weight_decay)
         super().__init__(params, defaults)
 
-    def
-
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
 
-
-
-                and returns the loss.
-        """
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
 
-
-        if closure is not None:
-            loss = closure()
+        active_p = [p for p in group['params'] if p.grad is not None]
 
-
-
-            decay = group['weight_decay']
-            k = group['k']
+        if not active_p:
+            return
 
-
-
+        for p in active_p:
+            if 'exp_avg' not in self.state_(p):
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
 
-
+        y, grad, exp_avg_sq, exp_avg = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
 
-
-
-
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+        # Decay the first and second moment running average coefficient
+        torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
+        denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
 
-
-
-
-
-            torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
-            denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
-
-            # Normalize grad in-place for memory efficiency
-            lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-            update_param_(y, exp_avg, lr, decay, lambda p, e, l: torch._foreach_addcdiv_(p, e, denom, l))
-            group['k'] = k + 1
-        return loss
+        # Normalize grad in-place for memory efficiency
+        lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
+        update_param_(y, exp_avg, lr, decay, lambda p, e, l: torch._foreach_addcdiv_(p, e, denom, l))
+        group['k'] = k + 1
heavyball/foreach_adopt.py
CHANGED
@@ -11,51 +11,41 @@ class ForeachADOPT(StatefulOptimizer):
                         lr_max=-1.0, weight_decay=weight_decay)
         super().__init__(params, defaults)
 
-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                torch._foreach_mul_(exp_avg, beta1)
-                torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
-
-            beta2 = beta_debias(group['betas'][1], k + 1)
-            torch._foreach_mul_(exp_avg_sq, beta2)
-            torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)
-            del grad
-
-            group['k'] = k + 1
-        return loss
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
+
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
+
+        active_p = [p for p in group['params'] if p.grad is not None]
+
+        if not active_p:
+            return
+
+        for p in active_p:
+            if 'exp_avg' not in self.state_(p):
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+
+        y, grad, exp_avg_sq, exp_avg = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
+
+        if k > 1:
+            lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
+
+            update_param_(y, exp_avg, lr, decay)
+        if k > 0:
+            beta1 = beta_debias(group['betas'][0], k)
+            denom = torch._foreach_sqrt(exp_avg_sq)
+            torch._foreach_maximum_(denom, eps)
+            torch._foreach_mul_(exp_avg, beta1)
+            torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
+
+        beta2 = beta_debias(group['betas'][1], k + 1)
+        torch._foreach_mul_(exp_avg_sq, beta2)
+        torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)
+        del grad
+
+        group['k'] = k + 1
heavyball/foreach_laprop.py
CHANGED
@@ -11,46 +11,36 @@ class ForeachLaProp(StatefulOptimizer):
                         lr_max=-1.0, weight_decay=weight_decay)
         super().__init__(params, defaults)
 
-    def
-
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
 
-
-
-                and returns the loss.
-        """
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
 
-
-        if closure is not None:
-            loss = closure()
+        active_p = [p for p in group['params'] if p.grad is not None]
 
-
-
-            decay = group['weight_decay']
-            k = group['k']
+        if not active_p:
+            return
 
-
-
+        for p in active_p:
+            if 'exp_avg' not in self.state_(p):
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
 
-
+        y, grad, exp_avg_sq, exp_avg = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
 
-
-
-
-
+        # Decay the first and second moment running average coefficient
+        denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
+        beta1 = beta_debias(group['betas'][0], k + 1)
+        torch._foreach_mul_(exp_avg, beta1)
+        torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
+        del grad
 
-
-
+        # Normalize grad in-place for memory efficiency
+        lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
+        update_param_(y, exp_avg, lr, decay)
 
-
-            denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
-            beta1 = beta_debias(group['betas'][0], k + 1)
-            torch._foreach_mul_(exp_avg, beta1)
-            torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
-            del grad
-
-            # Normalize grad in-place for memory efficiency
-            lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-            update_param_(y, exp_avg, lr, decay)
-
-            group['k'] = k + 1
-        return loss
+        group['k'] = k + 1
heavyball/foreach_sfadamw.py
CHANGED
@@ -13,52 +13,42 @@ class ForeachSFAdamW(ScheduleFree):
                         foreach=foreach)
         super().__init__(params, defaults)
 
-    def
-
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
 
-
-
-                and returns the loss.
-        """
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
 
-
-        if closure is not None:
-            loss = closure()
+        active_p = [p for p in group['params'] if p.grad is not None]
 
-
-
-            decay = group['weight_decay']
-            k = group['k']
+        if not active_p:
+            return
 
-
-
+        for p in active_p:
+            if 'z' not in self.state_(p):
+                self.state_(p)['z'] = torch.clone(p.data)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
 
-
+        y, grad, exp_avg_sq, z = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
 
-
-
-                self.state_(p)['z'] = torch.clone(p.data)
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+        # Decay the first moment running average coefficient
+        old_debiased = beta_debias(group['betas'][1], k + 1)
 
-
-
+        # Decay the first and second moment running average coefficient
+        denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
 
-
-
+        # Normalize grad in-place for memory efficiency
+        torch._foreach_div_(grad, denom)
 
-
-
+        # Weight decay calculated at y
+        if decay != 0:
+            torch._foreach_add_(grad, y, alpha=decay)
 
-
-
+        lr = warmup(group['lr'], k + 1, group['warmup_steps'])
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
+                                             y, z, grad, group['r'], k + 1)
 
-
-            if decay != 0:
-                torch._foreach_add_(grad, y, alpha=decay)
-
-            lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-            group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
-                                                 y, z, grad, group['r'], k + 1)
-
-            group['k'] = k + 1
-        return loss
+        group['k'] = k + 1
heavyball/foreach_soap.py
CHANGED
@@ -34,73 +34,59 @@ class ForeachSOAP(StatefulOptimizer):
         super().__init__(params, defaults)
         self._data_format = data_format
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                #
-                #
-
-
-
-
-
-
-
-
-                exp_avg_projected = project(ea, state['Q'], False)
-
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-                set_(d, project(exp_avg_projected / d, state['Q'], True))
-
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                      step > 0 and step % group['precondition_frequency'] == 0)
-
-            # Why does this have to be rebiased here?
-            step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-            update_param_(p_list, denom, step_size, group["weight_decay"])
-        return loss
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1, beta2 = group["betas"]
+
+        old_debiased1 = beta_debias(beta1, step)
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        torch._foreach_mul_(exp_avg, old_debiased1)
+        torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+
+        for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
+            state = self.state_(p)
+            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            exp_avg_projected = project(ea, state['Q'], False)
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            set_(d, project(exp_avg_projected / d, state['Q'], True))
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
+                                  step > 0 and step % group['precondition_frequency'] == 0)
+
+        # Why does this have to be rebiased here?
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+        update_param_(p_list, denom, step_size, group["weight_decay"])