heavyball 0.25.0__tar.gz → 0.25.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {heavyball-0.25.0 → heavyball-0.25.1}/PKG-INFO +1 -1
  2. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/__init__.py +12 -2
  3. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/foreach_soap.py +3 -1
  4. heavyball-0.25.1/heavyball/foreach_solp.py +89 -0
  5. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/palm_foreach_soap.py +2 -1
  6. heavyball-0.25.1/heavyball/palm_foreach_solp.py +98 -0
  7. heavyball-0.25.1/heavyball/precond_schedule_foreach_solp.py +95 -0
  8. heavyball-0.25.1/heavyball/precond_schedule_palm_foreach_solp.py +103 -0
  9. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/utils.py +23 -0
  10. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball.egg-info/PKG-INFO +1 -1
  11. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball.egg-info/SOURCES.txt +5 -0
  12. {heavyball-0.25.0 → heavyball-0.25.1}/setup.py +1 -1
  13. heavyball-0.25.1/test/test_solp.py +50 -0
  14. {heavyball-0.25.0 → heavyball-0.25.1}/LICENSE +0 -0
  15. {heavyball-0.25.0 → heavyball-0.25.1}/README.md +0 -0
  16. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/cached_delayed_psgd_kron.py +0 -0
  17. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/cached_psgd_kron.py +0 -0
  18. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/delayed_psgd.py +0 -0
  19. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/foreach_adamw.py +0 -0
  20. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/foreach_adopt.py +0 -0
  21. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/foreach_laprop.py +0 -0
  22. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/foreach_sfadamw.py +0 -0
  23. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/p_adam.py +0 -0
  24. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/palm_foreach_sfadamw.py +0 -0
  25. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/precond_schedule_foreach_soap.py +0 -0
  26. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  27. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/precond_schedule_sfpsoap.py +0 -0
  28. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/psgd_kron.py +0 -0
  29. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/pure_psgd.py +0 -0
  30. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  31. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball.egg-info/dependency_links.txt +0 -0
  32. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball.egg-info/requires.txt +0 -0
  33. {heavyball-0.25.0 → heavyball-0.25.1}/heavyball.egg-info/top_level.txt +0 -0
  34. {heavyball-0.25.0 → heavyball-0.25.1}/setup.cfg +0 -0
  35. {heavyball-0.25.0 → heavyball-0.25.1}/test/test_bf16_params.py +0 -0
  36. {heavyball-0.25.0 → heavyball-0.25.1}/test/test_bf16_q.py +0 -0
  37. {heavyball-0.25.0 → heavyball-0.25.1}/test/test_bf16_storage.py +0 -0
  38. {heavyball-0.25.0 → heavyball-0.25.1}/test/test_caution.py +0 -0
  39. {heavyball-0.25.0 → heavyball-0.25.1}/test/test_channels_last.py +0 -0
  40. {heavyball-0.25.0 → heavyball-0.25.1}/test/test_closure.py +0 -0
  41. {heavyball-0.25.0 → heavyball-0.25.1}/test/test_ema.py +0 -0
  42. {heavyball-0.25.0 → heavyball-0.25.1}/test/test_foreach.py +0 -0
  43. {heavyball-0.25.0 → heavyball-0.25.1}/test/test_hook.py +0 -0
  44. {heavyball-0.25.0 → heavyball-0.25.1}/test/test_mars.py +0 -0
  45. {heavyball-0.25.0 → heavyball-0.25.1}/test/test_memory.py +0 -0
  46. {heavyball-0.25.0 → heavyball-0.25.1}/test/test_merge.py +0 -0
  47. {heavyball-0.25.0 → heavyball-0.25.1}/test/test_no_grad.py +0 -0
  48. {heavyball-0.25.0 → heavyball-0.25.1}/test/test_psgd.py +0 -0
  49. {heavyball-0.25.0 → heavyball-0.25.1}/test/test_soap.py +0 -0
  50. {heavyball-0.25.0 → heavyball-0.25.1}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.25.0
3
+ Version: 0.25.1
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -15,6 +15,10 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
15
15
  from .psgd_kron import ForeachPSGDKron
16
16
  from .pure_psgd import ForeachPurePSGD
17
17
  from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
18
+ from .foreach_solp import ForeachSOLP
19
+ from .palm_foreach_solp import PaLMForeachSOLP
20
+ from .precond_schedule_palm_foreach_solp import PrecondSchedulePaLMForeachSOLP
21
+ from .precond_schedule_foreach_solp import PrecondScheduleForeachSOLP
18
22
 
19
23
  PalmForEachSoap = PaLMForeachSOAP
20
24
 
@@ -35,12 +39,18 @@ PaLMPAdam = ForeachPaLMPAdam
35
39
  DelayedPSGD = ForeachDelayedPSGD
36
40
  CachedPSGDKron = ForeachCachedPSGDKron
37
41
  CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
42
+ SOLP = ForeachSOLP
43
+ PaLMSOLP = PaLMForeachSOLP
44
+ PrecondSchedulePaLMSOLP = PrecondSchedulePaLMForeachSOLP
45
+ PrecondScheduleSOLP = PrecondScheduleForeachSOLP
38
46
 
# Public names exported by `from heavyball import *`: canonical class names first, then the alias names.
# Every SOLP alias (SOLP, PaLMSOLP, PrecondSchedulePaLMSOLP, PrecondScheduleSOLP) is listed exactly once.
__all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
           'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
           'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron', 'ForeachSOLP',
           'PaLMForeachSOLP', 'PrecondSchedulePaLMForeachSOLP', 'PrecondScheduleForeachSOLP',
           # alias names
           'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP', 'SOAP', 'SFAdamW',
           'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD', 'CachedPSGDKron',
           'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP', 'SOLP', 'PaLMSOLP',
           'PrecondSchedulePaLMSOLP', 'PrecondScheduleSOLP']
@@ -6,7 +6,7 @@ from .utils import init_preconditioner, update_preconditioner, project, beta_deb
6
6
 
7
7
  class ForeachSOAP(StatefulOptimizer):
8
8
  """
9
- SFPaLMForeachSOAP
9
+ ForeachSOAP
10
10
 
11
11
  Sources:
12
12
  Baseline SOAP:
@@ -75,6 +75,8 @@ class ForeachSOAP(StatefulOptimizer):
75
75
  step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
76
76
 
77
77
  for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
78
+ state = self.state_(p)
79
+
78
80
  d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
79
81
  # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
80
82
  # i.e. projecting to the eigenbases of matrices in state['GG']
@@ -0,0 +1,89 @@
1
+ import torch
2
+
3
+ from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
4
+ StatefulOptimizer, laprop_exp_avg_
5
+
6
+
7
class ForeachSOLP(StatefulOptimizer):
    """
    ForeachSOLP

    SOAP with its inner Adam step replaced by a LaProp step. Gradients are projected into the eigenbasis
    of Shampoo's preconditioner (`state['Q']`). They are then normalised by the running second moment and
    accumulated into the first moment (`laprop_exp_avg_`). Finally the first moment is projected back and
    applied to the parameter.

    Args (as read by `_step`):
        betas: (beta1, beta2). beta1 is the momentum coefficient, also forwarded to
            `split_p_and_g_in_group`. beta2 drives the second moment and the preconditioner debiasing.
        precondition_frequency: the eigenbasis is refreshed every `precondition_frequency` steps.
        warmup_steps: step size is `lr * min(step / warmup_steps, 1)`.
        weight_decay, caution: forwarded to `update_param_`.
        eps: stored in the param group. NOTE(review): the LaProp kernel currently uses a fixed 1e-8.
        Remaining options are stored in the param group; presumably consumed by StatefulOptimizer
        helpers (TODO confirm).

    Sources:
        LaProp:
            LaProp: Separating Momentum and Adaptivity in Adam
            Liu Ziyin, Zhikang T.Wang, Masahito Ueda
            https://arxiv.org/abs/2002.04839
            https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py

        Baseline SOAP:
            SOAP: Improving and Stabilizing Shampoo using Adam
            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
            https://arxiv.org/abs/2409.11321
            https://github.com/nikhilvyas/SOAP
    """

    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
                 split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
                 mars_gamma: float = 0.0025):
        defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split, 'mars': mars,
                    'caution': caution, 'mars_gamma': mars_gamma}
        super().__init__(params, defaults, foreach)
        self._data_format = data_format

    def _step(self, group):
        """Run one optimisation step for every parameter of `group` that has a gradient."""
        vals = []
        step = 0

        max_precond_dim = group['max_precond_dim']
        precondition_1d = group['precondition_1d']

        for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
            state = self.state_(p)
            param_step = state['step'] = state.get("step", -1) + 1

            if "exp_avg" not in state:
                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
                init_preconditioner(g, state, max_precond_dim, precondition_1d)
                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
                continue  # first step is skipped so that we never use the current gradients in the projection.

            # Only parameters that are actually updated below define the shared step. A freshly initialised
            # parameter (step 0, skipped above) must not zero the warmup or the debiasing for the others.
            step = param_step

            # Projecting gradients to the eigenbases of Shampoo's preconditioner
            # i.e. projecting to the eigenbases of matrices in state['GG']
            grad_projected = project(g, state['Q'], False)
            vals.append((p, g, grad_projected, state["exp_avg"], state["exp_avg_sq"]))

        if not vals:
            return

        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
        beta1, beta2 = group["betas"]

        old_debiased2 = beta_debias(beta2, step)

        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
        # Loop-invariant: whether this step refreshes the eigenbasis of every parameter.
        update_precond = step > 0 and step % group['precondition_frequency'] == 0

        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
            # LaProp: normalise the projected gradient by the second moment, then accumulate it into exp_avg.
            laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
            state = self.state_(p)
            # Project the momentum back from the eigenbasis into parameter space.
            precond = project(ea, state['Q'], True)

            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
@@ -6,7 +6,7 @@ from .utils import init_preconditioner, update_preconditioner, project, beta_deb
6
6
 
7
7
  class PaLMForeachSOAP(StatefulOptimizer):
8
8
  """
9
- SFPaLMForeachSOAP
9
+ PaLMForeachSOAP
10
10
 
11
11
  Sources:
12
12
  Baseline SOAP:
@@ -84,6 +84,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
84
84
  step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
85
85
 
86
86
  for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
87
+ state = self.state_(p)
87
88
  d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
88
89
  # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
89
90
  # i.e. projecting to the eigenbases of matrices in state['GG']
@@ -0,0 +1,98 @@
1
+ import torch
2
+
3
+ from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
4
+ StatefulOptimizer, laprop_exp_avg_
5
+
6
+
7
class PaLMForeachSOLP(StatefulOptimizer):
    """
    PaLMForeachSOLP

    SOAP with a LaProp inner step (see ForeachSOLP) and PaLM's step-dependent beta2 schedule:
    `beta2 = 1 - step ** -beta2_scale`.

    Args (as read by `_step`):
        beta: momentum coefficient. It is overridden by `betas[0]` when that is not None.
        beta2_scale: exponent of the PaLM beta2 schedule.
        precondition_frequency: the eigenbasis is refreshed every `precondition_frequency` steps.
        warmup_steps: step size is `lr * min(step / warmup_steps, 1)`.
        weight_decay, caution: forwarded to `update_param_`.
        eps: stored in the param group. NOTE(review): the LaProp kernel currently uses a fixed 1e-8.
        Remaining options are stored in the param group; presumably consumed by StatefulOptimizer
        helpers (TODO confirm).

    Sources:
        LaProp:
            LaProp: Separating Momentum and Adaptivity in Adam
            Liu Ziyin, Zhikang T.Wang, Masahito Ueda
            https://arxiv.org/abs/2002.04839
            https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py

        Baseline SOAP:
            SOAP: Improving and Stabilizing Shampoo using Adam
            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
            https://arxiv.org/abs/2409.11321
            https://github.com/nikhilvyas/SOAP

        Beta2 Schedule:
            PaLM: Scaling Language Modeling with Pathways
            Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
            https://arxiv.org/abs/2204.02311
    """

    def __init__(self, params, lr: float = 3e-3, beta=0.9, betas=(None, None), shampoo_beta: float = 0.95,
                 eps: float = 1e-8, weight_decay: float = 0.01, precondition_frequency: int = 2,
                 max_precond_dim: int = 2048,  #
                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
                 beta2_scale: float = 0.8, split: bool = False, foreach: bool = True, mars: bool = False,
                 caution: bool = False, mars_gamma: float = 0.0025):
        if betas[0] is not None:
            beta = betas[0]
        defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
                    'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
        super().__init__(params, defaults, foreach)
        self._data_format = data_format

    def _step(self, group):
        """Run one optimisation step for every parameter of `group` that has a gradient."""
        vals = []
        step = 0

        max_precond_dim = group['max_precond_dim']
        precondition_1d = group['precondition_1d']

        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
            state = self.state_(p)
            param_step = state['step'] = state.get("step", -1) + 1

            if "exp_avg" not in state:
                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
                init_preconditioner(g, state, max_precond_dim, precondition_1d)
                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
                continue  # first step is skipped so that we never use the current gradients in the projection.

            # Only parameters that are actually updated below define the shared step. A freshly initialised
            # parameter (step 0, skipped above) would otherwise feed 0 into the beta2 schedule.
            step = param_step

            # Projecting gradients to the eigenbases of Shampoo's preconditioner
            # i.e. projecting to the eigenbases of matrices in state['GG']
            grad_projected = project(g, state['Q'], False)
            vals.append((p, g, grad_projected, state["exp_avg"], state["exp_avg_sq"]))

        if not vals:
            return

        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
        beta1 = group["beta"]

        # PaLM schedule; max(step, 1) guards against 0 ** -beta2_scale (ZeroDivisionError).
        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
        old_debiased2 = beta_debias(beta2, step)

        beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
        # Loop-invariant: whether this step refreshes the eigenbasis of every parameter.
        update_precond = step > 0 and step % group['precondition_frequency'] == 0

        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
            # LaProp: normalise the projected gradient by the second moment, then accumulate it into exp_avg.
            laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
            state = self.state_(p)
            # Project the momentum back from the eigenbasis into parameter space.
            precond = project(ea, state['Q'], True)

            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
@@ -0,0 +1,95 @@
1
+
2
+
3
+ import random
4
+
5
+ import torch
6
+
7
+ from .utils import init_preconditioner, update_preconditioner, project, beta_debias, update_param_, \
8
+ precond_schedule, set_, StatefulOptimizer, laprop_exp_avg_
9
+
10
+
11
class PrecondScheduleForeachSOLP(StatefulOptimizer):
    """
    PrecondScheduleForeachSOLP

    SOAP with a LaProp inner step (see ForeachSOLP). Eigenbasis refreshes are decided by
    `precond_schedule(step, precond_scheduler, rng)` instead of a fixed frequency. The rng is seeded
    deterministically in `__init__`.

    Args (as read by `_step`):
        betas: (beta1, beta2). beta1 is the momentum coefficient, also forwarded to
            `split_p_and_g_in_group`. beta2 drives the second moment and the preconditioner debiasing.
        precond_scheduler: schedule parameters forwarded to `precond_schedule`.
        warmup_steps: step size is `lr * min(step / warmup_steps, 1)`.
        weight_decay, caution: forwarded to `update_param_`.
        eps: stored in the param group. NOTE(review): the LaProp kernel currently uses a fixed 1e-8.
        Remaining options are stored in the param group; presumably consumed by StatefulOptimizer
        helpers (TODO confirm).

    Sources:
        LaProp:
            LaProp: Separating Momentum and Adaptivity in Adam
            Liu Ziyin, Zhikang T.Wang, Masahito Ueda
            https://arxiv.org/abs/2002.04839
            https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py

        Preconditioner Schedules:
            Preconditioned Stochastic Gradient Descent
            Xi-Lin Li, Omead Pooladzandi, Evan Walters
            https://arxiv.org/abs/1512.04202
            https://github.com/evanatyourservice/kron_torch
            https://github.com/lixilinx/psgd_torch

        Baseline SOAP:
            SOAP: Improving and Stabilizing Shampoo using Adam
            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
            https://arxiv.org/abs/2409.11321
            https://github.com/nikhilvyas/SOAP
    """

    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
                 precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True, mars: bool = False,
                 caution: bool = False, mars_gamma: float = 0.0025):
        defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
                    'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
        super().__init__(params, defaults, foreach)
        self._data_format = data_format
        self.rng = random.Random(0x120983109)

    def _step(self, group):
        """Run one optimisation step for every parameter of `group` that has a gradient."""
        vals = []
        step = 0

        max_precond_dim = group['max_precond_dim']
        precondition_1d = group['precondition_1d']

        for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
            state = self.state_(p)
            param_step = state['step'] = state.get("step", -1) + 1

            if "exp_avg" not in state:
                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
                init_preconditioner(g, state, max_precond_dim, precondition_1d)
                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
                continue  # first step is skipped so that we never use the current gradients in the projection.

            # Only parameters that are actually updated below define the shared step. A freshly initialised
            # parameter (step 0, skipped above) must not zero the warmup or the debiasing for the others.
            step = param_step

            # Projecting gradients to the eigenbases of Shampoo's preconditioner
            # i.e. projecting to the eigenbases of matrices in state['GG']
            grad_projected = project(g, state['Q'], False)
            vals.append((p, g, grad_projected, state["exp_avg"], state["exp_avg_sq"]))

        if not vals:
            return

        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
        beta1, beta2 = group["betas"]

        old_debiased2 = beta_debias(beta2, step)

        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)

        # Consumes the seeded rng once per step that has at least one updatable parameter.
        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)

        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
            # LaProp: normalise the projected gradient by the second moment, then accumulate it into exp_avg.
            laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
            state = self.state_(p)
            # Project the momentum back from the eigenbasis into parameter space.
            precond = project(ea, state['Q'], True)

            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
@@ -0,0 +1,103 @@
1
+ import random
2
+
3
+ import torch
4
+
5
+ from .utils import init_preconditioner, update_preconditioner, project, beta_debias, laprop_exp_avg_, update_param_, \
6
+ precond_schedule, set_, StatefulOptimizer
7
+
8
+
9
class PrecondSchedulePaLMForeachSOLP(StatefulOptimizer):
    """
    PrecondSchedulePaLMForeachSOLP

    SOAP with a LaProp inner step (see ForeachSOLP) and PaLM's beta2 schedule
    `beta2 = 1 - step ** -beta2_scale`. Eigenbasis refreshes are decided by
    `precond_schedule(step, precond_scheduler, rng)`, using an rng seeded deterministically in `__init__`.

    Args (as read by `_step`):
        beta: momentum coefficient. It is overridden by `betas[0]` when that is not None.
        beta2_scale: exponent of the PaLM beta2 schedule.
        precond_scheduler: schedule parameters forwarded to `precond_schedule`.
        warmup_steps: step size is `lr * min(step / warmup_steps, 1)`.
        weight_decay, caution: forwarded to `update_param_`.
        eps: stored in the param group. NOTE(review): the LaProp kernel currently uses a fixed 1e-8.
        Remaining options are stored in the param group; presumably consumed by StatefulOptimizer
        helpers (TODO confirm).

    Sources:
        LaProp:
            LaProp: Separating Momentum and Adaptivity in Adam
            Liu Ziyin, Zhikang T.Wang, Masahito Ueda
            https://arxiv.org/abs/2002.04839
            https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py

        Preconditioner Schedules:
            Preconditioned Stochastic Gradient Descent
            Xi-Lin Li, Omead Pooladzandi, Evan Walters
            https://arxiv.org/abs/1512.04202
            https://github.com/evanatyourservice/kron_torch
            https://github.com/lixilinx/psgd_torch

        Baseline SOAP:
            SOAP: Improving and Stabilizing Shampoo using Adam
            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
            https://arxiv.org/abs/2409.11321
            https://github.com/nikhilvyas/SOAP

        Beta2 Schedule:
            PaLM: Scaling Language Modeling with Pathways
            Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
            https://arxiv.org/abs/2204.02311
    """

    def __init__(self, params, lr: float = 3e-3, beta=0.9, shampoo_beta: float = 0.95, eps: float = 1e-8,
                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
                 precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
                 foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
        if betas[0] is not None:
            beta = betas[0]
        defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
                    'beta2_scale': beta2_scale, 'split': split, 'mars': mars, 'caution': caution,
                    'mars_gamma': mars_gamma}
        super().__init__(params, defaults, foreach)
        self._data_format = data_format
        self.rng = random.Random(0x120983109)

    def _step(self, group):
        """Run one optimisation step for every parameter of `group` that has a gradient."""
        vals = []
        step = 0

        max_precond_dim = group['max_precond_dim']
        precondition_1d = group['precondition_1d']

        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
            state = self.state_(p)
            param_step = state['step'] = state.get("step", -1) + 1

            if "exp_avg" not in state:
                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
                init_preconditioner(g, state, max_precond_dim, precondition_1d)
                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
                continue  # first step is skipped so that we never use the current gradients in the projection.

            # Only parameters that are actually updated below define the shared step. A freshly initialised
            # parameter (step 0, skipped above) must not zero the warmup or the debiasing for the others.
            step = param_step

            # Projecting gradients to the eigenbases of Shampoo's preconditioner
            # i.e. projecting to the eigenbases of matrices in state['GG']
            grad_projected = project(g, state['Q'], False)
            vals.append((p, g, grad_projected, state["exp_avg"], state["exp_avg_sq"]))

        if not vals:
            return

        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
        beta1 = group["beta"]

        # PaLM schedule; max(step, 1) guards against 0 ** -beta2_scale (ZeroDivisionError).
        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
        old_debiased2 = beta_debias(beta2, step)

        beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
        step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)

        # Consumes the seeded rng once per step that has at least one updatable parameter.
        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)

        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
            # LaProp: normalise the projected gradient by the second moment, then accumulate it into exp_avg.
            laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
            state = self.state_(p)
            # Project the momentum back from the eigenbasis into parameter space.
            precond = project(ea, state['Q'], True)

            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
@@ -693,6 +693,29 @@ def exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor]
693
693
  return denom
694
694
 
695
695
 
696
+
697
@decorator_knowngood
def _compilable_laprop_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
                                grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor):
    """In-place LaProp moment update over lists of tensors.

    The projected gradient is divided by the denominator that `exp_avg_sq_` returns after updating the
    second moment. The normalised result is then lerped into `exp_avg` with weight `1 - beta1`. Both
    betas are first debiased for `step` via `beta_debias`.
    """
    debiased_beta1 = beta_debias(beta1, step)
    debiased_beta2 = beta_debias(beta2, step)

    # `promote` is applied to the projected gradients first, then to the second moments (original order).
    promoted_grads = [promote(t) for t in grad_projected]
    promoted_sq = [promote(t) for t in exp_avg_sq]

    # Updates promoted_sq in place and yields the per-tensor normaliser (eps fixed at 1e-8).
    denominator = exp_avg_sq_(promoted_sq, promoted_grads, debiased_beta2, 1e-8)
    normalized = torch._foreach_div(promoted_grads, denominator)
    stochastic_lerp_(exp_avg, normalized, 1 - debiased_beta1)

    # Write the updated second moment back into the caller's storage.
    copy_stochastic_list_(exp_avg_sq, promoted_sq)
710
+
711
+
712
def laprop_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor],
                    beta1: float, beta2: float, step: int):
    """Apply the LaProp first/second-moment update in place.

    Normalises `grad_projected` by the running second moment, then accumulates it into `exp_avg`.

    Inputs are wrapped with `list_guard` / `scalar_guard` and forwarded to `_compilable_laprop_exp_avg_`.
    Callers in this package pass python floats and 0-d tensors for `beta2`, and a 0-d int tensor for
    `step`; all three scalars go through `scalar_guard` alike.
    """
    exp_avg, exp_avg_sq, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(grad_projected)
    # Previously the guarded beta2 was bound to an unused local `beta`, so the raw beta2 reached the
    # compiled kernel. Guard it exactly like beta1 and step.
    beta1 = scalar_guard(beta1, exp_avg[0])
    beta2 = scalar_guard(beta2, exp_avg[0])
    step = scalar_guard(step, exp_avg[0])
    _compilable_laprop_exp_avg_(exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step)
717
+
718
+
696
719
  @decorator_knowngood
697
720
  def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
698
721
  """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.25.0
3
+ Version: 0.25.1
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -10,11 +10,15 @@ heavyball/foreach_adopt.py
10
10
  heavyball/foreach_laprop.py
11
11
  heavyball/foreach_sfadamw.py
12
12
  heavyball/foreach_soap.py
13
+ heavyball/foreach_solp.py
13
14
  heavyball/p_adam.py
14
15
  heavyball/palm_foreach_sfadamw.py
15
16
  heavyball/palm_foreach_soap.py
17
+ heavyball/palm_foreach_solp.py
16
18
  heavyball/precond_schedule_foreach_soap.py
19
+ heavyball/precond_schedule_foreach_solp.py
17
20
  heavyball/precond_schedule_palm_foreach_soap.py
21
+ heavyball/precond_schedule_palm_foreach_solp.py
18
22
  heavyball/precond_schedule_sfpsoap.py
19
23
  heavyball/psgd_kron.py
20
24
  heavyball/pure_psgd.py
@@ -40,4 +44,5 @@ test/test_merge.py
40
44
  test/test_no_grad.py
41
45
  test/test_psgd.py
42
46
  test/test_soap.py
47
+ test/test_solp.py
43
48
  test/test_stochastic_updates.py
@@ -10,7 +10,7 @@ setuptools.setup(
10
10
  name='heavyball',
11
11
  license='BSD',
12
12
  description='Efficient optimizers',
13
- version='0.25.0',
13
+ version='0.25.1',
14
14
  long_description=README,
15
15
  url='https://github.com/clashluke/heavyball',
16
16
  packages=setuptools.find_packages(),
@@ -0,0 +1,50 @@
1
+ import heavyball
2
+ import heavyball.utils
3
+ import pytest
4
+ import torch
5
+ from benchmark.utils import get_optim
6
+ from heavyball.utils import clean, set_torch, ScheduleFree
7
+ from torch import nn
8
+ from torch._dynamo import config
9
+
10
+ config.cache_size_limit = 128
11
+
12
@pytest.mark.parametrize("opt", heavyball.__all__)
@pytest.mark.parametrize("size,depth", [(128, 2)])
def test_solp(opt, size, depth: int, iterations: int = 65536, outer_iterations: int = 2):
    """Each SOLP variant must track its SOAP counterpart.

    Both optimizers train the same seeded MLP on the same synthetic objective. At every recorded step
    the SOLP loss must be at most 1.1x the SOAP loss.
    """
    if not torch.cuda.is_available():
        pytest.skip('CUDA is required for this test')
    set_torch()
    if 'SOAP' not in opt:
        pytest.skip('This test is only for SOAP')  # pytest.skip raises by itself; no `raise` needed

    opt_name = opt
    losses = []  # losses[0]: SOLP run, losses[1]: SOAP run

    for use_solp in [True, False]:
        cls_name = opt_name.replace("SOAP", "SOLP") if use_solp else opt_name
        try:
            opt_cls = getattr(heavyball, cls_name)
        except AttributeError:
            pytest.skip(f'{opt_name} does not have a SOLP variant')

        torch.manual_seed(0x2131290)  # identical initialisation and data for both runs
        losses.append([])

        for _ in range(outer_iterations):
            model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
            o = get_optim(opt_cls, model.parameters(), lr=1e-5)

            for _ in range(iterations):
                loss = model(torch.randn((1024, size), device='cuda')).square().mean()
                loss.backward()
                o.step()
                o.zero_grad()
                losses[-1].append(loss.detach())

            del model, o
            clean()

    for i, (solp_loss, soap_loss) in enumerate(zip(*losses)):
        solp_loss, soap_loss = solp_loss.item(), soap_loss.item()
        assert solp_loss <= soap_loss * 1.1, f'step {i}: SOLP loss {solp_loss} > 1.1 * SOAP loss {soap_loss}'
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes