heavyball 0.24.4.tar.gz → 0.25.1.tar.gz

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries, and is provided for informational purposes only.
Files changed (50)
  1. {heavyball-0.24.4 → heavyball-0.25.1}/PKG-INFO +1 -1
  2. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/__init__.py +12 -2
  3. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/foreach_soap.py +3 -1
  4. heavyball-0.25.1/heavyball/foreach_solp.py +89 -0
  5. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/palm_foreach_soap.py +2 -1
  6. heavyball-0.25.1/heavyball/palm_foreach_solp.py +98 -0
  7. heavyball-0.25.1/heavyball/precond_schedule_foreach_solp.py +95 -0
  8. heavyball-0.25.1/heavyball/precond_schedule_palm_foreach_solp.py +103 -0
  9. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/utils.py +33 -7
  10. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball.egg-info/PKG-INFO +1 -1
  11. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball.egg-info/SOURCES.txt +5 -0
  12. {heavyball-0.24.4 → heavyball-0.25.1}/setup.py +1 -1
  13. heavyball-0.25.1/test/test_solp.py +50 -0
  14. {heavyball-0.24.4 → heavyball-0.25.1}/LICENSE +0 -0
  15. {heavyball-0.24.4 → heavyball-0.25.1}/README.md +0 -0
  16. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/cached_delayed_psgd_kron.py +0 -0
  17. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/cached_psgd_kron.py +0 -0
  18. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/delayed_psgd.py +0 -0
  19. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/foreach_adamw.py +0 -0
  20. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/foreach_adopt.py +0 -0
  21. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/foreach_laprop.py +0 -0
  22. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/foreach_sfadamw.py +0 -0
  23. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/p_adam.py +0 -0
  24. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/palm_foreach_sfadamw.py +0 -0
  25. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/precond_schedule_foreach_soap.py +0 -0
  26. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  27. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/precond_schedule_sfpsoap.py +0 -0
  28. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/psgd_kron.py +0 -0
  29. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/pure_psgd.py +0 -0
  30. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  31. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball.egg-info/dependency_links.txt +0 -0
  32. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball.egg-info/requires.txt +0 -0
  33. {heavyball-0.24.4 → heavyball-0.25.1}/heavyball.egg-info/top_level.txt +0 -0
  34. {heavyball-0.24.4 → heavyball-0.25.1}/setup.cfg +0 -0
  35. {heavyball-0.24.4 → heavyball-0.25.1}/test/test_bf16_params.py +0 -0
  36. {heavyball-0.24.4 → heavyball-0.25.1}/test/test_bf16_q.py +0 -0
  37. {heavyball-0.24.4 → heavyball-0.25.1}/test/test_bf16_storage.py +0 -0
  38. {heavyball-0.24.4 → heavyball-0.25.1}/test/test_caution.py +0 -0
  39. {heavyball-0.24.4 → heavyball-0.25.1}/test/test_channels_last.py +0 -0
  40. {heavyball-0.24.4 → heavyball-0.25.1}/test/test_closure.py +0 -0
  41. {heavyball-0.24.4 → heavyball-0.25.1}/test/test_ema.py +0 -0
  42. {heavyball-0.24.4 → heavyball-0.25.1}/test/test_foreach.py +0 -0
  43. {heavyball-0.24.4 → heavyball-0.25.1}/test/test_hook.py +0 -0
  44. {heavyball-0.24.4 → heavyball-0.25.1}/test/test_mars.py +0 -0
  45. {heavyball-0.24.4 → heavyball-0.25.1}/test/test_memory.py +0 -0
  46. {heavyball-0.24.4 → heavyball-0.25.1}/test/test_merge.py +0 -0
  47. {heavyball-0.24.4 → heavyball-0.25.1}/test/test_no_grad.py +0 -0
  48. {heavyball-0.24.4 → heavyball-0.25.1}/test/test_psgd.py +0 -0
  49. {heavyball-0.24.4 → heavyball-0.25.1}/test/test_soap.py +0 -0
  50. {heavyball-0.24.4 → heavyball-0.25.1}/test/test_stochastic_updates.py +0 -0
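In short, 0.25.1 bumps the version metadata and adds a family of four "SOLP" optimizers (ForeachSOLP, PaLMForeachSOLP, PrecondScheduleForeachSOLP, PrecondSchedulePaLMForeachSOLP). They reuse the existing SOAP preconditioning code but swap the Adam-style moment update for a LaProp-style one via the new laprop_exp_avg_ helper in heavyball/utils.py. utils.py also replaces the (data_ptr, shape) state key with a direct tensor/mapping lookup, and test/test_solp.py compares each SOLP variant against its SOAP counterpart.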
{heavyball-0.24.4 → heavyball-0.25.1}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.24.4
+ Version: 0.25.1
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
{heavyball-0.24.4 → heavyball-0.25.1}/heavyball/__init__.py
@@ -15,6 +15,10 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
  from .psgd_kron import ForeachPSGDKron
  from .pure_psgd import ForeachPurePSGD
  from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
+ from .foreach_solp import ForeachSOLP
+ from .palm_foreach_solp import PaLMForeachSOLP
+ from .precond_schedule_palm_foreach_solp import PrecondSchedulePaLMForeachSOLP
+ from .precond_schedule_foreach_solp import PrecondScheduleForeachSOLP

  PalmForEachSoap = PaLMForeachSOAP

@@ -35,12 +39,18 @@ PaLMPAdam = ForeachPaLMPAdam
  DelayedPSGD = ForeachDelayedPSGD
  CachedPSGDKron = ForeachCachedPSGDKron
  CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
+ SOLP = ForeachSOLP
+ PaLMSOLP = PaLMForeachSOLP
+ PrecondSchedulePaLMSOLP = PrecondSchedulePaLMForeachSOLP
+ PrecondScheduleSOLP = PrecondScheduleForeachSOLP

  __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
             'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
             'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
-            'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron',
+            'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron', 'ForeachSOLP',
+            'PaLMForeachSOLP', 'PrecondSchedulePaLMForeachSOLP', 'PrecondScheduleForeachSOLP',
             #
             'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP', 'SOAP', 'SFAdamW',
             'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD', 'CachedPSGDKron',
-            'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP']
+            'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP', 'SOLP', 'PrecondScheduleSOLP',
+            'PrecondSchedulePaLMSOLP', 'PrecondScheduleSOLP']
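The new classes are drop-in exports alongside the existing SOAP variants. A minimal usage sketch (hypothetical toy model and training loop; the constructor arguments are the ones shown in foreach_solp.py below):

import torch
from torch import nn

import heavyball

model = nn.Linear(16, 16)
# ForeachSOLP takes the same style of arguments as ForeachSOAP (lr, betas, weight_decay, ...).
opt = heavyball.ForeachSOLP(model.parameters(), lr=3e-3, betas=(0.9, 0.95), weight_decay=0.01)

for _ in range(10):
    loss = model(torch.randn(8, 16)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()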
{heavyball-0.24.4 → heavyball-0.25.1}/heavyball/foreach_soap.py
@@ -6,7 +6,7 @@ from .utils import init_preconditioner, update_preconditioner, project, beta_deb

  class ForeachSOAP(StatefulOptimizer):
      """
-     SFPaLMForeachSOAP
+     ForeachSOAP

      Sources:
          Baseline SOAP:
@@ -75,6 +75,8 @@ class ForeachSOAP(StatefulOptimizer):
          step_size = -group["lr"] * min(step / group['warmup_steps'], 1)

          for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
+             state = self.state_(p)
+
              d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
              # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
              # i.e. projecting to the eigenbases of matrices in state['GG']
heavyball-0.25.1/heavyball/foreach_solp.py (new file)
@@ -0,0 +1,89 @@
+ import torch
+
+ from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
+     StatefulOptimizer, laprop_exp_avg_
+
+
+ class ForeachSOLP(StatefulOptimizer):
+     """
+     ForeachSOLP
+
+     Sources:
+         LaProp:
+             LaProp: Separating Momentum and Adaptivity in Adam
+             Liu Ziyin, Zhikang T.Wang, Masahito Ueda
+             https://arxiv.org/abs/2002.04839
+             https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py
+
+         Baseline SOAP:
+             SOAP: Improving and Stabilizing Shampoo using Adam
+             Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+             https://arxiv.org/abs/2409.11321
+             https://github.com/nikhilvyas/SOAP
+
+         ScheduleFree:
+             The Road Less Scheduled
+             Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
+             https://arxiv.org/abs/2405.15682
+             https://github.com/facebookresearch/schedule_free
+     """
+
+     def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
+                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
+                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
+                  split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
+                  mars_gamma: float = 0.0025):
+         defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
+                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
+                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
+                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split, 'mars': mars,
+                     'caution': caution, 'mars_gamma': mars_gamma}
+         super().__init__(params, defaults, foreach)
+         self._data_format = data_format
+
+     def _step(self, group):
+         vals = []
+         step = 0
+
+         max_precond_dim = group['max_precond_dim']
+         precondition_1d = group['precondition_1d']
+
+         for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
+             state = self.state_(p)
+             step = state['step'] = state.get("step", -1) + 1
+
+             if "exp_avg" not in state:
+                 state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
+                 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
+                 init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                 continue  # first step is skipped so that we never use the current gradients in the projection.
+
+             # Projecting gradients to the eigenbases of Shampoo's preconditioner
+             # i.e. projecting to the eigenbases of matrices in state['GG']
+             grad_projected = project(g, state['Q'], False)
+             exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+             vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+         if not vals:
+             return
+
+         p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+         beta1, beta2 = group["betas"]
+
+         old_debiased2 = beta_debias(beta2, step)
+
+         # Decay the first and second moment running average coefficient
+         # In-place operations to update the averages at the same time
+         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+
+         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
+         for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
+             laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
+             state = self.state_(p)
+             precond = project(ea, state['Q'], True)
+
+             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, step > 0 and step % group['precondition_frequency'] == 0)
+             update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
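In the notation of the code above (g_t the projected gradient, m_t the exp_avg buffer, v_t the exp_avg_sq buffer), the LaProp update performed by laprop_exp_avg_ is, roughly and ignoring bias correction:

    v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
    m_t = \beta_1 m_{t-1} + (1 - \beta_1) \frac{g_t}{\sqrt{v_t} + \epsilon}

That is, the second moment normalizes the gradient before it enters the momentum buffer, whereas Adam divides the accumulated momentum by the second moment afterwards. The step applied via update_param_ is then proportional to -lr * m_t after projecting back out of the eigenbasis.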
{heavyball-0.24.4 → heavyball-0.25.1}/heavyball/palm_foreach_soap.py
@@ -6,7 +6,7 @@ from .utils import init_preconditioner, update_preconditioner, project, beta_deb

  class PaLMForeachSOAP(StatefulOptimizer):
      """
-     SFPaLMForeachSOAP
+     PaLMForeachSOAP

      Sources:
          Baseline SOAP:
@@ -84,6 +84,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
          step_size = -group["lr"] * min(step / group['warmup_steps'], 1)

          for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
+             state = self.state_(p)
              d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
              # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
              # i.e. projecting to the eigenbases of matrices in state['GG']
heavyball-0.25.1/heavyball/palm_foreach_solp.py (new file)
@@ -0,0 +1,98 @@
+ import torch
+
+ from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
+     StatefulOptimizer, laprop_exp_avg_
+
+
+ class PaLMForeachSOLP(StatefulOptimizer):
+     """
+     PaLMForeachSOAP
+
+     Sources:
+         LaProp:
+             LaProp: Separating Momentum and Adaptivity in Adam
+             Liu Ziyin, Zhikang T.Wang, Masahito Ueda
+             https://arxiv.org/abs/2002.04839
+             https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py
+
+         Baseline SOAP:
+             SOAP: Improving and Stabilizing Shampoo using Adam
+             Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+             https://arxiv.org/abs/2409.11321
+             https://github.com/nikhilvyas/SOAP
+
+         ScheduleFree:
+             The Road Less Scheduled
+             Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
+             https://arxiv.org/abs/2405.15682
+             https://github.com/facebookresearch/schedule_free
+
+         Beta2 Schedule:
+             PaLM: Scaling Language Modeling with Pathways
+             Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
+             https://arxiv.org/abs/2204.02311
+     """
+
+     def __init__(self, params, lr: float = 3e-3, beta=0.9, betas=(None, None), shampoo_beta: float = 0.95,
+                  eps: float = 1e-8, weight_decay: float = 0.01, precondition_frequency: int = 2,
+                  max_precond_dim: int = 2048, #
+                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
+                  beta2_scale: float = 0.8, split: bool = False, foreach: bool = True, mars: bool = False,
+                  caution: bool = False, mars_gamma: float = 0.0025):
+         if betas[0] is not None:
+             beta = betas[0]
+         defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
+                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
+                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
+                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
+                     'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
+         super().__init__(params, defaults, foreach)
+         self._data_format = data_format
+
+     def _step(self, group):
+         vals = []
+         step = 0
+
+         max_precond_dim = group['max_precond_dim']
+         precondition_1d = group['precondition_1d']
+
+         for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
+             state = self.state_(p)
+             step = state['step'] = state.get("step", -1) + 1
+
+             if "exp_avg" not in state:
+                 state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
+                 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
+                 init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                 continue  # first step is skipped so that we never use the current gradients in the projection.
+
+             # Projecting gradients to the eigenbases of Shampoo's preconditioner
+             # i.e. projecting to the eigenbases of matrices in state['GG']
+             grad_projected = project(g, state['Q'], False)
+             exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+             vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+         if not vals:
+             return
+
+         p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+         beta1 = group["beta"]
+
+         beta2 = 1 - step ** -group['beta2_scale']
+         old_debiased2 = beta_debias(beta2, step)
+
+         # Decay the first and second moment running average coefficient
+         # In-place operations to update the averages at the same time
+         beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
+         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
+         for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
+             laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
+             state = self.state_(p)
+             precond = project(ea, state['Q'], True)
+
+             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, step > 0 and step % group['precondition_frequency'] == 0)
+             update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
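The PaLM-style schedule above derives the second-moment coefficient from the step count, beta2 = 1 - step ** -beta2_scale, so the averaging horizon grows as training progresses. A quick check with the default beta2_scale = 0.8 (values rounded):

for step in (10, 100, 10_000):
    print(step, round(1 - step ** -0.8, 4))
# 10 -> 0.8415, 100 -> 0.9749, 10000 -> 0.9994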
heavyball-0.25.1/heavyball/precond_schedule_foreach_solp.py (new file)
@@ -0,0 +1,95 @@
+
+
+ import random
+
+ import torch
+
+ from .utils import init_preconditioner, update_preconditioner, project, beta_debias, update_param_, \
+     precond_schedule, set_, StatefulOptimizer, laprop_exp_avg_
+
+
+ class PrecondScheduleForeachSOLP(StatefulOptimizer):
+     """
+     Sources:
+         LaProp:
+             LaProp: Separating Momentum and Adaptivity in Adam
+             Liu Ziyin, Zhikang T.Wang, Masahito Ueda
+             https://arxiv.org/abs/2002.04839
+             https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py
+
+         Preconditioner Schedules:
+             Preconditioned Stochastic Gradient Descent
+             Xi-Lin Li, Omead Pooladzandi, Evan Walters
+             https://arxiv.org/abs/1512.04202
+             https://github.com/evanatyourservice/kron_torch
+             https://github.com/lixilinx/psgd_torch
+
+         Baseline SOAP:
+             SOAP: Improving and Stabilizing Shampoo using Adam
+             Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+             https://arxiv.org/abs/2409.11321
+             https://github.com/nikhilvyas/SOAP
+     """
+
+     def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
+                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
+                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
+                  precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True, mars: bool = False,
+                  caution: bool = False, mars_gamma: float = 0.0025):
+         defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
+                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
+                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
+                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
+                     'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
+         super().__init__(params, defaults, foreach)
+         self._data_format = data_format
+         self.rng = random.Random(0x120983109)
+
+     def _step(self, group):
+         vals = []
+         step = 0
+
+         max_precond_dim = group['max_precond_dim']
+         precondition_1d = group['precondition_1d']
+
+         for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
+             state = self.state_(p)
+             step = state['step'] = state.get("step", -1) + 1
+
+             if "exp_avg" not in state:
+                 state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
+                 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
+                 init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                 continue  # first step is skipped so that we never use the current gradients in the projection.
+
+             # Projecting gradients to the eigenbases of Shampoo's preconditioner
+             # i.e. projecting to the eigenbases of matrices in state['GG']
+             grad_projected = project(g, state['Q'], False)
+             exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+             vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+         if not vals:
+             return
+
+         p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+         beta1, beta2 = group["betas"]
+
+         old_debiased2 = beta_debias(beta2, step)
+
+         # Decay the first and second moment running average coefficient
+         # In-place operations to update the averages at the same time
+         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+
+         update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
+
+         for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
+             laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
+             state = self.state_(p)
+             precond = project(ea, state['Q'], True)
+
+             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
+             update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
heavyball-0.25.1/heavyball/precond_schedule_palm_foreach_solp.py (new file)
@@ -0,0 +1,103 @@
+ import random
+
+ import torch
+
+ from .utils import init_preconditioner, update_preconditioner, project, beta_debias, laprop_exp_avg_, update_param_, \
+     precond_schedule, set_, StatefulOptimizer
+
+
+ class PrecondSchedulePaLMForeachSOLP(StatefulOptimizer):
+     """
+     Sources:
+         LaProp:
+             LaProp: Separating Momentum and Adaptivity in Adam
+             Liu Ziyin, Zhikang T.Wang, Masahito Ueda
+             https://arxiv.org/abs/2002.04839
+             https://github.com/ClashLuke/HeavyBall/blob/main/heavyball/foreach_laprop.py
+
+         Preconditioner Schedules:
+             Preconditioned Stochastic Gradient Descent
+             Xi-Lin Li, Omead Pooladzandi, Evan Walters
+             https://arxiv.org/abs/1512.04202
+             https://github.com/evanatyourservice/kron_torch
+             https://github.com/lixilinx/psgd_torch
+
+         Baseline SOAP:
+             SOAP: Improving and Stabilizing Shampoo using Adam
+             Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+             https://arxiv.org/abs/2409.11321
+             https://github.com/nikhilvyas/SOAP
+
+         Beta2 Schedule:
+             PaLM: Scaling Language Modeling with Pathways
+             Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
+             https://arxiv.org/abs/2204.02311
+     """
+
+     def __init__(self, params, lr: float = 3e-3, beta=0.9, shampoo_beta: float = 0.95, eps: float = 1e-8,
+                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
+                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
+                  precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
+                  foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
+         if betas[0] is not None:
+             beta = betas[0]
+         defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
+                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
+                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
+                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
+                     'beta2_scale': beta2_scale, 'split': split, 'mars': mars, 'caution': caution,
+                     'mars_gamma': mars_gamma}
+         super().__init__(params, defaults, foreach)
+         self._data_format = data_format
+         self.rng = random.Random(0x120983109)
+
+     def _step(self, group):
+         vals = []
+         step = 0
+
+         max_precond_dim = group['max_precond_dim']
+         precondition_1d = group['precondition_1d']
+
+         for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
+             state = self.state_(p)
+             step = state['step'] = state.get("step", -1) + 1
+
+             if "exp_avg" not in state:
+                 state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
+                 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
+                 init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                 continue  # first step is skipped so that we never use the current gradients in the projection.
+
+             # Projecting gradients to the eigenbases of Shampoo's preconditioner
+             # i.e. projecting to the eigenbases of matrices in state['GG']
+             grad_projected = project(g, state['Q'], False)
+             exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+             vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+         if not vals:
+             return
+
+         p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+         beta1 = group["beta"]
+
+         beta2 = 1 - max(step, 1) ** -group['beta2_scale']
+         old_debiased1 = beta_debias(beta1, step)
+         old_debiased2 = beta_debias(beta2, step)
+
+         # Decay the first and second moment running average coefficient
+         # In-place operations to update the averages at the same time
+         beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
+         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+
+         update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
+         for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
+             laprop_exp_avg_(ea, eas, gp, beta1, beta2, step_tensor)
+             state = self.state_(p)
+             precond = project(ea, state['Q'], True)
+
+             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
+             update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
{heavyball-0.24.4 → heavyball-0.25.1}/heavyball/utils.py
@@ -492,22 +492,20 @@ class StatefulOptimizer(torch.optim.Optimizer):
          super().__init__(params, {**defaults, 'foreach': foreach})
          self.fake_groups = {}
          self.use_ema = use_ema
-
-     def key(self, param: Tensor):
-         return (param.data_ptr(), tuple(param.shape))
+         self.mapping = {}

      def get_groups(self, group):
          if group['foreach']:
              return [group]

          for p in group['params']:
-             if self.key(p) not in self.fake_groups:
-                 self.fake_groups[self.key(p)] = {**group, 'params': [p]}
+             if p not in self.fake_groups:
+                 self.fake_groups[p] = {**group, 'params': [p]}

-         return [self.fake_groups[self.key(p)] for p in group['params']]
+         return [self.fake_groups[p] for p in group['params']]

      def state_(self, arg: Tensor):
-         return self.state[self.key(arg)]
+         return self.state[self.mapping.get(arg, arg)]

      def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
          for p, g in zip(p_list, g_list):
@@ -538,6 +536,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
              p_views = merge_group(group, p)
              if grad is not None:
                  grad = merge_group(group, grad)
+             for i, pv in enumerate(p_views):
+                 self.mapping[pv] = (p, i)
              if isinstance(p_views, Tensor):
                  yield p_views, grad
                  continue
@@ -622,11 +622,14 @@
          for top_group in self.param_groups:
              for group in self.get_groups(top_group):
                  self._step(group)
+                 self.mapping.clear()
                  if self.use_ema:
                      self.ema_update(group)
+
          return loss


+
  class ScheduleFree(StatefulOptimizer):
      def eval(self):
          for group in self.param_groups:
@@ -690,6 +693,29 @@ def exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor]
      return denom


+
+ @decorator_knowngood
+ def _compilable_laprop_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
+                                 grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor):
+     beta1 = beta_debias(beta1, step)
+     beta2 = beta_debias(beta2, step)
+
+     gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad_projected, exp_avg_sq]]
+
+     denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
+     gp32 = torch._foreach_div(gp32, denom)
+     stochastic_lerp_(exp_avg, gp32, 1 - beta1)
+
+     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
+
+ def laprop_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor],
+                     beta1: float, beta2: float, step: int):
+     exp_avg, exp_avg_sq, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(grad_projected)
+     beta1, beta, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step, exp_avg[0])
+     _compilable_laprop_exp_avg_(exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step)
+
+
  @decorator_knowngood
  def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
      """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
{heavyball-0.24.4 → heavyball-0.25.1}/heavyball.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.24.4
+ Version: 0.25.1
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
{heavyball-0.24.4 → heavyball-0.25.1}/heavyball.egg-info/SOURCES.txt
@@ -10,11 +10,15 @@ heavyball/foreach_adopt.py
  heavyball/foreach_laprop.py
  heavyball/foreach_sfadamw.py
  heavyball/foreach_soap.py
+ heavyball/foreach_solp.py
  heavyball/p_adam.py
  heavyball/palm_foreach_sfadamw.py
  heavyball/palm_foreach_soap.py
+ heavyball/palm_foreach_solp.py
  heavyball/precond_schedule_foreach_soap.py
+ heavyball/precond_schedule_foreach_solp.py
  heavyball/precond_schedule_palm_foreach_soap.py
+ heavyball/precond_schedule_palm_foreach_solp.py
  heavyball/precond_schedule_sfpsoap.py
  heavyball/psgd_kron.py
  heavyball/pure_psgd.py
@@ -40,4 +44,5 @@ test/test_merge.py
  test/test_no_grad.py
  test/test_psgd.py
  test/test_soap.py
+ test/test_solp.py
  test/test_stochastic_updates.py
{heavyball-0.24.4 → heavyball-0.25.1}/setup.py
@@ -10,7 +10,7 @@ setuptools.setup(
      name='heavyball',
      license='BSD',
      description='Efficient optimizers',
-     version='0.24.4',
+     version='0.25.1',
      long_description=README,
      url='https://github.com/clashluke/heavyball',
      packages=setuptools.find_packages(),
heavyball-0.25.1/test/test_solp.py (new file)
@@ -0,0 +1,50 @@
+ import heavyball
+ import heavyball.utils
+ import pytest
+ import torch
+ from benchmark.utils import get_optim
+ from heavyball.utils import clean, set_torch, ScheduleFree
+ from torch import nn
+ from torch._dynamo import config
+
+ config.cache_size_limit = 128
+
+ @pytest.mark.parametrize("opt", heavyball.__all__)
+ @pytest.mark.parametrize("size,depth", [(128, 2)])
+ def test_solp(opt, size, depth: int, iterations: int = 65536, outer_iterations: int = 2):
+     set_torch()
+     if 'SOAP' not in opt:
+         raise pytest.skip('This test is only for SOAP')
+
+     opt_name = opt
+     peaks = []
+     losses = []
+
+     for use_solp in [True, False]:
+         try:
+             opt = getattr(heavyball, opt_name.replace("SOAP", "SOLP") if use_solp else opt_name)
+         except AttributeError:
+             raise pytest.skip(f'{opt_name} does not have a SOLP variant')
+         print(opt, opt_name.replace("SOAP", "SOLP"))
+
+         torch.manual_seed(0x2131290)
+         peaks.append([])
+         losses.append([])
+
+         for i in range(outer_iterations):
+             model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
+             o = get_optim(opt, model.parameters(), lr=1e-5)
+
+             for _ in range(iterations):
+                 loss = model(torch.randn((1024, size), device='cuda')).square().mean()
+                 loss.backward()
+                 o.step()
+                 o.zero_grad()
+                 losses[-1].append(loss.detach())
+
+             del model, o
+             clean()
+
+     for i, (l0, l1) in enumerate(zip(*losses)):
+         print(i, l0.item(), l1.item())
+         assert l0.item() <= l1.item() * 1.1
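Note that the test pairs each SOAP optimizer with its SOLP counterpart (by replacing "SOAP" with "SOLP" in the class name), trains both on the same synthetic objective, and asserts that the SOLP loss at each recorded step is no more than 10% above the SOAP loss. It requires a CUDA device and the repository's benchmark.utils helper.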
The remaining 37 files listed above (entries 14-50) are unchanged between 0.24.4 and 0.25.1.