heavyball 0.24.4__py3-none-any.whl → 0.25.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +12 -2
- heavyball/foreach_soap.py +3 -1
- heavyball/foreach_solp.py +89 -0
- heavyball/palm_foreach_soap.py +2 -1
- heavyball/palm_foreach_solp.py +98 -0
- heavyball/precond_schedule_foreach_solp.py +95 -0
- heavyball/precond_schedule_palm_foreach_solp.py +103 -0
- heavyball/utils.py +33 -7
- {heavyball-0.24.4.dist-info → heavyball-0.25.1.dist-info}/METADATA +1 -1
- {heavyball-0.24.4.dist-info → heavyball-0.25.1.dist-info}/RECORD +13 -9
- {heavyball-0.24.4.dist-info → heavyball-0.25.1.dist-info}/LICENSE +0 -0
- {heavyball-0.24.4.dist-info → heavyball-0.25.1.dist-info}/WHEEL +0 -0
- {heavyball-0.24.4.dist-info → heavyball-0.25.1.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -15,6 +15,10 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
 from .psgd_kron import ForeachPSGDKron
 from .pure_psgd import ForeachPurePSGD
 from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
+from .foreach_solp import ForeachSOLP
+from .palm_foreach_solp import PaLMForeachSOLP
+from .precond_schedule_palm_foreach_solp import PrecondSchedulePaLMForeachSOLP
+from .precond_schedule_foreach_solp import PrecondScheduleForeachSOLP
 
 PalmForEachSoap = PaLMForeachSOAP
 
@@ -35,12 +39,18 @@ PaLMPAdam = ForeachPaLMPAdam
 DelayedPSGD = ForeachDelayedPSGD
 CachedPSGDKron = ForeachCachedPSGDKron
 CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
+SOLP = ForeachSOLP
+PaLMSOLP = PaLMForeachSOLP
+PrecondSchedulePaLMSOLP = PrecondSchedulePaLMForeachSOLP
+PrecondScheduleSOLP = PrecondScheduleForeachSOLP
 
 __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
            'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
            'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
-           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron',
+           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron', 'ForeachSOLP',
+           'PaLMForeachSOLP', 'PrecondSchedulePaLMForeachSOLP', 'PrecondScheduleForeachSOLP',
            #
            'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP', 'SOAP', 'SFAdamW',
            'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD', 'CachedPSGDKron',
-           'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP'
+           'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP', 'SOLP', 'PrecondScheduleSOLP',
+           'PrecondSchedulePaLMSOLP', 'PrecondScheduleSOLP']
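Taken together, the new classes are importable both under their full `Foreach*` names and under the short aliases defined above. A minimal usage sketch, assuming a toy model and placeholder hyperparameters (the values shown are the defaults from `foreach_solp.py` later in this diff); it is illustrative and not part of the package:

import torch
from heavyball import SOLP  # alias for ForeachSOLP, new in 0.25.1

model = torch.nn.Linear(128, 16)  # hypothetical toy model
opt = SOLP(model.parameters(), lr=3e-3, betas=(0.9, 0.95), precondition_frequency=2)

for _ in range(10):
    x = torch.randn(32, 128)
    loss = model(x).square().mean()
    opt.zero_grad()
    loss.backward()
    opt.step()  # the first step per parameter only initializes the preconditioner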
heavyball/foreach_soap.py
CHANGED
@@ -6,7 +6,7 @@ from .utils import init_preconditioner, update_preconditioner, project, beta_deb
 
 class ForeachSOAP(StatefulOptimizer):
     """
-
+    ForeachSOAP
 
     Sources:
         Baseline SOAP:
@@ -75,6 +75,8 @@ class ForeachSOAP(StatefulOptimizer):
         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
 
         for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
+            state = self.state_(p)
+
             d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
             # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
             # i.e. projecting to the eigenbases of matrices in state['GG']
heavyball/foreach_solp.py
ADDED
@@ -0,0 +1,89 @@
+import torch
+
+from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
+    StatefulOptimizer, laprop_exp_avg_
+
+
+class ForeachSOLP(StatefulOptimizer):
+    """
+    ForeachSOLP
+
+    Sources:
+        LaProp:
+            LaProp: Separating Momentum and Adaptivity in Adam
+            Liu Ziyin, Zhikang T.Wang, Masahito Ueda
+            https://arxiv.org/abs/2002.04839
+            https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py
+
+        Baseline SOAP:
+            SOAP: Improving and Stabilizing Shampoo using Adam
+            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+            https://arxiv.org/abs/2409.11321
+            https://github.com/nikhilvyas/SOAP
+
+        ScheduleFree:
+            The Road Less Scheduled
+            Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
+            https://arxiv.org/abs/2405.15682
+            https://github.com/facebookresearch/schedule_free
+    """
+
+    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
+                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
+                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
+                 split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025):
+        defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
+                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
+                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
+                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split, 'mars': mars,
+                    'caution': caution, 'mars_gamma': mars_gamma}
+        super().__init__(params, defaults, foreach)
+        self._data_format = data_format
+
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1, beta2 = group["betas"]
+
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
+        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
+            laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
+            state = self.state_(p)
+            precond = project(ea, state['Q'], True)
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, step > 0 and step % group['precondition_frequency'] == 0)
+            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
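For orientation, the per-parameter flow in `_step` above is: project the raw gradient into the Shampoo eigenbasis `state['Q']`, run the LaProp moment update on the projected gradient with `laprop_exp_avg_`, project the first moment back into parameter space, then refresh the preconditioner and apply the update. A condensed single-tensor paraphrase using the same helpers (the wrapper function and its control flow are illustrative, not the library's API; warmup scaling is omitted):

from heavyball.utils import (beta_debias, laprop_exp_avg_, project,
                             update_param_, update_preconditioner)

def solp_step_sketch(p, g, state, group, step):
    # gradient -> eigenbasis of the Shampoo preconditioner
    gp = project(g, state['Q'], False)
    # LaProp: second moment first, then momentum of the normalized gradient
    laprop_exp_avg_(state['exp_avg'], state['exp_avg_sq'], gp,
                    group['betas'][0], group['betas'][1], step)
    # first moment -> back to parameter space
    precond = project(state['exp_avg'], state['Q'], True)
    update_preconditioner(g, state, group['max_precond_dim'], group['precondition_1d'],
                          beta_debias(group['betas'][1], step),
                          step > 0 and step % group['precondition_frequency'] == 0)
    update_param_([p], [precond], -group['lr'], group['weight_decay'],
                  caution=group['caution'], grad=[g])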
heavyball/palm_foreach_soap.py
CHANGED
@@ -6,7 +6,7 @@ from .utils import init_preconditioner, update_preconditioner, project, beta_deb
 
 class PaLMForeachSOAP(StatefulOptimizer):
     """
-
+    PaLMForeachSOAP
 
     Sources:
         Baseline SOAP:
@@ -84,6 +84,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
 
         for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
+            state = self.state_(p)
             d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
             # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
             # i.e. projecting to the eigenbases of matrices in state['GG']
heavyball/palm_foreach_solp.py
ADDED
@@ -0,0 +1,98 @@
+import torch
+
+from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
+    StatefulOptimizer, laprop_exp_avg_
+
+
+class PaLMForeachSOLP(StatefulOptimizer):
+    """
+    PaLMForeachSOAP
+
+    Sources:
+        LaProp:
+            LaProp: Separating Momentum and Adaptivity in Adam
+            Liu Ziyin, Zhikang T.Wang, Masahito Ueda
+            https://arxiv.org/abs/2002.04839
+            https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py
+
+        Baseline SOAP:
+            SOAP: Improving and Stabilizing Shampoo using Adam
+            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+            https://arxiv.org/abs/2409.11321
+            https://github.com/nikhilvyas/SOAP
+
+        ScheduleFree:
+            The Road Less Scheduled
+            Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
+            https://arxiv.org/abs/2405.15682
+            https://github.com/facebookresearch/schedule_free
+
+        Beta2 Schedule:
+            PaLM: Scaling Language Modeling with Pathways
+            Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
+            https://arxiv.org/abs/2204.02311
+    """
+
+    def __init__(self, params, lr: float = 3e-3, beta=0.9, betas=(None, None), shampoo_beta: float = 0.95,
+                 eps: float = 1e-8, weight_decay: float = 0.01, precondition_frequency: int = 2,
+                 max_precond_dim: int = 2048,  #
+                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
+                 beta2_scale: float = 0.8, split: bool = False, foreach: bool = True, mars: bool = False,
+                 caution: bool = False, mars_gamma: float = 0.0025):
+        if betas[0] is not None:
+            beta = betas[0]
+        defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
+                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
+                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
+                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
+                    'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
+        super().__init__(params, defaults, foreach)
+        self._data_format = data_format
+
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1 = group["beta"]
+
+        beta2 = 1 - step ** -group['beta2_scale']
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
+        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
+        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
+            laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
+            state = self.state_(p)
+            precond = project(ea, state['Q'], True)
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, step > 0 and step % group['precondition_frequency'] == 0)
+            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
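The main difference from `ForeachSOLP` is the PaLM-style second-moment schedule: `beta2` is not a fixed hyperparameter but is recomputed every step as `1 - step ** -beta2_scale`, with `beta2_scale` defaulting to 0.8. A quick illustration of the values this schedule produces (plain arithmetic, not package code):

# beta2 under the PaLM schedule above, with the default beta2_scale = 0.8
for step in (1, 10, 100, 1000, 10000):
    print(step, round(1 - step ** -0.8, 4))
# -> 0.0, 0.8415, 0.9749, 0.996, 0.9994: adaptivity ramps up as training progresses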
heavyball/precond_schedule_foreach_solp.py
ADDED
@@ -0,0 +1,95 @@
+
+
+import random
+
+import torch
+
+from .utils import init_preconditioner, update_preconditioner, project, beta_debias, update_param_, \
+    precond_schedule, set_, StatefulOptimizer, laprop_exp_avg_
+
+
+class PrecondScheduleForeachSOLP(StatefulOptimizer):
+    """
+    Sources:
+        LaProp:
+            LaProp: Separating Momentum and Adaptivity in Adam
+            Liu Ziyin, Zhikang T.Wang, Masahito Ueda
+            https://arxiv.org/abs/2002.04839
+            https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py
+
+        Preconditioner Schedules:
+            Preconditioned Stochastic Gradient Descent
+            Xi-Lin Li, Omead Pooladzandi, Evan Walters
+            https://arxiv.org/abs/1512.04202
+            https://github.com/evanatyourservice/kron_torch
+            https://github.com/lixilinx/psgd_torch
+
+        Baseline SOAP:
+            SOAP: Improving and Stabilizing Shampoo using Adam
+            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+            https://arxiv.org/abs/2409.11321
+            https://github.com/nikhilvyas/SOAP
+    """
+
+    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
+                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
+                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
+                 precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True, mars: bool = False,
+                 caution: bool = False, mars_gamma: float = 0.0025):
+        defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
+                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
+                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
+                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
+                    'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
+        super().__init__(params, defaults, foreach)
+        self._data_format = data_format
+        self.rng = random.Random(0x120983109)
+
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1, beta2 = group["betas"]
+
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+
+        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
+
+        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
+            laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
+            state = self.state_(p)
+            precond = project(ea, state['Q'], True)
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
+            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
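Relative to `ForeachSOLP`, the eigenbasis refresh here is no longer tied to a fixed `precondition_frequency`; instead `precond_schedule(step, group['precond_scheduler'], self.rng)` decides each step whether `update_preconditioner` recomputes the basis, driven by the `(1 / 3, 9)` scheduler tuple and a seeded `random.Random`. The actual schedule lives in `heavyball.utils` and is not shown in this diff; the toy below only illustrates the general shape of such a stochastic schedule (refresh probability decaying with step) and is not heavyball's formula:

import random

rng = random.Random(0x120983109)  # same seed as the optimizer above

def toy_precond_schedule(step: int) -> bool:
    # Hypothetical stand-in, NOT heavyball.utils.precond_schedule:
    # refresh with a probability that decays as training progresses.
    prob = max(0.05, step ** -0.5)
    return rng.random() < prob

refreshes = sum(toy_precond_schedule(s) for s in range(1, 1001))
print(f"~{refreshes} eigenbasis refreshes over 1000 steps")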
heavyball/precond_schedule_palm_foreach_solp.py
ADDED
@@ -0,0 +1,103 @@
+import random
+
+import torch
+
+from .utils import init_preconditioner, update_preconditioner, project, beta_debias, laprop_exp_avg_, update_param_, \
+    precond_schedule, set_, StatefulOptimizer
+
+
+class PrecondSchedulePaLMForeachSOLP(StatefulOptimizer):
+    """
+    Sources:
+        LaProp:
+            LaProp: Separating Momentum and Adaptivity in Adam
+            Liu Ziyin, Zhikang T.Wang, Masahito Ueda
+            https://arxiv.org/abs/2002.04839
+            https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py
+
+        Preconditioner Schedules:
+            Preconditioned Stochastic Gradient Descent
+            Xi-Lin Li, Omead Pooladzandi, Evan Walters
+            https://arxiv.org/abs/1512.04202
+            https://github.com/evanatyourservice/kron_torch
+            https://github.com/lixilinx/psgd_torch
+
+        Baseline SOAP:
+            SOAP: Improving and Stabilizing Shampoo using Adam
+            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+            https://arxiv.org/abs/2409.11321
+            https://github.com/nikhilvyas/SOAP
+
+        Beta2 Schedule:
+            PaLM: Scaling Language Modeling with Pathways
+            Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
+            https://arxiv.org/abs/2204.02311
+    """
+
+    def __init__(self, params, lr: float = 3e-3, beta=0.9, shampoo_beta: float = 0.95, eps: float = 1e-8,
+                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
+                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
+                 precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
+                 foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
+        if betas[0] is not None:
+            beta = betas[0]
+        defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
+                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
+                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
+                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
+                    'beta2_scale': beta2_scale, 'split': split, 'mars': mars, 'caution': caution,
+                    'mars_gamma': mars_gamma}
+        super().__init__(params, defaults, foreach)
+        self._data_format = data_format
+        self.rng = random.Random(0x120983109)
+
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1 = group["beta"]
+
+        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
+        old_debiased1 = beta_debias(beta1, step)
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
+        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+
+        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
+        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
+            laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
+            state = self.state_(p)
+            precond = project(ea, state['Q'], True)
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
+            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
heavyball/utils.py
CHANGED
@@ -492,22 +492,20 @@ class StatefulOptimizer(torch.optim.Optimizer):
         super().__init__(params, {**defaults, 'foreach': foreach})
         self.fake_groups = {}
         self.use_ema = use_ema
-
-    def key(self, param: Tensor):
-        return (param.data_ptr(), tuple(param.shape))
+        self.mapping = {}
 
     def get_groups(self, group):
         if group['foreach']:
            return [group]
 
         for p in group['params']:
-            if self.key(p) not in self.fake_groups:
-                self.fake_groups[self.key(p)] = {**group, 'params': [p]}
+            if p not in self.fake_groups:
+                self.fake_groups[p] = {**group, 'params': [p]}
 
-        return [self.fake_groups[self.key(p)] for p in group['params']]
+        return [self.fake_groups[p] for p in group['params']]
 
     def state_(self, arg: Tensor):
-        return self.state[self.key(arg)]
+        return self.state[self.mapping.get(arg, arg)]
 
     def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
         for p, g in zip(p_list, g_list):
@@ -538,6 +536,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
             p_views = merge_group(group, p)
             if grad is not None:
                 grad = merge_group(group, grad)
+            for i, pv in enumerate(p_views):
+                self.mapping[pv] = (p, i)
             if isinstance(p_views, Tensor):
                 yield p_views, grad
                 continue
@@ -622,11 +622,14 @@ class StatefulOptimizer(torch.optim.Optimizer):
         for top_group in self.param_groups:
             for group in self.get_groups(top_group):
                 self._step(group)
+                self.mapping.clear()
                 if self.use_ema:
                     self.ema_update(group)
+
         return loss
 
 
+
 class ScheduleFree(StatefulOptimizer):
     def eval(self):
         for group in self.param_groups:
@@ -690,6 +693,29 @@ def exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor]
     return denom
 
 
+
+@decorator_knowngood
+def _compilable_laprop_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
+                                grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor):
+    beta1 = beta_debias(beta1, step)
+    beta2 = beta_debias(beta2, step)
+
+    gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad_projected, exp_avg_sq]]
+
+    denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
+    gp32 = torch._foreach_div(gp32, denom)
+    stochastic_lerp_(exp_avg, gp32, 1 - beta1)
+
+    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
+
+def laprop_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor],
+                    beta1: float, beta2: float, step: int):
+    exp_avg, exp_avg_sq, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(grad_projected)
+    beta1, beta, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step, exp_avg[0])
+    _compilable_laprop_exp_avg_(exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step)
+
+
 @decorator_knowngood
 def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
     """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
{heavyball-0.24.4.dist-info → heavyball-0.25.1.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-heavyball/__init__.py,sha256=
+heavyball/__init__.py,sha256=RdUfGDTXw-rtoQJNediWnhDseYyyWNPVsr6tRq_ucp8,2813
 heavyball/cached_delayed_psgd_kron.py,sha256=HEyT6vW6Le6FmWpf-vAEzgbAkPH2mByqXcVZn07KCMk,6866
 heavyball/cached_psgd_kron.py,sha256=rOgWAeVMENI7kdoBuRo3ywrCeatAnIqBdeYPHuVk2aU,6998
 heavyball/delayed_psgd.py,sha256=L6qRLPxJmJ_1e0Mk2zLYUEVxkt8NGHq6v3HKawlgFcU,6334
@@ -6,19 +6,23 @@ heavyball/foreach_adamw.py,sha256=K4xTes4drylAqaqWky8O_Bg_mmbAmcHZ5DEBs5vMD-s,28
 heavyball/foreach_adopt.py,sha256=fHnbEqvKKc5IKPDWC9Qo9PiISSjj1MEViy0Jb3BRgZQ,3582
 heavyball/foreach_laprop.py,sha256=EXkwFQ-H7hHWLmiNUsxUcmXhzNNLMjieHjfOlY_6kmo,2868
 heavyball/foreach_sfadamw.py,sha256=TeWf0nKXQEFcz02rADYRJenDM9mX1dGHhvILLks6OW8,3087
-heavyball/foreach_soap.py,sha256=
+heavyball/foreach_soap.py,sha256=ntFqg0fbkZ8EzERGlypXB8JWoGJ1sAY59f0CuWh_d48,4801
+heavyball/foreach_solp.py,sha256=1r7x_FUZRaUsoSLSvi-Z_-pZNtZrMresVJGq9m1EREA,4563
 heavyball/p_adam.py,sha256=qEcuU8VEc35vaWAXjT0O65vfCuNn_3ttwL4RlJKN3Xw,6389
 heavyball/palm_foreach_sfadamw.py,sha256=1qOr-uniSmI1sNCJc1SnvyKH5iFu80Z6H5h93lDTwcE,3410
-heavyball/palm_foreach_soap.py,sha256=
+heavyball/palm_foreach_soap.py,sha256=fbRL1Tx9YeQ16sHWFPtY5Kj60BFV2AMngOnTiE4muK0,6231
+heavyball/palm_foreach_solp.py,sha256=N3M3tnahOfSHvLu3en76JTI1yo-ISEbSliSKlpt8ZWw,5994
 heavyball/precond_schedule_foreach_soap.py,sha256=p7oD2bESyCPsdGkJYhHluraDb_1K5Q28RNL6fIvD5C8,4969
+heavyball/precond_schedule_foreach_solp.py,sha256=xGEQ6HHUTCKeT9-ppEbLTdXVAfE74P0tph0qS16USyg,4768
 heavyball/precond_schedule_palm_foreach_soap.py,sha256=Sb3Fhv-EG28_oXnbVpE0iHe5R8i5_hltqoi_DgPuoEU,6505
+heavyball/precond_schedule_palm_foreach_solp.py,sha256=gaoJwJo_ZBnYuMamgFepnV9iWpUCmbYrxMWiL1QkPh0,6253
 heavyball/precond_schedule_sfpsoap.py,sha256=KUKdZzd336w24zPRcqwRatj7IVmd1Us0a_VuzASluIo,7565
 heavyball/psgd_kron.py,sha256=PtTe6eR547Y-4CvgjpchgkQsr_kWr4AN-uY9L_JO_C8,6088
 heavyball/pure_psgd.py,sha256=344NdVNHwUFX3fU2R1S_Xh9SXAML3E4ryHr7xfMh9Cc,5076
 heavyball/schedule_free_palm_foreach_soap.py,sha256=KTQY37MZH7YnOSTLKY8uVySUXxWXbFVUA1QXN3iv8Ds,7244
-heavyball/utils.py,sha256=
-heavyball-0.
-heavyball-0.
-heavyball-0.
-heavyball-0.
-heavyball-0.
+heavyball/utils.py,sha256=_KvCCCnsu_l4I_OhiRr4noAiwUvzctN05JAuYPkrxXY,41191
+heavyball-0.25.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-0.25.1.dist-info/METADATA,sha256=WWR7dX_i7dcF-73-VJ42qcRFwZRL3unOSEwO4EM96e0,11926
+heavyball-0.25.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-0.25.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-0.25.1.dist-info/RECORD,,
{heavyball-0.24.4.dist-info → heavyball-0.25.1.dist-info}/LICENSE
File without changes
{heavyball-0.24.4.dist-info → heavyball-0.25.1.dist-info}/WHEEL
File without changes
{heavyball-0.24.4.dist-info → heavyball-0.25.1.dist-info}/top_level.txt
File without changes