heavyball 0.14.7__tar.gz → 0.15.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {heavyball-0.14.7 → heavyball-0.15.1}/PKG-INFO +1 -1
  2. {heavyball-0.14.7 → heavyball-0.15.1}/heavyball/__init__.py +25 -3
  3. heavyball-0.15.1/heavyball/cached_psgd_kron.py +141 -0
  4. {heavyball-0.14.7 → heavyball-0.15.1}/heavyball/delayed_psgd.py +43 -51
  5. heavyball-0.15.1/heavyball/foreach_adamw.py +41 -0
  6. heavyball-0.15.1/heavyball/foreach_adopt.py +51 -0
  7. heavyball-0.15.1/heavyball/foreach_laprop.py +46 -0
  8. heavyball-0.15.1/heavyball/foreach_sfadamw.py +54 -0
  9. heavyball-0.15.1/heavyball/foreach_soap.py +92 -0
  10. {heavyball-0.14.7 → heavyball-0.15.1}/heavyball/p_adam.py +46 -50
  11. heavyball-0.15.1/heavyball/palm_foreach_sfadamw.py +56 -0
  12. {heavyball-0.14.7 → heavyball-0.15.1}/heavyball/palm_foreach_soap.py +56 -70
  13. heavyball-0.15.1/heavyball/precond_schedule_foreach_soap.py +96 -0
  14. {heavyball-0.14.7 → heavyball-0.15.1}/heavyball/precond_schedule_palm_foreach_soap.py +58 -73
  15. {heavyball-0.14.7 → heavyball-0.15.1}/heavyball/precond_schedule_sfpsoap.py +60 -72
  16. {heavyball-0.14.7 → heavyball-0.15.1}/heavyball/psgd_kron.py +43 -49
  17. {heavyball-0.14.7 → heavyball-0.15.1}/heavyball/pure_psgd.py +36 -43
  18. {heavyball-0.14.7 → heavyball-0.15.1}/heavyball/schedule_free_palm_foreach_soap.py +61 -72
  19. {heavyball-0.14.7 → heavyball-0.15.1}/heavyball/utils.py +23 -7
  20. {heavyball-0.14.7 → heavyball-0.15.1}/heavyball.egg-info/PKG-INFO +1 -1
  21. {heavyball-0.14.7 → heavyball-0.15.1}/heavyball.egg-info/SOURCES.txt +3 -0
  22. {heavyball-0.14.7 → heavyball-0.15.1}/setup.py +1 -1
  23. heavyball-0.15.1/test/test_closure.py +44 -0
  24. heavyball-0.15.1/test/test_no_grad.py +39 -0
  25. heavyball-0.14.7/heavyball/foreach_adamw.py +0 -51
  26. heavyball-0.14.7/heavyball/foreach_adopt.py +0 -61
  27. heavyball-0.14.7/heavyball/foreach_laprop.py +0 -56
  28. heavyball-0.14.7/heavyball/foreach_sfadamw.py +0 -64
  29. heavyball-0.14.7/heavyball/foreach_soap.py +0 -106
  30. heavyball-0.14.7/heavyball/palm_foreach_sfadamw.py +0 -66
  31. heavyball-0.14.7/heavyball/precond_schedule_foreach_soap.py +0 -110
  32. {heavyball-0.14.7 → heavyball-0.15.1}/LICENSE +0 -0
  33. {heavyball-0.14.7 → heavyball-0.15.1}/README.md +0 -0
  34. {heavyball-0.14.7 → heavyball-0.15.1}/heavyball.egg-info/dependency_links.txt +0 -0
  35. {heavyball-0.14.7 → heavyball-0.15.1}/heavyball.egg-info/requires.txt +0 -0
  36. {heavyball-0.14.7 → heavyball-0.15.1}/heavyball.egg-info/top_level.txt +0 -0
  37. {heavyball-0.14.7 → heavyball-0.15.1}/setup.cfg +0 -0
  38. {heavyball-0.14.7 → heavyball-0.15.1}/test/test_memory.py +0 -0
  39. {heavyball-0.14.7 → heavyball-0.15.1}/test/test_merge.py +0 -0
  40. {heavyball-0.14.7 → heavyball-0.15.1}/test/test_psgd.py +0 -0
  41. {heavyball-0.14.7 → heavyball-0.15.1}/test/test_soap.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.14.7
+Version: 0.15.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -1,3 +1,5 @@
+from .cached_psgd_kron import ForeachCachedPSGDKron
+from .delayed_psgd import ForeachDelayedPSGD
 from .foreach_adamw import ForeachAdamW
 from .foreach_adopt import ForeachADOPT
 from .foreach_laprop import ForeachLaProp
@@ -12,11 +14,31 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
 from .psgd_kron import ForeachPSGDKron
 from .pure_psgd import ForeachPurePSGD
 from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
-from .delayed_psgd import ForeachDelayedPSGD
 
 PalmForEachSoap = PaLMForeachSOAP
 
+PaLMSOAP = PaLMForeachSOAP
+PaLMSFAdamW = PaLMForeachSFAdamW
+PaLMSFSoap = SFPaLMForeachSOAP
+PaLMForeachSOAP = PaLMForeachSOAP
+PrecondScheduleSFPaLMSOAP = PrecondScheduleSFPaLMSOAP
+SOAP = ForeachSOAP
+SFAdamW = ForeachSFAdamW
+LaProp = ForeachLaProp
+ADOPT = ForeachADOPT
+PrecondScheduleForeachSOAP = PrecondScheduleForeachSOAP
+PrecondSchedulePaLMForeachSOAP = PrecondSchedulePaLMForeachSOAP
+PSGDKron = ForeachPSGDKron
+AdamW = ForeachAdamW
+PurePSGD = ForeachPurePSGD
+PaLMPAdam = ForeachPaLMPAdam
+DelayedPSGD = ForeachDelayedPSGD
+CachedPSGDKron = ForeachCachedPSGDKron
+
 __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
            'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
-           'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD',
-           'ForeachPaLMPAdam', 'ForeachDelayedPSGD']
+           'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
+           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron',  #
+           'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
+           'SOAP', 'SFAdamW', 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD',
+           'CachedPSGDKron']
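
The block of new module-level aliases in heavyball/__init__.py above exposes every optimizer under a short name alongside the existing Foreach* names. A minimal usage sketch of the new aliases; the model and data below are placeholders:

import torch
import heavyball

model = torch.nn.Linear(16, 4)
opt = heavyball.AdamW(model.parameters(), lr=1e-3)  # same class as heavyball.ForeachAdamW

x, y = torch.randn(8, 16), torch.randn(8, 4)
torch.nn.functional.mse_loss(model(x), y).backward()
opt.step()
opt.zero_grad()
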
@@ -0,0 +1,141 @@
+"""
+Originally from Evan Walters and Omead Pooladzandi, 2024
+Modified under Creative Commons Attribution 4.0 International
+Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
+"""
+
+from typing import Optional
+
+import torch
+from heavyball.utils import einsum_base
+
+from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
+    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base
+
+
+class ForeachCachedPSGDKron(PSGDBase):
+    """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch with cached preconditioners.
+
+    Args:
+        params (iterable): Iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float): Learning rate.
+        b1 (float): Momentum parameter.
+        weight_decay (float): Weight decay (L2 penalty).
+        preconditioner_update_probability (callable or float, optional): Probability of
+            updating the preconditioner. If None, defaults to a schedule that anneals
+            from 1.0 to 0.03 by 4000 steps.
+        max_size_triangular (int): Max size for dim's preconditioner to be triangular.
+        min_ndim_triangular (int): Minimum number of dimensions a layer needs
+            to have triangular preconditioners.
+        memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
+            to set all preconditioners to be triangular, 'one_diag' sets the largest
+            or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
+            to be diagonal.
+        momentum_into_precond_update: (bool), whether to send momentum into preconditioner
+            update instead of raw gradients.
+    """
+
+    def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
+                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
+                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
+                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= beta < 1.0:
+            raise ValueError(f"Invalid beta parameter: {beta}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        if preconditioner_update_probability is None:
+            preconditioner_update_probability = precond_update_prob_schedule()
+        if clip_fn is None:
+            clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
+        self.preconditioner_update_probability = preconditioner_update_probability
+        self.clip_fn = clip_fn
+
+        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
+                        min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
+                        # precond lr hardcoded to 0.1
+                        precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
+        super().__init__(params, defaults)
+
+        self._prob_step = 0
+
+    def _step(self, group):
+        # update preconditioners all together
+        update_prob = self.preconditioner_update_probability
+        if callable(update_prob):
+            update_prob = update_prob(self._prob_step)
+        do_update = self.rng.random() < update_prob
+        self._prob_step += 1
+
+        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        beta = group['beta']
+        store_triu_as_line = group['store_triu_as_line']
+
+        vals = []
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+
+            if 'Q' not in state:
+                state["exp_avg"] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
+                state['Q_cache'] = [torch.empty_like(q) for q in Q]
+
+                expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(einsum_base, Q)]
+                expr = ','.join(expr)
+                grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
+                out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
+                expr = f'{expr},{grad_expr}->{out_expr}'
+
+                state['cache_expr'] = expr
+
+            vals.append((p, g, state["exp_avg"], state["Q"], state['Q_cache']))
+
+        if not vals:
+            return
+
+        p_list, grad_list, exp_avg_list, Q_list, Q_cache_list = zip(*vals)
+        del vals
+
+        group["step"] += 1
+
+        torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+
+        grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
+            exp_avg_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            cached_q = Q_cache_list.pop(0)
+            q_orig = Q_list.pop(0)
+            ea = exp_avg_list.pop(0)
+
+            if do_update:
+                q = line_to_triu(q_orig) if store_triu_as_line else q_orig
+                self.balance([g], [q])
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
+                               [q_orig] if store_triu_as_line else None)
+                for c_, q_ in zip(cached_q, q):
+                    if q_.ndim == 2:
+                        torch.matmul(q_.T.conj(), q_, out=c_)
+                    else:
+                        torch.mul(q_.conj(), q_, out=c_)
+
+            set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
+        grad_list = self.clip_fn(grad_list)
+
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
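
The new heavyball/cached_psgd_kron.py above differs from ForeachPSGDKron by storing the products q.T.conj() @ q in Q_cache (or |q|^2 for diagonal factors) together with a per-parameter einsum string cache_expr, so that between preconditioner refreshes the preconditioned update is a single einsum over the cached products. A sketch of the expression it builds, assuming einsum_base is simply the lowercase alphabet (the exact constant lives in heavyball/utils.py):

import string
import torch

einsum_base = string.ascii_lowercase  # assumption for this sketch
Q = [torch.eye(4), torch.eye(3)]      # stand-ins for the two Kronecker factors of a 4x3 weight
g = torch.randn(4, 3)                 # stand-in for the momentum buffer being preconditioned

expr = ','.join(f'{c.upper()}{c}' if q.ndim == 2 else c for c, q in zip(einsum_base, Q))
grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
cache_expr = f'{expr},{grad_expr}->{out_expr}'
print(cache_expr)  # Aa,Bb,ab->AB

cached = [q.T.conj() @ q for q in Q]           # what the optimizer writes into Q_cache
update = torch.einsum(cache_expr, *cached, g)  # cached preconditioning, no per-step refactorization
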
@@ -5,8 +5,8 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-
 from heavyball.utils import copy_stochastic_list_
+
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
     precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_
 
@@ -38,7 +38,7 @@ class ForeachDelayedPSGD(PSGDBase):
     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None):
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -58,18 +58,13 @@ class ForeachDelayedPSGD(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split)
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
         super().__init__(params, defaults)
 
         self._prob_step = 0
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
@@ -77,55 +72,52 @@ class ForeachDelayedPSGD(PSGDBase):
         do_update = self.rng.random() < update_prob
         self._prob_step += 1
 
-        for group in self.param_groups:
-            momentum_into_precond_update = group.get("momentum_into_precond_update", True)
-            precond_init_scale = group['precond_init_scale']
-            max_size_triangular = group['max_size_triangular']
-            min_ndim_triangular = group['min_ndim_triangular']
-            memory_save_mode = group['memory_save_mode']
-            precond_lr = group['precond_lr']
-            weight_decay = group['weight_decay']
-            lr = group['lr']
-            beta = group['beta']
-
-            vals = []
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
+        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        beta = group['beta']
+        store_triu_as_line = group['store_triu_as_line']
 
-                if 'Q' not in state:
-                    state["exp_avg"] = torch.zeros_like(g)
-                    Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                     memory_save_mode, dtype=g.dtype)
-                    state["Q"] = triu_to_line(Q)
+        vals = []
 
-                vals.append((p, g, state["exp_avg"], state["Q"]))
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
 
-            if not vals:
-                continue
+            if 'Q' not in state:
+                state["exp_avg"] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
 
-            p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
-            del vals
+            vals.append((p, g, state["exp_avg"], state["Q"]))
 
-            group["step"] += 1
+        if not vals:
+            return
 
-            torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+        p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
+        del vals
 
-            Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
-            for i, (p, g) in enumerate(zip(p_list, grad_list)):
-                q_orig = Q_list.pop(0)
-                ea = exp_avg_list.pop(0)
-                q = line_to_triu(q_orig)
-                self.balance(do_update, [g], [q])
-                new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+        group["step"] += 1
 
-                if do_update:
-                    self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
-                set_(g, new)
+        torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
 
-            grad_list = self.clip_fn(grad_list)
+        Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            q_orig = Q_list.pop(0)
+            ea = exp_avg_list.pop(0)
+            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
+            new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+            if do_update:
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
+                self.balance([g], [q])
+            set_(g, new)
 
-            lr = -warmup(lr, group['step'], group['warmup_steps'])
-            update_param_(p_list, grad_list, lr, weight_decay)
+        grad_list = self.clip_fn(grad_list)
 
-            return loss
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
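
The heavyball/delayed_psgd.py hunks above also show the structural change shared by every optimizer in this release: the per-optimizer step(closure) boilerplate is removed and each class now only implements _step(group), with closure handling and group iteration centralized in the base classes in heavyball/utils.py (exercised by the new test_closure.py and test_no_grad.py). The utils.py diff is only summarized in the file list, so the following is a hedged sketch of what that dispatch plausibly looks like, not the actual StatefulOptimizer implementation:

import torch


class StatefulOptimizerSketch(torch.optim.Optimizer):
    """Hypothetical base class; heavyball.utils.StatefulOptimizer may differ in detail."""

    def state_(self, p):
        # per-parameter state dict, as used by the optimizers in this diff
        return self.state[p]

    def _step(self, group):
        # subclasses implement the per-group update, e.g. ForeachDelayedPSGD._step above
        raise NotImplementedError

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            self._step(group)
        return loss
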
@@ -0,0 +1,41 @@
+import torch
+import torch.optim
+
+from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer
+
+
+class ForeachAdamW(StatefulOptimizer):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0):
+        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
+
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
+
+        active_p = [p for p in group['params'] if p.grad is not None]
+
+        if not active_p:
+            return
+
+        for p in active_p:
+            if 'exp_avg' not in self.state_(p):
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+
+        y, grad, exp_avg_sq, exp_avg = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
+
+        # Decay the first and second moment running average coefficient
+        torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
+        denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
+
+        # Normalize grad in-place for memory efficiency
+        lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
+        update_param_(y, exp_avg, lr, decay, lambda p, e, l: torch._foreach_addcdiv_(p, e, denom, l))
+        group['k'] = k + 1
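
The foreach calls in the new heavyball/foreach_adamw.py are a batched version of the standard decoupled AdamW step: beta_debias folds the bias correction into the effective decay rate and exp_avg_sq_ returns the debiased second-moment denominator. A single-tensor reference of the same update, assuming beta_debias(beta, t) = 1 - (1 - beta) / (1 - beta**t) and leaving the exact weight-decay handling to update_param_ in heavyball/utils.py:

import torch


def adamw_reference_step(p, grad, exp_avg, exp_avg_sq, lr, beta1, beta2, eps, weight_decay, step):
    beta1_t = 1 - (1 - beta1) / (1 - beta1 ** step)   # assumed beta_debias
    beta2_t = 1 - (1 - beta2) / (1 - beta2 ** step)
    exp_avg.lerp_(grad, 1 - beta1_t)                  # debiased first moment (the _foreach_lerp_ above)
    exp_avg_sq.mul_(beta2_t).addcmul_(grad, grad, value=1 - beta2_t)
    denom = exp_avg_sq.sqrt().add_(eps)               # debiased second-moment RMS
    p.mul_(1 - lr * weight_decay)                     # decoupled weight decay
    p.addcdiv_(exp_avg, denom, value=-lr)
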
@@ -0,0 +1,51 @@
+import torch
+import torch.optim
+
+from .utils import warmup, beta_debias, update_param_, StatefulOptimizer
+
+
+class ForeachADOPT(StatefulOptimizer):
+
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0):
+        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
+
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
+
+        active_p = [p for p in group['params'] if p.grad is not None]
+
+        if not active_p:
+            return
+
+        for p in active_p:
+            if 'exp_avg' not in self.state_(p):
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+
+        y, grad, exp_avg_sq, exp_avg = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
+
+        if k > 1:
+            lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
+
+            update_param_(y, exp_avg, lr, decay)
+        if k > 0:
+            beta1 = beta_debias(group['betas'][0], k)
+            denom = torch._foreach_sqrt(exp_avg_sq)
+            torch._foreach_maximum_(denom, eps)
+            torch._foreach_mul_(exp_avg, beta1)
+            torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
+
+        beta2 = beta_debias(group['betas'][1], k + 1)
+        torch._foreach_mul_(exp_avg_sq, beta2)
+        torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)
+        del grad
+
+        group['k'] = k + 1
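
The branching on k in heavyball/foreach_adopt.py implements ADOPT's decorrelated update order: the parameter step at iteration k divides the momentum by a second moment that has not yet seen the current gradient, and exp_avg_sq is updated last. A condensed single-tensor restatement of that ordering (bias correction and weight decay omitted):

def adopt_order_sketch(p, grad, exp_avg, exp_avg_sq, lr, beta1, beta2, eps, k):
    if k > 1:
        p.add_(exp_avg, alpha=-lr)                  # uses only state from steps < k
    if k > 0:
        denom = exp_avg_sq.sqrt().clamp_(min=eps)   # second moment without the current gradient
        exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # updated last
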
@@ -0,0 +1,46 @@
+import torch
+import torch.optim
+
+from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer
+
+
+class ForeachLaProp(StatefulOptimizer):
+
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1):
+        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
+
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
+
+        active_p = [p for p in group['params'] if p.grad is not None]
+
+        if not active_p:
+            return
+
+        for p in active_p:
+            if 'exp_avg' not in self.state_(p):
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+
+        y, grad, exp_avg_sq, exp_avg = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
+
+        # Decay the first and second moment running average coefficient
+        denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
+        beta1 = beta_debias(group['betas'][0], k + 1)
+        torch._foreach_mul_(exp_avg, beta1)
+        torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
+        del grad
+
+        # Normalize grad in-place for memory efficiency
+        lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
+        update_param_(y, exp_avg, lr, decay)
+
+        group['k'] = k + 1
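
heavyball/foreach_laprop.py differs from AdamW only in the order of operations: each gradient is divided by the second-moment RMS first, momentum is kept over the already-normalized gradients, and the parameter step uses exp_avg directly with no further division. In condensed single-tensor form (bias correction and weight decay omitted):

def laprop_order_sketch(p, grad, exp_avg, exp_avg_sq, lr, beta1, beta2, eps):
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = exp_avg_sq.sqrt().add_(eps)
    exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)  # momentum of normalized gradients
    p.add_(exp_avg, alpha=-lr)                                  # no division at step time
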
@@ -0,0 +1,54 @@
+import torch
+import torch.optim
+
+from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debias
+
+
+class ForeachSFAdamW(ScheduleFree):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
+                 weight_lr_power=2.0, foreach=hasattr(torch, "_foreach_mul_")):
+
+        defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
+                        weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
+                        foreach=foreach)
+        super().__init__(params, defaults)
+
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
+
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
+
+        active_p = [p for p in group['params'] if p.grad is not None]
+
+        if not active_p:
+            return
+
+        for p in active_p:
+            if 'z' not in self.state_(p):
+                self.state_(p)['z'] = torch.clone(p.data)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+
+        y, grad, exp_avg_sq, z = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
+
+        # Decay the first moment running average coefficient
+        old_debiased = beta_debias(group['betas'][1], k + 1)
+
+        # Decay the first and second moment running average coefficient
+        denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
+
+        # Normalize grad in-place for memory efficiency
+        torch._foreach_div_(grad, denom)
+
+        # Weight decay calculated at y
+        if decay != 0:
+            torch._foreach_add_(grad, y, alpha=decay)
+
+        lr = warmup(group['lr'], k + 1, group['warmup_steps'])
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
+                                             y, z, grad, group['r'], k + 1)
+
+        group['k'] = k + 1
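
heavyball/foreach_sfadamw.py keeps two iterates per parameter: the evaluation point y (stored in p.data) and the base point z, with schedule_free_ in heavyball/utils.py moving both. A rough single-tensor restatement of that step, assuming the standard schedule-free formulation from Defazio et al. (the exact weighting is computed inside schedule_free_, so treat this as an approximation):

def schedule_free_sketch(y, z, update, lr, beta1, weight_sum, weight_lr_power, step, r=0.0):
    weight = (step ** r) * (lr ** weight_lr_power)      # with r=0.0 (the default above) the step factor is 1
    weight_sum = weight_sum + weight
    ckp1 = weight / weight_sum if weight_sum else 0.0   # interpolation coefficient
    y.lerp_(z, ckp1)                                    # y <- (1 - c) * y + c * z
    y.add_(update, alpha=lr * (beta1 * (1 - ckp1) - 1)) # gradient contribution at y
    z.add_(update, alpha=-lr)                           # z takes the plain normalized-gradient step
    return weight_sum
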
@@ -0,0 +1,92 @@
+import torch
+
+from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
+    split_p_and_g_in_group, StatefulOptimizer
+
+
+class ForeachSOAP(StatefulOptimizer):
+    """
+    SFPaLMForeachSOAP
+
+    Sources:
+        Baseline SOAP:
+            SOAP: Improving and Stabilizing Shampoo using Adam
+            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+            https://arxiv.org/abs/2409.11321
+            https://github.com/nikhilvyas/SOAP
+
+        ScheduleFree:
+            The Road Less Scheduled
+            Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
+            https://arxiv.org/abs/2405.15682
+            https://github.com/facebookresearch/schedule_free
+    """
+
+    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
+                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
+                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
+                 split: bool = False):
+        defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
+                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
+                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
+                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split}
+        super().__init__(params, defaults)
+        self._data_format = data_format
+
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1, beta2 = group["betas"]
+
+        old_debiased1 = beta_debias(beta1, step)
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        torch._foreach_mul_(exp_avg, old_debiased1)
+        torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+
+        for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
+            state = self.state_(p)
+            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            exp_avg_projected = project(ea, state['Q'], False)
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            set_(d, project(exp_avg_projected / d, state['Q'], True))
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
+                                  step > 0 and step % group['precondition_frequency'] == 0)
+
+        # Why does this have to be rebiased here?
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+        update_param_(p_list, denom, step_size, group["weight_decay"])