heavyball 0.14.6__tar.gz → 0.15.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {heavyball-0.14.6 → heavyball-0.15.0}/PKG-INFO +1 -1
  2. {heavyball-0.14.6 → heavyball-0.15.0}/heavyball/delayed_psgd.py +39 -48
  3. heavyball-0.15.0/heavyball/foreach_adamw.py +41 -0
  4. heavyball-0.15.0/heavyball/foreach_adopt.py +51 -0
  5. heavyball-0.15.0/heavyball/foreach_laprop.py +46 -0
  6. heavyball-0.15.0/heavyball/foreach_sfadamw.py +54 -0
  7. heavyball-0.15.0/heavyball/foreach_soap.py +92 -0
  8. {heavyball-0.14.6 → heavyball-0.15.0}/heavyball/p_adam.py +39 -48
  9. heavyball-0.15.0/heavyball/palm_foreach_sfadamw.py +56 -0
  10. {heavyball-0.14.6 → heavyball-0.15.0}/heavyball/palm_foreach_soap.py +56 -70
  11. heavyball-0.15.0/heavyball/precond_schedule_foreach_soap.py +96 -0
  12. {heavyball-0.14.6 → heavyball-0.15.0}/heavyball/precond_schedule_palm_foreach_soap.py +58 -73
  13. {heavyball-0.14.6 → heavyball-0.15.0}/heavyball/precond_schedule_sfpsoap.py +60 -72
  14. {heavyball-0.14.6 → heavyball-0.15.0}/heavyball/psgd_kron.py +39 -47
  15. {heavyball-0.14.6 → heavyball-0.15.0}/heavyball/pure_psgd.py +32 -41
  16. {heavyball-0.14.6 → heavyball-0.15.0}/heavyball/schedule_free_palm_foreach_soap.py +61 -72
  17. {heavyball-0.14.6 → heavyball-0.15.0}/heavyball/utils.py +17 -2
  18. {heavyball-0.14.6 → heavyball-0.15.0}/heavyball.egg-info/PKG-INFO +1 -1
  19. {heavyball-0.14.6 → heavyball-0.15.0}/heavyball.egg-info/SOURCES.txt +2 -0
  20. {heavyball-0.14.6 → heavyball-0.15.0}/setup.py +1 -1
  21. heavyball-0.15.0/test/test_closure.py +44 -0
  22. heavyball-0.15.0/test/test_no_grad.py +39 -0
  23. heavyball-0.14.6/heavyball/foreach_adamw.py +0 -51
  24. heavyball-0.14.6/heavyball/foreach_adopt.py +0 -61
  25. heavyball-0.14.6/heavyball/foreach_laprop.py +0 -56
  26. heavyball-0.14.6/heavyball/foreach_sfadamw.py +0 -64
  27. heavyball-0.14.6/heavyball/foreach_soap.py +0 -106
  28. heavyball-0.14.6/heavyball/palm_foreach_sfadamw.py +0 -66
  29. heavyball-0.14.6/heavyball/precond_schedule_foreach_soap.py +0 -110
  30. {heavyball-0.14.6 → heavyball-0.15.0}/LICENSE +0 -0
  31. {heavyball-0.14.6 → heavyball-0.15.0}/README.md +0 -0
  32. {heavyball-0.14.6 → heavyball-0.15.0}/heavyball/__init__.py +1 -1
  33. {heavyball-0.14.6 → heavyball-0.15.0}/heavyball.egg-info/dependency_links.txt +0 -0
  34. {heavyball-0.14.6 → heavyball-0.15.0}/heavyball.egg-info/requires.txt +0 -0
  35. {heavyball-0.14.6 → heavyball-0.15.0}/heavyball.egg-info/top_level.txt +0 -0
  36. {heavyball-0.14.6 → heavyball-0.15.0}/setup.cfg +0 -0
  37. {heavyball-0.14.6 → heavyball-0.15.0}/test/test_memory.py +0 -0
  38. {heavyball-0.14.6 → heavyball-0.15.0}/test/test_merge.py +0 -0
  39. {heavyball-0.14.6 → heavyball-0.15.0}/test/test_psgd.py +0 -0
  40. {heavyball-0.14.6 → heavyball-0.15.0}/test/test_soap.py +0 -0
{heavyball-0.14.6 → heavyball-0.15.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.14.6
+Version: 0.15.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
{heavyball-0.14.6 → heavyball-0.15.0}/heavyball/delayed_psgd.py
@@ -5,8 +5,8 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-
 from heavyball.utils import copy_stochastic_list_
+
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
     precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_
 
@@ -63,13 +63,7 @@ class ForeachDelayedPSGD(PSGDBase):
 
         self._prob_step = 0
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
@@ -77,55 +71,52 @@ class ForeachDelayedPSGD(PSGDBase):
             do_update = self.rng.random() < update_prob
         self._prob_step += 1
 
-        for group in self.param_groups:
-            momentum_into_precond_update = group.get("momentum_into_precond_update", True)
-            precond_init_scale = group['precond_init_scale']
-            max_size_triangular = group['max_size_triangular']
-            min_ndim_triangular = group['min_ndim_triangular']
-            memory_save_mode = group['memory_save_mode']
-            precond_lr = group['precond_lr']
-            weight_decay = group['weight_decay']
-            lr = group['lr']
-            beta = group['beta']
-
-            vals = []
+        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        beta = group['beta']
 
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
+        vals = []
 
-                if 'Q' not in state:
-                    state["exp_avg"] = torch.zeros_like(g)
-                    Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                     memory_save_mode, dtype=g.dtype)
-                    state["Q"] = triu_to_line(Q)
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
 
-                vals.append((p, g, state["exp_avg"], state["Q"]))
+            if 'Q' not in state:
+                state["exp_avg"] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state["Q"] = triu_to_line(Q)
 
-            if not vals:
-                continue
+            vals.append((p, g, state["exp_avg"], state["Q"]))
 
-            p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
-            del vals
+        if not vals:
+            return
 
-            group["step"] += 1
+        p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
+        del vals
 
-            torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+        group["step"] += 1
 
-            Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
-            for i, (p, g) in enumerate(zip(p_list, grad_list)):
-                q_orig = Q_list.pop(0)
-                ea = exp_avg_list.pop(0)
-                q = line_to_triu(q_orig)
-                self.balance(do_update, [g], [q])
-                new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+        torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
 
-                if do_update:
-                    self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
-                set_(g, new)
+        Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            q_orig = Q_list.pop(0)
+            ea = exp_avg_list.pop(0)
+            q = line_to_triu(q_orig)
+            self.balance(do_update, [g], [q])
+            new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
 
-            grad_list = self.clip_fn(grad_list)
+            if do_update:
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
+            set_(g, new)
 
-            lr = -warmup(lr, group['step'], group['warmup_steps'])
-            update_param_(p_list, grad_list, lr, weight_decay)
+        grad_list = self.clip_fn(grad_list)
 
-        return loss
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
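The hunks above remove each optimizer's own closure handling and move the per-group logic into a _step(group) method. The matching step() wrapper presumably lives in the shared base classes in heavyball/utils.py (changed by +17 -2 in this release and exercised by the new test_closure.py and test_no_grad.py). A minimal sketch of what such a wrapper could look like under that assumption; this is not the actual heavyball code:

import torch

class StatefulOptimizerSketch(torch.optim.Optimizer):
    # Hypothetical stand-in for heavyball's StatefulOptimizer / PSGDBase base classes.
    def state_(self, p):
        return self.state[p]

    def _step(self, group):
        # Subclasses (ForeachAdamW, ForeachSOAP, ForeachDelayedPSGD, ...) implement
        # the per-parameter-group update here.
        raise NotImplementedError

    @torch.no_grad()
    def step(self, closure=None):
        # Evaluate the closure once with gradients enabled, then update every
        # parameter group without tracking autograd history.
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            self._step(group)
        return loss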
heavyball-0.15.0/heavyball/foreach_adamw.py
@@ -0,0 +1,41 @@
+import torch
+import torch.optim
+
+from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer
+
+
+class ForeachAdamW(StatefulOptimizer):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0):
+        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
+
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
+
+        active_p = [p for p in group['params'] if p.grad is not None]
+
+        if not active_p:
+            return
+
+        for p in active_p:
+            if 'exp_avg' not in self.state_(p):
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+
+        y, grad, exp_avg_sq, exp_avg = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
+
+        # Decay the first and second moment running average coefficient
+        torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
+        denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
+
+        # Normalize grad in-place for memory efficiency
+        lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
+        update_param_(y, exp_avg, lr, decay, lambda p, e, l: torch._foreach_addcdiv_(p, e, denom, l))
+        group['k'] = k + 1
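For reference, a short usage sketch of the rewritten ForeachAdamW. It assumes the class is still re-exported at the package level, as in earlier releases (heavyball/__init__.py only changes by one line in this diff):

import torch
import heavyball

model = torch.nn.Linear(16, 1)
opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

for _ in range(10):
    opt.zero_grad()
    loss = model(torch.randn(4, 16)).square().mean()
    loss.backward()
    opt.step()  # the shared base class dispatches to ForeachAdamW._step per param group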
heavyball-0.15.0/heavyball/foreach_adopt.py
@@ -0,0 +1,51 @@
+import torch
+import torch.optim
+
+from .utils import warmup, beta_debias, update_param_, StatefulOptimizer
+
+
+class ForeachADOPT(StatefulOptimizer):
+
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0):
+        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
+
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
+
+        active_p = [p for p in group['params'] if p.grad is not None]
+
+        if not active_p:
+            return
+
+        for p in active_p:
+            if 'exp_avg' not in self.state_(p):
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+
+        y, grad, exp_avg_sq, exp_avg = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
+
+        if k > 1:
+            lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
+
+            update_param_(y, exp_avg, lr, decay)
+        if k > 0:
+            beta1 = beta_debias(group['betas'][0], k)
+            denom = torch._foreach_sqrt(exp_avg_sq)
+            torch._foreach_maximum_(denom, eps)
+            torch._foreach_mul_(exp_avg, beta1)
+            torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
+
+        beta2 = beta_debias(group['betas'][1], k + 1)
+        torch._foreach_mul_(exp_avg_sq, beta2)
+        torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)
+        del grad
+
+        group['k'] = k + 1
heavyball-0.15.0/heavyball/foreach_laprop.py
@@ -0,0 +1,46 @@
+import torch
+import torch.optim
+
+from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer
+
+
+class ForeachLaProp(StatefulOptimizer):
+
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1):
+        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
+
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
+
+        active_p = [p for p in group['params'] if p.grad is not None]
+
+        if not active_p:
+            return
+
+        for p in active_p:
+            if 'exp_avg' not in self.state_(p):
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+
+        y, grad, exp_avg_sq, exp_avg = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
+
+        # Decay the first and second moment running average coefficient
+        denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
+        beta1 = beta_debias(group['betas'][0], k + 1)
+        torch._foreach_mul_(exp_avg, beta1)
+        torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
+        del grad
+
+        # Normalize grad in-place for memory efficiency
+        lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
+        update_param_(y, exp_avg, lr, decay)
+
+        group['k'] = k + 1
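ForeachAdamW, ForeachADOPT and ForeachLaProp all lean on the beta_debias and exp_avg_sq_ helpers from heavyball/utils.py. A rough sketch of what those helpers are expected to compute, for orientation only; the real implementations in utils.py may differ in details such as out= handling:

import torch

def beta_debias(beta, step):
    # Effective beta after bias correction, so running averages are not
    # dominated by their zero initialization in the first steps.
    return 1 - (1 - beta) / (1 - beta ** step)

def exp_avg_sq_(state, grad, beta2, eps):
    # Update the second-moment buffers in place and return sqrt(v) clamped
    # to eps, one denominator tensor per parameter.
    torch._foreach_mul_(state, beta2)
    torch._foreach_addcmul_(state, grad, grad, value=1 - beta2)
    denom = torch._foreach_sqrt(state)
    torch._foreach_maximum_(denom, eps)
    return denom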
heavyball-0.15.0/heavyball/foreach_sfadamw.py
@@ -0,0 +1,54 @@
+import torch
+import torch.optim
+
+from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debias
+
+
+class ForeachSFAdamW(ScheduleFree):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
+                 weight_lr_power=2.0, foreach=hasattr(torch, "_foreach_mul_")):
+
+        defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
+                        weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
+                        foreach=foreach)
+        super().__init__(params, defaults)
+
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
+
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
+
+        active_p = [p for p in group['params'] if p.grad is not None]
+
+        if not active_p:
+            return
+
+        for p in active_p:
+            if 'z' not in self.state_(p):
+                self.state_(p)['z'] = torch.clone(p.data)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+
+        y, grad, exp_avg_sq, z = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
+
+        # Decay the first moment running average coefficient
+        old_debiased = beta_debias(group['betas'][1], k + 1)
+
+        # Decay the first and second moment running average coefficient
+        denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
+
+        # Normalize grad in-place for memory efficiency
+        torch._foreach_div_(grad, denom)
+
+        # Weight decay calculated at y
+        if decay != 0:
+            torch._foreach_add_(grad, y, alpha=decay)
+
+        lr = warmup(group['lr'], k + 1, group['warmup_steps'])
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
+                                             y, z, grad, group['r'], k + 1)
+
+        group['k'] = k + 1
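ForeachSFAdamW (and PaLMForeachSFAdamW below) check group['train_mode'] before updating, which matches the schedule-free pattern of switching the optimizer between training and evaluation parameters. A usage sketch, assuming heavyball's ScheduleFree base keeps the train()/eval() toggle of the schedule-free reference implementation:

import torch
import heavyball

model = torch.nn.Linear(16, 1)
opt = heavyball.ForeachSFAdamW(model.parameters(), lr=1e-3)

for _ in range(10):
    opt.train()  # training mode: updates are applied against the z/y iterates
    opt.zero_grad()
    model(torch.randn(4, 16)).square().mean().backward()
    opt.step()

opt.eval()       # switch parameters to the averaged iterate before validation
with torch.no_grad():
    val_loss = model(torch.randn(4, 16)).square().mean()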
heavyball-0.15.0/heavyball/foreach_soap.py
@@ -0,0 +1,92 @@
+import torch
+
+from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
+    split_p_and_g_in_group, StatefulOptimizer
+
+
+class ForeachSOAP(StatefulOptimizer):
+    """
+    SFPaLMForeachSOAP
+
+    Sources:
+        Baseline SOAP:
+            SOAP: Improving and Stabilizing Shampoo using Adam
+            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+            https://arxiv.org/abs/2409.11321
+            https://github.com/nikhilvyas/SOAP
+
+        ScheduleFree:
+            The Road Less Scheduled
+            Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
+            https://arxiv.org/abs/2405.15682
+            https://github.com/facebookresearch/schedule_free
+    """
+
+    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
+                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
+                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
+                 split: bool = False):
+        defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
+                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
+                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
+                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split}
+        super().__init__(params, defaults)
+        self._data_format = data_format
+
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1, beta2 = group["betas"]
+
+        old_debiased1 = beta_debias(beta1, step)
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        torch._foreach_mul_(exp_avg, old_debiased1)
+        torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+
+        for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
+            state = self.state_(p)
+            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            exp_avg_projected = project(ea, state['Q'], False)
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            set_(d, project(exp_avg_projected / d, state['Q'], True))
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
+                                  step > 0 and step % group['precondition_frequency'] == 0)
+
+        # Why does this have to be rebiased here?
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+        update_param_(p_list, denom, step_size, group["weight_decay"])
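The ForeachSOAP constructor above exposes the SOAP-specific knobs directly. A construction sketch using those keyword arguments, again assuming the class is exported at the package level like the other optimizers:

import torch
import heavyball

model = torch.nn.Linear(64, 64)
opt = heavyball.ForeachSOAP(model.parameters(), lr=3e-3, betas=(0.9, 0.95),
                            precondition_frequency=2,  # refresh the Shampoo eigenbases every 2 steps
                            max_precond_dim=2048,      # upper bound on dims that get a full preconditioner
                            merge_dims=True, warmup_steps=100)

for _ in range(10):
    opt.zero_grad()
    model(torch.randn(8, 64)).square().mean().backward()
    opt.step()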
{heavyball-0.14.6 → heavyball-0.15.0}/heavyball/p_adam.py
@@ -62,13 +62,7 @@ class ForeachPaLMPAdam(PSGDBase):
 
         self._prob_step = 0
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
@@ -76,57 +70,54 @@ class ForeachPaLMPAdam(PSGDBase):
             do_update = self.rng.random() < update_prob
         self._prob_step += 1
 
-        for group in self.param_groups:
-            precond_init_scale = group['precond_init_scale']
-            max_size_triangular = group['max_size_triangular']
-            min_ndim_triangular = group['min_ndim_triangular']
-            memory_save_mode = group['memory_save_mode']
-            precond_lr = group['precond_lr']
-            weight_decay = group['weight_decay']
-            lr = group['lr']
-
-            vals = []
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
 
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
+        vals = []
 
-                if 'Q' not in state:
-                    state['exp_avg'] = torch.zeros_like(g)
-                    state['exp_avg_sq'] = torch.zeros_like(g)
-                    state["Q"], state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
-                                                              min_ndim_triangular, memory_save_mode, dtype=g.dtype)
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
 
-                vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
+            if 'Q' not in state:
+                state['exp_avg'] = torch.zeros_like(g)
+                state['exp_avg_sq'] = torch.zeros_like(g)
+                state["Q"], state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
+                                                          min_ndim_triangular, memory_save_mode, dtype=g.dtype)
 
-            if not vals:
-                continue
+            vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
 
-            p_list, grad_list, Q_list, exp_avg, exp_avg_sq = zip(*vals)
-            del vals
+        if not vals:
+            return
 
-            group["step"] += 1
+        p_list, grad_list, Q_list, exp_avg, exp_avg_sq = zip(*vals)
+        del vals
 
-            self.balance(do_update, grad_list, Q_list)
-            if do_update:
-                self.do_update(p_list, grad_list, Q_list, precond_lr)
+        group["step"] += 1
 
-            torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
+        self.balance(do_update, grad_list, Q_list)
+        if do_update:
+            self.do_update(p_list, grad_list, Q_list, precond_lr)
 
-            beta2 = 1 - group['step'] ** -group['beta2_scale']
+        torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
 
-            for p, Q, g, ea, eas in zip(p_list, Q_list, grad_list, exp_avg, exp_avg_sq):
-                psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
-                ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
-                exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
-                torch.div(ea, g, out=g)
-                """
-                divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
-                divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
-                """
+        beta2 = 1 - group['step'] ** -group['beta2_scale']
 
-            grad_list = self.clip_fn(grad_list)
+        for p, Q, g, ea, eas in zip(p_list, Q_list, grad_list, exp_avg, exp_avg_sq):
+            psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
+            ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
+            exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
+            torch.div(ea, g, out=g)
+            """
+            divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
+            divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
+            """
 
-            lr = -warmup(lr, group['step'], group['warmup_steps'])
-            update_param_(p_list, grad_list, lr, weight_decay)
+            grad_list = self.clip_fn(grad_list)
 
-        return loss
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
heavyball-0.15.0/heavyball/palm_foreach_sfadamw.py
@@ -0,0 +1,56 @@
+import torch
+import torch.optim
+
+from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debias
+
+
+class PaLMForeachSFAdamW(ScheduleFree):
+    def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
+                 weight_lr_power=2.0, beta2_scale: float = 0.8):
+        if betas[0] is not None:
+            beta = betas[0]
+        defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
+                        beta2_scale=beta2_scale)
+        super().__init__(params, defaults)
+
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
+
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
+
+        active_p = [p for p in group['params'] if p.grad is not None]
+
+        if not active_p:
+            return
+
+        for p in active_p:
+            if 'z' not in self.state_(p):
+                self.state_(p)['z'] = torch.clone(p.data)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+
+        y, grad, exp_avg_sq, z = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
+
+        # Decay the first moment running average coefficient
+        beta2 = 1 - (k + 1) ** -group['beta2_scale']
+        old_debiased = beta_debias(beta2, k + 1)
+
+        # Decay the first and second moment running average coefficient
+        denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
+
+        # Normalize grad in-place for memory efficiency
+        torch._foreach_div_(grad, denom)
+
+        # Weight decay calculated at y
+        if decay != 0:
+            torch._foreach_add_(grad, y, alpha=decay)
+
+        lr = warmup(group['lr'], k + 1, group['warmup_steps'])
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], y, z,
+                                             grad, group['r'], k + 1)
+
+        group['k'] = k + 1
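The new test/test_closure.py and test/test_no_grad.py are not shown in this excerpt, but given the refactor they presumably verify that step(closure) still returns the loss and that updates run without tracking gradients. A hypothetical check in that spirit, not the actual tests (package-level export of ForeachAdamW assumed):

import torch
import heavyball

def check_closure(opt_cls):
    p = torch.nn.Parameter(torch.randn(4))
    opt = opt_cls([p], lr=1e-3)

    def closure():
        opt.zero_grad()
        loss = p.square().sum()
        loss.backward()
        return loss

    before = p.detach().clone()
    loss = opt.step(closure)
    assert loss is not None                      # the closure's loss is passed back through step()
    assert not torch.equal(p.detach(), before)   # the parameter actually moved

check_closure(heavyball.ForeachAdamW)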