heavyball 0.25.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in a supported public registry. It is provided for informational purposes only.
@@ -1,95 +0,0 @@
- import random
-
- import torch
-
- from .utils import init_preconditioner, update_preconditioner, project, beta_debias, update_param_, \
-     precond_schedule, set_, StatefulOptimizer, exp_avg_
-
-
- class PrecondScheduleForeachSOAP(StatefulOptimizer):
-     """
-     Sources:
-         Preconditioner Schedules:
-             Preconditioned Stochastic Gradient Descent
-             Xi-Lin Li, Omead Pooladzandi, Evan Walters
-             https://arxiv.org/abs/1512.04202
-             https://github.com/evanatyourservice/kron_torch
-             https://github.com/lixilinx/psgd_torch
-
-         Baseline SOAP:
-             SOAP: Improving and Stabilizing Shampoo using Adam
-             Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-             https://arxiv.org/abs/2409.11321
-             https://github.com/nikhilvyas/SOAP
-     """
-
-     def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
-                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
-                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                  precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True, mars: bool = False,
-                  caution: bool = False, mars_gamma: float = 0.0025):
-         defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
-                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
-                     'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
-         super().__init__(params, defaults, foreach)
-         self._data_format = data_format
-         self.rng = random.Random(0x120983109)
-
-     def _step(self, group):
-         vals = []
-         step = 0
-
-         max_precond_dim = group['max_precond_dim']
-         precondition_1d = group['precondition_1d']
-
-         for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
-             state = self.state_(p)
-             step = state['step'] = state.get("step", -1) + 1
-
-             if "exp_avg" not in state:
-                 state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                 continue # first step is skipped so that we never use the current gradients in the projection.
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             grad_projected = project(g, state['Q'], False)
-             exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-             vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-         if not vals:
-             return
-
-         p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-         beta1, beta2 = group["betas"]
-
-         old_debiased2 = beta_debias(beta2, step)
-
-         # Decay the first and second moment running average coefficient
-         # In-place operations to update the averages at the same time
-         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-
-         update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
-         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-
-         for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
-             d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
-             state = self.state_(p)
-             # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             exp_avg_projected = project(ea, state['Q'], False)
-
-             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-             # to the original space
-             # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-             precond = project(exp_avg_projected / d, state['Q'], True)
-
-             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
-
-             update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
-
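
Note: the classes removed in this diff were regular torch.optim-style optimizers. As orientation for readers, the sketch below shows how PrecondScheduleForeachSOAP could be constructed, using only arguments visible in the signature above; the model, data, training loop, and the top-level `from heavyball import ...` path are assumptions for illustration, not part of the diff.

    import torch
    from heavyball import PrecondScheduleForeachSOAP  # assumed top-level export in 0.25.x

    # hypothetical model and data, for illustration only
    model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
    opt = PrecondScheduleForeachSOAP(model.parameters(), lr=3e-3, betas=(0.9, 0.95),
                                     weight_decay=0.01, precond_scheduler=(1 / 3, 9))

    for _ in range(10):
        x, y = torch.randn(8, 16), torch.randn(8, 1)
        opt.zero_grad()
        torch.nn.functional.mse_loss(model(x), y).backward()
        opt.step()  # the first update per parameter only initializes state (see the `continue` above)
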
@@ -1,95 +0,0 @@
-
-
- import random
-
- import torch
-
- from .utils import init_preconditioner, update_preconditioner, project, beta_debias, update_param_, \
-     precond_schedule, set_, StatefulOptimizer, laprop_exp_avg_
-
-
- class PrecondScheduleForeachSOLP(StatefulOptimizer):
-     """
-     Sources:
-         LaProp:
-             LaProp: Separating Momentum and Adaptivity in Adam
-             Liu Ziyin, Zhikang T.Wang, Masahito Ueda
-             https://arxiv.org/abs/2002.04839
-             https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py
-
-         Preconditioner Schedules:
-             Preconditioned Stochastic Gradient Descent
-             Xi-Lin Li, Omead Pooladzandi, Evan Walters
-             https://arxiv.org/abs/1512.04202
-             https://github.com/evanatyourservice/kron_torch
-             https://github.com/lixilinx/psgd_torch
-
-         Baseline SOAP:
-             SOAP: Improving and Stabilizing Shampoo using Adam
-             Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-             https://arxiv.org/abs/2409.11321
-             https://github.com/nikhilvyas/SOAP
-     """
-
-     def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
-                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
-                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                  precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True, mars: bool = False,
-                  caution: bool = False, mars_gamma: float = 0.0025):
-         defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
-                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
-                     'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
-         super().__init__(params, defaults, foreach)
-         self._data_format = data_format
-         self.rng = random.Random(0x120983109)
-
-     def _step(self, group):
-         vals = []
-         step = 0
-
-         max_precond_dim = group['max_precond_dim']
-         precondition_1d = group['precondition_1d']
-
-         for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
-             state = self.state_(p)
-             step = state['step'] = state.get("step", -1) + 1
-
-             if "exp_avg" not in state:
-                 state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                 continue # first step is skipped so that we never use the current gradients in the projection.
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             grad_projected = project(g, state['Q'], False)
-             exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-             vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-         if not vals:
-             return
-
-         p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-         beta1, beta2 = group["betas"]
-
-         old_debiased2 = beta_debias(beta2, step)
-
-         # Decay the first and second moment running average coefficient
-         # In-place operations to update the averages at the same time
-         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-
-         update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
-         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-
-
-         for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
-             laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
-             state = self.state_(p)
-             precond = project(ea, state['Q'], True)
-
-             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
-             update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
@@ -1,105 +0,0 @@
- import random
-
- import torch
-
- from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_, update_param_, \
-     precond_schedule, set_, StatefulOptimizer
-
-
- class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
-     """
-     Sources:
-         Preconditioner Schedules:
-             Preconditioned Stochastic Gradient Descent
-             Xi-Lin Li, Omead Pooladzandi, Evan Walters
-             https://arxiv.org/abs/1512.04202
-             https://github.com/evanatyourservice/kron_torch
-             https://github.com/lixilinx/psgd_torch
-
-         Baseline SOAP:
-             SOAP: Improving and Stabilizing Shampoo using Adam
-             Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-             https://arxiv.org/abs/2409.11321
-             https://github.com/nikhilvyas/SOAP
-
-         Beta2 Schedule:
-             PaLM: Scaling Language Modeling with Pathways
-             Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
-             https://arxiv.org/abs/2204.02311
-     """
-
-     def __init__(self, params, lr: float = 3e-3, beta=0.9, shampoo_beta: float = 0.95, eps: float = 1e-8,
-                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
-                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                  precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
-                  foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
-         if betas[0] is not None:
-             beta = betas[0]
-         defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
-                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
-                     'beta2_scale': beta2_scale, 'split': split, 'mars': mars, 'caution': caution,
-                     'mars_gamma': mars_gamma}
-         super().__init__(params, defaults, foreach)
-         self._data_format = data_format
-         self.rng = random.Random(0x120983109)
-
-     def _step(self, group):
-         vals = []
-         step = 0
-
-         max_precond_dim = group['max_precond_dim']
-         precondition_1d = group['precondition_1d']
-
-         for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
-             state = self.state_(p)
-             step = state['step'] = state.get("step", -1) + 1
-
-             if "exp_avg" not in state:
-                 state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                 continue # first step is skipped so that we never use the current gradients in the projection.
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             grad_projected = project(g, state['Q'], False)
-             exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-             vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-         if not vals:
-             return
-
-         p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-         beta1 = group["beta"]
-
-         beta2 = 1 - max(step, 1) ** -group['beta2_scale']
-         old_debiased1 = beta_debias(beta1, step)
-         old_debiased2 = beta_debias(beta2, step)
-
-         # Decay the first and second moment running average coefficient
-         # In-place operations to update the averages at the same time
-         beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
-         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-
-         update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
-         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-
-         for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
-             d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
-             state = self.state_(p)
-             # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             exp_avg_projected = project(ea, state['Q'], False)
-
-             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-             # to the original space
-             # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-             exp_avg_projected = exp_avg_projected / d
-             precond = project(exp_avg_projected, state['Q'], True)
-
-             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
-             update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
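
The PaLM variants above differ from the plain SOAP classes mainly in the second-moment decay: instead of a fixed beta2 they recompute `beta2 = 1 - max(step, 1) ** -group['beta2_scale']` every step. A small standalone sketch of that schedule follows; the helper name is invented for illustration, but the formula and the 0.8 default come directly from the code above.

    # PaLM-style beta2 schedule, copied from the deleted classes above
    def palm_beta2(step: int, beta2_scale: float = 0.8) -> float:
        return 1 - max(step, 1) ** -beta2_scale

    print(palm_beta2(1))     # 0.0    -> the second moment is just the current squared gradient
    print(palm_beta2(10))    # ~0.84
    print(palm_beta2(1000))  # ~0.996 -> approaches 1, i.e. an ever longer averaging window
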
@@ -1,103 +0,0 @@
- import random
-
- import torch
-
- from .utils import init_preconditioner, update_preconditioner, project, beta_debias, laprop_exp_avg_, update_param_, \
-     precond_schedule, set_, StatefulOptimizer
-
-
- class PrecondSchedulePaLMForeachSOLP(StatefulOptimizer):
-     """
-     Sources:
-         LaProp:
-             LaProp: Separating Momentum and Adaptivity in Adam
-             Liu Ziyin, Zhikang T.Wang, Masahito Ueda
-             https://arxiv.org/abs/2002.04839
-             https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py
-
-         Preconditioner Schedules:
-             Preconditioned Stochastic Gradient Descent
-             Xi-Lin Li, Omead Pooladzandi, Evan Walters
-             https://arxiv.org/abs/1512.04202
-             https://github.com/evanatyourservice/kron_torch
-             https://github.com/lixilinx/psgd_torch
-
-         Baseline SOAP:
-             SOAP: Improving and Stabilizing Shampoo using Adam
-             Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-             https://arxiv.org/abs/2409.11321
-             https://github.com/nikhilvyas/SOAP
-
-         Beta2 Schedule:
-             PaLM: Scaling Language Modeling with Pathways
-             Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
-             https://arxiv.org/abs/2204.02311
-     """
-
-     def __init__(self, params, lr: float = 3e-3, beta=0.9, shampoo_beta: float = 0.95, eps: float = 1e-8,
-                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
-                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                  precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
-                  foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
-         if betas[0] is not None:
-             beta = betas[0]
-         defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
-                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
-                     'beta2_scale': beta2_scale, 'split': split, 'mars': mars, 'caution': caution,
-                     'mars_gamma': mars_gamma}
-         super().__init__(params, defaults, foreach)
-         self._data_format = data_format
-         self.rng = random.Random(0x120983109)
-
-     def _step(self, group):
-         vals = []
-         step = 0
-
-         max_precond_dim = group['max_precond_dim']
-         precondition_1d = group['precondition_1d']
-
-         for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
-             state = self.state_(p)
-             step = state['step'] = state.get("step", -1) + 1
-
-             if "exp_avg" not in state:
-                 state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                 continue # first step is skipped so that we never use the current gradients in the projection.
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             grad_projected = project(g, state['Q'], False)
-             exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-             vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-         if not vals:
-             return
-
-         p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-         beta1 = group["beta"]
-
-         beta2 = 1 - max(step, 1) ** -group['beta2_scale']
-         old_debiased1 = beta_debias(beta1, step)
-         old_debiased2 = beta_debias(beta2, step)
-
-         # Decay the first and second moment running average coefficient
-         # In-place operations to update the averages at the same time
-         beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
-         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-
-         update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
-         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-
-         for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
-             laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
-             state = self.state_(p)
-             precond = project(ea, state['Q'], True)
-
-             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
-             update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
@@ -1,141 +0,0 @@
- import random
-
- import torch
-
- from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
-     beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, copy_stochastic_list_, \
-     promote, decorator_knowngood
-
-
- @decorator_knowngood
- def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
-     eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
-     denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
-     torch._foreach_div_(gp32, denom)
-
-     copy_stochastic_list_(exp_avg_sq, eas32)
-     copy_stochastic_list_(grad_projected, gp32)
-
-
- class PrecondScheduleSFPaLMSOAP(ScheduleFree):
-     """
-     SFPaLMForeachSOAP with preconditioner schedules
-
-     Sources:
-         Preconditioner Schedules:
-             Preconditioned Stochastic Gradient Descent
-             Xi-Lin Li, Omead Pooladzandi, Evan Walters
-             https://arxiv.org/abs/1512.04202
-             https://github.com/evanatyourservice/kron_torch
-             https://github.com/lixilinx/psgd_torch
-
-         Baseline SOAP:
-             SOAP: Improving and Stabilizing Shampoo using Adam
-             Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-             https://arxiv.org/abs/2409.11321
-             https://github.com/nikhilvyas/SOAP
-
-         ScheduleFree:
-             The Road Less Scheduled
-             Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
-             https://arxiv.org/abs/2405.15682
-             https://github.com/facebookresearch/schedule_free
-
-         Beta2 Schedule:
-             PaLM: Scaling Language Modeling with Pathways
-             Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
-             https://arxiv.org/abs/2204.02311
-     """
-
-     def __init__(self, params, lr: float = 3e-3, beta=0.9, beta2_scale: float = 0.8, eps: float = 1e-8,
-                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
-                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
-                  weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9), betas=(None, None),
-                  split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
-                  mars_gamma: float = 0.0025):
-         if betas[0] is not None:
-             beta = betas[0]
-
-         assert not caution, "Caution is not implemented in ScheduleFree optimizers"
-
-         defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
-                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
-                     'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1, 'weight_sum': 0,
-                     'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split,
-                     'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
-         super().__init__(params, defaults, foreach)
-         self._data_format = data_format
-         self.rng = random.Random(0x120983109)
-
-     def _step(self, group):
-         vals = []
-         max_precond_dim = group['max_precond_dim']
-         precondition_1d = group['precondition_1d']
-
-         step = group['step'] = group.get("step", 0) + 1
-
-         for p in group["params"]:
-             if p.grad is None:
-                 continue
-             grad = p.grad.float()
-             vals.append((p, grad))
-
-         if not vals:
-             return
-
-         p_list, grad = zip(*vals)
-         vals = []
-
-         # adaptive gradient clipping
-         adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
-
-         for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
-             state = self.state_(p)
-
-             if "z" not in state:
-                 state["z"] = torch.clone(p.data, memory_format=torch.preserve_format)
-                 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                 init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                 continue # first step is skipped so that we never use the current gradients in the projection.
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             grad_projected = project(g, state['Q'], False)
-             z, exp_avg_sq = state["z"], state["exp_avg_sq"]
-             vals.append((p, g, grad_projected, z, exp_avg_sq))
-
-         if not vals:
-             return
-
-         p_list, grad, grad_projected, z, exp_avg_sq = zip(*vals)
-         del vals
-
-         beta2 = 1 - max(step, 1) ** -group['beta2_scale']
-         old_debiased2 = beta_debias(beta2, step)
-
-         # Decay the first and second moment running average coefficient
-         # In-place operations to update the averages at the same time
-         old_debiased_tensor = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(old_debiased2)
-         _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased_tensor, group["eps"])
-
-         update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
-
-         for p, g, gp in zip(p_list, grad, grad_projected):
-             state = self.state_(p)
-             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-             # to the original space
-             set_(gp, project(gp, state['Q'], back=True))
-
-             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
-
-         # Weight decay calculated at y
-         if group["weight_decay"] > 0:
-             torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])
-
-         lr = warmup(group['lr'], step, group['warmup_steps'])
-         group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], p_list,
-                                              z, grad_projected, group['r'], step)
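
PrecondScheduleSFPaLMSOAP above combines the same SOAP/PaLM machinery with a schedule-free update (note the `z` buffer, `weight_lr_power`, and `weight_sum` bookkeeping) and explicitly rejects `caution=True`. A minimal construction sketch, again assuming the top-level heavyball export and a hypothetical module:

    import torch
    from heavyball import PrecondScheduleSFPaLMSOAP  # assumed top-level export

    model = torch.nn.Linear(16, 1)  # hypothetical module, for illustration
    opt = PrecondScheduleSFPaLMSOAP(model.parameters(), lr=3e-3, beta=0.9, weight_decay=0.01,
                                    gradient_clip_val=0.1, warmup_steps=100)

    # The constructor asserts against caution (see the assert above):
    # PrecondScheduleSFPaLMSOAP(model.parameters(), caution=True)  # raises AssertionError
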
heavyball/psgd_kron.py DELETED
@@ -1,120 +0,0 @@
- """
- Originally from Evan Walters and Omead Pooladzandi, 2024
- Modified under Creative Commons Attribution 4.0 International
- Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
- """
-
- from typing import Optional
-
- import torch
-
- from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-     line_to_triu, triu_to_line, promote, stochastic_lerp_, beta_debias
-
-
- class ForeachPSGDKron(PSGDBase):
-     """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch.
-
-     Args:
-         params (iterable): Iterable of parameters to optimize or dicts defining
-             parameter groups.
-         lr (float): Learning rate.
-         b1 (float): Momentum parameter.
-         weight_decay (float): Weight decay (L2 penalty).
-         preconditioner_update_probability (callable or float, optional): Probability of
-             updating the preconditioner. If None, defaults to a schedule that anneals
-             from 1.0 to 0.03 by 4000 steps.
-         max_size_triangular (int): Max size for dim's preconditioner to be triangular.
-         min_ndim_triangular (int): Minimum number of dimensions a layer needs
-             to have triangular preconditioners.
-         memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
-             to set all preconditioners to be triangular, 'one_diag' sets the largest
-             or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
-             to be diagonal.
-         momentum_into_precond_update: (bool), whether to send momentum into preconditioner
-             update instead of raw gradients.
-     """
-
-     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
-                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
-                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                  foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
-                  storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
-                  #
-                  # expert parameters
-                  precond_init_scale=1.0, precond_lr=0.1):
-         if not 0.0 <= lr:
-             raise ValueError(f"Invalid learning rate: {lr}")
-         if not 0.0 <= beta < 1.0:
-             raise ValueError(f"Invalid beta parameter: {beta}")
-         if not 0.0 <= weight_decay:
-             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
-         if clip_fn is None:
-             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-
-         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
-                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
-                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                         split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
-                         storage_dtype=storage_dtype,
-                         mars=mars, caution=caution, mars_gamma=mars_gamma)
-         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
-
-     def _step(self, group):
-         should_update = self.should_update(group)
-         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
-         precond_init_scale = group['precond_init_scale']
-         max_size_triangular = group['max_size_triangular']
-         min_ndim_triangular = group['min_ndim_triangular']
-         memory_save_mode = group['memory_save_mode']
-         precond_lr = group['precond_lr']
-         weight_decay = group['weight_decay']
-         lr = group['lr']
-         beta = group['beta']
-         store_triu_as_line = group['store_triu_as_line']
-         q_dtype = getattr(torch, group['q_dtype'])
-         storage_dtype = getattr(torch, group['storage_dtype'])
-
-         vals = []
-
-         for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
-             state = self.state_(p)
-
-             if 'Q' not in state:
-                 state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype, memory_format=torch.preserve_format)
-                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                  memory_save_mode, dtype=q_dtype)
-                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
-
-             vals.append((p, g, state["exp_avg"], state["Q"]))
-
-         if not vals:
-             return
-
-         p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
-         del vals
-
-         group["step"] += 1
-
-         beta = beta_debias(beta, group["step"])
-         beta = torch.empty((), dtype=torch.float32, device=grad_list[0].device).fill_(1 - beta)
-         stochastic_lerp_(exp_avg_list, grad_list, 1 - beta)
-
-         grad_list, Q_list, exp_avg_list = list(grad_list), list(Q_list), list(exp_avg_list)
-
-         lr = -warmup(lr, group['step'], group['warmup_steps'])
-
-         for i, (p, g) in enumerate(zip(p_list, grad_list)):
-             q_orig = Q_list.pop(0)
-             ea = exp_avg_list.pop(0)
-             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-
-             if should_update:
-                 q32 = [promote(q_) for q_ in q]
-                 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
-                                store_triu_as_line)
-             g = psgd_precond_grad(False, self.state_(p)["exprs"][-1], ea, *q)
-             update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])
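
For completeness, a construction sketch for the removed ForeachPSGDKron; the keyword names come from the signature and docstring above, while the import path, module, and the choice of a fixed update probability are assumptions for illustration:

    import torch
    from heavyball import ForeachPSGDKron  # assumed top-level export

    model = torch.nn.Linear(16, 1)  # hypothetical module, for illustration
    opt = ForeachPSGDKron(model.parameters(), lr=1e-3, beta=0.9, weight_decay=0.0,
                          preconditioner_update_probability=0.1,  # fixed probability instead of the default anneal
                          max_size_triangular=2048, memory_save_mode='one_diag')

    x, y = torch.randn(8, 16), torch.randn(8, 1)
    opt.zero_grad()
    torch.nn.functional.mse_loss(model(x), y).backward()
    opt.step()
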