heavyball 0.25.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +193 -40
- heavyball/chainable.py +475 -0
- heavyball/utils.py +318 -187
- {heavyball-0.25.1.dist-info → heavyball-1.0.0.dist-info}/METADATA +4 -3
- heavyball-1.0.0.dist-info/RECORD +8 -0
- heavyball/cached_delayed_psgd_kron.py +0 -135
- heavyball/cached_psgd_kron.py +0 -136
- heavyball/delayed_psgd.py +0 -122
- heavyball/foreach_adamw.py +0 -63
- heavyball/foreach_adopt.py +0 -83
- heavyball/foreach_laprop.py +0 -67
- heavyball/foreach_sfadamw.py +0 -69
- heavyball/foreach_soap.py +0 -93
- heavyball/foreach_solp.py +0 -89
- heavyball/p_adam.py +0 -121
- heavyball/palm_foreach_sfadamw.py +0 -77
- heavyball/palm_foreach_soap.py +0 -101
- heavyball/palm_foreach_solp.py +0 -98
- heavyball/precond_schedule_foreach_soap.py +0 -95
- heavyball/precond_schedule_foreach_solp.py +0 -95
- heavyball/precond_schedule_palm_foreach_soap.py +0 -105
- heavyball/precond_schedule_palm_foreach_solp.py +0 -103
- heavyball/precond_schedule_sfpsoap.py +0 -141
- heavyball/psgd_kron.py +0 -120
- heavyball/pure_psgd.py +0 -105
- heavyball/schedule_free_palm_foreach_soap.py +0 -136
- heavyball-0.25.1.dist-info/RECORD +0 -28
- {heavyball-0.25.1.dist-info → heavyball-1.0.0.dist-info}/LICENSE +0 -0
- {heavyball-0.25.1.dist-info → heavyball-1.0.0.dist-info}/WHEEL +0 -0
- {heavyball-0.25.1.dist-info → heavyball-1.0.0.dist-info}/top_level.txt +0 -0
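The per-optimizer modules listed as deleted below were folded into `heavyball/__init__.py` and the new `heavyball/chainable.py`. A minimal migration sketch, assuming the optimizer classes remain re-exported from the package root in 1.0.0 (the exact exports should be checked against `heavyball/__init__.py`):

```python
# Migration sketch implied by this diff: the per-optimizer modules (heavyball/foreach_soap.py,
# heavyball/psgd_kron.py, ...) are removed in 1.0.0, so imports must go through the package
# root. The 1.0.0 class name below is an assumption -- check heavyball/__init__.py for the
# exact exports after the rewrite onto heavyball/chainable.py.
import torch
import heavyball

model = torch.nn.Linear(16, 4)

# 0.25.1 style (module deleted in 1.0.0):
#   from heavyball.foreach_soap import ForeachSOAP

# 1.0.0 style, assuming the class is still re-exported from the package root:
opt = heavyball.ForeachSOAP(model.parameters(), lr=3e-3)

loss = model(torch.randn(8, 16)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```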
heavyball/foreach_sfadamw.py
DELETED
@@ -1,69 +0,0 @@
-import torch
-import torch.optim
-
-from .utils import get_ckp1, copy_stochastic_list_, warmup, ScheduleFree, exp_avg_sq_, beta_debias, promote, _compilable_schedule_free_, decorator_knowngood
-
-
-@decorator_knowngood
-def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
-    old_debiased2 = beta_debias(beta2, step)
-
-    g32 = [promote(g_) for g_ in grad]
-    exp_avg_sq32 = [promote(e_) for e_ in exp_avg_sq]
-
-    denom = exp_avg_sq_(exp_avg_sq32, g32, old_debiased2, eps)
-    torch._foreach_div_(g32, denom)
-    if decay != 0:
-        torch._foreach_add_(g32, y, alpha=decay)
-    for p, z_, g in zip(y, z, g32):
-        _compilable_schedule_free_(p, z_, ckp1, g, lr, beta1)
-
-    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
-
-class ForeachSFAdamW(ScheduleFree):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                 weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
-                 caution: bool = False, mars_gamma: float = 0.0025):
-
-        assert not caution, "Caution not implemented for SFAdamW"
-
-        defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
-                        weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
-                        foreach=foreach, storage_dtype=storage_dtype, mars=mars, caution=caution, mars_gamma=mars_gamma)
-        super().__init__(params, defaults, foreach)
-
-    def _step(self, group):
-        eps = group['eps']
-        decay = group['weight_decay']
-        k = group['k']
-
-        if not group['train_mode']:
-            raise Exception("Not in train mode!")
-
-        active_p = [p for p in group['params'] if p.grad is not None]
-
-        if not active_p:
-            return
-
-        storage_dtype = getattr(torch, group['storage_dtype'])
-
-        for p in active_p:
-            if 'z' not in self.state_(p):
-                self.state_(p)['z'] = torch.clone(p.data, memory_format=torch.preserve_format)
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype, memory_format=torch.preserve_format)
-
-        y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z']) #
-                                       for p in active_p])
-
-        if group['mars']:
-            self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
-
-        lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-        ckp1, group['weight_sum'] = get_ckp1(lr, group['weight_lr_power'], group['weight_sum'], group['r'], k + 1)
-
-        step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k + 1)
-        ckp1 = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(ckp1)
-        lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
-        _compilable_step_(y, grad, exp_avg_sq, z, group['betas'][0], group['betas'][1], step, ckp1, eps, decay, lr)
-        group['k'] = k + 1
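ForeachSFAdamW above is the fused schedule-free AdamW; `_compilable_schedule_free_` and `get_ckp1` (defined in `heavyball/utils.py`, not shown in this diff) carry the interpolation and the averaging weights. A minimal sketch of the underlying schedule-free rule from Defazio et al., "The Road Less Scheduled", with a plain SGD inner step rather than the AdamW normalization used above:

```python
# Minimal sketch of the schedule-free rule (Defazio et al., "The Road Less Scheduled"),
# with a plain SGD inner step; the deleted ForeachSFAdamW adds the AdamW second-moment
# normalization (exp_avg_sq_), weight decay, and lr-weighted averaging (get_ckp1) on top.
import torch

def schedule_free_step(x, z, grad_fn, t, lr=0.0025, beta=0.9):
    y = (1 - beta) * z + beta * x      # point where the gradient is evaluated
    g = grad_fn(y)
    z = z - lr * g                     # base (SGD-style) iterate
    c = 1.0 / (t + 1)                  # uniform averaging weight; heavyball weights by lr ** weight_lr_power
    x = (1 - c) * x + c * z            # running average of the z iterates replaces a decay schedule
    return x, z

# toy quadratic: minimize 0.5 * ||w||^2, so grad(w) = w
x = z = torch.ones(3)
for t in range(200):
    x, z = schedule_free_step(x, z, lambda w: w, t)
print(x)   # approaches zero without any learning-rate schedule
```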
heavyball/foreach_soap.py
DELETED
@@ -1,93 +0,0 @@
-import torch
-
-from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-    StatefulOptimizer, exp_avg_
-
-
-class ForeachSOAP(StatefulOptimizer):
-    """
-    ForeachSOAP
-
-    Sources:
-        Baseline SOAP:
-            SOAP: Improving and Stabilizing Shampoo using Adam
-            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-            https://arxiv.org/abs/2409.11321
-            https://github.com/nikhilvyas/SOAP
-
-        ScheduleFree:
-            The Road Less Scheduled
-            Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
-            https://arxiv.org/abs/2405.15682
-            https://github.com/facebookresearch/schedule_free
-    """
-
-    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
-                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
-                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025):
-        defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
-                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split, 'mars': mars,
-                    'caution': caution, 'mars_gamma': mars_gamma}
-        super().__init__(params, defaults, foreach)
-        self._data_format = data_format
-
-    def _step(self, group):
-        vals = []
-        step = 0
-
-        max_precond_dim = group['max_precond_dim']
-        precondition_1d = group['precondition_1d']
-
-        for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
-            state = self.state_(p)
-            step = state['step'] = state.get("step", -1) + 1
-
-            if "exp_avg" not in state:
-                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                continue  # first step is skipped so that we never use the current gradients in the projection.
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            grad_projected = project(g, state['Q'], False)
-            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-        if not vals:
-            return
-
-        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-        beta1, beta2 = group["betas"]
-
-        old_debiased2 = beta_debias(beta2, step)
-
-        # Decay the first and second moment running average coefficient
-        # In-place operations to update the averages at the same time
-        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-
-        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-
-        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
-            state = self.state_(p)
-
-            d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
-            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            exp_avg_projected = project(ea, state['Q'], False)
-
-            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-            # to the original space
-            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-            precond = project(exp_avg_projected / d, state['Q'], True)
-
-            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                  step > 0 and step % group['precondition_frequency'] == 0)
-
-            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
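ForeachSOAP's `_step` runs an Adam-style update inside the eigenbasis of the accumulated Shampoo statistics (`state['GG']`, eigenvectors in `state['Q']`). A sketch of that project / update / project-back cycle for a single 2-D parameter, using explicit eigendecompositions in place of heavyball's `project` and `exp_avg_` helpers:

```python
# Sketch of SOAP's project / Adam / project-back cycle for one 2-D parameter.
# Bias correction and the preconditioner EMA (shampoo_beta) are omitted; helper
# names and shapes are illustrative, not heavyball's utils API.
import torch

torch.manual_seed(0)
p = torch.randn(4, 3)
g = torch.randn(4, 3)
beta1, beta2, lr, eps = 0.9, 0.95, 3e-3, 1e-8

# Shampoo statistics (state['GG']) and their eigenbases (state['Q'])
QL = torch.linalg.eigh(g @ g.T).eigenvectors   # left basis,  4x4
QR = torch.linalg.eigh(g.T @ g).eigenvectors   # right basis, 3x3

exp_avg = torch.zeros_like(g)
exp_avg_sq = torch.zeros_like(g)

g_proj = QL.T @ g @ QR                                            # project(g, Q, back=False)
exp_avg.mul_(beta1).add_(g, alpha=1 - beta1)                      # first moment, original basis
exp_avg_sq.mul_(beta2).addcmul_(g_proj, g_proj, value=1 - beta2)  # second moment, eigenbasis
denom = exp_avg_sq.sqrt().add_(eps)

ea_proj = QL.T @ exp_avg @ QR                 # project the momentum into the eigenbasis
update = QL @ (ea_proj / denom) @ QR.T        # precondition, then project back
p.add_(update, alpha=-lr)
```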
heavyball/foreach_solp.py
DELETED
@@ -1,89 +0,0 @@
-import torch
-
-from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-    StatefulOptimizer, laprop_exp_avg_
-
-
-class ForeachSOLP(StatefulOptimizer):
-    """
-    ForeachSOLP
-
-    Sources:
-        LaProp:
-            LaProp: Separating Momentum and Adaptivity in Adam
-            Liu Ziyin, Zhikang T.Wang, Masahito Ueda
-            https://arxiv.org/abs/2002.04839
-            https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py
-
-        Baseline SOAP:
-            SOAP: Improving and Stabilizing Shampoo using Adam
-            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-            https://arxiv.org/abs/2409.11321
-            https://github.com/nikhilvyas/SOAP
-
-        ScheduleFree:
-            The Road Less Scheduled
-            Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
-            https://arxiv.org/abs/2405.15682
-            https://github.com/facebookresearch/schedule_free
-    """
-
-    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
-                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
-                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025):
-        defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
-                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split, 'mars': mars,
-                    'caution': caution, 'mars_gamma': mars_gamma}
-        super().__init__(params, defaults, foreach)
-        self._data_format = data_format
-
-    def _step(self, group):
-        vals = []
-        step = 0
-
-        max_precond_dim = group['max_precond_dim']
-        precondition_1d = group['precondition_1d']
-
-        for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
-            state = self.state_(p)
-            step = state['step'] = state.get("step", -1) + 1
-
-            if "exp_avg" not in state:
-                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                continue  # first step is skipped so that we never use the current gradients in the projection.
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            grad_projected = project(g, state['Q'], False)
-            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-        if not vals:
-            return
-
-        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-        beta1, beta2 = group["betas"]
-
-        old_debiased2 = beta_debias(beta2, step)
-
-        # Decay the first and second moment running average coefficient
-        # In-place operations to update the averages at the same time
-        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-
-        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-
-        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
-            laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
-            state = self.state_(p)
-            precond = project(ea, state['Q'], True)
-
-            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, step > 0 and step % group['precondition_frequency'] == 0)
-            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
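ForeachSOLP shares ForeachSOAP's outer loop but replaces the Adam update in the eigenbasis with LaProp (`laprop_exp_avg_`): the projected gradient is normalized by the second-moment estimate before the momentum average is taken. A sketch of the plain LaProp rule, bias corrections omitted:

```python
# Hedged sketch of the LaProp rule (Ziyin et al.) that laprop_exp_avg_ applies in the
# deleted ForeachSOLP: unlike Adam, the gradient is divided by the second-moment estimate
# *before* the momentum average is taken. Bias corrections are omitted for brevity.
import torch

def laprop_step(param, grad, exp_avg, exp_avg_sq, lr=3e-3, beta1=0.9, beta2=0.95, eps=1e-8):
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    normalized = grad / exp_avg_sq.sqrt().add(eps)          # adaptivity first ...
    exp_avg.mul_(beta1).add_(normalized, alpha=1 - beta1)   # ... momentum second
    param.add_(exp_avg, alpha=-lr)

p = torch.zeros(5)
m, v = torch.zeros(5), torch.zeros(5)
laprop_step(p, torch.randn(5), m, v)
```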
heavyball/p_adam.py
DELETED
@@ -1,121 +0,0 @@
-"""
-Originally from Evan Walters and Omead Pooladzandi, 2024
-Modified under Creative Commons Attribution 4.0 International
-Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
-"""
-
-import torch
-from heavyball.utils import triu_to_line, line_to_triu, identity, stochastic_lerp_
-
-from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, exp_avg_sq_, beta_debias, \
-    promote
-
-
-class ForeachPaLMPAdam(PSGDBase):
-    """
-    Kronecker Factorized Adam with PSGD preconditioner
-
-    Args:
-        params (iterable): Iterable of parameters to optimize or dicts defining
-            parameter groups.
-        lr (float): Learning rate.
-        weight_decay (float): Weight decay (L2 penalty).
-        preconditioner_update_probability (callable or float, optional): Probability of
-            updating the preconditioner. If None, defaults to a schedule that anneals
-            from 1.0 to 0.03 by 4000 steps.
-        max_size_triangular (int): Max size for dim's preconditioner to be triangular.
-        min_ndim_triangular (int): Minimum number of dimensions a layer needs
-            to have triangular preconditioners.
-        memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
-            to set all preconditioners to be triangular, 'one_diag' sets the largest
-            or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
-            to be diagonal.
-        momentum_into_precond_update: (bool), whether to send momentum into preconditioner
-            update instead of raw gradients.
-    """
-
-    def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
-                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
-                 momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
-                 beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
-                 store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
-                 stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
-                 caution: bool = False, mars_gamma: float = 0.0025, #
-                 # expert parameters
-                 precond_init_scale=1.0, precond_lr=0.1):
-        if not 0.0 <= lr:
-            raise ValueError(f"Invalid learning rate: {lr}")
-        if not 0.0 <= weight_decay:
-            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-        if betas[0] is not None:
-            beta = betas[0]
-
-        if clip_fn is None:
-            clip_fn = identity
-
-        defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
-                        min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
-                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, beta=beta,
-                        beta2_scale=beta2_scale, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype,
-                        mars=mars, caution=caution, mars_gamma=mars_gamma)
-        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
-
-    def _step(self, group):
-        should_update = self.should_update(group)
-        precond_init_scale = group['precond_init_scale']
-        max_size_triangular = group['max_size_triangular']
-        min_ndim_triangular = group['min_ndim_triangular']
-        memory_save_mode = group['memory_save_mode']
-        precond_lr = group['precond_lr']
-        weight_decay = group['weight_decay']
-        lr = group['lr']
-        store_triu_as_line = group['store_triu_as_line']
-        q_dtype = getattr(torch, group['q_dtype'])
-        storage_dtype = getattr(torch, group['storage_dtype'])
-
-        vals = []
-
-        for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=group['beta']):
-            state = self.state_(p)
-
-            if 'Q' not in state:
-                state['exp_avg'] = torch.zeros_like(g, dtype=storage_dtype, memory_format=torch.preserve_format)
-                state['exp_avg_sq'] = torch.zeros_like(g, dtype=storage_dtype, memory_format=torch.preserve_format)
-                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                 memory_save_mode, dtype=q_dtype)
-                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
-
-            vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
-
-        if not vals:
-            return
-
-        p_list, grad_list, Q_list, exp_avg, exp_avg_sq = zip(*vals)
-        del vals
-
-        group["step"] += 1
-
-        Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
-        if should_update:
-            for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
-                q32 = [promote(qq_) for qq_ in q_]
-                self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
-        stochastic_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
-
-        beta2 = 1 - group['step'] ** -group['beta2_scale']
-
-        lr = -warmup(lr, group['step'], group['warmup_steps'])
-
-        for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
-            gc = g.clone() if group['caution'] else None
-            psgd_precond_grad(True, self.state_(p)["exprs"][-1], g, *Q)
-            ea = psgd_precond_grad(False, self.state_(p)["exprs"][-1], ea, *Q)
-            exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
-            torch.div(ea, g, out=g)
-            """
-            divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
-            divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
-            """
-            update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=gc)
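ForeachPaLMPAdam applies a PSGD Kronecker-factored preconditioner to both the gradient and the momentum, then divides the preconditioned momentum by an Adam-style denominator built from the preconditioned gradient. A sketch of that combination for one matrix parameter, with fixed illustrative Kronecker factors in place of the `init_Q_exprs` / `do_update` machinery:

```python
# Hedged sketch of the ForeachPaLMPAdam update for one matrix parameter. Ql and Qr stand in
# for the learned PSGD factors; the factor update itself (do_update) is not reproduced here.
import torch

m, n = 4, 3
G = torch.randn(m, n)                              # raw gradient
Ql, Qr = torch.eye(m) * 0.9, torch.eye(n) * 1.1    # stand-ins for the learned factors
beta1, beta2, lr, eps = 0.9, 0.95, 1e-3, 1e-8

exp_avg = torch.zeros(m, n)
exp_avg.lerp_(G, 1 - beta1)          # momentum of the raw gradient (stochastic_lerp_)

def precondition(x):
    # applying P = (Qr^T Qr) kron (Ql^T Ql) to vec(x) equals this sandwich product,
    # roughly what psgd_precond_grad does via einsum expressions
    return (Ql.T @ Ql) @ x @ (Qr.T @ Qr)

pg = precondition(G)                 # preconditioned gradient
pm = precondition(exp_avg)           # preconditioned momentum

exp_avg_sq = torch.zeros(m, n)
exp_avg_sq.mul_(beta2).addcmul_(pg, pg, value=1 - beta2)   # second moment of preconditioned grad
update = pm / exp_avg_sq.sqrt().add(eps)                   # what update_param_ applies with -lr

param = torch.zeros(m, n)
param.add_(update, alpha=-lr)
```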
heavyball/palm_foreach_sfadamw.py
DELETED
@@ -1,77 +0,0 @@
-import torch
-import torch.optim
-
-from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, get_ckp1, promote, \
-    _compilable_schedule_free_, copy_stochastic_list_, decorator_knowngood
-
-
-@decorator_knowngood
-def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
-    old_debiased2 = beta_debias(beta2, step)
-
-    g32 = [promote(g_) for g_ in grad]
-    exp_avg_sq32 = [promote(e_) for e_ in exp_avg_sq]
-
-    denom = exp_avg_sq_(exp_avg_sq32, g32, old_debiased2, eps)
-    torch._foreach_div_(g32, denom)
-    if decay != 0:
-        torch._foreach_add_(g32, y, alpha=decay)
-    for p, z_, g in zip(y, z, g32):
-        _compilable_schedule_free_(p, z_, ckp1, g, lr, beta1)
-
-    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
-
-class PaLMForeachSFAdamW(ScheduleFree):
-    def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                 weight_lr_power=2.0, beta2_scale: float = 0.8, foreach: bool = True, storage_dtype: str = 'float32',
-                 mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
-        if betas[0] is not None:
-            beta = betas[0]
-
-        assert not caution, "Caution not implemented for SFAdamW"
-
-        defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-                        lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
-                        beta2_scale=beta2_scale, storage_dtype=storage_dtype, mars=mars, caution=caution,
-                        mars_gamma=mars_gamma)
-        super().__init__(params, defaults, foreach)
-
-    def _step(self, group):
-        eps = group['eps']
-        decay = group['weight_decay']
-        k = group['k']
-
-        if not group['train_mode']:
-            raise Exception("Not in train mode!")
-
-        active_p = [p for p in group['params'] if p.grad is not None]
-
-        if not active_p:
-            return
-
-        storage_dtype = getattr(torch, group['storage_dtype'])
-
-        for p in active_p:
-            if 'z' not in self.state_(p):
-                self.state_(p)['z'] = torch.clone(p.data, memory_format=torch.preserve_format)
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype, memory_format=torch.preserve_format)
-
-        # Decay the first moment running average coefficient
-        beta2 = 1 - (k + 1) ** -group['beta2_scale']
-
-        y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z']) #
-                                       for p in active_p])
-
-        if group['mars']:
-            self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
-
-        lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-        ckp1, group['weight_sum'] = get_ckp1(lr, group['weight_lr_power'], group['weight_sum'], group['r'], k + 1)
-
-        step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k + 1)
-        ckp1 = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(ckp1)
-        beta2 = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(beta2)
-        lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
-        _compilable_step_(y, grad, exp_avg_sq, z, group['beta'], beta2, step, ckp1, eps, decay, lr)
-        group['k'] = k + 1
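PaLMForeachSFAdamW differs from ForeachSFAdamW above mainly in the PaLM-style beta2 schedule, `beta2 = 1 - (k + 1) ** -beta2_scale`, which lengthens the second-moment averaging window as training proceeds:

```python
# The PaLM beta2 schedule used by the deleted PaLMForeach* classes: beta2 grows toward 1
# with the step count, so the second-moment window lengthens over training.
for step in (1, 10, 100, 1_000, 10_000):
    beta2 = 1 - step ** -0.8          # beta2_scale defaults to 0.8
    print(step, round(beta2, 4))      # approx: 0.0, 0.8415, 0.9749, 0.996, 0.9994
```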
heavyball/palm_foreach_soap.py
DELETED
@@ -1,101 +0,0 @@
-import torch
-
-from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-    StatefulOptimizer, exp_avg_
-
-
-class PaLMForeachSOAP(StatefulOptimizer):
-    """
-    PaLMForeachSOAP
-
-    Sources:
-        Baseline SOAP:
-            SOAP: Improving and Stabilizing Shampoo using Adam
-            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-            https://arxiv.org/abs/2409.11321
-            https://github.com/nikhilvyas/SOAP
-
-        ScheduleFree:
-            The Road Less Scheduled
-            Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
-            https://arxiv.org/abs/2405.15682
-            https://github.com/facebookresearch/schedule_free
-
-        Beta2 Schedule:
-            PaLM: Scaling Language Modeling with Pathways
-            Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
-            https://arxiv.org/abs/2204.02311
-    """
-
-    def __init__(self, params, lr: float = 3e-3, beta=0.9, betas=(None, None), shampoo_beta: float = 0.95,
-                 eps: float = 1e-8, weight_decay: float = 0.01, precondition_frequency: int = 2,
-                 max_precond_dim: int = 2048, #
-                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 beta2_scale: float = 0.8, split: bool = False, foreach: bool = True, mars: bool = False,
-                 caution: bool = False, mars_gamma: float = 0.0025):
-        if betas[0] is not None:
-            beta = betas[0]
-        defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
-                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
-                    'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
-        super().__init__(params, defaults, foreach)
-        self._data_format = data_format
-
-    def _step(self, group):
-        vals = []
-        step = 0
-
-        max_precond_dim = group['max_precond_dim']
-        precondition_1d = group['precondition_1d']
-
-        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
-            state = self.state_(p)
-            step = state['step'] = state.get("step", -1) + 1
-
-            if "exp_avg" not in state:
-                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                continue  # first step is skipped so that we never use the current gradients in the projection.
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            grad_projected = project(g, state['Q'], False)
-            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-        if not vals:
-            return
-
-        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-        beta1 = group["beta"]
-
-        beta2 = 1 - step ** -group['beta2_scale']
-        old_debiased2 = beta_debias(beta2, step)
-
-        # Decay the first and second moment running average coefficient
-        # In-place operations to update the averages at the same time
-        beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
-        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-
-        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
-            state = self.state_(p)
-            d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
-            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            exp_avg_projected = project(ea, state['Q'], False)
-
-            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-            # to the original space
-            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-            precond = project(exp_avg_projected / d, state['Q'], True)
-
-            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                  step > 0 and step % group['precondition_frequency'] == 0)
-
-            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
heavyball/palm_foreach_solp.py
DELETED
@@ -1,98 +0,0 @@
-import torch
-
-from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-    StatefulOptimizer, laprop_exp_avg_
-
-
-class PaLMForeachSOLP(StatefulOptimizer):
-    """
-    PaLMForeachSOAP
-
-    Sources:
-        LaProp:
-            LaProp: Separating Momentum and Adaptivity in Adam
-            Liu Ziyin, Zhikang T.Wang, Masahito Ueda
-            https://arxiv.org/abs/2002.04839
-            https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py
-
-        Baseline SOAP:
-            SOAP: Improving and Stabilizing Shampoo using Adam
-            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-            https://arxiv.org/abs/2409.11321
-            https://github.com/nikhilvyas/SOAP
-
-        ScheduleFree:
-            The Road Less Scheduled
-            Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
-            https://arxiv.org/abs/2405.15682
-            https://github.com/facebookresearch/schedule_free
-
-        Beta2 Schedule:
-            PaLM: Scaling Language Modeling with Pathways
-            Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
-            https://arxiv.org/abs/2204.02311
-    """
-
-    def __init__(self, params, lr: float = 3e-3, beta=0.9, betas=(None, None), shampoo_beta: float = 0.95,
-                 eps: float = 1e-8, weight_decay: float = 0.01, precondition_frequency: int = 2,
-                 max_precond_dim: int = 2048, #
-                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 beta2_scale: float = 0.8, split: bool = False, foreach: bool = True, mars: bool = False,
-                 caution: bool = False, mars_gamma: float = 0.0025):
-        if betas[0] is not None:
-            beta = betas[0]
-        defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
-                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
-                    'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
-        super().__init__(params, defaults, foreach)
-        self._data_format = data_format
-
-    def _step(self, group):
-        vals = []
-        step = 0
-
-        max_precond_dim = group['max_precond_dim']
-        precondition_1d = group['precondition_1d']
-
-        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
-            state = self.state_(p)
-            step = state['step'] = state.get("step", -1) + 1
-
-            if "exp_avg" not in state:
-                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                continue  # first step is skipped so that we never use the current gradients in the projection.
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            grad_projected = project(g, state['Q'], False)
-            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-        if not vals:
-            return
-
-        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-        beta1 = group["beta"]
-
-        beta2 = 1 - step ** -group['beta2_scale']
-        old_debiased2 = beta_debias(beta2, step)
-
-        # Decay the first and second moment running average coefficient
-        # In-place operations to update the averages at the same time
-        beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
-        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-
-        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
-            laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
-            state = self.state_(p)
-            precond = project(ea, state['Q'], True)
-
-            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, step > 0 and step % group['precondition_frequency'] == 0)
-            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])