heavyball 0.15.0__tar.gz → 0.16.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {heavyball-0.15.0 → heavyball-0.16.0}/PKG-INFO +4 -2
  2. {heavyball-0.15.0 → heavyball-0.16.0}/README.md +3 -1
  3. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/__init__.py +24 -2
  4. heavyball-0.16.0/heavyball/cached_psgd_kron.py +142 -0
  5. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/delayed_psgd.py +11 -8
  6. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/foreach_adamw.py +3 -2
  7. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/foreach_adopt.py +3 -2
  8. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/foreach_laprop.py +3 -2
  9. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/foreach_sfadamw.py +4 -4
  10. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/foreach_soap.py +4 -3
  11. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/p_adam.py +14 -8
  12. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/palm_foreach_sfadamw.py +3 -2
  13. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/palm_foreach_soap.py +3 -2
  14. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/precond_schedule_foreach_soap.py +3 -2
  15. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/precond_schedule_palm_foreach_soap.py +3 -2
  16. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/precond_schedule_sfpsoap.py +3 -3
  17. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/psgd_kron.py +11 -7
  18. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/pure_psgd.py +10 -7
  19. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/schedule_free_palm_foreach_soap.py +4 -3
  20. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball/utils.py +29 -11
  21. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball.egg-info/PKG-INFO +4 -2
  22. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball.egg-info/SOURCES.txt +2 -0
  23. {heavyball-0.15.0 → heavyball-0.16.0}/setup.py +1 -1
  24. heavyball-0.16.0/test/test_foreach.py +65 -0
  25. {heavyball-0.15.0 → heavyball-0.16.0}/LICENSE +0 -0
  26. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball.egg-info/dependency_links.txt +0 -0
  27. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball.egg-info/requires.txt +0 -0
  28. {heavyball-0.15.0 → heavyball-0.16.0}/heavyball.egg-info/top_level.txt +0 -0
  29. {heavyball-0.15.0 → heavyball-0.16.0}/setup.cfg +0 -0
  30. {heavyball-0.15.0 → heavyball-0.16.0}/test/test_closure.py +0 -0
  31. {heavyball-0.15.0 → heavyball-0.16.0}/test/test_memory.py +0 -0
  32. {heavyball-0.15.0 → heavyball-0.16.0}/test/test_merge.py +0 -0
  33. {heavyball-0.15.0 → heavyball-0.16.0}/test/test_no_grad.py +0 -0
  34. {heavyball-0.15.0 → heavyball-0.16.0}/test/test_psgd.py +0 -0
  35. {heavyball-0.15.0 → heavyball-0.16.0}/test/test_soap.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.15.0
+ Version: 0.16.0
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -39,12 +39,14 @@ recommended experimental optimizer is `ForeachPSGDKron`.
 
  * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
  * **Inplace EMA**: Same math, but less memory, less compute and higher stability
- * **Foreach**: Fast multi-tensor application
+ * **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
  * **PaLM Beta2**: Fast initial
    convergence, [stable late convergence](https://x.com/_clashluke/status/1820810798693818761)
  * **ScheduleFree**: No learning rate schedule, but better convergence
  * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
    better step-per-second in late convergence (explained below)
+ * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
+   bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
 
  ## Getting started
 
@@ -15,12 +15,14 @@ recommended experimental optimizer is `ForeachPSGDKron`.
 
  * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
  * **Inplace EMA**: Same math, but less memory, less compute and higher stability
- * **Foreach**: Fast multi-tensor application
+ * **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
  * **PaLM Beta2**: Fast initial
    convergence, [stable late convergence](https://x.com/_clashluke/status/1820810798693818761)
  * **ScheduleFree**: No learning rate schedule, but better convergence
  * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
    better step-per-second in late convergence (explained below)
+ * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
+   bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
 
  ## Getting started
 
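Both new options surface directly in the optimizer constructors. Below is a minimal usage sketch; the model and training loop are illustrative only, while the `foreach` and `store_triu_as_line` keyword arguments are the ones introduced in this release.

```python
import torch
from torch import nn

import heavyball

model = nn.Linear(16, 16)
# foreach=False trades some speed for lower memory; store_triu_as_line=True
# (the default) keeps PSGD's triangular factors in a compact flattened form.
opt = heavyball.ForeachPSGDKron(model.parameters(), lr=1e-3,
                                foreach=False, store_triu_as_line=True)

loss = model(torch.randn(4, 16)).sum()
loss.backward()
opt.step()
opt.zero_grad()
```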
@@ -1,3 +1,4 @@
+ from .cached_psgd_kron import ForeachCachedPSGDKron
  from .delayed_psgd import ForeachDelayedPSGD
  from .foreach_adamw import ForeachAdamW
  from .foreach_adopt import ForeachADOPT
@@ -16,7 +17,28 @@ from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
 
  PalmForEachSoap = PaLMForeachSOAP
 
+ PaLMSOAP = PaLMForeachSOAP
+ PaLMSFAdamW = PaLMForeachSFAdamW
+ PaLMSFSoap = SFPaLMForeachSOAP
+ PaLMForeachSOAP = PaLMForeachSOAP
+ PrecondScheduleSFPaLMSOAP = PrecondScheduleSFPaLMSOAP
+ SOAP = ForeachSOAP
+ SFAdamW = ForeachSFAdamW
+ LaProp = ForeachLaProp
+ ADOPT = ForeachADOPT
+ PrecondScheduleForeachSOAP = PrecondScheduleForeachSOAP
+ PrecondSchedulePaLMForeachSOAP = PrecondSchedulePaLMForeachSOAP
+ PSGDKron = ForeachPSGDKron
+ AdamW = ForeachAdamW
+ PurePSGD = ForeachPurePSGD
+ PaLMPAdam = ForeachPaLMPAdam
+ DelayedPSGD = ForeachDelayedPSGD
+ CachedPSGDKron = ForeachCachedPSGDKron
+
  __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
             'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
-            'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD',
-            'ForeachPaLMPAdam', 'ForeachDelayedPSGD']
+            'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
+            'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', #
+            'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
+            'SOAP', 'SFAdamW', 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD',
+            'CachedPSGDKron']
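The new module-level names are plain aliases, so both spellings resolve to the same classes. A quick sketch using only names exported above:

```python
import heavyball

# Each short alias is the identical class object, not a wrapper.
assert heavyball.PSGDKron is heavyball.ForeachPSGDKron
assert heavyball.SOAP is heavyball.ForeachSOAP
assert heavyball.CachedPSGDKron is heavyball.ForeachCachedPSGDKron
```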
@@ -0,0 +1,142 @@
+ """
+ Originally from Evan Walters and Omead Pooladzandi, 2024
+ Modified under Creative Commons Attribution 4.0 International
+ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
+ """
+
+ from typing import Optional
+
+ import torch
+ from heavyball.utils import einsum_base
+
+ from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
+     precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base
+
+
+ class ForeachCachedPSGDKron(PSGDBase):
+     """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch with cached preconditioners.
+
+     Args:
+         params (iterable): Iterable of parameters to optimize or dicts defining
+             parameter groups.
+         lr (float): Learning rate.
+         b1 (float): Momentum parameter.
+         weight_decay (float): Weight decay (L2 penalty).
+         preconditioner_update_probability (callable or float, optional): Probability of
+             updating the preconditioner. If None, defaults to a schedule that anneals
+             from 1.0 to 0.03 by 4000 steps.
+         max_size_triangular (int): Max size for dim's preconditioner to be triangular.
+         min_ndim_triangular (int): Minimum number of dimensions a layer needs
+             to have triangular preconditioners.
+         memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
+             to set all preconditioners to be triangular, 'one_diag' sets the largest
+             or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
+             to be diagonal.
+         momentum_into_precond_update: (bool), whether to send momentum into preconditioner
+             update instead of raw gradients.
+     """
+
+     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
+                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
+                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
+                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
+                  foreach: bool = True):
+         if not 0.0 <= lr:
+             raise ValueError(f"Invalid learning rate: {lr}")
+         if not 0.0 <= beta < 1.0:
+             raise ValueError(f"Invalid beta parameter: {beta}")
+         if not 0.0 <= weight_decay:
+             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+         if preconditioner_update_probability is None:
+             preconditioner_update_probability = precond_update_prob_schedule()
+         if clip_fn is None:
+             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
+         self.preconditioner_update_probability = preconditioner_update_probability
+         self.clip_fn = clip_fn
+
+         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
+                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
+                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
+                         # precond lr hardcoded to 0.1
+                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
+                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                         store_triu_as_line=store_triu_as_line)
+         super().__init__(params, defaults, foreach)
+
+         self._prob_step = 0
+
+     def _step(self, group):
+         # update preconditioners all together
+         update_prob = self.preconditioner_update_probability
+         if callable(update_prob):
+             update_prob = update_prob(self._prob_step)
+         do_update = self.rng.random() < update_prob
+         self._prob_step += 1
+
+         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+         precond_init_scale = group['precond_init_scale']
+         max_size_triangular = group['max_size_triangular']
+         min_ndim_triangular = group['min_ndim_triangular']
+         memory_save_mode = group['memory_save_mode']
+         precond_lr = group['precond_lr']
+         weight_decay = group['weight_decay']
+         lr = group['lr']
+         beta = group['beta']
+         store_triu_as_line = group['store_triu_as_line']
+
+         vals = []
+
+         for p, g in split_p_and_g_in_group(group):
+             state = self.state_(p)
+
+             if 'Q' not in state:
+                 state["exp_avg"] = torch.zeros_like(g)
+                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                  memory_save_mode, dtype=g.dtype)
+                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
+                 state['Q_cache'] = [torch.empty_like(q) for q in Q]
+
+                 expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(einsum_base, Q)]
+                 expr = ','.join(expr)
+                 grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
+                 out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
+                 expr = f'{expr},{grad_expr}->{out_expr}'
+
+                 state['cache_expr'] = expr
+
+             vals.append((p, g, state["exp_avg"], state["Q"], state['Q_cache']))
+
+         if not vals:
+             return
+
+         p_list, grad_list, exp_avg_list, Q_list, Q_cache_list = zip(*vals)
+         del vals
+
+         group["step"] += 1
+
+         torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+
+         grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
+             exp_avg_list)
+         for i, (p, g) in enumerate(zip(p_list, grad_list)):
+             cached_q = Q_cache_list.pop(0)
+             q_orig = Q_list.pop(0)
+             ea = exp_avg_list.pop(0)
+
+             if do_update:
+                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
+                 self.balance([g], [q])
+                 self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
+                                [q_orig] if store_triu_as_line else None)
+                 for c_, q_ in zip(cached_q, q):
+                     if q_.ndim == 2:
+                         torch.matmul(q_.T.conj(), q_, out=c_)
+                     else:
+                         torch.mul(q_.conj(), q_, out=c_)
+
+             set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
+         grad_list = self.clip_fn(grad_list)
+
+         lr = -warmup(lr, group['step'], group['warmup_steps'])
+         update_param_(p_list, grad_list, lr, weight_decay)
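The caching above builds, per parameter, an einsum expression that applies the precomputed products QᵀQ (or q·q for diagonal factors) in a single call, so the expensive preconditioner only has to be materialised when it is actually updated. The following standalone sketch shows how `cache_expr` comes out for a hypothetical 2D parameter with one triangular and one diagonal factor; it mirrors the expression-building loop in `_step` but is not the optimizer itself.

```python
import string
import torch

einsum_base = string.ascii_lowercase + string.ascii_uppercase

g = torch.randn(8, 4)                     # gradient / momentum buffer
Q = [torch.randn(8, 8), torch.randn(4)]   # triangular factor, diagonal factor

# 2D factors contribute a pair 'Aa', diagonal factors a single letter.
expr = ','.join(f'{c.upper()}{c}' if q.ndim == 2 else c for c, q in zip(einsum_base, Q))
grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
cache_expr = f'{expr},{grad_expr}->{out_expr}'   # 'Aa,b,ab->Ab'

# What the Q_cache buffers hold: Q.T @ Q for 2D factors, q * q for 1D ones.
cached = [Q[0].T @ Q[0], Q[1] * Q[1]]

print(cache_expr)
print(torch.einsum(cache_expr, *cached, g).shape)   # torch.Size([8, 4])
```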
@@ -38,7 +38,8 @@ class ForeachDelayedPSGD(PSGDBase):
      def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                   max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                   momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                  split: bool = False, clip_fn: callable = None):
+                  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
+                  foreach: bool = True):
          if not 0.0 <= lr:
              raise ValueError(f"Invalid learning rate: {lr}")
          if not 0.0 <= beta < 1.0:
@@ -58,8 +59,9 @@ class ForeachDelayedPSGD(PSGDBase):
                          momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                          # precond lr hardcoded to 0.1
                          precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split)
-         super().__init__(params, defaults)
+                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                         store_triu_as_line=store_triu_as_line)
+         super().__init__(params, defaults, foreach)
 
          self._prob_step = 0
 
@@ -80,6 +82,7 @@ class ForeachDelayedPSGD(PSGDBase):
          weight_decay = group['weight_decay']
          lr = group['lr']
          beta = group['beta']
+         store_triu_as_line = group['store_triu_as_line']
 
          vals = []
 
@@ -90,7 +93,7 @@ class ForeachDelayedPSGD(PSGDBase):
                  state["exp_avg"] = torch.zeros_like(g)
                  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
                                                   memory_save_mode, dtype=g.dtype)
-                 state["Q"] = triu_to_line(Q)
+                 state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
 
              vals.append((p, g, state["exp_avg"], state["Q"]))
 
@@ -108,12 +111,12 @@ class ForeachDelayedPSGD(PSGDBase):
          for i, (p, g) in enumerate(zip(p_list, grad_list)):
              q_orig = Q_list.pop(0)
              ea = exp_avg_list.pop(0)
-             q = line_to_triu(q_orig)
-             self.balance(do_update, [g], [q])
+             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
              new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
-
              if do_update:
-                 self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
+                 self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
+                                [q_orig] if store_triu_as_line else None)
+                 self.balance([g], [q])
              set_(g, new)
 
          grad_list = self.clip_fn(grad_list)
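`store_triu_as_line`, used by `ForeachDelayedPSGD` above and by the other PSGD variants below, keeps each square preconditioner factor as its upper-triangular entries only and rebuilds the full matrix on demand. The sketch below illustrates that idea under the assumption of a simple row-major triangular layout; heavyball's actual `triu_to_line`/`line_to_triu` helpers live in `heavyball/utils.py` and may store the data differently.

```python
import torch

def to_line(q: torch.Tensor) -> torch.Tensor:
    # n*(n+1)/2 values instead of n*n
    idx = torch.triu_indices(*q.shape)
    return q[idx[0], idx[1]]

def to_triu(line: torch.Tensor, n: int) -> torch.Tensor:
    # Rebuild the full (upper-triangular) matrix when it is needed.
    q = torch.zeros(n, n, dtype=line.dtype)
    idx = torch.triu_indices(n, n)
    q[idx[0], idx[1]] = line
    return q

q = torch.triu(torch.randn(4, 4))
assert torch.equal(to_triu(to_line(q), 4), q)
```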
@@ -5,10 +5,11 @@ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOpti
 
 
  class ForeachAdamW(StatefulOptimizer):
-     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0):
+     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                  foreach: bool = True):
          defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                          lr_max=-1.0, weight_decay=weight_decay)
-         super().__init__(params, defaults)
+         super().__init__(params, defaults, foreach)
 
      def _step(self, group):
          eps = group['eps']
@@ -6,10 +6,11 @@ from .utils import warmup, beta_debias, update_param_, StatefulOptimizer
 
  class ForeachADOPT(StatefulOptimizer):
 
-     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0):
+     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                  foreach: bool = True):
          defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                          lr_max=-1.0, weight_decay=weight_decay)
-         super().__init__(params, defaults)
+         super().__init__(params, defaults, foreach)
 
      def _step(self, group):
          eps = group['eps']
@@ -6,10 +6,11 @@ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOpti
 
  class ForeachLaProp(StatefulOptimizer):
 
-     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1):
+     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
+                  foreach: bool = True):
          defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                          lr_max=-1.0, weight_decay=weight_decay)
-         super().__init__(params, defaults)
+         super().__init__(params, defaults, foreach)
 
      def _step(self, group):
          eps = group['eps']
@@ -6,12 +6,12 @@ from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debia
 
  class ForeachSFAdamW(ScheduleFree):
      def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                  weight_lr_power=2.0, foreach=hasattr(torch, "_foreach_mul_")):
+                  weight_lr_power=2.0, foreach: bool = True):
 
          defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
                          weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
                          foreach=foreach)
-         super().__init__(params, defaults)
+         super().__init__(params, defaults, foreach)
 
      def _step(self, group):
          eps = group['eps']
@@ -48,7 +48,7 @@ class ForeachSFAdamW(ScheduleFree):
              torch._foreach_add_(grad, y, alpha=decay)
 
          lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-         group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
-                                              y, z, grad, group['r'], k + 1)
+         group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0], y, z,
+                                              grad, group['r'], k + 1)
 
          group['k'] = k + 1
@@ -26,12 +26,13 @@ class ForeachSOAP(StatefulOptimizer):
                   weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                   merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                   data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                  split: bool = False):
+                  split: bool = False,
+                  foreach: bool = True):
          defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                      "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                      "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                      "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split}
-         super().__init__(params, defaults)
+         super().__init__(params, defaults, foreach)
          self._data_format = data_format
 
      def _step(self, group):
@@ -59,7 +60,7 @@ class ForeachSOAP(StatefulOptimizer):
              vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
 
          if not vals:
-             return
+             return
 
          p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
          beta1, beta2 = group["betas"]
@@ -5,6 +5,7 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
  """
 
  import torch
+ from heavyball.utils import triu_to_line, line_to_triu
 
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
      exp_avg_sq_, beta_debias, split_p_and_g_in_group
@@ -36,7 +37,9 @@ class ForeachPaLMPAdam(PSGDBase):
      def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                   max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                   momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
-                  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None):
+                  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
+                  store_triu_as_line: bool = True,
+                  foreach: bool = True):
          if not 0.0 <= lr:
              raise ValueError(f"Invalid learning rate: {lr}")
          if not 0.0 <= weight_decay:
@@ -57,8 +60,8 @@ class ForeachPaLMPAdam(PSGDBase):
                          # precond lr hardcoded to 0.1
                          precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                          step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
-                         split=split)
-         super().__init__(params, defaults)
+                         split=split, store_triu_as_line=store_triu_as_line)
+         super().__init__(params, defaults, foreach)
 
          self._prob_step = 0
 
@@ -77,6 +80,7 @@ class ForeachPaLMPAdam(PSGDBase):
          precond_lr = group['precond_lr']
          weight_decay = group['weight_decay']
          lr = group['lr']
+         store_triu_as_line = group['store_triu_as_line']
 
          vals = []
 
@@ -86,8 +90,9 @@ class ForeachPaLMPAdam(PSGDBase):
              if 'Q' not in state:
                  state['exp_avg'] = torch.zeros_like(g)
                  state['exp_avg_sq'] = torch.zeros_like(g)
-                 state["Q"], state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
-                                                           min_ndim_triangular, memory_save_mode, dtype=g.dtype)
+                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
+                                                  min_ndim_triangular, memory_save_mode, dtype=g.dtype)
+                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
              vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
 
@@ -99,15 +104,16 @@ class ForeachPaLMPAdam(PSGDBase):
 
          group["step"] += 1
 
-         self.balance(do_update, grad_list, Q_list)
+         Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
          if do_update:
-             self.do_update(p_list, grad_list, Q_list, precond_lr)
+             self.balance(grad_list, Q_triu)
+             self.do_update(p_list, grad_list, Q_triu, precond_lr, Q_list if store_triu_as_line else None)
 
          torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
 
          beta2 = 1 - group['step'] ** -group['beta2_scale']
 
-         for p, Q, g, ea, eas in zip(p_list, Q_list, grad_list, exp_avg, exp_avg_sq):
+         for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
              psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
              ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
              exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
@@ -6,13 +6,14 @@ from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debia
 
  class PaLMForeachSFAdamW(ScheduleFree):
      def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                  weight_lr_power=2.0, beta2_scale: float = 0.8):
+                  weight_lr_power=2.0, beta2_scale: float = 0.8,
+                  foreach: bool = True):
          if betas[0] is not None:
              beta = betas[0]
          defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                          lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
                          beta2_scale=beta2_scale)
-         super().__init__(params, defaults)
+         super().__init__(params, defaults, foreach)
 
      def _step(self, group):
          eps = group['eps']
@@ -32,7 +32,8 @@ class PaLMForeachSOAP(StatefulOptimizer):
                   max_precond_dim: int = 2048,  #
                   merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                   data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                  beta2_scale: float = 0.8, split: bool = False):
+                  beta2_scale: float = 0.8, split: bool = False,
+                  foreach: bool = True):
          if betas[0] is not None:
              beta = betas[0]
          defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -40,7 +41,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
                      "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                      "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
                      'split': split}
-         super().__init__(params, defaults)
+         super().__init__(params, defaults, foreach)
          self._data_format = data_format
 
      def _step(self, group):
@@ -27,13 +27,14 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
                   weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                   merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                   data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                  precond_scheduler=(1 / 3, 9), split: bool = False):
+                  precond_scheduler=(1 / 3, 9), split: bool = False,
+                  foreach: bool = True):
          defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                      "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                      "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                      "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
                      'split': split}
-         super().__init__(params, defaults)
+         super().__init__(params, defaults, foreach)
          self._data_format = data_format
          self.rng = random.Random(0x120983109)
 
@@ -32,7 +32,8 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
                   weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                   merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                   data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                  precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False):
+                  precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
+                  foreach: bool = True):
          if betas[0] is not None:
              beta = betas[0]
          defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -40,7 +41,7 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
                      "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                      "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
                      'beta2_scale': beta2_scale, 'split': split}
-         super().__init__(params, defaults)
+         super().__init__(params, defaults, foreach)
          self._data_format = data_format
          self.rng = random.Random(0x120983109)
 
@@ -41,7 +41,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
                   merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                   data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
                   weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9),
-                  betas=(None, None), split: bool = False):
+                  betas=(None, None), split: bool = False, foreach: bool = True):
          if betas[0] is not None:
              beta = betas[0]
          defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -50,7 +50,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
                      "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
                      'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1, 'weight_sum': 0,
                      'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split}
-         super().__init__(params, defaults)
+         super().__init__(params, defaults, foreach)
          self._data_format = data_format
          self.rng = random.Random(0x120983109)
 
@@ -59,7 +59,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
          max_precond_dim = group['max_precond_dim']
          precondition_1d = group['precondition_1d']
 
-         step = group['step'] = group.get("step", -1) + 1
+         step = group['step'] = group.get("step", 0) + 1
 
          for p in group["params"]:
              if p.grad is None:
@@ -38,7 +38,8 @@ class ForeachPSGDKron(PSGDBase):
      def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                   max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                   momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                  split: bool = False, clip_fn: Optional[callable] = None):
+                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
+                  foreach: bool = True):
          if not 0.0 <= lr:
              raise ValueError(f"Invalid learning rate: {lr}")
          if not 0.0 <= beta < 1.0:
@@ -58,8 +59,9 @@ class ForeachPSGDKron(PSGDBase):
                          momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                          # precond lr hardcoded to 0.1
                          precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split)
-         super().__init__(params, defaults)
+                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                         store_triu_as_line=store_triu_as_line)
+         super().__init__(params, defaults, foreach)
 
          self._prob_step = 0
 
@@ -80,6 +82,7 @@ class ForeachPSGDKron(PSGDBase):
          weight_decay = group['weight_decay']
          lr = group['lr']
          beta = group['beta']
+         store_triu_as_line = group['store_triu_as_line']
 
          vals = []
 
@@ -90,7 +93,7 @@ class ForeachPSGDKron(PSGDBase):
                  state["exp_avg"] = torch.zeros_like(g)
                  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
                                                   memory_save_mode, dtype=g.dtype)
-                 state['Q'] = triu_to_line(Q)
+                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
              vals.append((p, g, state["exp_avg"], state["Q"]))
 
@@ -108,11 +111,12 @@ class ForeachPSGDKron(PSGDBase):
          for i, (p, g) in enumerate(zip(p_list, grad_list)):
              q_orig = Q_list.pop(0)
              ea = exp_avg_list.pop(0)
-             q = line_to_triu(q_orig)
+             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-             self.balance(do_update, [g], [q])
              if do_update:
-                 self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
+                 self.balance([g], [q])
+                 self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
+                                [q_orig] if store_triu_as_line else None)
              set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
 
          grad_list = self.clip_fn(grad_list)
@@ -36,7 +36,8 @@ class ForeachPurePSGD(PSGDBase):
      def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                   max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                   momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                  split: bool = False, clip_fn: callable = None):
+                  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
+                  foreach: bool = True):
          if not 0.0 <= lr:
              raise ValueError(f"Invalid learning rate: {lr}")
          if not 0.0 <= weight_decay:
@@ -54,8 +55,9 @@ class ForeachPurePSGD(PSGDBase):
                          momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                          # precond lr hardcoded to 0.1
                          precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split)
-         super().__init__(params, defaults)
+                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                         store_triu_as_line=store_triu_as_line)
+         super().__init__(params, defaults, foreach)
 
          self._prob_step = 0
 
@@ -74,6 +76,7 @@ class ForeachPurePSGD(PSGDBase):
          precond_lr = group['precond_lr']
          weight_decay = group['weight_decay']
          lr = group['lr']
+         store_triu_as_line = group['store_triu_as_line']
 
          vals = []
 
@@ -83,7 +86,7 @@ class ForeachPurePSGD(PSGDBase):
              if 'Q' not in state:
                  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
                                                   memory_save_mode, dtype=g.dtype)
-                 state['Q'] = triu_to_line(Q)
+                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
              vals.append((p, g, state["Q"]))
 
@@ -98,11 +101,11 @@ class ForeachPurePSGD(PSGDBase):
          Q_list = list(Q_list)
          for i, (p, g) in enumerate(zip(p_list, grad_list)):
              q_orig = Q_list.pop(0)
-             q = line_to_triu(q_orig)
+             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-             self.balance(do_update, [g], [q])
              if do_update:
-                 self.do_update([p], [g], [q], precond_lr, [q_orig])
+                 self.balance([g], [q])
+                 self.do_update([p], [g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
              psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
 
          grad_list = self.clip_fn(grad_list)
@@ -33,7 +33,8 @@ class SFPaLMForeachSOAP(ScheduleFree):
                   weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                   merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                   data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
-                  weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False):
+                  weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False,
+                  foreach: bool = True):
          if betas[0] is not None:
              beta = betas[0]
          defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -42,7 +43,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
                      "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
                      'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1,
                      'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split}
-         super().__init__(params, defaults)
+         super().__init__(params, defaults, foreach)
          self._data_format = data_format
          self.rng = random.Random(0x120983109)
 
@@ -51,7 +52,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
          max_precond_dim = group['max_precond_dim']
          precondition_1d = group['precondition_1d']
 
-         step = group['step'] = group.get("step", -1) + 1
+         step = group['step'] = group.get("step", 0) + 1
 
          for p in group["params"]:
              if p.grad is None:
@@ -29,7 +29,7 @@ def decorator(func):
      return _fn
 
 
- _einsum_base = string.ascii_lowercase + string.ascii_uppercase
+ einsum_base = string.ascii_lowercase + string.ascii_uppercase
 
 
  def warmup(lr: float, step: int, warmup_steps: int):
@@ -317,8 +317,8 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
      for idx, sh in enumerate(grad.shape):
          if sh > max_precond_dim:
              continue
-         b = _einsum_base[idx]
-         g0 = _einsum_base[:grad.dim()]
+         b = einsum_base[idx]
+         g0 = einsum_base[:grad.dim()]
          g1 = g0.replace(b, b.upper())
          outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
          GG[idx].lerp_(promote(outer_product), 1 - beta)
@@ -374,7 +374,7 @@ def project(grad, Q, back: bool):
      :param back: whether to project to Shampoo eigenbases or back to original space
      :return:
      """
-     param = _einsum_base[:grad.dim()]
+     param = einsum_base[:grad.dim()]
      preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if len(m) > 0])
      if preconditioners:
          out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
@@ -383,8 +383,25 @@ def project(grad, Q, back: bool):
 
 
  class StatefulOptimizer(torch.optim.Optimizer):
+     def __init__(self, params, defaults, foreach: bool = True):
+         super().__init__(params, {**defaults, 'foreach': foreach})
+         self.fake_groups = {}
+
+     def key(self, param: torch.Tensor):
+         return (param.data_ptr(), tuple(param.shape))
+
+     def get_groups(self, group):
+         if group['foreach']:
+             return [group]
+
+         for p in group['params']:
+             if self.key(p) not in self.fake_groups:
+                 self.fake_groups[self.key(p)] = {**group, 'params': [p]}
+
+         return [self.fake_groups[self.key(p)] for p in group['params']]
+
      def state_(self, arg: torch.Tensor):
-         return self.state[(arg.data_ptr(), tuple(arg.shape))]
+         return self.state[self.key(arg)]
 
      def state_size(self) -> int:
          total_bytes = 0
@@ -409,8 +426,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
              with torch.enable_grad():
                  loss = closure()
          with torch.no_grad():
-             for group in self.param_groups:
-                 self._step(group)
+             for top_group in self.param_groups:
+                 for group in self.get_groups(top_group):
+                     self._step(group)
          return loss
 
 
@@ -754,13 +772,13 @@ def update_triu_(q_state, materialised):
 
 
  class PSGDBase(StatefulOptimizer):
-     def __init__(self, parameters, groups):
-         super().__init__(parameters, groups)
+     def __init__(self, parameters, groups, foreach: bool = True):
+         super().__init__(parameters, groups, foreach)
          self.rng = random.Random(0x1923213)
          self._tiny = torch.finfo(torch.bfloat16).tiny
 
-     def balance(self, do_update, grad_list, Q_list):
-         if not do_update or self.rng.random() > 0.01:
+     def balance(self, grad_list, Q_list):
+         if self.rng.random() > 0.01:
              return
 
          for g, q in zip(grad_list, Q_list):
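The `foreach` flag is implemented once, in `StatefulOptimizer`: with `foreach=False`, every parameter is wrapped in its own single-tensor "fake group", so `_step` runs per parameter and the multi-tensor `_foreach_*` lists stay tiny. Below is a hedged, simplified standalone sketch of that grouping logic; the names follow the diff above, but it is not the library's exact code.

```python
import torch

def get_groups(group, fake_groups):
    if group['foreach']:
        return [group]                              # one fused multi-tensor step
    for p in group['params']:
        key = (p.data_ptr(), tuple(p.shape))        # same key used for optimizer state
        if key not in fake_groups:
            fake_groups[key] = {**group, 'params': [p]}
    return [fake_groups[(p.data_ptr(), tuple(p.shape))] for p in group['params']]

params = [torch.randn(4), torch.randn(3)]
group = {'params': params, 'foreach': False, 'lr': 1e-3}
print(len(get_groups(group, {})))                   # 2: one group per parameter
```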
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.15.0
+ Version: 0.16.0
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -39,12 +39,14 @@ recommended experimental optimizer is `ForeachPSGDKron`.
 
  * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
  * **Inplace EMA**: Same math, but less memory, less compute and higher stability
- * **Foreach**: Fast multi-tensor application
+ * **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
  * **PaLM Beta2**: Fast initial
    convergence, [stable late convergence](https://x.com/_clashluke/status/1820810798693818761)
  * **ScheduleFree**: No learning rate schedule, but better convergence
  * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
    better step-per-second in late convergence (explained below)
+ * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
+   bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
 
  ## Getting started
 
@@ -2,6 +2,7 @@ LICENSE
  README.md
  setup.py
  heavyball/__init__.py
+ heavyball/cached_psgd_kron.py
  heavyball/delayed_psgd.py
  heavyball/foreach_adamw.py
  heavyball/foreach_adopt.py
@@ -24,6 +25,7 @@ heavyball.egg-info/dependency_links.txt
  heavyball.egg-info/requires.txt
  heavyball.egg-info/top_level.txt
  test/test_closure.py
+ test/test_foreach.py
  test/test_memory.py
  test/test_merge.py
  test/test_no_grad.py
@@ -10,7 +10,7 @@ setuptools.setup(
      name='heavyball',
      license='BSD',
      description='Efficient optimizers',
-     version='0.15.0',
+     version='0.16.0',
      long_description=README,
      url='https://github.com/clashluke/heavyball',
      packages=setuptools.find_packages(),
@@ -0,0 +1,65 @@
+ import heavyball
+ import heavyball.utils
+ import pytest
+ import torch
+ from benchmark.utils import get_optim
+ from heavyball.utils import clean, set_torch, PSGDBase
+ from torch import nn
+
+
+ def get_memory():
+     clean()
+     torch.cuda.synchronize()
+     clean()
+     torch.cuda.synchronize()
+     return torch.cuda.memory_allocated()
+
+
+ @pytest.mark.parametrize("opt", heavyball.__all__)
+ @pytest.mark.parametrize("size,depth", [(256, 128)])
+ def test_foreach(opt, size, depth: int, iterations: int = 5, outer_iterations: int = 3):
+     set_torch()
+
+     opt = getattr(heavyball, opt)
+
+     peaks = []
+     losses = []
+
+     for foreach in [True, False]:
+         peaks.append([])
+         losses.append([])
+
+         for i in range(outer_iterations):
+             torch.manual_seed(0x2131290)
+             clean()
+             model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
+             clean()
+
+             torch.cuda.reset_peak_memory_stats()
+             torch.cuda.reset_max_memory_allocated()
+             torch.cuda.reset_max_memory_cached()
+             torch.cuda.reset_accumulated_memory_stats()
+
+             clean()
+             o = get_optim(opt, model.parameters(), lr=1e-3, foreach=foreach)
+             clean()
+
+             for _ in range(iterations):
+                 loss = model(torch.randn((1, size)).cuda()).sum()
+                 loss.backward()
+                 o.step()
+                 o.zero_grad()
+                 losses[-1].append(loss.detach())
+
+             del model, o
+             clean()
+
+             peak = torch.cuda.memory_stats()['allocated_bytes.all.peak']
+
+             if i > 0:
+                 peaks[-1].append(peak)
+
+     for p0, p1 in zip(*peaks):
+         assert p0 > p1
+     for l0, l1 in zip(*losses):  # increase error tolerance for PSGD, as we have different RNGs -> expected differences
+         assert torch.allclose(l0, l1, rtol=0.01 if isinstance(opt, PSGDBase) else 1e-5)
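The new test asserts, for every exported optimizer, that `foreach=False` matches the losses of `foreach=True` (within tolerance) while using strictly lower peak CUDA memory. One way to run it locally, under the assumption of a CUDA device and a repository checkout (it imports `benchmark.utils.get_optim`, which is not shipped in the wheel):

```python
import pytest

# Select only the AdamW-based parametrizations as a quick smoke test.
pytest.main(["test/test_foreach.py", "-v", "-k", "AdamW"])
```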