heavyball 0.25.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
@@ -1,69 +0,0 @@
- import torch
- import torch.optim
-
- from .utils import get_ckp1, copy_stochastic_list_, warmup, ScheduleFree, exp_avg_sq_, beta_debias, promote, _compilable_schedule_free_, decorator_knowngood
-
-
- @decorator_knowngood
- def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
-     old_debiased2 = beta_debias(beta2, step)
-
-     g32 = [promote(g_) for g_ in grad]
-     exp_avg_sq32 = [promote(e_) for e_ in exp_avg_sq]
-
-     denom = exp_avg_sq_(exp_avg_sq32, g32, old_debiased2, eps)
-     torch._foreach_div_(g32, denom)
-     if decay != 0:
-         torch._foreach_add_(g32, y, alpha=decay)
-     for p, z_, g in zip(y, z, g32):
-         _compilable_schedule_free_(p, z_, ckp1, g, lr, beta1)
-
-     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
-
- class ForeachSFAdamW(ScheduleFree):
-     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                  weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
-                  caution: bool = False, mars_gamma: float = 0.0025):
-
-         assert not caution, "Caution not implemented for SFAdamW"
-
-         defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
-                         weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
-                         foreach=foreach, storage_dtype=storage_dtype, mars=mars, caution=caution, mars_gamma=mars_gamma)
-         super().__init__(params, defaults, foreach)
-
-     def _step(self, group):
-         eps = group['eps']
-         decay = group['weight_decay']
-         k = group['k']
-
-         if not group['train_mode']:
-             raise Exception("Not in train mode!")
-
-         active_p = [p for p in group['params'] if p.grad is not None]
-
-         if not active_p:
-             return
-
-         storage_dtype = getattr(torch, group['storage_dtype'])
-
-         for p in active_p:
-             if 'z' not in self.state_(p):
-                 self.state_(p)['z'] = torch.clone(p.data, memory_format=torch.preserve_format)
-                 self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype, memory_format=torch.preserve_format)
-
-         y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z']) #
-                                        for p in active_p])
-
-         if group['mars']:
-             self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
-
-         lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-         ckp1, group['weight_sum'] = get_ckp1(lr, group['weight_lr_power'], group['weight_sum'], group['r'], k + 1)
-
-         step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k + 1)
-         ckp1 = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(ckp1)
-         lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
-         _compilable_step_(y, grad, exp_avg_sq, z, group['betas'][0], group['betas'][1], step, ckp1, eps, decay, lr)
-         group['k'] = k + 1
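
ForeachSFAdamW above refuses to step outside of train mode, mirroring the reference schedule_free API in which training and evaluation use different weight sequences. A minimal usage sketch, assuming the class is still exported at the package top level and that the ScheduleFree base provides train()/eval() toggles as in facebookresearch/schedule_free (both are assumptions, not confirmed by this diff):

    import torch
    import heavyball  # assumption: ForeachSFAdamW is exported at the package top level

    model = torch.nn.Linear(16, 1)
    opt = heavyball.ForeachSFAdamW(model.parameters(), lr=2.5e-3, betas=(0.9, 0.99))

    opt.train()  # assumed ScheduleFree toggle; _step raises "Not in train mode!" otherwise
    for _ in range(10):
        loss = model(torch.randn(4, 16)).square().mean()
        loss.backward()
        opt.step()
        opt.zero_grad()
    opt.eval()  # assumed toggle: switch the parameters to the evaluation weights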
heavyball/foreach_soap.py DELETED
@@ -1,93 +0,0 @@
- import torch
-
- from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-     StatefulOptimizer, exp_avg_
-
-
- class ForeachSOAP(StatefulOptimizer):
-     """
-     ForeachSOAP
-
-     Sources:
-         Baseline SOAP:
-             SOAP: Improving and Stabilizing Shampoo using Adam
-             Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-             https://arxiv.org/abs/2409.11321
-             https://github.com/nikhilvyas/SOAP
-
-         ScheduleFree:
-             The Road Less Scheduled
-             Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
-             https://arxiv.org/abs/2405.15682
-             https://github.com/facebookresearch/schedule_free
-     """
-
-     def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
-                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
-                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                  split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
-                  mars_gamma: float = 0.0025):
-         defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
-                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split, 'mars': mars,
-                     'caution': caution, 'mars_gamma': mars_gamma}
-         super().__init__(params, defaults, foreach)
-         self._data_format = data_format
-
-     def _step(self, group):
-         vals = []
-         step = 0
-
-         max_precond_dim = group['max_precond_dim']
-         precondition_1d = group['precondition_1d']
-
-         for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
-             state = self.state_(p)
-             step = state['step'] = state.get("step", -1) + 1
-
-             if "exp_avg" not in state:
-                 state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                 continue  # first step is skipped so that we never use the current gradients in the projection.
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             grad_projected = project(g, state['Q'], False)
-             exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-             vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-         if not vals:
-             return
-
-         p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-         beta1, beta2 = group["betas"]
-
-         old_debiased2 = beta_debias(beta2, step)
-
-         # Decay the first and second moment running average coefficient
-         # In-place operations to update the averages at the same time
-         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-
-         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-
-         for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
-             state = self.state_(p)
-
-             d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
-             # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             exp_avg_projected = project(ea, state['Q'], False)
-
-             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-             # to the original space
-             # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-             precond = project(exp_avg_projected / d, state['Q'], True)
-
-             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                   step > 0 and step % group['precondition_frequency'] == 0)
-
-             update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
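
The comments in _step above spell out the SOAP recipe: rotate the gradient into the eigenbases of the accumulated GG statistics, run an Adam-style update in that basis, then rotate the preconditioned update back before applying it. A self-contained sketch of that projected update for a single 2-D parameter, written against plain PyTorch rather than the heavyball helpers (the bias handling and the per-step eigendecomposition here are simplifications, not the package's internals):

    import torch

    def soap_like_step(p, g, state, lr=3e-3, beta1=0.9, beta2=0.95, shampoo_beta=0.95, eps=1e-8):
        if 'GL' not in state:  # Kronecker-factored statistics and Adam buffers
            state['GL'] = torch.zeros(g.shape[0], g.shape[0], dtype=g.dtype)
            state['GR'] = torch.zeros(g.shape[1], g.shape[1], dtype=g.dtype)
            state['m'] = torch.zeros_like(g)
            state['v'] = torch.zeros_like(g)
        state['GL'].lerp_(g @ g.T, 1 - shampoo_beta)
        state['GR'].lerp_(g.T @ g, 1 - shampoo_beta)
        QL = torch.linalg.eigh(state['GL']).eigenvectors  # eigenbases of the 'GG' matrices
        QR = torch.linalg.eigh(state['GR']).eigenvectors

        g_proj = QL.T @ g @ QR                        # project the gradient into the eigenbasis
        state['m'].lerp_(g, 1 - beta1)                # first moment lives in the original space
        state['v'].lerp_(g_proj.square(), 1 - beta2)  # second moment lives in the rotated space
        m_proj = QL.T @ state['m'] @ QR
        update = m_proj / (state['v'].sqrt() + eps)

        p -= lr * (QL @ update @ QR.T)                # rotate the preconditioned update back

The real implementation amortizes the eigendecomposition through precondition_frequency, skips the very first step so the projection never sees the current gradient, and handles debiasing and weight decay via the shared helpers; the sketch recomputes the bases on every call purely for readability.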
heavyball/foreach_solp.py DELETED
@@ -1,89 +0,0 @@
- import torch
-
- from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-     StatefulOptimizer, laprop_exp_avg_
-
-
- class ForeachSOLP(StatefulOptimizer):
-     """
-     ForeachSOLP
-
-     Sources:
-         LaProp:
-             LaProp: Separating Momentum and Adaptivity in Adam
-             Liu Ziyin, Zhikang T.Wang, Masahito Ueda
-             https://arxiv.org/abs/2002.04839
-             https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py
-
-         Baseline SOAP:
-             SOAP: Improving and Stabilizing Shampoo using Adam
-             Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-             https://arxiv.org/abs/2409.11321
-             https://github.com/nikhilvyas/SOAP
-
-         ScheduleFree:
-             The Road Less Scheduled
-             Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
-             https://arxiv.org/abs/2405.15682
-             https://github.com/facebookresearch/schedule_free
-     """
-
-     def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
-                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
-                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                  split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
-                  mars_gamma: float = 0.0025):
-         defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
-                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split, 'mars': mars,
-                     'caution': caution, 'mars_gamma': mars_gamma}
-         super().__init__(params, defaults, foreach)
-         self._data_format = data_format
-
-     def _step(self, group):
-         vals = []
-         step = 0
-
-         max_precond_dim = group['max_precond_dim']
-         precondition_1d = group['precondition_1d']
-
-         for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
-             state = self.state_(p)
-             step = state['step'] = state.get("step", -1) + 1
-
-             if "exp_avg" not in state:
-                 state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                 continue  # first step is skipped so that we never use the current gradients in the projection.
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             grad_projected = project(g, state['Q'], False)
-             exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-             vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-         if not vals:
-             return
-
-         p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-         beta1, beta2 = group["betas"]
-
-         old_debiased2 = beta_debias(beta2, step)
-
-         # Decay the first and second moment running average coefficient
-         # In-place operations to update the averages at the same time
-         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-
-         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-
-         for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
-             laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
-             state = self.state_(p)
-             precond = project(ea, state['Q'], True)
-
-             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, step > 0 and step % group['precondition_frequency'] == 0)
-             update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
heavyball/p_adam.py DELETED
@@ -1,121 +0,0 @@
- """
- Originally from Evan Walters and Omead Pooladzandi, 2024
- Modified under Creative Commons Attribution 4.0 International
- Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
- """
-
- import torch
- from heavyball.utils import triu_to_line, line_to_triu, identity, stochastic_lerp_
-
- from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, exp_avg_sq_, beta_debias, \
-     promote
-
-
- class ForeachPaLMPAdam(PSGDBase):
-     """
-     Kronecker Factorized Adam with PSGD preconditioner
-
-     Args:
-         params (iterable): Iterable of parameters to optimize or dicts defining
-             parameter groups.
-         lr (float): Learning rate.
-         weight_decay (float): Weight decay (L2 penalty).
-         preconditioner_update_probability (callable or float, optional): Probability of
-             updating the preconditioner. If None, defaults to a schedule that anneals
-             from 1.0 to 0.03 by 4000 steps.
-         max_size_triangular (int): Max size for dim's preconditioner to be triangular.
-         min_ndim_triangular (int): Minimum number of dimensions a layer needs
-             to have triangular preconditioners.
-         memory_save_mode (string, optional): None, 'one_diag', or 'all_diag'. The default,
-             None, keeps all preconditioners triangular; 'one_diag' makes the largest
-             or last dim diagonal per layer; 'all_diag' makes all preconditioners
-             diagonal.
-         momentum_into_precond_update (bool): whether to send momentum into the preconditioner
-             update instead of raw gradients.
-     """
-
-     def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
-                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
-                  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
-                  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
-                  store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
-                  stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
-                  caution: bool = False, mars_gamma: float = 0.0025,  #
-                  # expert parameters
-                  precond_init_scale=1.0, precond_lr=0.1):
-         if not 0.0 <= lr:
-             raise ValueError(f"Invalid learning rate: {lr}")
-         if not 0.0 <= weight_decay:
-             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-         if betas[0] is not None:
-             beta = betas[0]
-
-         if clip_fn is None:
-             clip_fn = identity
-
-         defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
-                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
-                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, beta=beta,
-                         beta2_scale=beta2_scale, merge_dims=merge_dims, split=split,
-                         store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype,
-                         mars=mars, caution=caution, mars_gamma=mars_gamma)
-         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
-
-     def _step(self, group):
-         should_update = self.should_update(group)
-         precond_init_scale = group['precond_init_scale']
-         max_size_triangular = group['max_size_triangular']
-         min_ndim_triangular = group['min_ndim_triangular']
-         memory_save_mode = group['memory_save_mode']
-         precond_lr = group['precond_lr']
-         weight_decay = group['weight_decay']
-         lr = group['lr']
-         store_triu_as_line = group['store_triu_as_line']
-         q_dtype = getattr(torch, group['q_dtype'])
-         storage_dtype = getattr(torch, group['storage_dtype'])
-
-         vals = []
-
-         for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=group['beta']):
-             state = self.state_(p)
-
-             if 'Q' not in state:
-                 state['exp_avg'] = torch.zeros_like(g, dtype=storage_dtype, memory_format=torch.preserve_format)
-                 state['exp_avg_sq'] = torch.zeros_like(g, dtype=storage_dtype, memory_format=torch.preserve_format)
-                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                  memory_save_mode, dtype=q_dtype)
-                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
-
-             vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
-
-         if not vals:
-             return
-
-         p_list, grad_list, Q_list, exp_avg, exp_avg_sq = zip(*vals)
-         del vals
-
-         group["step"] += 1
-
-         Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
-         if should_update:
-             for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
-                 q32 = [promote(qq_) for qq_ in q_]
-                 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
-         stochastic_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
-
-         beta2 = 1 - group['step'] ** -group['beta2_scale']
-
-         lr = -warmup(lr, group['step'], group['warmup_steps'])
-
-         for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
-             gc = g.clone() if group['caution'] else None
-             psgd_precond_grad(True, self.state_(p)["exprs"][-1], g, *Q)
-             ea = psgd_precond_grad(False, self.state_(p)["exprs"][-1], ea, *Q)
-             exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
-             torch.div(ea, g, out=g)
-             """
-             divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
-             divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
-             """
-             update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=gc)
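
The Args block above is the only documentation ForeachPaLMPAdam ships with, so a construction example may help. The keyword names below come straight from the __init__ signature; the import path and the toy model are assumptions:

    import torch
    from heavyball import ForeachPaLMPAdam  # assumption: exported at the package top level

    model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
    opt = ForeachPaLMPAdam(
        model.parameters(),
        lr=1e-3,
        weight_decay=0.0,
        preconditioner_update_probability=None,  # None: anneal from 1.0 to 0.03 by ~4000 steps (per the docstring)
        max_size_triangular=2048,                # dims above this size fall back to diagonal preconditioners
        memory_save_mode=None,                   # or 'one_diag' / 'all_diag' to trade accuracy for memory
        beta=0.9,
        beta2_scale=0.8,                         # PaLM-style schedule: beta2 = 1 - step ** -beta2_scale
    )

    for _ in range(5):
        model(torch.randn(8, 32)).square().mean().backward()
        opt.step()
        opt.zero_grad()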
@@ -1,77 +0,0 @@
- import torch
- import torch.optim
-
- from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, get_ckp1, promote, \
-     _compilable_schedule_free_, copy_stochastic_list_, decorator_knowngood
-
-
- @decorator_knowngood
- def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
-     old_debiased2 = beta_debias(beta2, step)
-
-     g32 = [promote(g_) for g_ in grad]
-     exp_avg_sq32 = [promote(e_) for e_ in exp_avg_sq]
-
-     denom = exp_avg_sq_(exp_avg_sq32, g32, old_debiased2, eps)
-     torch._foreach_div_(g32, denom)
-     if decay != 0:
-         torch._foreach_add_(g32, y, alpha=decay)
-     for p, z_, g in zip(y, z, g32):
-         _compilable_schedule_free_(p, z_, ckp1, g, lr, beta1)
-
-     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
-
- class PaLMForeachSFAdamW(ScheduleFree):
-     def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                  weight_lr_power=2.0, beta2_scale: float = 0.8, foreach: bool = True, storage_dtype: str = 'float32',
-                  mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
-         if betas[0] is not None:
-             beta = betas[0]
-
-         assert not caution, "Caution not implemented for SFAdamW"
-
-         defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-                         lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
-                         beta2_scale=beta2_scale, storage_dtype=storage_dtype, mars=mars, caution=caution,
-                         mars_gamma=mars_gamma)
-         super().__init__(params, defaults, foreach)
-
-     def _step(self, group):
-         eps = group['eps']
-         decay = group['weight_decay']
-         k = group['k']
-
-         if not group['train_mode']:
-             raise Exception("Not in train mode!")
-
-         active_p = [p for p in group['params'] if p.grad is not None]
-
-         if not active_p:
-             return
-
-         storage_dtype = getattr(torch, group['storage_dtype'])
-
-         for p in active_p:
-             if 'z' not in self.state_(p):
-                 self.state_(p)['z'] = torch.clone(p.data, memory_format=torch.preserve_format)
-                 self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype, memory_format=torch.preserve_format)
-
-         # Decay the first moment running average coefficient
-         beta2 = 1 - (k + 1) ** -group['beta2_scale']
-
-         y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z']) #
-                                        for p in active_p])
-
-         if group['mars']:
-             self.mars_correct_list(group, y, grad, group['mars_gamma'], group['beta'])
-
-         lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-         ckp1, group['weight_sum'] = get_ckp1(lr, group['weight_lr_power'], group['weight_sum'], group['r'], k + 1)
-
-         step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k + 1)
-         ckp1 = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(ckp1)
-         beta2 = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(beta2)
-         lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
-         _compilable_step_(y, grad, exp_avg_sq, z, group['beta'], beta2, step, ckp1, eps, decay, lr)
-         group['k'] = k + 1
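
PaLMForeachSFAdamW differs from ForeachSFAdamW above mainly in replacing the fixed beta2 with the PaLM schedule computed in _step, beta2 = 1 - (k + 1) ** -beta2_scale. A quick check of what that schedule does with the default beta2_scale = 0.8:

    # beta2 = 1 - (k + 1) ** -beta2_scale with beta2_scale = 0.8
    for k in (0, 9, 99, 999, 9999):
        print(k + 1, round(1 - (k + 1) ** -0.8, 5))
    # 1      0.0
    # 10     0.84151
    # 100    0.97488
    # 1000   0.99602
    # 10000  0.99937

so the second-moment horizon lengthens as training progresses, as in the PaLM paper cited below.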
@@ -1,101 +0,0 @@
- import torch
-
- from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-     StatefulOptimizer, exp_avg_
-
-
- class PaLMForeachSOAP(StatefulOptimizer):
-     """
-     PaLMForeachSOAP
-
-     Sources:
-         Baseline SOAP:
-             SOAP: Improving and Stabilizing Shampoo using Adam
-             Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-             https://arxiv.org/abs/2409.11321
-             https://github.com/nikhilvyas/SOAP
-
-         ScheduleFree:
-             The Road Less Scheduled
-             Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
-             https://arxiv.org/abs/2405.15682
-             https://github.com/facebookresearch/schedule_free
-
-         Beta2 Schedule:
-             PaLM: Scaling Language Modeling with Pathways
-             Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
-             https://arxiv.org/abs/2204.02311
-     """
-
-     def __init__(self, params, lr: float = 3e-3, beta=0.9, betas=(None, None), shampoo_beta: float = 0.95,
-                  eps: float = 1e-8, weight_decay: float = 0.01, precondition_frequency: int = 2,
-                  max_precond_dim: int = 2048,  #
-                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                  beta2_scale: float = 0.8, split: bool = False, foreach: bool = True, mars: bool = False,
-                  caution: bool = False, mars_gamma: float = 0.0025):
-         if betas[0] is not None:
-             beta = betas[0]
-         defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
-                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
-                     'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
-         super().__init__(params, defaults, foreach)
-         self._data_format = data_format
-
-     def _step(self, group):
-         vals = []
-         step = 0
-
-         max_precond_dim = group['max_precond_dim']
-         precondition_1d = group['precondition_1d']
-
-         for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
-             state = self.state_(p)
-             step = state['step'] = state.get("step", -1) + 1
-
-             if "exp_avg" not in state:
-                 state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                 continue  # first step is skipped so that we never use the current gradients in the projection.
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             grad_projected = project(g, state['Q'], False)
-             exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-             vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-         if not vals:
-             return
-
-         p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-         beta1 = group["beta"]
-
-         beta2 = 1 - step ** -group['beta2_scale']
-         old_debiased2 = beta_debias(beta2, step)
-
-         # Decay the first and second moment running average coefficient
-         # In-place operations to update the averages at the same time
-         beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
-         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-
-         for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
-             state = self.state_(p)
-             d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
-             # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             exp_avg_projected = project(ea, state['Q'], False)
-
-             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-             # to the original space
-             # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-             precond = project(exp_avg_projected / d, state['Q'], True)
-
-             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                   step > 0 and step % group['precondition_frequency'] == 0)
-
-             update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
@@ -1,98 +0,0 @@
- import torch
-
- from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-     StatefulOptimizer, laprop_exp_avg_
-
-
- class PaLMForeachSOLP(StatefulOptimizer):
-     """
-     PaLMForeachSOLP
-
-     Sources:
-         LaProp:
-             LaProp: Separating Momentum and Adaptivity in Adam
-             Liu Ziyin, Zhikang T.Wang, Masahito Ueda
-             https://arxiv.org/abs/2002.04839
-             https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py
-
-         Baseline SOAP:
-             SOAP: Improving and Stabilizing Shampoo using Adam
-             Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-             https://arxiv.org/abs/2409.11321
-             https://github.com/nikhilvyas/SOAP
-
-         ScheduleFree:
-             The Road Less Scheduled
-             Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
-             https://arxiv.org/abs/2405.15682
-             https://github.com/facebookresearch/schedule_free
-
-         Beta2 Schedule:
-             PaLM: Scaling Language Modeling with Pathways
-             Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
-             https://arxiv.org/abs/2204.02311
-     """
-
-     def __init__(self, params, lr: float = 3e-3, beta=0.9, betas=(None, None), shampoo_beta: float = 0.95,
-                  eps: float = 1e-8, weight_decay: float = 0.01, precondition_frequency: int = 2,
-                  max_precond_dim: int = 2048,  #
-                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                  beta2_scale: float = 0.8, split: bool = False, foreach: bool = True, mars: bool = False,
-                  caution: bool = False, mars_gamma: float = 0.0025):
-         if betas[0] is not None:
-             beta = betas[0]
-         defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
-                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
-                     'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
-         super().__init__(params, defaults, foreach)
-         self._data_format = data_format
-
-     def _step(self, group):
-         vals = []
-         step = 0
-
-         max_precond_dim = group['max_precond_dim']
-         precondition_1d = group['precondition_1d']
-
-         for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
-             state = self.state_(p)
-             step = state['step'] = state.get("step", -1) + 1
-
-             if "exp_avg" not in state:
-                 state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                 continue  # first step is skipped so that we never use the current gradients in the projection.
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             grad_projected = project(g, state['Q'], False)
-             exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-             vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-         if not vals:
-             return
-
-         p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-         beta1 = group["beta"]
-
-         beta2 = 1 - step ** -group['beta2_scale']
-         old_debiased2 = beta_debias(beta2, step)
-
-         # Decay the first and second moment running average coefficient
-         # In-place operations to update the averages at the same time
-         beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
-         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-
-         for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
-             laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
-             state = self.state_(p)
-             precond = project(ea, state['Q'], True)
-
-             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, step > 0 and step % group['precondition_frequency'] == 0)
-             update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])