heavyball 0.25.0__py3-none-any.whl → 0.25.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +12 -2
- heavyball/foreach_soap.py +3 -1
- heavyball/foreach_solp.py +89 -0
- heavyball/palm_foreach_soap.py +2 -1
- heavyball/palm_foreach_solp.py +98 -0
- heavyball/precond_schedule_foreach_solp.py +95 -0
- heavyball/precond_schedule_palm_foreach_solp.py +103 -0
- heavyball/utils.py +23 -0
- {heavyball-0.25.0.dist-info → heavyball-0.25.1.dist-info}/METADATA +1 -1
- {heavyball-0.25.0.dist-info → heavyball-0.25.1.dist-info}/RECORD +13 -9
- {heavyball-0.25.0.dist-info → heavyball-0.25.1.dist-info}/LICENSE +0 -0
- {heavyball-0.25.0.dist-info → heavyball-0.25.1.dist-info}/WHEEL +0 -0
- {heavyball-0.25.0.dist-info → heavyball-0.25.1.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -15,6 +15,10 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
|
|
15
15
|
from .psgd_kron import ForeachPSGDKron
|
16
16
|
from .pure_psgd import ForeachPurePSGD
|
17
17
|
from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
|
18
|
+
from .foreach_solp import ForeachSOLP
|
19
|
+
from .palm_foreach_solp import PaLMForeachSOLP
|
20
|
+
from .precond_schedule_palm_foreach_solp import PrecondSchedulePaLMForeachSOLP
|
21
|
+
from .precond_schedule_foreach_solp import PrecondScheduleForeachSOLP
|
18
22
|
|
19
23
|
PalmForEachSoap = PaLMForeachSOAP
|
20
24
|
|
@@ -35,12 +39,18 @@ PaLMPAdam = ForeachPaLMPAdam
|
|
35
39
|
DelayedPSGD = ForeachDelayedPSGD
|
36
40
|
CachedPSGDKron = ForeachCachedPSGDKron
|
37
41
|
CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
|
42
|
+
SOLP = ForeachSOLP
|
43
|
+
PaLMSOLP = PaLMForeachSOLP
|
44
|
+
PrecondSchedulePaLMSOLP = PrecondSchedulePaLMForeachSOLP
|
45
|
+
PrecondScheduleSOLP = PrecondScheduleForeachSOLP
|
38
46
|
|
39
47
|
# Public API of the package: the concrete optimizer classes first, then (after the bare `#` separator) the short
# aliases defined above. Every SOLP alias assigned in this module ('SOLP', 'PaLMSOLP', 'PrecondSchedulePaLMSOLP',
# 'PrecondScheduleSOLP') is exported exactly once.
__all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
           'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
           'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron', 'ForeachSOLP',
           'PaLMForeachSOLP', 'PrecondSchedulePaLMForeachSOLP', 'PrecondScheduleForeachSOLP',
           #
           'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP', 'SOAP', 'SFAdamW',
           'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD', 'CachedPSGDKron',
           'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP', 'SOLP', 'PaLMSOLP',
           'PrecondSchedulePaLMSOLP', 'PrecondScheduleSOLP']
|
heavyball/foreach_soap.py
CHANGED
@@ -6,7 +6,7 @@ from .utils import init_preconditioner, update_preconditioner, project, beta_deb
|
|
6
6
|
|
7
7
|
class ForeachSOAP(StatefulOptimizer):
|
8
8
|
"""
|
9
|
-
|
9
|
+
ForeachSOAP
|
10
10
|
|
11
11
|
Sources:
|
12
12
|
Baseline SOAP:
|
@@ -75,6 +75,8 @@ class ForeachSOAP(StatefulOptimizer):
|
|
75
75
|
step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
|
76
76
|
|
77
77
|
for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
|
78
|
+
state = self.state_(p)
|
79
|
+
|
78
80
|
d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
|
79
81
|
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
80
82
|
# i.e. projecting to the eigenbases of matrices in state['GG']
|
@@ -0,0 +1,89 @@
|
|
1
|
+
import torch
|
2
|
+
|
3
|
+
from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
|
4
|
+
StatefulOptimizer, laprop_exp_avg_
|
5
|
+
|
6
|
+
|
7
|
+
class ForeachSOLP(StatefulOptimizer):
    """
    ForeachSOLP

    SOAP-style optimizer that runs a LaProp update (``laprop_exp_avg_``) on gradients projected into the
    eigenbases stored in ``state['Q']``.

    Sources:
        LaProp:
            LaProp: Separating Momentum and Adaptivity in Adam
            Liu Ziyin, Zhikang T.Wang, Masahito Ueda
            https://arxiv.org/abs/2002.04839
            https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py

        Baseline SOAP:
            SOAP: Improving and Stabilizing Shampoo using Adam
            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
            https://arxiv.org/abs/2409.11321
            https://github.com/nikhilvyas/SOAP

        ScheduleFree:
            The Road Less Scheduled
            Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
            https://arxiv.org/abs/2405.15682
            https://github.com/facebookresearch/schedule_free
    """

    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
                 split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
                 mars_gamma: float = 0.0025):
        """
        Store the hyperparameters as per-group defaults and hand them to ``StatefulOptimizer``.

        ``_step`` itself reads ``lr``, ``betas``, ``weight_decay``, ``precondition_frequency``,
        ``max_precond_dim``, ``precondition_1d``, ``warmup_steps`` and ``caution``; the remaining entries are only
        stored here (presumably consumed by the base class - verify against ``StatefulOptimizer``).
        ``data_format`` is kept on the instance rather than in the group defaults.
        """
        defaults = dict(lr=lr, betas=betas, shampoo_beta=shampoo_beta, eps=eps, weight_decay=weight_decay,
                        precondition_frequency=precondition_frequency, max_precond_dim=max_precond_dim,
                        merge_dims=merge_dims, precondition_1d=precondition_1d, normalize_grads=normalize_grads,
                        correct_bias=correct_bias, warmup_steps=warmup_steps, split=split, mars=mars,
                        caution=caution, mars_gamma=mars_gamma)
        super().__init__(params, defaults, foreach)
        self._data_format = data_format

    def _step(self, group):
        """Run one optimization step for every parameter yielded for ``group``."""
        dim_cap = group['max_precond_dim']
        use_1d = group['precondition_1d']

        # `step` deliberately outlives the loop: the value left by the last parameter drives the schedule below.
        step = 0
        queued = []
        for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
            state = self.state_(p)
            step = state.get("step", -1) + 1
            state['step'] = step

            if "exp_avg" in state:
                # Projecting gradients to the eigenbases of Shampoo's preconditioner
                # i.e. projecting to the eigenbases of matrices in state['GG']
                queued.append((p, g, project(g, state['Q'], False), state["exp_avg"], state["exp_avg_sq"]))
                continue

            # First time this parameter is seen: allocate the moment buffers, build the preconditioner from the
            # current gradient and skip the update so the current gradient is never used in its own projection.
            state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
            state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
            init_preconditioner(g, state, dim_cap, use_1d)
            update_preconditioner(g, state, dim_cap, use_1d, 0, True)

        if len(queued) == 0:
            return

        beta1, beta2 = group["betas"]
        precond_beta = beta_debias(beta2, step)

        # The step counter is handed to the moment update as a 0-dim int32 tensor on the parameters' device.
        device = queued[0][0].device
        step_tensor = torch.empty((), dtype=torch.int32, device=device).fill_(step)

        # Linear warmup of the (negated) learning rate over `warmup_steps` steps.
        warmup = min(step / group['warmup_steps'], 1)
        step_size = -group["lr"] * warmup

        for p, g, g_proj, momentum, second_moment in queued:
            laprop_exp_avg_(momentum, second_moment, g_proj, beta1, beta2, step_tensor)
            state = self.state_(p)
            # Third argument True is the opposite direction of the projection applied to the gradient above.
            update = project(momentum, state['Q'], True)

            refresh = step > 0 and step % group['precondition_frequency'] == 0
            update_preconditioner(g, state, dim_cap, use_1d, precond_beta, refresh)
            update_param_([p], [update], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
|
heavyball/palm_foreach_soap.py
CHANGED
@@ -6,7 +6,7 @@ from .utils import init_preconditioner, update_preconditioner, project, beta_deb
|
|
6
6
|
|
7
7
|
class PaLMForeachSOAP(StatefulOptimizer):
|
8
8
|
"""
|
9
|
-
|
9
|
+
PaLMForeachSOAP
|
10
10
|
|
11
11
|
Sources:
|
12
12
|
Baseline SOAP:
|
@@ -84,6 +84,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
|
|
84
84
|
step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
|
85
85
|
|
86
86
|
for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
|
87
|
+
state = self.state_(p)
|
87
88
|
d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
|
88
89
|
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
89
90
|
# i.e. projecting to the eigenbases of matrices in state['GG']
|
@@ -0,0 +1,98 @@
|
|
1
|
+
import torch
|
2
|
+
|
3
|
+
from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
|
4
|
+
StatefulOptimizer, laprop_exp_avg_
|
5
|
+
|
6
|
+
|
7
|
+
class PaLMForeachSOLP(StatefulOptimizer):
    """
    PaLMForeachSOLP

    SOAP-style optimizer that runs a LaProp update (``laprop_exp_avg_``) on gradients projected into the
    eigenbases stored in ``state['Q']``, with beta2 derived from the step count (PaLM schedule) instead of being
    a fixed hyperparameter.

    Sources:
        LaProp:
            LaProp: Separating Momentum and Adaptivity in Adam
            Liu Ziyin, Zhikang T.Wang, Masahito Ueda
            https://arxiv.org/abs/2002.04839
            https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py

        Baseline SOAP:
            SOAP: Improving and Stabilizing Shampoo using Adam
            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
            https://arxiv.org/abs/2409.11321
            https://github.com/nikhilvyas/SOAP

        ScheduleFree:
            The Road Less Scheduled
            Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
            https://arxiv.org/abs/2405.15682
            https://github.com/facebookresearch/schedule_free

        Beta2 Schedule:
            PaLM: Scaling Language Modeling with Pathways
            Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
            https://arxiv.org/abs/2204.02311
    """

    def __init__(self, params, lr: float = 3e-3, beta=0.9, betas=(None, None), shampoo_beta: float = 0.95,
                 eps: float = 1e-8, weight_decay: float = 0.01, precondition_frequency: int = 2,
                 max_precond_dim: int = 2048,  #
                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
                 beta2_scale: float = 0.8, split: bool = False, foreach: bool = True, mars: bool = False,
                 caution: bool = False, mars_gamma: float = 0.0025):
        """
        Store the hyperparameters as per-group defaults and hand them to ``StatefulOptimizer``.

        ``betas`` exists so the class can be constructed like the fixed-beta optimizers: when ``betas[0]`` is not
        ``None`` it overrides ``beta``; ``betas[1]`` is ignored because beta2 is scheduled from ``beta2_scale``.
        ``data_format`` is kept on the instance rather than in the group defaults.
        """
        if betas[0] is not None:
            beta = betas[0]
        defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
                    'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
        super().__init__(params, defaults, foreach)
        self._data_format = data_format

    def _step(self, group):
        """Run one optimization step for every parameter yielded for ``group``."""
        vals = []
        # `step` outlives the loop: the value left by the last parameter iterated drives the schedules below.
        step = 0

        max_precond_dim = group['max_precond_dim']
        precondition_1d = group['precondition_1d']

        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
            state = self.state_(p)
            step = state['step'] = state.get("step", -1) + 1

            if "exp_avg" not in state:
                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
                init_preconditioner(g, state, max_precond_dim, precondition_1d)
                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
                continue  # first step is skipped so that we never use the current gradients in the projection.

            # Projecting gradients to the eigenbases of Shampoo's preconditioner
            # i.e. projecting to the eigenbases of matrices in state['GG']
            grad_projected = project(g, state['Q'], False)
            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))

        if not vals:
            return

        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
        beta1 = group["beta"]

        # PaLM beta2 schedule. `step` can still be 0 here when the last parameter iterated was seen for the first
        # time while earlier ones were already queued; clamping to 1 avoids raising 0 to a negative power and
        # matches PrecondSchedulePaLMForeachSOLP.
        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
        old_debiased2 = beta_debias(beta2, step)

        # beta2 and the step counter are handed to the moment update as 0-dim tensors on the parameters' device.
        beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
        # Linear warmup of the (negated) learning rate over `warmup_steps` steps.
        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)

        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
            laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
            state = self.state_(p)
            # Third argument True is the opposite direction of the projection applied to the gradient above.
            precond = project(ea, state['Q'], True)

            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
                                  step > 0 and step % group['precondition_frequency'] == 0)
            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
|
@@ -0,0 +1,95 @@
|
|
1
|
+
|
2
|
+
|
3
|
+
import random
|
4
|
+
|
5
|
+
import torch
|
6
|
+
|
7
|
+
from .utils import init_preconditioner, update_preconditioner, project, beta_debias, update_param_, \
|
8
|
+
precond_schedule, set_, StatefulOptimizer, laprop_exp_avg_
|
9
|
+
|
10
|
+
|
11
|
+
class PrecondScheduleForeachSOLP(StatefulOptimizer):
    """
    SOAP-style optimizer that runs a LaProp update (``laprop_exp_avg_``) on gradients projected into the
    eigenbases stored in ``state['Q']``; whether the preconditioner is refreshed on a given step is decided by
    ``precond_schedule``.

    Sources:
        LaProp:
            LaProp: Separating Momentum and Adaptivity in Adam
            Liu Ziyin, Zhikang T.Wang, Masahito Ueda
            https://arxiv.org/abs/2002.04839
            https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py

        Preconditioner Schedules:
            Preconditioned Stochastic Gradient Descent
            Xi-Lin Li, Omead Pooladzandi, Evan Walters
            https://arxiv.org/abs/1512.04202
            https://github.com/evanatyourservice/kron_torch
            https://github.com/lixilinx/psgd_torch

        Baseline SOAP:
            SOAP: Improving and Stabilizing Shampoo using Adam
            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
            https://arxiv.org/abs/2409.11321
            https://github.com/nikhilvyas/SOAP
    """

    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
                 precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True, mars: bool = False,
                 caution: bool = False, mars_gamma: float = 0.0025):
        """
        Store the hyperparameters as per-group defaults and hand them to ``StatefulOptimizer``.

        ``data_format`` is kept on the instance rather than in the group defaults, and a private, fixed-seed
        ``random.Random`` is created for ``precond_schedule``.
        """
        defaults = dict(lr=lr, betas=betas, shampoo_beta=shampoo_beta, eps=eps, weight_decay=weight_decay,
                        precondition_frequency=precondition_frequency, max_precond_dim=max_precond_dim,
                        merge_dims=merge_dims, precondition_1d=precondition_1d, normalize_grads=normalize_grads,
                        correct_bias=correct_bias, warmup_steps=warmup_steps, precond_scheduler=precond_scheduler,
                        split=split, mars=mars, caution=caution, mars_gamma=mars_gamma)
        super().__init__(params, defaults, foreach)
        self._data_format = data_format
        self.rng = random.Random(0x120983109)

    def _step(self, group):
        """Run one optimization step for every parameter yielded for ``group``."""
        dim_cap = group['max_precond_dim']
        use_1d = group['precondition_1d']

        # `step` deliberately outlives the loop: the value left by the last parameter drives the schedule below.
        step = 0
        queued = []
        for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
            state = self.state_(p)
            step = state.get("step", -1) + 1
            state['step'] = step

            if "exp_avg" in state:
                # Projecting gradients to the eigenbases of Shampoo's preconditioner
                # i.e. projecting to the eigenbases of matrices in state['GG']
                queued.append((p, g, project(g, state['Q'], False), state["exp_avg"], state["exp_avg_sq"]))
                continue

            # First time this parameter is seen: allocate the moment buffers, build the preconditioner from the
            # current gradient and skip the update so the current gradient is never used in its own projection.
            state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
            state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
            init_preconditioner(g, state, dim_cap, use_1d)
            update_preconditioner(g, state, dim_cap, use_1d, 0, True)

        if len(queued) == 0:
            return

        beta1, beta2 = group["betas"]
        precond_beta = beta_debias(beta2, step)

        # The step counter is handed to the moment update as a 0-dim int32 tensor on the parameters' device.
        device = queued[0][0].device
        step_tensor = torch.empty((), dtype=torch.int32, device=device).fill_(step)

        # One schedule decision per call, shared by every queued parameter of the group.
        refresh = precond_schedule(step, group['precond_scheduler'], self.rng)

        # Linear warmup of the (negated) learning rate over `warmup_steps` steps.
        warmup = min(step / group['warmup_steps'], 1)
        step_size = -group["lr"] * warmup

        for p, g, g_proj, momentum, second_moment in queued:
            laprop_exp_avg_(momentum, second_moment, g_proj, beta1, beta2, step_tensor)
            state = self.state_(p)
            # Third argument True is the opposite direction of the projection applied to the gradient above.
            update = project(momentum, state['Q'], True)

            update_preconditioner(g, state, dim_cap, use_1d, precond_beta, refresh)
            update_param_([p], [update], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
|
@@ -0,0 +1,103 @@
|
|
1
|
+
import random
|
2
|
+
|
3
|
+
import torch
|
4
|
+
|
5
|
+
from .utils import init_preconditioner, update_preconditioner, project, beta_debias, laprop_exp_avg_, update_param_, \
|
6
|
+
precond_schedule, set_, StatefulOptimizer
|
7
|
+
|
8
|
+
|
9
|
+
class PrecondSchedulePaLMForeachSOLP(StatefulOptimizer):
    """
    SOAP-style optimizer that runs a LaProp update (``laprop_exp_avg_``) on gradients projected into the
    eigenbases stored in ``state['Q']``. beta2 is derived from the step count (PaLM schedule) and whether the
    preconditioner is refreshed on a given step is decided by ``precond_schedule``.

    Sources:
        LaProp:
            LaProp: Separating Momentum and Adaptivity in Adam
            Liu Ziyin, Zhikang T.Wang, Masahito Ueda
            https://arxiv.org/abs/2002.04839
            https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py

        Preconditioner Schedules:
            Preconditioned Stochastic Gradient Descent
            Xi-Lin Li, Omead Pooladzandi, Evan Walters
            https://arxiv.org/abs/1512.04202
            https://github.com/evanatyourservice/kron_torch
            https://github.com/lixilinx/psgd_torch

        Baseline SOAP:
            SOAP: Improving and Stabilizing Shampoo using Adam
            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
            https://arxiv.org/abs/2409.11321
            https://github.com/nikhilvyas/SOAP

        Beta2 Schedule:
            PaLM: Scaling Language Modeling with Pathways
            Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
            https://arxiv.org/abs/2204.02311
    """

    def __init__(self, params, lr: float = 3e-3, beta=0.9, shampoo_beta: float = 0.95, eps: float = 1e-8,
                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
                 precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
                 foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
        """
        Store the hyperparameters as per-group defaults and hand them to ``StatefulOptimizer``.

        ``betas`` exists so the class can be constructed like the fixed-beta optimizers: when ``betas[0]`` is not
        ``None`` it overrides ``beta``; ``betas[1]`` is ignored because beta2 is scheduled from ``beta2_scale``.
        ``data_format`` is kept on the instance rather than in the group defaults, and a private, fixed-seed
        ``random.Random`` is created for ``precond_schedule``.
        """
        if betas[0] is not None:
            beta = betas[0]
        defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
                    'beta2_scale': beta2_scale, 'split': split, 'mars': mars, 'caution': caution,
                    'mars_gamma': mars_gamma}
        super().__init__(params, defaults, foreach)
        self._data_format = data_format
        self.rng = random.Random(0x120983109)

    def _step(self, group):
        """Run one optimization step for every parameter yielded for ``group``."""
        vals = []
        # `step` outlives the loop: the value left by the last parameter iterated drives the schedules below.
        step = 0

        max_precond_dim = group['max_precond_dim']
        precondition_1d = group['precondition_1d']

        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
            state = self.state_(p)
            step = state['step'] = state.get("step", -1) + 1

            if "exp_avg" not in state:
                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
                init_preconditioner(g, state, max_precond_dim, precondition_1d)
                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
                continue  # first step is skipped so that we never use the current gradients in the projection.

            # Projecting gradients to the eigenbases of Shampoo's preconditioner
            # i.e. projecting to the eigenbases of matrices in state['GG']
            grad_projected = project(g, state['Q'], False)
            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))

        if not vals:
            return

        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
        beta1 = group["beta"]

        # PaLM beta2 schedule; the clamp keeps the base of the negative power non-zero.
        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
        # Only the debiased beta2 is needed here (it feeds update_preconditioner below); beta1 is passed on as-is.
        old_debiased2 = beta_debias(beta2, step)

        # beta2 and the step counter are handed to the moment update as 0-dim tensors on the parameters' device.
        beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)

        # One schedule decision per call, shared by every queued parameter of the group.
        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
        # Linear warmup of the (negated) learning rate over `warmup_steps` steps.
        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)

        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
            laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
            state = self.state_(p)
            # Third argument True is the opposite direction of the projection applied to the gradient above.
            precond = project(ea, state['Q'], True)

            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
|
heavyball/utils.py
CHANGED
@@ -693,6 +693,29 @@ def exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor]
|
|
693
693
|
return denom
|
694
694
|
|
695
695
|
|
696
|
+
|
697
|
+
@decorator_knowngood
def _compilable_laprop_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
                                grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor):
    """
    LaProp moment update, in place: the gradient is divided by the second-moment denominator first and only the
    normalised gradient is blended into ``exp_avg``.

    :param exp_avg: momentum buffers, updated in place via ``stochastic_lerp_``.
    :param exp_avg_sq: second-moment buffers, written back via ``copy_stochastic_list_``.
    :param grad_projected: gradients to accumulate (callers pass the projected gradients).
    :param beta1: momentum decay, debiased with ``step`` before use.
    :param beta2: second-moment decay, debiased with ``step`` before use.
    :param step: current step count.
    """
    momentum_decay = beta_debias(beta1, step)
    second_moment_decay = beta_debias(beta2, step)

    # Work on promoted copies (presumably fp32, going by the original `*32` names - confirm against `promote`).
    grads_hp = list(map(promote, grad_projected))
    second_moment_hp = list(map(promote, exp_avg_sq))

    # exp_avg_sq_ updates the promoted second moment and returns the denominator; 1e-8 is its epsilon argument.
    denominators = exp_avg_sq_(second_moment_hp, grads_hp, second_moment_decay, 1e-8)
    normalised = torch._foreach_div(grads_hp, denominators)
    stochastic_lerp_(exp_avg, normalised, 1 - momentum_decay)

    # Persist the promoted second moment back into the caller's buffers.
    copy_stochastic_list_(exp_avg_sq, second_moment_hp)
|
710
|
+
|
711
|
+
|
712
|
+
def laprop_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor],
                    beta1: float, beta2: float, step: int):
    """
    Public entry point for the LaProp moment update; normalises the arguments and delegates to
    ``_compilable_laprop_exp_avg_``, which updates ``exp_avg`` and ``exp_avg_sq`` in place.

    :param exp_avg: momentum buffer(s); passed through ``list_guard`` (callers in this package pass one tensor).
    :param exp_avg_sq: second-moment buffer(s); passed through ``list_guard``.
    :param grad_projected: gradient(s) to accumulate; passed through ``list_guard``.
    :param beta1: momentum decay.
    :param beta2: second-moment decay.
    :param step: current step count.
    :return: None.
    """
    exp_avg, exp_avg_sq, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(grad_projected)
    # All three scalars go through scalar_guard against the first momentum buffer (presumably to turn them into
    # tensors matching it - confirm against scalar_guard). The guarded beta2 must be bound to `beta2` itself,
    # otherwise the raw value is what reaches the compiled kernel.
    beta1, beta2, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step, exp_avg[0])
    _compilable_laprop_exp_avg_(exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step)
|
717
|
+
|
718
|
+
|
696
719
|
@decorator_knowngood
|
697
720
|
def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
|
698
721
|
"""Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
|
@@ -1,4 +1,4 @@
|
|
1
|
-
heavyball/__init__.py,sha256=
|
1
|
+
heavyball/__init__.py,sha256=RdUfGDTXw-rtoQJNediWnhDseYyyWNPVsr6tRq_ucp8,2813
|
2
2
|
heavyball/cached_delayed_psgd_kron.py,sha256=HEyT6vW6Le6FmWpf-vAEzgbAkPH2mByqXcVZn07KCMk,6866
|
3
3
|
heavyball/cached_psgd_kron.py,sha256=rOgWAeVMENI7kdoBuRo3ywrCeatAnIqBdeYPHuVk2aU,6998
|
4
4
|
heavyball/delayed_psgd.py,sha256=L6qRLPxJmJ_1e0Mk2zLYUEVxkt8NGHq6v3HKawlgFcU,6334
|
@@ -6,19 +6,23 @@ heavyball/foreach_adamw.py,sha256=K4xTes4drylAqaqWky8O_Bg_mmbAmcHZ5DEBs5vMD-s,28
|
|
6
6
|
heavyball/foreach_adopt.py,sha256=fHnbEqvKKc5IKPDWC9Qo9PiISSjj1MEViy0Jb3BRgZQ,3582
|
7
7
|
heavyball/foreach_laprop.py,sha256=EXkwFQ-H7hHWLmiNUsxUcmXhzNNLMjieHjfOlY_6kmo,2868
|
8
8
|
heavyball/foreach_sfadamw.py,sha256=TeWf0nKXQEFcz02rADYRJenDM9mX1dGHhvILLks6OW8,3087
|
9
|
-
heavyball/foreach_soap.py,sha256=
|
9
|
+
heavyball/foreach_soap.py,sha256=ntFqg0fbkZ8EzERGlypXB8JWoGJ1sAY59f0CuWh_d48,4801
|
10
|
+
heavyball/foreach_solp.py,sha256=1r7x_FUZRaUsoSLSvi-Z_-pZNtZrMresVJGq9m1EREA,4563
|
10
11
|
heavyball/p_adam.py,sha256=qEcuU8VEc35vaWAXjT0O65vfCuNn_3ttwL4RlJKN3Xw,6389
|
11
12
|
heavyball/palm_foreach_sfadamw.py,sha256=1qOr-uniSmI1sNCJc1SnvyKH5iFu80Z6H5h93lDTwcE,3410
|
12
|
-
heavyball/palm_foreach_soap.py,sha256=
|
13
|
+
heavyball/palm_foreach_soap.py,sha256=fbRL1Tx9YeQ16sHWFPtY5Kj60BFV2AMngOnTiE4muK0,6231
|
14
|
+
heavyball/palm_foreach_solp.py,sha256=N3M3tnahOfSHvLu3en76JTI1yo-ISEbSliSKlpt8ZWw,5994
|
13
15
|
heavyball/precond_schedule_foreach_soap.py,sha256=p7oD2bESyCPsdGkJYhHluraDb_1K5Q28RNL6fIvD5C8,4969
|
16
|
+
heavyball/precond_schedule_foreach_solp.py,sha256=xGEQ6HHUTCKeT9-ppEbLTdXVAfE74P0tph0qS16USyg,4768
|
14
17
|
heavyball/precond_schedule_palm_foreach_soap.py,sha256=Sb3Fhv-EG28_oXnbVpE0iHe5R8i5_hltqoi_DgPuoEU,6505
|
18
|
+
heavyball/precond_schedule_palm_foreach_solp.py,sha256=gaoJwJo_ZBnYuMamgFepnV9iWpUCmbYrxMWiL1QkPh0,6253
|
15
19
|
heavyball/precond_schedule_sfpsoap.py,sha256=KUKdZzd336w24zPRcqwRatj7IVmd1Us0a_VuzASluIo,7565
|
16
20
|
heavyball/psgd_kron.py,sha256=PtTe6eR547Y-4CvgjpchgkQsr_kWr4AN-uY9L_JO_C8,6088
|
17
21
|
heavyball/pure_psgd.py,sha256=344NdVNHwUFX3fU2R1S_Xh9SXAML3E4ryHr7xfMh9Cc,5076
|
18
22
|
heavyball/schedule_free_palm_foreach_soap.py,sha256=KTQY37MZH7YnOSTLKY8uVySUXxWXbFVUA1QXN3iv8Ds,7244
|
19
|
-
heavyball/utils.py,sha256=
|
20
|
-
heavyball-0.25.
|
21
|
-
heavyball-0.25.
|
22
|
-
heavyball-0.25.
|
23
|
-
heavyball-0.25.
|
24
|
-
heavyball-0.25.
|
23
|
+
heavyball/utils.py,sha256=_KvCCCnsu_l4I_OhiRr4noAiwUvzctN05JAuYPkrxXY,41191
|
24
|
+
heavyball-0.25.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
25
|
+
heavyball-0.25.1.dist-info/METADATA,sha256=WWR7dX_i7dcF-73-VJ42qcRFwZRL3unOSEwO4EM96e0,11926
|
26
|
+
heavyball-0.25.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
27
|
+
heavyball-0.25.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
28
|
+
heavyball-0.25.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|