heavyball 0.15.1.tar.gz → 0.17.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {heavyball-0.15.1 → heavyball-0.17.0}/PKG-INFO +4 -2
  2. {heavyball-0.15.1 → heavyball-0.17.0}/README.md +3 -1
  3. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/__init__.py +4 -2
  4. heavyball-0.17.0/heavyball/cached_delayed_psgd_kron.py +146 -0
  5. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/cached_psgd_kron.py +15 -9
  6. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/delayed_psgd.py +10 -7
  7. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/foreach_adamw.py +3 -2
  8. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/foreach_adopt.py +3 -2
  9. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/foreach_laprop.py +3 -2
  10. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/foreach_sfadamw.py +4 -4
  11. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/foreach_soap.py +4 -3
  12. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/p_adam.py +11 -8
  13. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/palm_foreach_sfadamw.py +3 -2
  14. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/palm_foreach_soap.py +3 -2
  15. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/precond_schedule_foreach_soap.py +3 -2
  16. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/precond_schedule_palm_foreach_soap.py +3 -2
  17. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/precond_schedule_sfpsoap.py +3 -3
  18. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/psgd_kron.py +10 -7
  19. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/pure_psgd.py +11 -7
  20. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/schedule_free_palm_foreach_soap.py +4 -3
  21. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball/utils.py +41 -18
  22. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball.egg-info/PKG-INFO +4 -2
  23. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball.egg-info/SOURCES.txt +3 -0
  24. {heavyball-0.15.1 → heavyball-0.17.0}/setup.py +1 -1
  25. heavyball-0.17.0/test/test_bf16_q.py +52 -0
  26. {heavyball-0.15.1 → heavyball-0.17.0}/test/test_closure.py +1 -1
  27. heavyball-0.17.0/test/test_foreach.py +65 -0
  28. {heavyball-0.15.1 → heavyball-0.17.0}/test/test_memory.py +2 -2
  29. {heavyball-0.15.1 → heavyball-0.17.0}/test/test_merge.py +1 -1
  30. {heavyball-0.15.1 → heavyball-0.17.0}/test/test_psgd.py +3 -14
  31. {heavyball-0.15.1 → heavyball-0.17.0}/LICENSE +0 -0
  32. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball.egg-info/dependency_links.txt +0 -0
  33. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball.egg-info/requires.txt +0 -0
  34. {heavyball-0.15.1 → heavyball-0.17.0}/heavyball.egg-info/top_level.txt +0 -0
  35. {heavyball-0.15.1 → heavyball-0.17.0}/setup.cfg +0 -0
  36. {heavyball-0.15.1 → heavyball-0.17.0}/test/test_no_grad.py +0 -0
  37. {heavyball-0.15.1 → heavyball-0.17.0}/test/test_soap.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.15.1
3
+ Version: 0.17.0
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -39,12 +39,14 @@ recommended experimental optimizer is `ForeachPSGDKron`.
39
39
 
40
40
  * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
41
41
  * **Inplace EMA**: Same math, but less memory, less compute and higher stability
42
- * **Foreach**: Fast multi-tensor application
42
+ * **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
43
43
  * **PaLM Beta2**: Fast initial
44
44
  convergence, [stable late convergence](https://x.com/_clashluke/status/1820810798693818761)
45
45
  * **ScheduleFree**: No learning rate schedule, but better convergence
46
46
  * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
47
47
  better step-per-second in late convergence (explained below)
48
+ * **Memory-efficient storage**: PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
49
+ bandwidth; turn it off for lower compute overhead at the cost of more memory (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
48
50
 
49
51
  ## Getting started
50
52
 
@@ -15,12 +15,14 @@ recommended experimental optimizer is `ForeachPSGDKron`.
15
15
 
16
16
  * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
17
17
  * **Inplace EMA**: Same math, but less memory, less compute and higher stability
18
- * **Foreach**: Fast multi-tensor application
18
+ * **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
19
19
  * **PaLM Beta2**: Fast initial
20
20
  convergence, [stable late convergence](https://x.com/_clashluke/status/1820810798693818761)
21
21
  * **ScheduleFree**: No learning rate schedule, but better convergence
22
22
  * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
23
23
  better step-per-second in late convergence (explained below)
24
+ * **Memory-efficient storage**: PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
25
+ bandwidth; turn it off for lower compute overhead at the cost of more memory (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
24
26
 
25
27
  ## Getting started
26
28
 
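The two options called out in the new bullets above are plain constructor arguments on the 0.17.0 optimizers. A minimal usage sketch (model and sizes are illustrative; `foreach`, `store_triu_as_line`, and `lr` are the arguments shown in the diffs below):

import torch
from torch import nn
import heavyball

model = nn.Linear(128, 128)
# foreach=False processes one parameter at a time: slower, but lower peak memory.
# store_triu_as_line=True (the default) keeps triangular PSGD factors flattened,
# saving memory at the cost of extra memory bandwidth per step.
opt = heavyball.ForeachPSGDKron(model.parameters(), lr=1e-3,
                                foreach=False, store_triu_as_line=True)

loss = model(torch.randn(32, 128)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()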
@@ -14,6 +14,7 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
14
14
  from .psgd_kron import ForeachPSGDKron
15
15
  from .pure_psgd import ForeachPurePSGD
16
16
  from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
17
+ from .cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron
17
18
 
18
19
  PalmForEachSoap = PaLMForeachSOAP
19
20
 
@@ -34,11 +35,12 @@ PurePSGD = ForeachPurePSGD
34
35
  PaLMPAdam = ForeachPaLMPAdam
35
36
  DelayedPSGD = ForeachDelayedPSGD
36
37
  CachedPSGDKron = ForeachCachedPSGDKron
38
+ CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
37
39
 
38
40
  __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
39
41
  'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
40
42
  'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
41
- 'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', #
43
+ 'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron',  #
42
44
  'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
43
45
  'SOAP', 'SFAdamW', 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD',
44
- 'CachedPSGDKron']
46
+ 'CachedPSGDKron', 'CachedDelayedPSGDKron']
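`ForeachCachedDelayedPSGDKron` (aliased as `CachedDelayedPSGDKron`) is the new export in this release. A minimal construction sketch using only arguments that appear in the new file's signature; `q_dtype='bfloat16'` is the new option for keeping preconditioner factors in bf16:

import torch
from torch import nn
import heavyball

model = nn.Linear(64, 64)
# Preconditioner factors (Q) are stored in bfloat16; updates promote them to fp32
# and write back with stochastic rounding (see the utils.py changes further down).
opt = heavyball.ForeachCachedDelayedPSGDKron(model.parameters(), lr=1e-3,
                                             q_dtype='bfloat16')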
@@ -0,0 +1,146 @@
1
+ """
2
+ Originally from Evan Walters and Omead Pooladzandi, 2024
3
+ Modified under Creative Commons Attribution 4.0 International
4
+ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
5
+ """
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+ from heavyball.utils import einsum_base
11
+
12
+ from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
13
+ precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote
14
+
15
+
16
+ class ForeachCachedDelayedPSGDKron(PSGDBase):
17
+ """
18
+ Implements PSGD with off-by-one preconditioning (akin to ADOPT and SOAP) with cached preconditioners.
19
+
20
+
21
+ Args:
22
+ params (iterable): Iterable of parameters to optimize or dicts defining
23
+ parameter groups.
24
+ lr (float): Learning rate.
25
+ beta (float): Momentum parameter.
26
+ weight_decay (float): Weight decay (L2 penalty).
27
+ preconditioner_update_probability (callable or float, optional): Probability of
28
+ updating the preconditioner. If None, defaults to a schedule that anneals
29
+ from 1.0 to 0.03 by 4000 steps.
30
+ max_size_triangular (int): Max size for dim's preconditioner to be triangular.
31
+ min_ndim_triangular (int): Minimum number of dimensions a layer needs
32
+ to have triangular preconditioners.
33
+ memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
34
+ to set all preconditioners to be triangular, 'one_diag' sets the largest
35
+ or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
36
+ to be diagonal.
37
+ momentum_into_precond_update: (bool), whether to send momentum into preconditioner
38
+ update instead of raw gradients.
39
+ """
40
+
41
+ def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
42
+ max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
43
+ momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
44
+ split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
45
+ foreach: bool = True, q_dtype='float32'):
46
+ if not 0.0 <= lr:
47
+ raise ValueError(f"Invalid learning rate: {lr}")
48
+ if not 0.0 <= beta < 1.0:
49
+ raise ValueError(f"Invalid beta parameter: {beta}")
50
+ if not 0.0 <= weight_decay:
51
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")
52
+
53
+ if preconditioner_update_probability is None:
54
+ preconditioner_update_probability = precond_update_prob_schedule()
55
+ if clip_fn is None:
56
+ clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
57
+ self.preconditioner_update_probability = preconditioner_update_probability
58
+ self.clip_fn = clip_fn
59
+
60
+ defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
61
+ min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
62
+ momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
63
+ # precond lr hardcoded to 0.1
64
+ precond_init_scale=1.0, # precond init scale hardcoded to 1.0
65
+ step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
66
+ store_triu_as_line=store_triu_as_line,
67
+ q_dtype=q_dtype)
68
+ super().__init__(params, defaults, foreach)
69
+
70
+ self._prob_step = 0
71
+
72
+ def _step(self, group):
73
+ # update preconditioners all together
74
+ update_prob = self.preconditioner_update_probability
75
+ if callable(update_prob):
76
+ update_prob = update_prob(self._prob_step)
77
+ do_update = self.rng.random() < update_prob
78
+ self._prob_step += 1
79
+
80
+ momentum_into_precond_update = group.get("momentum_into_precond_update", True)
81
+ precond_init_scale = group['precond_init_scale']
82
+ max_size_triangular = group['max_size_triangular']
83
+ min_ndim_triangular = group['min_ndim_triangular']
84
+ memory_save_mode = group['memory_save_mode']
85
+ precond_lr = group['precond_lr']
86
+ weight_decay = group['weight_decay']
87
+ lr = group['lr']
88
+ beta = group['beta']
89
+ store_triu_as_line = group['store_triu_as_line']
90
+ q_dtype = getattr(torch, group['q_dtype'])
91
+
92
+ vals = []
93
+
94
+ for p, g in split_p_and_g_in_group(group):
95
+ state = self.state_(p)
96
+
97
+ if 'Q' not in state:
98
+ state["exp_avg"] = torch.zeros_like(g)
99
+ Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
100
+ memory_save_mode, dtype=q_dtype)
101
+ state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
102
+ state['Q_cache'] = [torch.empty_like(q) for q in Q]
103
+
104
+ expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(einsum_base, Q)]
105
+ expr = ','.join(expr)
106
+ grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
107
+ out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
108
+ expr = f'{expr},{grad_expr}->{out_expr}'
109
+
110
+ state['cache_expr'] = expr
111
+
112
+ vals.append((p, g, state["exp_avg"], state["Q"], state['Q_cache']))
113
+
114
+ if not vals:
115
+ return
116
+
117
+ p_list, grad_list, exp_avg_list, Q_list, Q_cache_list = zip(*vals)
118
+ del vals
119
+
120
+ group["step"] += 1
121
+
122
+ torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
123
+
124
+ grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
125
+ exp_avg_list)
126
+ for i, (p, g) in enumerate(zip(p_list, grad_list)):
127
+ cached_q = Q_cache_list.pop(0)
128
+ q_orig = Q_list.pop(0)
129
+ ea = exp_avg_list.pop(0)
130
+
131
+ if do_update:
132
+ q = line_to_triu(q_orig) if store_triu_as_line else q_orig
133
+ q32 = [promote(q_) for q_ in q]
134
+ self.balance([g], [q32])
135
+ self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
136
+ for c_, q_ in zip(cached_q, q):
137
+ if q_.ndim == 2:
138
+ torch.matmul(q_.T.conj(), q_, out=c_)
139
+ else:
140
+ torch.mul(q_.conj(), q_, out=c_)
141
+
142
+ set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
143
+ grad_list = self.clip_fn(grad_list)
144
+
145
+ lr = -warmup(lr, group['step'], group['warmup_steps'])
146
+ update_param_(p_list, grad_list, lr, weight_decay)
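The `cache_expr` built in `_step` contracts every cached factor (Q^H Q, written into `Q_cache` by the matmul above) with the momentum buffer in a single einsum. A self-contained sketch of what that expression computes for a 2-D parameter; shapes and variable names here are illustrative only:

import torch

m, n = 4, 3
Q0, Q1 = torch.triu(torch.randn(m, m)), torch.triu(torch.randn(n, n))
ea = torch.randn(m, n)                 # stand-in for the exp_avg buffer

C0, C1 = Q0.T @ Q0, Q1.T @ Q1          # what Q_cache holds after a preconditioner update
cache_expr = 'Aa,Bb,ab->AB'            # the string the loop above builds for a 2-D tensor

out = torch.einsum(cache_expr, C0, C1, ea)
assert torch.allclose(out, C0 @ ea @ C1.T, atol=1e-5)   # ea preconditioned from both sides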
@@ -10,7 +10,7 @@ import torch
10
10
  from heavyball.utils import einsum_base
11
11
 
12
12
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
13
- precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base
13
+ precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote
14
14
 
15
15
 
16
16
  class ForeachCachedPSGDKron(PSGDBase):
@@ -39,7 +39,8 @@ class ForeachCachedPSGDKron(PSGDBase):
39
39
  def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
40
40
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
41
41
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
42
- split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
42
+ split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
43
+ foreach: bool = True, q_dtype='float32'):
43
44
  if not 0.0 <= lr:
44
45
  raise ValueError(f"Invalid learning rate: {lr}")
45
46
  if not 0.0 <= beta < 1.0:
@@ -60,8 +61,9 @@ class ForeachCachedPSGDKron(PSGDBase):
60
61
  # precond lr hardcoded to 0.1
61
62
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
62
63
  step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
63
- store_triu_as_line=store_triu_as_line)
64
- super().__init__(params, defaults)
64
+ store_triu_as_line=store_triu_as_line,
65
+ q_dtype=q_dtype)
66
+ super().__init__(params, defaults, foreach)
65
67
 
66
68
  self._prob_step = 0
67
69
 
@@ -83,6 +85,7 @@ class ForeachCachedPSGDKron(PSGDBase):
83
85
  lr = group['lr']
84
86
  beta = group['beta']
85
87
  store_triu_as_line = group['store_triu_as_line']
88
+ q_dtype = getattr(torch, group['q_dtype'])
86
89
 
87
90
  vals = []
88
91
 
@@ -92,7 +95,7 @@ class ForeachCachedPSGDKron(PSGDBase):
92
95
  if 'Q' not in state:
93
96
  state["exp_avg"] = torch.zeros_like(g)
94
97
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
95
- memory_save_mode, dtype=g.dtype)
98
+ memory_save_mode, dtype=q_dtype)
96
99
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
97
100
  state['Q_cache'] = [torch.empty_like(q) for q in Q]
98
101
 
@@ -123,18 +126,21 @@ class ForeachCachedPSGDKron(PSGDBase):
123
126
  q_orig = Q_list.pop(0)
124
127
  ea = exp_avg_list.pop(0)
125
128
 
129
+ new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
130
+
126
131
  if do_update:
127
132
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
128
- self.balance([g], [q])
129
- self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
130
- [q_orig] if store_triu_as_line else None)
133
+ q32 = [promote(q_) for q_ in q]
134
+ self.balance([g], [q32])
135
+ self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
131
136
  for c_, q_ in zip(cached_q, q):
132
137
  if q_.ndim == 2:
133
138
  torch.matmul(q_.T.conj(), q_, out=c_)
134
139
  else:
135
140
  torch.mul(q_.conj(), q_, out=c_)
136
141
 
137
- set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
142
+ set_(g, new)
143
+
138
144
  grad_list = self.clip_fn(grad_list)
139
145
 
140
146
  lr = -warmup(lr, group['step'], group['warmup_steps'])
@@ -8,7 +8,7 @@ import torch
8
8
  from heavyball.utils import copy_stochastic_list_
9
9
 
10
10
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
11
- precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_
11
+ precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_, promote
12
12
 
13
13
 
14
14
  class ForeachDelayedPSGD(PSGDBase):
@@ -38,7 +38,8 @@ class ForeachDelayedPSGD(PSGDBase):
38
38
  def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
39
39
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
40
40
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
41
- split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
41
+ split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
42
+ foreach: bool = True, q_dtype='float32'):
42
43
  if not 0.0 <= lr:
43
44
  raise ValueError(f"Invalid learning rate: {lr}")
44
45
  if not 0.0 <= beta < 1.0:
@@ -59,8 +60,8 @@ class ForeachDelayedPSGD(PSGDBase):
59
60
  # precond lr hardcoded to 0.1
60
61
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
61
62
  step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
62
- store_triu_as_line=store_triu_as_line)
63
- super().__init__(params, defaults)
63
+ store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
64
+ super().__init__(params, defaults, foreach)
64
65
 
65
66
  self._prob_step = 0
66
67
 
@@ -82,6 +83,7 @@ class ForeachDelayedPSGD(PSGDBase):
82
83
  lr = group['lr']
83
84
  beta = group['beta']
84
85
  store_triu_as_line = group['store_triu_as_line']
86
+ q_dtype = getattr(torch, group['q_dtype'])
85
87
 
86
88
  vals = []
87
89
 
@@ -91,7 +93,7 @@ class ForeachDelayedPSGD(PSGDBase):
91
93
  if 'Q' not in state:
92
94
  state["exp_avg"] = torch.zeros_like(g)
93
95
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
94
- memory_save_mode, dtype=g.dtype)
96
+ memory_save_mode, dtype=q_dtype)
95
97
  state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
96
98
 
97
99
  vals.append((p, g, state["exp_avg"], state["Q"]))
@@ -113,8 +115,9 @@ class ForeachDelayedPSGD(PSGDBase):
113
115
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
114
116
  new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
115
117
  if do_update:
116
- self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
117
- self.balance([g], [q])
118
+ q32 = [promote(q_) for q_ in q]
119
+ self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
120
+ self.balance([g], [q32])
118
121
  set_(g, new)
119
122
 
120
123
  grad_list = self.clip_fn(grad_list)
@@ -5,10 +5,11 @@ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOpti
5
5
 
6
6
 
7
7
  class ForeachAdamW(StatefulOptimizer):
8
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0):
8
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
9
+ foreach: bool = True):
9
10
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
10
11
  lr_max=-1.0, weight_decay=weight_decay)
11
- super().__init__(params, defaults)
12
+ super().__init__(params, defaults, foreach)
12
13
 
13
14
  def _step(self, group):
14
15
  eps = group['eps']
@@ -6,10 +6,11 @@ from .utils import warmup, beta_debias, update_param_, StatefulOptimizer
6
6
 
7
7
  class ForeachADOPT(StatefulOptimizer):
8
8
 
9
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0):
9
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
10
+ foreach: bool = True):
10
11
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
11
12
  lr_max=-1.0, weight_decay=weight_decay)
12
- super().__init__(params, defaults)
13
+ super().__init__(params, defaults, foreach)
13
14
 
14
15
  def _step(self, group):
15
16
  eps = group['eps']
@@ -6,10 +6,11 @@ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOpti
6
6
 
7
7
  class ForeachLaProp(StatefulOptimizer):
8
8
 
9
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1):
9
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
10
+ foreach: bool = True):
10
11
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
11
12
  lr_max=-1.0, weight_decay=weight_decay)
12
- super().__init__(params, defaults)
13
+ super().__init__(params, defaults, foreach)
13
14
 
14
15
  def _step(self, group):
15
16
  eps = group['eps']
@@ -6,12 +6,12 @@ from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debia
6
6
 
7
7
  class ForeachSFAdamW(ScheduleFree):
8
8
  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
9
- weight_lr_power=2.0, foreach=hasattr(torch, "_foreach_mul_")):
9
+ weight_lr_power=2.0, foreach: bool = True):
10
10
 
11
11
  defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
12
12
  weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
13
13
  foreach=foreach)
14
- super().__init__(params, defaults)
14
+ super().__init__(params, defaults, foreach)
15
15
 
16
16
  def _step(self, group):
17
17
  eps = group['eps']
@@ -48,7 +48,7 @@ class ForeachSFAdamW(ScheduleFree):
48
48
  torch._foreach_add_(grad, y, alpha=decay)
49
49
 
50
50
  lr = warmup(group['lr'], k + 1, group['warmup_steps'])
51
- group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
52
- y, z, grad, group['r'], k + 1)
51
+ group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0], y, z,
52
+ grad, group['r'], k + 1)
53
53
 
54
54
  group['k'] = k + 1
@@ -26,12 +26,13 @@ class ForeachSOAP(StatefulOptimizer):
26
26
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
27
27
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
28
28
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
29
- split: bool = False):
29
+ split: bool = False,
30
+ foreach: bool = True):
30
31
  defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
31
32
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
32
33
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
33
34
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split}
34
- super().__init__(params, defaults)
35
+ super().__init__(params, defaults, foreach)
35
36
  self._data_format = data_format
36
37
 
37
38
  def _step(self, group):
@@ -59,7 +60,7 @@ class ForeachSOAP(StatefulOptimizer):
59
60
  vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
60
61
 
61
62
  if not vals:
62
- return
63
+ return
63
64
 
64
65
  p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
65
66
  beta1, beta2 = group["betas"]
@@ -8,7 +8,7 @@ import torch
8
8
  from heavyball.utils import triu_to_line, line_to_triu
9
9
 
10
10
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
11
- exp_avg_sq_, beta_debias, split_p_and_g_in_group
11
+ exp_avg_sq_, beta_debias, split_p_and_g_in_group, promote
12
12
 
13
13
 
14
14
  class ForeachPaLMPAdam(PSGDBase):
@@ -38,7 +38,8 @@ class ForeachPaLMPAdam(PSGDBase):
38
38
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
39
39
  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
40
40
  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
41
- store_triu_as_line: bool = True):
41
+ store_triu_as_line: bool = True,
42
+ foreach: bool = True, q_dtype='float32'):
42
43
  if not 0.0 <= lr:
43
44
  raise ValueError(f"Invalid learning rate: {lr}")
44
45
  if not 0.0 <= weight_decay:
@@ -59,8 +60,8 @@ class ForeachPaLMPAdam(PSGDBase):
59
60
  # precond lr hardcoded to 0.1
60
61
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
61
62
  step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
62
- split=split, store_triu_as_line=store_triu_as_line)
63
- super().__init__(params, defaults)
63
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
64
+ super().__init__(params, defaults, foreach)
64
65
 
65
66
  self._prob_step = 0
66
67
 
@@ -80,6 +81,7 @@ class ForeachPaLMPAdam(PSGDBase):
80
81
  weight_decay = group['weight_decay']
81
82
  lr = group['lr']
82
83
  store_triu_as_line = group['store_triu_as_line']
84
+ q_dtype = getattr(torch, group['q_dtype'])
83
85
 
84
86
  vals = []
85
87
 
@@ -90,7 +92,7 @@ class ForeachPaLMPAdam(PSGDBase):
90
92
  state['exp_avg'] = torch.zeros_like(g)
91
93
  state['exp_avg_sq'] = torch.zeros_like(g)
92
94
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
93
- min_ndim_triangular, memory_save_mode, dtype=g.dtype)
95
+ min_ndim_triangular, memory_save_mode, dtype=q_dtype)
94
96
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
95
97
 
96
98
  vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
@@ -105,9 +107,10 @@ class ForeachPaLMPAdam(PSGDBase):
105
107
 
106
108
  Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
107
109
  if do_update:
108
- self.balance(grad_list, Q_triu)
109
- self.do_update(p_list, grad_list, Q_triu, precond_lr, Q_list if store_triu_as_line else None)
110
-
110
+ for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
111
+ q32 = [promote(qq_) for qq_ in q_]
112
+ self.balance([g], [q32])
113
+ self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
111
114
  torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
112
115
 
113
116
  beta2 = 1 - group['step'] ** -group['beta2_scale']
@@ -6,13 +6,14 @@ from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debia
6
6
 
7
7
  class PaLMForeachSFAdamW(ScheduleFree):
8
8
  def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
9
- weight_lr_power=2.0, beta2_scale: float = 0.8):
9
+ weight_lr_power=2.0, beta2_scale: float = 0.8,
10
+ foreach: bool = True):
10
11
  if betas[0] is not None:
11
12
  beta = betas[0]
12
13
  defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
13
14
  lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
14
15
  beta2_scale=beta2_scale)
15
- super().__init__(params, defaults)
16
+ super().__init__(params, defaults, foreach)
16
17
 
17
18
  def _step(self, group):
18
19
  eps = group['eps']
@@ -32,7 +32,8 @@ class PaLMForeachSOAP(StatefulOptimizer):
32
32
  max_precond_dim: int = 2048, #
33
33
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
34
34
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
35
- beta2_scale: float = 0.8, split: bool = False):
35
+ beta2_scale: float = 0.8, split: bool = False,
36
+ foreach: bool = True):
36
37
  if betas[0] is not None:
37
38
  beta = betas[0]
38
39
  defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -40,7 +41,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
40
41
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
41
42
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
42
43
  'split': split}
43
- super().__init__(params, defaults)
44
+ super().__init__(params, defaults, foreach)
44
45
  self._data_format = data_format
45
46
 
46
47
  def _step(self, group):
@@ -27,13 +27,14 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
27
27
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
28
28
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
29
29
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
30
- precond_scheduler=(1 / 3, 9), split: bool = False):
30
+ precond_scheduler=(1 / 3, 9), split: bool = False,
31
+ foreach: bool = True):
31
32
  defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
32
33
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
33
34
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
34
35
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
35
36
  'split': split}
36
- super().__init__(params, defaults)
37
+ super().__init__(params, defaults, foreach)
37
38
  self._data_format = data_format
38
39
  self.rng = random.Random(0x120983109)
39
40
 
@@ -32,7 +32,8 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
32
32
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
33
33
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
34
34
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
35
- precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False):
35
+ precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
36
+ foreach: bool = True):
36
37
  if betas[0] is not None:
37
38
  beta = betas[0]
38
39
  defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -40,7 +41,7 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
40
41
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
41
42
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
42
43
  'beta2_scale': beta2_scale, 'split': split}
43
- super().__init__(params, defaults)
44
+ super().__init__(params, defaults, foreach)
44
45
  self._data_format = data_format
45
46
  self.rng = random.Random(0x120983109)
46
47
 
@@ -41,7 +41,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
41
41
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
42
42
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
43
43
  weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9),
44
- betas=(None, None), split: bool = False):
44
+ betas=(None, None), split: bool = False, foreach: bool = True):
45
45
  if betas[0] is not None:
46
46
  beta = betas[0]
47
47
  defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -50,7 +50,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
50
50
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
51
51
  'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1, 'weight_sum': 0,
52
52
  'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split}
53
- super().__init__(params, defaults)
53
+ super().__init__(params, defaults, foreach)
54
54
  self._data_format = data_format
55
55
  self.rng = random.Random(0x120983109)
56
56
 
@@ -59,7 +59,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
59
59
  max_precond_dim = group['max_precond_dim']
60
60
  precondition_1d = group['precondition_1d']
61
61
 
62
- step = group['step'] = group.get("step", -1) + 1
62
+ step = group['step'] = group.get("step", 0) + 1
63
63
 
64
64
  for p in group["params"]:
65
65
  if p.grad is None:
@@ -9,7 +9,7 @@ from typing import Optional
9
9
  import torch
10
10
 
11
11
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
12
- precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_
12
+ precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, promote
13
13
 
14
14
 
15
15
  class ForeachPSGDKron(PSGDBase):
@@ -38,7 +38,8 @@ class ForeachPSGDKron(PSGDBase):
38
38
  def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
39
39
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
40
40
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
41
- split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
41
+ split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
42
+ foreach: bool = True, q_dtype='float32'):
42
43
  if not 0.0 <= lr:
43
44
  raise ValueError(f"Invalid learning rate: {lr}")
44
45
  if not 0.0 <= beta < 1.0:
@@ -59,8 +60,8 @@ class ForeachPSGDKron(PSGDBase):
59
60
  # precond lr hardcoded to 0.1
60
61
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
61
62
  step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
62
- store_triu_as_line=store_triu_as_line)
63
- super().__init__(params, defaults)
63
+ store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
64
+ super().__init__(params, defaults, foreach)
64
65
 
65
66
  self._prob_step = 0
66
67
 
@@ -82,6 +83,7 @@ class ForeachPSGDKron(PSGDBase):
82
83
  lr = group['lr']
83
84
  beta = group['beta']
84
85
  store_triu_as_line = group['store_triu_as_line']
86
+ q_dtype = getattr(torch, group['q_dtype'])
85
87
 
86
88
  vals = []
87
89
 
@@ -91,7 +93,7 @@ class ForeachPSGDKron(PSGDBase):
91
93
  if 'Q' not in state:
92
94
  state["exp_avg"] = torch.zeros_like(g)
93
95
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
94
- memory_save_mode, dtype=g.dtype)
96
+ memory_save_mode, dtype=q_dtype)
95
97
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
96
98
 
97
99
  vals.append((p, g, state["exp_avg"], state["Q"]))
@@ -113,8 +115,9 @@ class ForeachPSGDKron(PSGDBase):
113
115
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
114
116
 
115
117
  if do_update:
116
- self.balance([g], [q])
117
- self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
118
+ q32 = [promote(q_) for q_ in q]
119
+ self.balance([ea if momentum_into_precond_update else g], [q32])
120
+ self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
118
121
  set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
119
122
 
120
123
  grad_list = self.clip_fn(grad_list)
@@ -5,9 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
5
5
  """
6
6
 
7
7
  import torch
8
+ from heavyball.utils import copy_stochastic_list_
8
9
 
9
10
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
10
- split_p_and_g_in_group, line_to_triu, triu_to_line
11
+ split_p_and_g_in_group, line_to_triu, triu_to_line, promote
11
12
 
12
13
 
13
14
  class ForeachPurePSGD(PSGDBase):
@@ -36,7 +37,8 @@ class ForeachPurePSGD(PSGDBase):
36
37
  def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
37
38
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
38
39
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
39
- split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
40
+ split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
41
+ foreach: bool = True, q_dtype='float32'):
40
42
  if not 0.0 <= lr:
41
43
  raise ValueError(f"Invalid learning rate: {lr}")
42
44
  if not 0.0 <= weight_decay:
@@ -55,8 +57,8 @@ class ForeachPurePSGD(PSGDBase):
55
57
  # precond lr hardcoded to 0.1
56
58
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
57
59
  step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
58
- store_triu_as_line=store_triu_as_line)
59
- super().__init__(params, defaults)
60
+ store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
61
+ super().__init__(params, defaults, foreach)
60
62
 
61
63
  self._prob_step = 0
62
64
 
@@ -76,6 +78,7 @@ class ForeachPurePSGD(PSGDBase):
76
78
  weight_decay = group['weight_decay']
77
79
  lr = group['lr']
78
80
  store_triu_as_line = group['store_triu_as_line']
81
+ q_dtype = getattr(torch, group['q_dtype'])
79
82
 
80
83
  vals = []
81
84
 
@@ -84,7 +87,7 @@ class ForeachPurePSGD(PSGDBase):
84
87
 
85
88
  if 'Q' not in state:
86
89
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
87
- memory_save_mode, dtype=g.dtype)
90
+ memory_save_mode, dtype=q_dtype)
88
91
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
89
92
 
90
93
  vals.append((p, g, state["Q"]))
@@ -103,8 +106,9 @@ class ForeachPurePSGD(PSGDBase):
103
106
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
104
107
 
105
108
  if do_update:
106
- self.balance([g], [q])
107
- self.do_update([p], [g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
109
+ q32 = [promote(q_) for q_ in q]
110
+ self.balance([g], [q32])
111
+ self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
108
112
  psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
109
113
 
110
114
  grad_list = self.clip_fn(grad_list)
@@ -33,7 +33,8 @@ class SFPaLMForeachSOAP(ScheduleFree):
33
33
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
34
34
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
35
35
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
36
- weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False):
36
+ weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False,
37
+ foreach: bool = True):
37
38
  if betas[0] is not None:
38
39
  beta = betas[0]
39
40
  defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -42,7 +43,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
42
43
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
43
44
  'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1,
44
45
  'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split}
45
- super().__init__(params, defaults)
46
+ super().__init__(params, defaults, foreach)
46
47
  self._data_format = data_format
47
48
  self.rng = random.Random(0x120983109)
48
49
 
@@ -51,7 +52,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
51
52
  max_precond_dim = group['max_precond_dim']
52
53
  precondition_1d = group['precondition_1d']
53
54
 
54
- step = group['step'] = group.get("step", -1) + 1
55
+ step = group['step'] = group.get("step", 0) + 1
55
56
 
56
57
  for p in group["params"]:
57
58
  if p.grad is None:
@@ -325,9 +325,9 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
325
325
 
326
326
 
327
327
  def promote(x):
328
- if x is (torch.bfloat16, torch.float16):
328
+ if x in (torch.bfloat16, torch.float16):
329
329
  return torch.float32
330
- if x.dtype in (torch.bfloat16, torch.float16):
330
+ if hasattr(x, 'dtype') and x.dtype in (torch.bfloat16, torch.float16):
331
331
  return x.float()
332
332
  return x
333
333
 
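The corrected `promote` above now accepts both dtypes and tensors (the old `is` comparison could never be true). A quick behavioural sketch, assuming heavyball 0.17.0 is installed:

import torch
from heavyball.utils import promote

assert promote(torch.bfloat16) is torch.float32   # low-precision dtypes map to fp32
x = torch.randn(3, dtype=torch.float16)
assert promote(x).dtype is torch.float32          # low-precision tensors are upcast
y = torch.randn(3)
assert promote(y) is y                            # fp32 inputs pass through unchanged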
@@ -383,8 +383,25 @@ def project(grad, Q, back: bool):
383
383
 
384
384
 
385
385
  class StatefulOptimizer(torch.optim.Optimizer):
386
+ def __init__(self, params, defaults, foreach: bool = True):
387
+ super().__init__(params, {**defaults, 'foreach': foreach})
388
+ self.fake_groups = {}
389
+
390
+ def key(self, param: torch.Tensor):
391
+ return (param.data_ptr(), tuple(param.shape))
392
+
393
+ def get_groups(self, group):
394
+ if group['foreach']:
395
+ return [group]
396
+
397
+ for p in group['params']:
398
+ if self.key(p) not in self.fake_groups:
399
+ self.fake_groups[self.key(p)] = {**group, 'params': [p]}
400
+
401
+ return [self.fake_groups[self.key(p)] for p in group['params']]
402
+
386
403
  def state_(self, arg: torch.Tensor):
387
- return self.state[(arg.data_ptr(), tuple(arg.shape))]
404
+ return self.state[self.key(arg)]
388
405
 
389
406
  def state_size(self) -> int:
390
407
  total_bytes = 0
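`get_groups` is how the new `foreach=False` option works: with `foreach=True` a parameter group is stepped as one multi-tensor batch, otherwise it is split into cached single-parameter "fake" groups so `_step` runs once per tensor and peak memory stays lower. A simplified standalone sketch of that splitting logic (no caching, illustrative names):

import torch

def split_group(group: dict) -> list:
    # mirrors StatefulOptimizer.get_groups: one group when foreach is on,
    # one single-parameter group per tensor when it is off
    if group['foreach']:
        return [group]
    return [{**group, 'params': [p]} for p in group['params']]

params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(3)]
print(len(split_group({'params': params, 'foreach': True, 'lr': 1e-3})))   # 1
print(len(split_group({'params': params, 'foreach': False, 'lr': 1e-3})))  # 3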
@@ -409,8 +426,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
409
426
  with torch.enable_grad():
410
427
  loss = closure()
411
428
  with torch.no_grad():
412
- for group in self.param_groups:
413
- self._step(group)
429
+ for top_group in self.param_groups:
430
+ for group in self.get_groups(top_group):
431
+ self._step(group)
414
432
  return loss
415
433
 
416
434
 
@@ -450,15 +468,15 @@ class ScheduleFree(StatefulOptimizer):
450
468
 
451
469
  def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]):
452
470
  for t, s in zip(target, source):
453
- if t.dtype == torch.bfloat16:
454
- copy_stochastic_(t, s)
455
- else:
456
- set_(t, s)
471
+ copy_stochastic_(t, s)
457
472
 
458
473
 
459
474
  def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
460
475
  if target.data_ptr() == source.data_ptr():
461
476
  return
477
+ if target.dtype != torch.bfloat16:
478
+ set_(target, source)
479
+ return
462
480
 
463
481
  """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
464
482
  # create a random 16 bit integer
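The rest of `copy_stochastic_` is unchanged and therefore outside this hunk. For reference, the recipe the docstring points to (the linked PyTorch issue) rounds an fp32 source into a bf16 target by randomizing the truncated mantissa bits; a sketch, with an illustrative function name:

import torch

def copy_stochastic_sketch(target: torch.Tensor, source: torch.Tensor):
    # target: bf16 tensor, source: fp32 tensor of the same shape
    result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
    result.add_(source.view(dtype=torch.int32))   # add noise to the low 16 mantissa bits
    result.bitwise_and_(-65536)                   # mask = 0xFFFF0000, truncate those bits
    target.copy_(result.view(dtype=torch.float32))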
@@ -537,7 +555,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
537
555
  for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
538
556
  if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d:
539
557
  # use diagonal matrix as preconditioner for this dim
540
- Q.append(scale * torch.ones(size, dtype=dtype, device=t.device))
558
+ Q.append(scale * torch.ones(size, dtype=promote(dtype), device=t.device))
541
559
 
542
560
  piece1A.append(letters[i])
543
561
  piece2A = piece2A + letters[i]
@@ -651,11 +669,11 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
651
669
  @decorator
652
670
  def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
653
671
  """Precondition gradient G with preconditioner Q."""
654
- out = torch.einsum(exprs[-1], *[q.conj() for q in Q], *Q, G)
672
+ out = torch.einsum(exprs[-1], *[q.conj() for q in Q], *Q, G.to(Q[0].dtype))
655
673
  if inplace:
656
674
  set_(G, out)
657
675
  return G
658
- return out
676
+ return out.to(G.dtype)
659
677
 
660
678
 
661
679
  def norm_clip_(x, scale=None):
@@ -750,28 +768,33 @@ def line_to_triu(Q_list: List[Tuple[Optional[List[int]], torch.Tensor]]):
750
768
  def update_triu_(q_state, materialised):
751
769
  for (shape0, q), (shape1, m) in zip(q_state, triu_to_line(materialised)):
752
770
  assert shape0 == shape1
753
- set_(q, m)
771
+ copy_stochastic_(q, m)
754
772
 
755
773
 
756
774
  class PSGDBase(StatefulOptimizer):
757
- def __init__(self, parameters, groups):
758
- super().__init__(parameters, groups)
775
+ balance_probability: float = 0.01
776
+
777
+ def __init__(self, parameters, groups, foreach: bool = True):
778
+ super().__init__(parameters, groups, foreach)
759
779
  self.rng = random.Random(0x1923213)
760
780
  self._tiny = torch.finfo(torch.bfloat16).tiny
761
781
 
762
782
  def balance(self, grad_list, Q_list):
763
- if self.rng.random() > 0.01:
783
+ if self.rng.random() > self.balance_probability:
764
784
  return
765
785
 
766
786
  for g, q in zip(grad_list, Q_list):
767
787
  if g.dim() > 1:
768
788
  psgd_balance_Q(q)
769
789
 
770
- def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None):
790
+ def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None, store_triu_as_line=False):
771
791
  for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
772
792
  psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
773
793
  if original_q:
774
- update_triu_(original_q[i], Q)
794
+ if store_triu_as_line:
795
+ update_triu_(original_q[i], Q)
796
+ else:
797
+ copy_stochastic_(original_q[i], Q)
775
798
 
776
799
 
777
800
  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
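The body of `precond_update_prob_schedule` lies outside this hunk; based on its signature and the docstring earlier in the diff ("anneals from 1.0 to 0.03 by 4000 steps"), a plausible sketch is an exponential decay with a flat start. Treat the exact formula as an assumption, not the package's verbatim code:

import math

def precond_update_prob_schedule_sketch(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
    def _schedule(n: int) -> float:
        # full-probability warmup for flat_start steps, then exponential anneal,
        # clamped so the update probability never drops below min_prob
        prob = max_prob * math.exp(-decay * max(n - flat_start, 0))
        return max(min(prob, max_prob), min_prob)
    return _schedule

schedule = precond_update_prob_schedule_sketch()
print(schedule(0), schedule(1000), schedule(4000))   # 1.0, ~0.47, 0.03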
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.15.1
3
+ Version: 0.17.0
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -39,12 +39,14 @@ recommended experimental optimizer is `ForeachPSGDKron`.
39
39
 
40
40
  * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
41
41
  * **Inplace EMA**: Same math, but less memory, less compute and higher stability
42
- * **Foreach**: Fast multi-tensor application
42
+ * **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
43
43
  * **PaLM Beta2**: Fast initial
44
44
  convergence, [stable late convergence](https://x.com/_clashluke/status/1820810798693818761)
45
45
  * **ScheduleFree**: No learning rate schedule, but better convergence
46
46
  * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
47
47
  better step-per-second in late convergence (explained below)
48
+ * **Memory-efficient storage**: PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
49
+ bandwidth; turn it off for lower compute overhead at the cost of more memory (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
48
50
 
49
51
  ## Getting started
50
52
 
@@ -2,6 +2,7 @@ LICENSE
2
2
  README.md
3
3
  setup.py
4
4
  heavyball/__init__.py
5
+ heavyball/cached_delayed_psgd_kron.py
5
6
  heavyball/cached_psgd_kron.py
6
7
  heavyball/delayed_psgd.py
7
8
  heavyball/foreach_adamw.py
@@ -24,7 +25,9 @@ heavyball.egg-info/SOURCES.txt
24
25
  heavyball.egg-info/dependency_links.txt
25
26
  heavyball.egg-info/requires.txt
26
27
  heavyball.egg-info/top_level.txt
28
+ test/test_bf16_q.py
27
29
  test/test_closure.py
30
+ test/test_foreach.py
28
31
  test/test_memory.py
29
32
  test/test_merge.py
30
33
  test/test_no_grad.py
@@ -10,7 +10,7 @@ setuptools.setup(
10
10
  name='heavyball',
11
11
  license='BSD',
12
12
  description='Efficient optimizers',
13
- version='0.15.1',
13
+ version='0.17.0',
14
14
  long_description=README,
15
15
  url='https://github.com/clashluke/heavyball',
16
16
  packages=setuptools.find_packages(),
@@ -0,0 +1,52 @@
1
+ import heavyball
2
+ import heavyball.utils
3
+ import pytest
4
+ import torch
5
+ from benchmark.utils import get_optim
6
+ from heavyball.utils import clean, set_torch, PSGDBase
7
+ from torch import nn
8
+
9
+
10
+ def get_memory():
11
+ clean()
12
+ torch.cuda.synchronize()
13
+ clean()
14
+ torch.cuda.synchronize()
15
+ return torch.cuda.memory_allocated()
16
+
17
+
18
+ @pytest.mark.parametrize("opt", heavyball.__all__)
19
+ @pytest.mark.parametrize("size,depth", [(256, 2)])
20
+ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3):
21
+ set_torch()
22
+
23
+ opt = getattr(heavyball, opt)
24
+ if not issubclass(opt, PSGDBase):
25
+ raise pytest.skip('Only PSGD is supported')
26
+
27
+ peaks = []
28
+ losses = []
29
+
30
+ for q_dtype in ['float32', 'bfloat16']:
31
+ peaks.append([])
32
+ losses.append([])
33
+
34
+ for i in range(outer_iterations):
35
+ torch.manual_seed(0x2131290)
36
+ model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
37
+ o = get_optim(opt, model.parameters(), lr=1e-3, q_dtype=q_dtype)
38
+
39
+ for _ in range(iterations):
40
+ loss = model(torch.randn((1024, size)).cuda()).square().mean()
41
+ loss.backward()
42
+ o.step()
43
+ o.zero_grad()
44
+ losses[-1].append(loss.detach())
45
+
46
+ del model, o
47
+ clean()
48
+
49
+
50
+ for i, (l0, l1) in enumerate(zip(*losses)):
51
+ print(i, l0.item(), l1.item())
52
+ assert torch.allclose(l0, l1, rtol=0.1)
@@ -20,7 +20,7 @@ class Param(nn.Module):
20
20
 
21
21
  @pytest.mark.parametrize("opt", heavyball.__all__)
22
22
  @pytest.mark.parametrize("size", [(4, 4, 4, 4), ])
23
- def test_closre(opt, size: List[int], depth: int = 2, iterations: int = 5, outer_iterations: int = 3):
23
+ def test_closure(opt, size: List[int], depth: int = 2, iterations: int = 5, outer_iterations: int = 3):
24
24
  clean()
25
25
  set_torch()
26
26
 
@@ -0,0 +1,65 @@
1
+ import heavyball
2
+ import heavyball.utils
3
+ import pytest
4
+ import torch
5
+ from benchmark.utils import get_optim
6
+ from heavyball.utils import clean, set_torch, PSGDBase
7
+ from torch import nn
8
+
9
+
10
+ def get_memory():
11
+ clean()
12
+ torch.cuda.synchronize()
13
+ clean()
14
+ torch.cuda.synchronize()
15
+ return torch.cuda.memory_allocated()
16
+
17
+
18
+ @pytest.mark.parametrize("opt", heavyball.__all__)
19
+ @pytest.mark.parametrize("size,depth", [(256, 128)])
20
+ def test_foreach(opt, size, depth: int, iterations: int = 5, outer_iterations: int = 3):
21
+ set_torch()
22
+
23
+ opt = getattr(heavyball, opt)
24
+
25
+ peaks = []
26
+ losses = []
27
+
28
+ for foreach in [True, False]:
29
+ peaks.append([])
30
+ losses.append([])
31
+
32
+ for i in range(outer_iterations):
33
+ torch.manual_seed(0x2131290)
34
+ clean()
35
+ model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
36
+ clean()
37
+
38
+ torch.cuda.reset_peak_memory_stats()
39
+ torch.cuda.reset_max_memory_allocated()
40
+ torch.cuda.reset_max_memory_cached()
41
+ torch.cuda.reset_accumulated_memory_stats()
42
+
43
+ clean()
44
+ o = get_optim(opt, model.parameters(), lr=1e-3, foreach=foreach)
45
+ clean()
46
+
47
+ for _ in range(iterations):
48
+ loss = model(torch.randn((1, size)).cuda()).sum()
49
+ loss.backward()
50
+ o.step()
51
+ o.zero_grad()
52
+ losses[-1].append(loss.detach())
53
+
54
+ del model, o
55
+ clean()
56
+
57
+ peak = torch.cuda.memory_stats()['allocated_bytes.all.peak']
58
+
59
+ if i > 0:
60
+ peaks[-1].append(peak)
61
+
62
+ for p0, p1 in zip(*peaks):
63
+ assert p0 > p1
64
+ for l0, l1 in zip(*losses): # increase error tolerance for PSGD, as we have different RNGs -> expected differences
65
+ assert torch.allclose(l0, l1, rtol=0.01 if issubclass(opt, PSGDBase) else 1e-5)
@@ -25,14 +25,14 @@ expected_memory = {'adamw': {'after': 4, 'peak': 5.1}, 'soap': {'after': 7, 'pea
25
25
  @pytest.mark.parametrize("size,depth", [(8192, 1), (2048, 16)])
26
26
  def test_memory(opt, method, size, depth: int, iterations: int = 5, outer_iterations: int = 3):
27
27
  if 'soap' not in opt.lower() and method != 'qr':
28
- return
28
+ raise pytest.skip('Only SOAP supports `method` argument')
29
29
  set_torch()
30
30
 
31
31
  for k, v in expected_memory.items():
32
32
  if k in opt.lower():
33
33
  break
34
34
  else:
35
- raise ValueError(f'Unknown optimizer {opt}')
35
+ raise pytest.skip(f'Opt {opt} not supported')
36
36
 
37
37
  opt = getattr(heavyball, opt)
38
38
  heavyball.utils.zeroth_power_mode = method
@@ -26,7 +26,7 @@ class Param(nn.Module):
26
26
  def test_merge(opt, method, size: List[int], merge, split, depth: int = 2, iterations: int = 5,
27
27
  outer_iterations: int = 3):
28
28
  if 'soap' not in opt.lower() and method != 'qr':
29
- return
29
+ raise pytest.skip('Only SOAP supports `method` argument')
30
30
  clean()
31
31
  set_torch()
32
32
 
@@ -1,11 +1,10 @@
1
- import pytest
2
- import torch
3
- from torch import nn
4
-
5
1
  import heavyball
6
2
  import heavyball.utils
3
+ import pytest
4
+ import torch
7
5
  from benchmark.utils import get_optim
8
6
  from heavyball.utils import clean, set_torch
7
+ from torch import nn
9
8
 
10
9
 
11
10
  def get_memory():
@@ -16,10 +15,6 @@ def get_memory():
16
15
  return torch.cuda.memory_allocated()
17
16
 
18
17
 
19
- expected_memory = {'adamw': {'after': 4, 'peak': 5.1}, 'soap': {'after': 7, 'peak': 14},
20
- 'psgd': {'after': 4, 'peak': 11.5}, 'padam': {'after': 5, 'peak': 11.4}}
21
-
22
-
23
18
  @pytest.mark.parametrize("opt", ['ForeachPSGDKron', 'ForeachPaLMPAdam', 'ForeachPurePSGD', 'ForeachDelayedPSGD'])
24
19
  @pytest.mark.parametrize("method",
25
20
  ['norm_clip_', 'mu_law_compress', 'a_law_compress', 'trust_region_clip_', 'identity'])
@@ -27,12 +22,6 @@ expected_memory = {'adamw': {'after': 4, 'peak': 5.1}, 'soap': {'after': 7, 'pea
27
22
  def test_clip(opt, method, size, depth: int, iterations: int = 100, outer_iterations: int = 3):
28
23
  set_torch()
29
24
 
30
- for k, v in expected_memory.items():
31
- if k in opt.lower():
32
- break
33
- else:
34
- raise ValueError(f'Unknown optimizer {opt}')
35
-
36
25
  opt = getattr(heavyball, opt)
37
26
 
38
27
  for i in range(outer_iterations):