heavyball 0.25.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +193 -40
- heavyball/chainable.py +475 -0
- heavyball/utils.py +318 -187
- {heavyball-0.25.1.dist-info → heavyball-1.0.0.dist-info}/METADATA +4 -3
- heavyball-1.0.0.dist-info/RECORD +8 -0
- heavyball/cached_delayed_psgd_kron.py +0 -135
- heavyball/cached_psgd_kron.py +0 -136
- heavyball/delayed_psgd.py +0 -122
- heavyball/foreach_adamw.py +0 -63
- heavyball/foreach_adopt.py +0 -83
- heavyball/foreach_laprop.py +0 -67
- heavyball/foreach_sfadamw.py +0 -69
- heavyball/foreach_soap.py +0 -93
- heavyball/foreach_solp.py +0 -89
- heavyball/p_adam.py +0 -121
- heavyball/palm_foreach_sfadamw.py +0 -77
- heavyball/palm_foreach_soap.py +0 -101
- heavyball/palm_foreach_solp.py +0 -98
- heavyball/precond_schedule_foreach_soap.py +0 -95
- heavyball/precond_schedule_foreach_solp.py +0 -95
- heavyball/precond_schedule_palm_foreach_soap.py +0 -105
- heavyball/precond_schedule_palm_foreach_solp.py +0 -103
- heavyball/precond_schedule_sfpsoap.py +0 -141
- heavyball/psgd_kron.py +0 -120
- heavyball/pure_psgd.py +0 -105
- heavyball/schedule_free_palm_foreach_soap.py +0 -136
- heavyball-0.25.1.dist-info/RECORD +0 -28
- {heavyball-0.25.1.dist-info → heavyball-1.0.0.dist-info}/LICENSE +0 -0
- {heavyball-0.25.1.dist-info → heavyball-1.0.0.dist-info}/WHEEL +0 -0
- {heavyball-0.25.1.dist-info → heavyball-1.0.0.dist-info}/top_level.txt +0 -0
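Each of the deleted modules below shipped a standalone torch optimizer class in 0.25.x; version 1.0.0 consolidates them into heavyball/chainable.py. As a point of reference, here is a minimal usage sketch (not taken from the diff) of one of the removed classes, built from the `__init__` signature visible in the first deleted file; the model is hypothetical, and it assumes the class was re-exported from `heavyball/__init__.py` as in 0.25.x.

```python
import torch
import heavyball  # assumes heavyball==0.25.1, where this class still exists

model = torch.nn.Linear(16, 4)  # hypothetical model
# Arguments mirror the deleted PrecondScheduleForeachSOAP.__init__ shown below.
opt = heavyball.PrecondScheduleForeachSOAP(model.parameters(), lr=3e-3, betas=(0.9, 0.95),
                                           weight_decay=0.01, warmup_steps=100)

x = torch.randn(8, 16)
loss = model(x).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```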
heavyball/precond_schedule_foreach_soap.py
DELETED
@@ -1,95 +0,0 @@
-import random
-
-import torch
-
-from .utils import init_preconditioner, update_preconditioner, project, beta_debias, update_param_, \
-    precond_schedule, set_, StatefulOptimizer, exp_avg_
-
-
-class PrecondScheduleForeachSOAP(StatefulOptimizer):
-    """
-    Sources:
-        Preconditioner Schedules:
-            Preconditioned Stochastic Gradient Descent
-            Xi-Lin Li, Omead Pooladzandi, Evan Walters
-            https://arxiv.org/abs/1512.04202
-            https://github.com/evanatyourservice/kron_torch
-            https://github.com/lixilinx/psgd_torch
-
-        Baseline SOAP:
-            SOAP: Improving and Stabilizing Shampoo using Adam
-            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-            https://arxiv.org/abs/2409.11321
-            https://github.com/nikhilvyas/SOAP
-    """
-
-    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
-                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
-                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True, mars: bool = False,
-                 caution: bool = False, mars_gamma: float = 0.0025):
-        defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
-                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
-                    'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
-        super().__init__(params, defaults, foreach)
-        self._data_format = data_format
-        self.rng = random.Random(0x120983109)
-
-    def _step(self, group):
-        vals = []
-        step = 0
-
-        max_precond_dim = group['max_precond_dim']
-        precondition_1d = group['precondition_1d']
-
-        for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
-            state = self.state_(p)
-            step = state['step'] = state.get("step", -1) + 1
-
-            if "exp_avg" not in state:
-                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                continue  # first step is skipped so that we never use the current gradients in the projection.
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            grad_projected = project(g, state['Q'], False)
-            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-        if not vals:
-            return
-
-        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-        beta1, beta2 = group["betas"]
-
-        old_debiased2 = beta_debias(beta2, step)
-
-        # Decay the first and second moment running average coefficient
-        # In-place operations to update the averages at the same time
-        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-
-        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
-        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-
-        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
-            d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
-            state = self.state_(p)
-            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            exp_avg_projected = project(ea, state['Q'], False)
-
-            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-            # to the original space
-            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-            precond = project(exp_avg_projected / d, state['Q'], True)
-
-            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
-
-            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
-
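The comments in the deleted file describe projecting into "the eigenbases of matrices in state['GG']". A standalone sketch of that idea for a single 2-D parameter follows; it is an illustration of the SOAP projection pattern, not heavyball's `project`/`init_preconditioner` helpers, and the `GG_left`/`GG_right` names and the Adam stand-in are assumptions for the example.

```python
import torch

g = torch.randn(32, 16)            # gradient of a 2-D parameter
GG_left = g @ g.T                  # accumulated row statistics (analogue of state['GG'][0])
GG_right = g.T @ g                 # accumulated column statistics
Q_l = torch.linalg.eigh(GG_left).eigenvectors
Q_r = torch.linalg.eigh(GG_right).eigenvectors

g_proj = Q_l.T @ g @ Q_r           # analogue of project(g, state['Q'], False)
adam_update = g_proj / (g_proj.abs() + 1e-8)   # stand-in for the Adam denominator applied in that basis
update = Q_l @ adam_update @ Q_r.T # analogue of project(..., True): rotate back to the original space
```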
heavyball/precond_schedule_foreach_solp.py
DELETED
@@ -1,95 +0,0 @@
-
-
-import random
-
-import torch
-
-from .utils import init_preconditioner, update_preconditioner, project, beta_debias, update_param_, \
-    precond_schedule, set_, StatefulOptimizer, laprop_exp_avg_
-
-
-class PrecondScheduleForeachSOLP(StatefulOptimizer):
-    """
-    Sources:
-        LaProp:
-            LaProp: Separating Momentum and Adaptivity in Adam
-            Liu Ziyin, Zhikang T.Wang, Masahito Ueda
-            https://arxiv.org/abs/2002.04839
-            https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py
-
-        Preconditioner Schedules:
-            Preconditioned Stochastic Gradient Descent
-            Xi-Lin Li, Omead Pooladzandi, Evan Walters
-            https://arxiv.org/abs/1512.04202
-            https://github.com/evanatyourservice/kron_torch
-            https://github.com/lixilinx/psgd_torch
-
-        Baseline SOAP:
-            SOAP: Improving and Stabilizing Shampoo using Adam
-            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-            https://arxiv.org/abs/2409.11321
-            https://github.com/nikhilvyas/SOAP
-    """
-
-    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
-                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
-                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True, mars: bool = False,
-                 caution: bool = False, mars_gamma: float = 0.0025):
-        defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
-                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
-                    'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
-        super().__init__(params, defaults, foreach)
-        self._data_format = data_format
-        self.rng = random.Random(0x120983109)
-
-    def _step(self, group):
-        vals = []
-        step = 0
-
-        max_precond_dim = group['max_precond_dim']
-        precondition_1d = group['precondition_1d']
-
-        for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
-            state = self.state_(p)
-            step = state['step'] = state.get("step", -1) + 1
-
-            if "exp_avg" not in state:
-                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                continue  # first step is skipped so that we never use the current gradients in the projection.
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            grad_projected = project(g, state['Q'], False)
-            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-        if not vals:
-            return
-
-        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-        beta1, beta2 = group["betas"]
-
-        old_debiased2 = beta_debias(beta2, step)
-
-        # Decay the first and second moment running average coefficient
-        # In-place operations to update the averages at the same time
-        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-
-        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
-        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-
-
-        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
-            laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
-            state = self.state_(p)
-            precond = project(ea, state['Q'], True)
-
-            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
-            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
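The SOLP variants differ from the SOAP ones mainly by calling `laprop_exp_avg_` instead of `exp_avg_`. As a minimal sketch of what that LaProp-style update looks like per the cited paper (arXiv:2002.04839): the gradient is normalized by the second moment first, and momentum is then taken over the already-normalized gradient. The function and names below are illustrative assumptions, not heavyball's implementation, and bias correction is omitted.

```python
import torch

def laprop_update(exp_avg, exp_avg_sq, grad, beta1, beta2, eps=1e-8):
    # Second moment first, as in Adam/RMSprop.
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    normalized = grad / (exp_avg_sq.sqrt() + eps)
    # Momentum over the normalized gradient, unlike Adam, which averages raw
    # gradients and divides afterwards.
    exp_avg.mul_(beta1).add_(normalized, alpha=1 - beta1)
    return exp_avg
```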
heavyball/precond_schedule_palm_foreach_soap.py
DELETED
@@ -1,105 +0,0 @@
-import random
-
-import torch
-
-from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_, update_param_, \
-    precond_schedule, set_, StatefulOptimizer
-
-
-class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
-    """
-    Sources:
-        Preconditioner Schedules:
-            Preconditioned Stochastic Gradient Descent
-            Xi-Lin Li, Omead Pooladzandi, Evan Walters
-            https://arxiv.org/abs/1512.04202
-            https://github.com/evanatyourservice/kron_torch
-            https://github.com/lixilinx/psgd_torch
-
-        Baseline SOAP:
-            SOAP: Improving and Stabilizing Shampoo using Adam
-            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-            https://arxiv.org/abs/2409.11321
-            https://github.com/nikhilvyas/SOAP
-
-        Beta2 Schedule:
-            PaLM: Scaling Language Modeling with Pathways
-            Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
-            https://arxiv.org/abs/2204.02311
-    """
-
-    def __init__(self, params, lr: float = 3e-3, beta=0.9, shampoo_beta: float = 0.95, eps: float = 1e-8,
-                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
-                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
-                 foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
-        if betas[0] is not None:
-            beta = betas[0]
-        defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
-                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
-                    'beta2_scale': beta2_scale, 'split': split, 'mars': mars, 'caution': caution,
-                    'mars_gamma': mars_gamma}
-        super().__init__(params, defaults, foreach)
-        self._data_format = data_format
-        self.rng = random.Random(0x120983109)
-
-    def _step(self, group):
-        vals = []
-        step = 0
-
-        max_precond_dim = group['max_precond_dim']
-        precondition_1d = group['precondition_1d']
-
-        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
-            state = self.state_(p)
-            step = state['step'] = state.get("step", -1) + 1
-
-            if "exp_avg" not in state:
-                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                continue  # first step is skipped so that we never use the current gradients in the projection.
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            grad_projected = project(g, state['Q'], False)
-            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-        if not vals:
-            return
-
-        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-        beta1 = group["beta"]
-
-        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
-        old_debiased1 = beta_debias(beta1, step)
-        old_debiased2 = beta_debias(beta2, step)
-
-        # Decay the first and second moment running average coefficient
-        # In-place operations to update the averages at the same time
-        beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
-        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-
-        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
-        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-
-        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
-            d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
-            state = self.state_(p)
-            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            exp_avg_projected = project(ea, state['Q'], False)
-
-            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-            # to the original space
-            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-            exp_avg_projected = exp_avg_projected / d
-            precond = project(exp_avg_projected, state['Q'], True)
-
-            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
-            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
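The PaLM-style schedule in this file is the line `beta2 = 1 - max(step, 1) ** -group['beta2_scale']`. A quick illustration of what that does with the default `beta2_scale = 0.8` (the second-moment horizon lengthens as training progresses):

```python
# Worked values of the PaLM beta2 schedule used above, with beta2_scale = 0.8.
for step in (1, 10, 100, 1_000, 10_000):
    beta2 = 1 - max(step, 1) ** -0.8
    print(step, round(beta2, 5))
# 1 -> 0.0, 10 -> 0.84151, 100 -> 0.97488, 1000 -> 0.99602, 10000 -> 0.99937
```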
heavyball/precond_schedule_palm_foreach_solp.py
DELETED
@@ -1,103 +0,0 @@
-import random
-
-import torch
-
-from .utils import init_preconditioner, update_preconditioner, project, beta_debias, laprop_exp_avg_, update_param_, \
-    precond_schedule, set_, StatefulOptimizer
-
-
-class PrecondSchedulePaLMForeachSOLP(StatefulOptimizer):
-    """
-    Sources:
-        LaProp:
-            LaProp: Separating Momentum and Adaptivity in Adam
-            Liu Ziyin, Zhikang T.Wang, Masahito Ueda
-            https://arxiv.org/abs/2002.04839
-            https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py
-
-        Preconditioner Schedules:
-            Preconditioned Stochastic Gradient Descent
-            Xi-Lin Li, Omead Pooladzandi, Evan Walters
-            https://arxiv.org/abs/1512.04202
-            https://github.com/evanatyourservice/kron_torch
-            https://github.com/lixilinx/psgd_torch
-
-        Baseline SOAP:
-            SOAP: Improving and Stabilizing Shampoo using Adam
-            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-            https://arxiv.org/abs/2409.11321
-            https://github.com/nikhilvyas/SOAP
-
-        Beta2 Schedule:
-            PaLM: Scaling Language Modeling with Pathways
-            Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
-            https://arxiv.org/abs/2204.02311
-    """
-
-    def __init__(self, params, lr: float = 3e-3, beta=0.9, shampoo_beta: float = 0.95, eps: float = 1e-8,
-                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
-                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
-                 foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
-        if betas[0] is not None:
-            beta = betas[0]
-        defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
-                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
-                    'beta2_scale': beta2_scale, 'split': split, 'mars': mars, 'caution': caution,
-                    'mars_gamma': mars_gamma}
-        super().__init__(params, defaults, foreach)
-        self._data_format = data_format
-        self.rng = random.Random(0x120983109)
-
-    def _step(self, group):
-        vals = []
-        step = 0
-
-        max_precond_dim = group['max_precond_dim']
-        precondition_1d = group['precondition_1d']
-
-        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
-            state = self.state_(p)
-            step = state['step'] = state.get("step", -1) + 1
-
-            if "exp_avg" not in state:
-                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                continue  # first step is skipped so that we never use the current gradients in the projection.
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            grad_projected = project(g, state['Q'], False)
-            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-        if not vals:
-            return
-
-        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-        beta1 = group["beta"]
-
-        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
-        old_debiased1 = beta_debias(beta1, step)
-        old_debiased2 = beta_debias(beta2, step)
-
-        # Decay the first and second moment running average coefficient
-        # In-place operations to update the averages at the same time
-        beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
-        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-
-        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
-        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-
-        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
-            laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
-            state = self.state_(p)
-            precond = project(ea, state['Q'], True)
-
-            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
-            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
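All of these deleted `_step` methods compute their step size as `-group["lr"] * min(step / group['warmup_steps'], 1)`, i.e. a plain linear learning-rate ramp that is negated so `update_param_` can apply it additively. A tiny standalone illustration (the helper name is ours, not heavyball's):

```python
def warmed_up_step_size(lr: float, step: int, warmup_steps: int) -> float:
    # Linear warmup: ramps from 0 to lr over warmup_steps, then stays at lr.
    return -lr * min(step / warmup_steps, 1)

# With lr=3e-3 and warmup_steps=100:
#   step=10   -> -3e-4
#   step=50   -> -1.5e-3
#   step>=100 -> -3e-3
```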
heavyball/precond_schedule_sfpsoap.py
DELETED
@@ -1,141 +0,0 @@
-import random
-
-import torch
-
-from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
-    beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, copy_stochastic_list_, \
-    promote, decorator_knowngood
-
-
-@decorator_knowngood
-def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
-    eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
-    denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
-    torch._foreach_div_(gp32, denom)
-
-    copy_stochastic_list_(exp_avg_sq, eas32)
-    copy_stochastic_list_(grad_projected, gp32)
-
-
-class PrecondScheduleSFPaLMSOAP(ScheduleFree):
-    """
-    SFPaLMForeachSOAP with preconditioner schedules
-
-    Sources:
-        Preconditioner Schedules:
-            Preconditioned Stochastic Gradient Descent
-            Xi-Lin Li, Omead Pooladzandi, Evan Walters
-            https://arxiv.org/abs/1512.04202
-            https://github.com/evanatyourservice/kron_torch
-            https://github.com/lixilinx/psgd_torch
-
-        Baseline SOAP:
-            SOAP: Improving and Stabilizing Shampoo using Adam
-            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-            https://arxiv.org/abs/2409.11321
-            https://github.com/nikhilvyas/SOAP
-
-        ScheduleFree:
-            The Road Less Scheduled
-            Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
-            https://arxiv.org/abs/2405.15682
-            https://github.com/facebookresearch/schedule_free
-
-        Beta2 Schedule:
-            PaLM: Scaling Language Modeling with Pathways
-            Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
-            https://arxiv.org/abs/2204.02311
-    """
-
-    def __init__(self, params, lr: float = 3e-3, beta=0.9, beta2_scale: float = 0.8, eps: float = 1e-8,
-                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
-                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
-                 weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9), betas=(None, None),
-                 split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025):
-        if betas[0] is not None:
-            beta = betas[0]
-
-        assert not caution, "Caution is not implemented in ScheduleFree optimizers"
-
-        defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
-                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
-                    'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1, 'weight_sum': 0,
-                    'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split,
-                    'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
-        super().__init__(params, defaults, foreach)
-        self._data_format = data_format
-        self.rng = random.Random(0x120983109)
-
-    def _step(self, group):
-        vals = []
-        max_precond_dim = group['max_precond_dim']
-        precondition_1d = group['precondition_1d']
-
-        step = group['step'] = group.get("step", 0) + 1
-
-        for p in group["params"]:
-            if p.grad is None:
-                continue
-            grad = p.grad.float()
-            vals.append((p, grad))
-
-        if not vals:
-            return
-
-        p_list, grad = zip(*vals)
-        vals = []
-
-        # adaptive gradient clipping
-        adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
-
-        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
-            state = self.state_(p)
-
-            if "z" not in state:
-                state["z"] = torch.clone(p.data, memory_format=torch.preserve_format)
-                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                continue  # first step is skipped so that we never use the current gradients in the projection.
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            grad_projected = project(g, state['Q'], False)
-            z, exp_avg_sq = state["z"], state["exp_avg_sq"]
-            vals.append((p, g, grad_projected, z, exp_avg_sq))
-
-        if not vals:
-            return
-
-        p_list, grad, grad_projected, z, exp_avg_sq = zip(*vals)
-        del vals
-
-        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
-        old_debiased2 = beta_debias(beta2, step)
-
-        # Decay the first and second moment running average coefficient
-        # In-place operations to update the averages at the same time
-        old_debiased_tensor = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(old_debiased2)
-        _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased_tensor, group["eps"])
-
-        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
-
-        for p, g, gp in zip(p_list, grad, grad_projected):
-            state = self.state_(p)
-            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-            # to the original space
-            set_(gp, project(gp, state['Q'], back=True))
-
-            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
-
-        # Weight decay calculated at y
-        if group["weight_decay"] > 0:
-            torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])
-
-        lr = warmup(group['lr'], step, group['warmup_steps'])
-        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], p_list,
-                                             z, grad_projected, group['r'], step)
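The final call hands the preconditioned gradients to `schedule_free_`. Its exact behavior is not visible in this hunk; as a rough sketch of the schedule-free recursion from the cited paper (arXiv:2405.15682), with the preconditioned gradient standing in for `grad` and uniform averaging in place of the `weight_lr_power`/`r` weighting:

```python
import torch

def schedule_free_step(x, z, grad, lr, beta=0.9, step=1):
    # z is the fast (SGD-like) iterate, x is the running average used for evaluation;
    # gradients are taken at y = (1 - beta) * z + beta * x.
    z.sub_(grad, alpha=lr)                # z_{t+1} = z_t - lr * grad(y_t)
    ckp1 = 1 / (step + 1)                 # uniform averaging weight (simplification)
    x.mul_(1 - ckp1).add_(z, alpha=ckp1)  # x_{t+1} = (1 - c) * x_t + c * z_{t+1}
    y = x.mul(beta).add(z, alpha=1 - beta)
    return y                              # next evaluation point
```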
heavyball/psgd_kron.py
DELETED
@@ -1,120 +0,0 @@
-"""
-Originally from Evan Walters and Omead Pooladzandi, 2024
-Modified under Creative Commons Attribution 4.0 International
-Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
-"""
-
-from typing import Optional
-
-import torch
-
-from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    line_to_triu, triu_to_line, promote, stochastic_lerp_, beta_debias
-
-
-class ForeachPSGDKron(PSGDBase):
-    """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch.
-
-    Args:
-        params (iterable): Iterable of parameters to optimize or dicts defining
-            parameter groups.
-        lr (float): Learning rate.
-        b1 (float): Momentum parameter.
-        weight_decay (float): Weight decay (L2 penalty).
-        preconditioner_update_probability (callable or float, optional): Probability of
-            updating the preconditioner. If None, defaults to a schedule that anneals
-            from 1.0 to 0.03 by 4000 steps.
-        max_size_triangular (int): Max size for dim's preconditioner to be triangular.
-        min_ndim_triangular (int): Minimum number of dimensions a layer needs
-            to have triangular preconditioners.
-        memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
-            to set all preconditioners to be triangular, 'one_diag' sets the largest
-            or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
-            to be diagonal.
-        momentum_into_precond_update: (bool), whether to send momentum into preconditioner
-            update instead of raw gradients.
-    """
-
-    def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
-                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
-                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
-                 storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
-                 #
-                 # expert parameters
-                 precond_init_scale=1.0, precond_lr=0.1):
-        if not 0.0 <= lr:
-            raise ValueError(f"Invalid learning rate: {lr}")
-        if not 0.0 <= beta < 1.0:
-            raise ValueError(f"Invalid beta parameter: {beta}")
-        if not 0.0 <= weight_decay:
-            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
-        if clip_fn is None:
-            clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-
-        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
-                        min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
-                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
-                        storage_dtype=storage_dtype,
-                        mars=mars, caution=caution, mars_gamma=mars_gamma)
-        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
-
-    def _step(self, group):
-        should_update = self.should_update(group)
-        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
-        precond_init_scale = group['precond_init_scale']
-        max_size_triangular = group['max_size_triangular']
-        min_ndim_triangular = group['min_ndim_triangular']
-        memory_save_mode = group['memory_save_mode']
-        precond_lr = group['precond_lr']
-        weight_decay = group['weight_decay']
-        lr = group['lr']
-        beta = group['beta']
-        store_triu_as_line = group['store_triu_as_line']
-        q_dtype = getattr(torch, group['q_dtype'])
-        storage_dtype = getattr(torch, group['storage_dtype'])
-
-        vals = []
-
-        for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
-            state = self.state_(p)
-
-            if 'Q' not in state:
-                state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype, memory_format=torch.preserve_format)
-                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                 memory_save_mode, dtype=q_dtype)
-                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
-
-            vals.append((p, g, state["exp_avg"], state["Q"]))
-
-        if not vals:
-            return
-
-        p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
-        del vals
-
-        group["step"] += 1
-
-        beta = beta_debias(beta, group["step"])
-        beta = torch.empty((), dtype=torch.float32, device=grad_list[0].device).fill_(1 - beta)
-        stochastic_lerp_(exp_avg_list, grad_list, 1 - beta)
-
-        grad_list, Q_list, exp_avg_list = list(grad_list), list(Q_list), list(exp_avg_list)
-
-        lr = -warmup(lr, group['step'], group['warmup_steps'])
-
-        for i, (p, g) in enumerate(zip(p_list, grad_list)):
-            q_orig = Q_list.pop(0)
-            ea = exp_avg_list.pop(0)
-            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-
-            if should_update:
-                q32 = [promote(q_) for q_ in q]
-                self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
-                               store_triu_as_line)
-            g = psgd_precond_grad(False, self.state_(p)["exprs"][-1], ea, *q)
-            update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])