heavyball 0.19.0__tar.gz → 0.21.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {heavyball-0.19.0 → heavyball-0.21.0}/PKG-INFO +2 -2
  2. {heavyball-0.19.0 → heavyball-0.21.0}/README.md +1 -1
  3. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/cached_delayed_psgd_kron.py +11 -11
  4. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/cached_psgd_kron.py +13 -12
  5. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/delayed_psgd.py +15 -18
  6. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/foreach_soap.py +4 -7
  7. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/p_adam.py +9 -9
  8. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/palm_foreach_soap.py +6 -6
  9. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/precond_schedule_foreach_soap.py +6 -10
  10. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/precond_schedule_palm_foreach_soap.py +4 -4
  11. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/precond_schedule_sfpsoap.py +20 -10
  12. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/psgd_kron.py +15 -12
  13. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/pure_psgd.py +3 -6
  14. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/schedule_free_palm_foreach_soap.py +17 -8
  15. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/utils.py +169 -58
  16. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball.egg-info/PKG-INFO +2 -2
  17. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball.egg-info/SOURCES.txt +1 -0
  18. {heavyball-0.19.0 → heavyball-0.21.0}/setup.py +1 -1
  19. {heavyball-0.19.0 → heavyball-0.21.0}/test/test_bf16_params.py +2 -1
  20. heavyball-0.21.0/test/test_ema.py +61 -0
  21. {heavyball-0.19.0 → heavyball-0.21.0}/LICENSE +0 -0
  22. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/__init__.py +0 -0
  23. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/foreach_adamw.py +0 -0
  24. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/foreach_adopt.py +0 -0
  25. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/foreach_laprop.py +0 -0
  26. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/foreach_sfadamw.py +0 -0
  27. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/palm_foreach_sfadamw.py +0 -0
  28. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball.egg-info/dependency_links.txt +0 -0
  29. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball.egg-info/requires.txt +0 -0
  30. {heavyball-0.19.0 → heavyball-0.21.0}/heavyball.egg-info/top_level.txt +0 -0
  31. {heavyball-0.19.0 → heavyball-0.21.0}/setup.cfg +0 -0
  32. {heavyball-0.19.0 → heavyball-0.21.0}/test/test_bf16_q.py +0 -0
  33. {heavyball-0.19.0 → heavyball-0.21.0}/test/test_bf16_storage.py +0 -0
  34. {heavyball-0.19.0 → heavyball-0.21.0}/test/test_closure.py +0 -0
  35. {heavyball-0.19.0 → heavyball-0.21.0}/test/test_foreach.py +0 -0
  36. {heavyball-0.19.0 → heavyball-0.21.0}/test/test_memory.py +0 -0
  37. {heavyball-0.19.0 → heavyball-0.21.0}/test/test_merge.py +0 -0
  38. {heavyball-0.19.0 → heavyball-0.21.0}/test/test_no_grad.py +0 -0
  39. {heavyball-0.19.0 → heavyball-0.21.0}/test/test_psgd.py +0 -0
  40. {heavyball-0.19.0 → heavyball-0.21.0}/test/test_soap.py +0 -0
  41. {heavyball-0.19.0 → heavyball-0.21.0}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.19.0
+ Version: 0.21.0
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
  largely static alternative to `torch.optim` with more and better optimizers.

- Currently (2024-11-22, 0.19), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+ Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

  ## Features
@@ -8,7 +8,7 @@ A simple package of efficient optimizers
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
  largely static alternative to `torch.optim` with more and better optimizers.

- Currently (2024-11-22, 0.19), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+ Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

  ## Features
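
The README (and the PKG-INFO above) continues to recommend `PrecondSchedulePaLMSOAP` as the stable default. For reference, a minimal training-loop sketch, assuming the class is exported from the package root under the name the README uses; the model, batch size, and learning rate are purely illustrative:

import torch
from torch import nn

import heavyball

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 1))
opt = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)

for _ in range(10):
    loss = model(torch.randn(32, 256)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()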
@@ -9,7 +9,7 @@ from typing import Optional
  import torch

  from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
- line_to_triu, triu_to_line, set_, einsum_base, promote
+ line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias


  class ForeachCachedDelayedPSGDKron(PSGDBase):
@@ -41,7 +41,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
- foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
+ foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+ storage_dtype: str = 'float32', #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -58,7 +59,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -74,14 +75,15 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  beta = group['beta']
  store_triu_as_line = group['store_triu_as_line']
  q_dtype = getattr(torch, group['q_dtype'])
+ storage_dtype = getattr(torch, group['storage_dtype'])

  vals = []

- for p, g in split_p_and_g_in_group(group):
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
  state = self.state_(p)

  if 'Q' not in state:
- state["exp_avg"] = torch.zeros_like(g)
+ state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
  memory_save_mode, dtype=q_dtype)
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -105,7 +107,9 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):

  group["step"] += 1

- torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+ stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))
+
+ lr = -warmup(lr, group['step'], group['warmup_steps'])

  grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
  exp_avg_list)
@@ -127,8 +131,4 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  else:
  torch.mul(q_.conj(), q_, out=c_)

- set_(g, new)
- grad_list = self.clip_fn(grad_list)
-
- lr = -warmup(lr, group['step'], group['warmup_steps'])
- update_param_(p_list, grad_list, lr, weight_decay)
+ update_param_([p], self.clip_fn([new]), lr, weight_decay)
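
Across the PSGD-based optimizers in this release (`ForeachCachedDelayedPSGDKron`, `ForeachCachedPSGDKron`, `ForeachDelayedPSGD`, `ForeachPaLMPAdam`, `ForeachPSGDKron`), a new `storage_dtype` argument controls the dtype of the momentum buffers (`exp_avg`, plus `exp_avg_sq` where present), alongside the existing `q_dtype` for the preconditioner. A sketch of the intended use, assuming the class is importable from the package root under this name and that bfloat16 optimizer state is acceptable for the model in question:

import torch
from torch import nn

import heavyball

model = nn.Linear(512, 512).cuda()

# q_dtype keeps the Kronecker preconditioner Q in float32, while storage_dtype
# stores the exp_avg momentum buffer in bfloat16 (hypothetical configuration;
# the defaults keep both in float32).
opt = heavyball.ForeachCachedDelayedPSGDKron(model.parameters(), lr=1e-3,
                                             q_dtype='float32',
                                             storage_dtype='bfloat16')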
@@ -9,7 +9,7 @@ from typing import Optional
  import torch

  from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
- line_to_triu, triu_to_line, set_, einsum_base, promote
+ line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias


  class ForeachCachedPSGDKron(PSGDBase):
@@ -39,7 +39,8 @@ class ForeachCachedPSGDKron(PSGDBase):
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
- foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
+ foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+ storage_dtype: str = 'float32', #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -56,7 +57,8 @@ class ForeachCachedPSGDKron(PSGDBase):
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+ storage_dtype=storage_dtype)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -71,15 +73,16 @@ class ForeachCachedPSGDKron(PSGDBase):
  beta = group['beta']
  store_triu_as_line = group['store_triu_as_line']
  q_dtype = getattr(torch, group['q_dtype'])
+ storage_dtype = getattr(torch, group['storage_dtype'])
  should_update = self.should_update(group)

  vals = []

- for p, g in split_p_and_g_in_group(group):
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
  state = self.state_(p)

  if 'Q' not in state:
- state["exp_avg"] = torch.zeros_like(g)
+ state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
  memory_save_mode, dtype=q_dtype)
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -103,7 +106,9 @@ class ForeachCachedPSGDKron(PSGDBase):

  group["step"] += 1

- torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+ stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))
+
+ lr = -warmup(lr, group['step'], group['warmup_steps'])

  grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
  exp_avg_list)
@@ -123,9 +128,5 @@ class ForeachCachedPSGDKron(PSGDBase):
  else:
  torch.mul(q_.conj(), q_, out=c_)

- set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
-
- grad_list = self.clip_fn(grad_list)
-
- lr = -warmup(lr, group['step'], group['warmup_steps'])
- update_param_(p_list, grad_list, lr, weight_decay)
+ g = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
+ update_param_([p], self.clip_fn([g]), lr, weight_decay)
@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
  """

  import torch
- from heavyball.utils import copy_stochastic_list_

+ from heavyball.utils import stochastic_lerp_, beta_debias
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
- precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_, promote
+ split_p_and_g_in_group, triu_to_line, line_to_triu, promote


  class ForeachDelayedPSGD(PSGDBase):
@@ -38,8 +38,8 @@ class ForeachDelayedPSGD(PSGDBase):
  def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
- split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
- foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
+ split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
+ q_dtype='float32', stochastic_schedule: bool = True, storage_dtype:str='float32', #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -55,12 +55,10 @@ class ForeachDelayedPSGD(PSGDBase):
  defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
- precond_init_scale=precond_init_scale,
- step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
- store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+ precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

-
  def _step(self, group):
  should_update = self.should_update(group)
  momentum_into_precond_update = group.get("momentum_into_precond_update", True)
@@ -74,14 +72,15 @@ class ForeachDelayedPSGD(PSGDBase):
  beta = group['beta']
  store_triu_as_line = group['store_triu_as_line']
  q_dtype = getattr(torch, group['q_dtype'])
+ storage_dtype = getattr(torch, group['storage_dtype'])

  vals = []

- for p, g in split_p_and_g_in_group(group):
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
  state = self.state_(p)

  if 'Q' not in state:
- state["exp_avg"] = torch.zeros_like(g)
+ state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
  memory_save_mode, dtype=q_dtype)
  state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
@@ -96,7 +95,9 @@ class ForeachDelayedPSGD(PSGDBase):

  group["step"] += 1

- torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+ stochastic_lerp_(exp_avg_list, grad_list, beta_debias(beta, group["step"]))
+
+ lr = -warmup(lr, group['step'], group['warmup_steps'])

  Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
  for i, (p, g) in enumerate(zip(p_list, grad_list)):
@@ -106,10 +107,6 @@ class ForeachDelayedPSGD(PSGDBase):
  new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
  if should_update:
  q32 = [promote(q_) for q_ in q]
- self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line)
- set_(g, new)
-
- grad_list = self.clip_fn(grad_list)
-
- lr = -warmup(lr, group['step'], group['warmup_steps'])
- update_param_(p_list, grad_list, lr, weight_decay)
+ self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
+ store_triu_as_line)
+ update_param_([p], self.clip_fn([new]), lr, weight_decay)
@@ -1,7 +1,7 @@
  import torch

  from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
- split_p_and_g_in_group, StatefulOptimizer
+ split_p_and_g_in_group, StatefulOptimizer, exp_avg_


  class ForeachSOAP(StatefulOptimizer):
@@ -26,8 +26,7 @@ class ForeachSOAP(StatefulOptimizer):
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
- split: bool = False,
- foreach: bool = True):
+ split: bool = False, foreach: bool = True):
  defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
@@ -65,14 +64,12 @@ class ForeachSOAP(StatefulOptimizer):
  p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
  beta1, beta2 = group["betas"]

- old_debiased1 = beta_debias(beta1, step)
  old_debiased2 = beta_debias(beta2, step)

  # Decay the first and second moment running average coefficient
  # In-place operations to update the averages at the same time
- torch._foreach_mul_(exp_avg, old_debiased1)
- torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
- denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+ step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+ denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)

  for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
  state = self.state_(p)
@@ -39,7 +39,7 @@ class ForeachPaLMPAdam(PSGDBase):
  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
  store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
- stochastic_schedule: bool = True, #
+ stochastic_schedule: bool = True, storage_dtype:str ='float32',#
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -57,7 +57,7 @@ class ForeachPaLMPAdam(PSGDBase):
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, beta=beta,
  beta2_scale=beta2_scale, merge_dims=merge_dims, split=split,
- store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+ store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -71,15 +71,16 @@ class ForeachPaLMPAdam(PSGDBase):
  lr = group['lr']
  store_triu_as_line = group['store_triu_as_line']
  q_dtype = getattr(torch, group['q_dtype'])
+ storage_dtype = getattr(torch, group['storage_dtype'])

  vals = []

- for p, g in split_p_and_g_in_group(group):
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
  state = self.state_(p)

  if 'Q' not in state:
- state['exp_avg'] = torch.zeros_like(g)
- state['exp_avg_sq'] = torch.zeros_like(g)
+ state['exp_avg'] = torch.zeros_like(g, dtype=storage_dtype)
+ state['exp_avg_sq'] = torch.zeros_like(g, dtype=storage_dtype)
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
  memory_save_mode, dtype=q_dtype)
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -103,6 +104,8 @@ class ForeachPaLMPAdam(PSGDBase):

  beta2 = 1 - group['step'] ** -group['beta2_scale']

+ lr = -warmup(lr, group['step'], group['warmup_steps'])
+
  for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
  psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
  ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
@@ -112,8 +115,5 @@ class ForeachPaLMPAdam(PSGDBase):
  divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
  divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
  """
+ update_param_([p], self.clip_fn([g]), lr, weight_decay)

- grad_list = self.clip_fn(grad_list)
-
- lr = -warmup(lr, group['step'], group['warmup_steps'])
- update_param_(p_list, grad_list, lr, weight_decay)
@@ -1,7 +1,8 @@
  import torch

  from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
- split_p_and_g_in_group, StatefulOptimizer
+ split_p_and_g_in_group, StatefulOptimizer, exp_avg_
+


  class PaLMForeachSOAP(StatefulOptimizer):
@@ -32,8 +33,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
  max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
- beta2_scale: float = 0.8, split: bool = False,
- foreach: bool = True):
+ beta2_scale: float = 0.8, split: bool = False, foreach: bool = True):
  if betas[0] is not None:
  beta = betas[0]
  defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -75,13 +75,13 @@ class PaLMForeachSOAP(StatefulOptimizer):
  beta1 = group["beta"]

  beta2 = 1 - step ** -group['beta2_scale']
- old_debiased1 = beta_debias(beta1, step)
  old_debiased2 = beta_debias(beta2, step)

  # Decay the first and second moment running average coefficient
  # In-place operations to update the averages at the same time
- torch._foreach_lerp_(exp_avg, grad, 1 - old_debiased1)
- denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+ beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
+ step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+ denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)

  for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
  state = self.state_(p)
@@ -2,8 +2,8 @@ import random

  import torch

- from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, \
- precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer
+ from .utils import init_preconditioner, update_preconditioner, project, beta_debias, update_param_, \
+ precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer, exp_avg_


  class PrecondScheduleForeachSOAP(StatefulOptimizer):
@@ -27,8 +27,7 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
- precond_scheduler=(1 / 3, 9), split: bool = False,
- foreach: bool = True):
+ precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True):
  defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
@@ -68,14 +67,12 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
  p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
  beta1, beta2 = group["betas"]

- old_debiased1 = beta_debias(beta1, step)
  old_debiased2 = beta_debias(beta2, step)

  # Decay the first and second moment running average coefficient
  # In-place operations to update the averages at the same time
- torch._foreach_mul_(exp_avg, old_debiased1)
- torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
- denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+ step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+ denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)

  update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
  for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
@@ -89,8 +86,7 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
  # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
  set_(d, project(exp_avg_projected / d, state['Q'], True))

- update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
- update_precond)
+ update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)

  # Why does this have to be rebiased here?
  step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
@@ -2,7 +2,7 @@ import random

  import torch

- from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, \
+ from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_, update_param_, \
  precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer


@@ -81,9 +81,9 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):

  # Decay the first and second moment running average coefficient
  # In-place operations to update the averages at the same time
- torch._foreach_mul_(exp_avg, old_debiased1)
- torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
- denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+ beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
+ step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+ denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)

  update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
  for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
@@ -2,8 +2,19 @@ import random

  import torch

- from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, \
- exp_avg_sq_, beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, split_p_and_g_in_group
+ from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
+ beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, split_p_and_g_in_group, copy_stochastic_list_, \
+ promote
+
+
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
+ eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
+ denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
+ torch._foreach_div_(gp32, denom)
+
+ copy_stochastic_list_(exp_avg_sq, eas32)
+ copy_stochastic_list_(grad_projected, gp32)


  class PrecondScheduleSFPaLMSOAP(ScheduleFree):
@@ -40,8 +51,8 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
- weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9),
- betas=(None, None), split: bool = False, foreach: bool = True):
+ weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9), betas=(None, None),
+ split: bool = False, foreach: bool = True):
  if betas[0] is not None:
  beta = betas[0]
  defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -103,8 +114,8 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):

  # Decay the first and second moment running average coefficient
  # In-place operations to update the averages at the same time
- denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
- torch._foreach_div_(grad_projected, denom)
+ old_debiased_tensor = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(old_debiased2)
+ _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased_tensor, group["eps"])

  update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)

@@ -114,13 +125,12 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
  # to the original space
  set_(gp, project(gp, state['Q'], back=True))

- update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
- update_precond)
+ update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)

  # Weight decay calculated at y
  if group["weight_decay"] > 0:
  torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])

  lr = warmup(group['lr'], step, group['warmup_steps'])
- group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
- p_list, z, grad_projected, group['r'], step)
+ group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], p_list,
+ z, grad_projected, group['r'], step)
@@ -9,7 +9,7 @@ from typing import Optional
  import torch

  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
- split_p_and_g_in_group, line_to_triu, triu_to_line, set_, promote
+ split_p_and_g_in_group, line_to_triu, triu_to_line, promote, stochastic_lerp_, beta_debias


  class ForeachPSGDKron(PSGDBase):
@@ -39,7 +39,8 @@ class ForeachPSGDKron(PSGDBase):
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
- foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
+ foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+ storage_dtype: str = 'float32', #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -56,7 +57,7 @@ class ForeachPSGDKron(PSGDBase):
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -72,14 +73,15 @@ class ForeachPSGDKron(PSGDBase):
  beta = group['beta']
  store_triu_as_line = group['store_triu_as_line']
  q_dtype = getattr(torch, group['q_dtype'])
+ storage_dtype = getattr(torch, group['storage_dtype'])

  vals = []

- for p, g in split_p_and_g_in_group(group):
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
  state = self.state_(p)

  if 'Q' not in state:
- state["exp_avg"] = torch.zeros_like(g)
+ state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
  memory_save_mode, dtype=q_dtype)
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -94,9 +96,14 @@ class ForeachPSGDKron(PSGDBase):

  group["step"] += 1

- torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+ beta = beta_debias(beta, group["step"])
+ beta = torch.empty((), dtype=torch.float32, device=grad_list[0].device).fill_(1 - beta)
+ stochastic_lerp_(exp_avg_list, grad_list, 1 - beta)

  grad_list, Q_list, exp_avg_list = list(grad_list), list(Q_list), list(exp_avg_list)
+
+ lr = -warmup(lr, group['step'], group['warmup_steps'])
+
  for i, (p, g) in enumerate(zip(p_list, grad_list)):
  q_orig = Q_list.pop(0)
  ea = exp_avg_list.pop(0)
@@ -106,9 +113,5 @@ class ForeachPSGDKron(PSGDBase):
  q32 = [promote(q_) for q_ in q]
  self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
  store_triu_as_line)
- set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
-
- grad_list = self.clip_fn(grad_list)
-
- lr = -warmup(lr, group['step'], group['warmup_steps'])
- update_param_(p_list, grad_list, lr, weight_decay)
+ g = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+ update_param_([p], self.clip_fn([g]), lr, weight_decay)
@@ -70,7 +70,7 @@ class ForeachPurePSGD(PSGDBase):

  vals = []

- for p, g in split_p_and_g_in_group(group):
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
  state = self.state_(p)

  if 'Q' not in state:
@@ -89,6 +89,7 @@ class ForeachPurePSGD(PSGDBase):
  group["step"] += 1

  Q_list = list(Q_list)
+ lr = -warmup(lr, group['step'], group['warmup_steps'])
  for i, (p, g) in enumerate(zip(p_list, grad_list)):
  q_orig = Q_list.pop(0)
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
@@ -97,8 +98,4 @@ class ForeachPurePSGD(PSGDBase):
  q32 = [promote(q_) for q_ in q]
  self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
  psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
-
- grad_list = self.clip_fn(grad_list)
-
- lr = -warmup(lr, group['step'], group['warmup_steps'])
- update_param_(p_list, grad_list, lr, weight_decay)
+ update_param_([p], self.clip_fn([g]), lr, weight_decay)
@@ -2,8 +2,18 @@ import random

  import torch

- from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, \
- exp_avg_sq_, beta_debias, schedule_free_, warmup, ScheduleFree, split_p_and_g_in_group
+ from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
+ beta_debias, schedule_free_, warmup, ScheduleFree, split_p_and_g_in_group, copy_stochastic_list_, promote
+
+
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
+ eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
+ denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
+ torch._foreach_div_(gp32, denom)
+
+ copy_stochastic_list_(exp_avg_sq, eas32)
+ copy_stochastic_list_(grad_projected, gp32)


  class SFPaLMForeachSOAP(ScheduleFree):
@@ -95,8 +105,8 @@ class SFPaLMForeachSOAP(ScheduleFree):

  # Decay the first and second moment running average coefficient
  # In-place operations to update the averages at the same time
- denom = exp_avg_sq_(exp_avg_sq, grad, new_debiased2, group["eps"])
- torch._foreach_div_(grad_projected, denom)
+ old_debiased_tensor = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(new_debiased2)
+ _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased_tensor, group["eps"])

  update_precond = group['step'] > 0 and group['step'] % group['precondition_frequency'] == 0

@@ -107,13 +117,12 @@ class SFPaLMForeachSOAP(ScheduleFree):
  # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
  set_(gp, project(gp, state['Q'], back=True))

- update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2,
- update_precond)
+ update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2, update_precond)

  # Weight decay calculated at y
  if group["weight_decay"] > 0:
  torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])

  lr = warmup(group['lr'], step, group['warmup_steps'])
- group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
- p_list, z, grad_projected, group['r'], step)
+ group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], p_list,
+ z, grad_projected, group['r'], step)
@@ -3,7 +3,7 @@ import gc
  import math
  import random
  import string
- from typing import List, Optional, Tuple, Callable
+ from typing import List, Optional, Tuple, Callable, Union

  import numpy as np
  import torch
@@ -141,6 +141,7 @@ def beta_debias(beta, step):
  return 1 - (1 - beta) / (1 - beta ** step)


+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
  def exp_avg_sq_(state, grad, beta2, eps, out=None):
  if isinstance(state, torch.Tensor):
  state.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
@@ -327,6 +328,36 @@ def get_orthogonal_matrix(mat):
  return final


+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+ for x_, y_ in zip(x, y):
+ x32 = promote(x_)
+ y32 = promote(y_)
+ x32.lerp_(y32, a)
+ copy_stochastic_(x_, x32)
+
+
+ def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+ if not isinstance(a, torch.Tensor):
+ a = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(a)
+ _compilable_stochastic_lerp_(x, y, a)
+
+
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+ for x_, y_ in zip(x, y):
+ x32 = promote(x_)
+ y32 = promote(y_)
+ x32.add_(y32, alpha=alpha)
+ copy_stochastic_(x_, x32)
+
+
+ def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+ if not isinstance(alpha, torch.Tensor):
+ alpha = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(alpha)
+ _compilable_stochastic_add_(x, y, alpha)
+
+
  @decorator
  def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
  if grad.dim() == 1 and (not precondition_1d or grad.shape[0] > max_precond_dim):
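
The new `stochastic_lerp_` / `stochastic_add_` helpers wrap a plain Python scalar into a 0-d tensor, run the per-tensor loop under `torch.compile`, and write the result back with stochastic rounding (`copy_stochastic_`), so low-precision state such as bfloat16 does not systematically drop small updates. A minimal usage sketch, assuming the package is installed and a CUDA device is available; shapes and coefficients are illustrative only:

import torch

from heavyball.utils import stochastic_add_, stochastic_lerp_

# Hypothetical bfloat16 optimizer state paired with float32 gradients.
state = [torch.zeros(1024, dtype=torch.bfloat16, device='cuda') for _ in range(2)]
grads = [torch.randn(1024, device='cuda') for _ in range(2)]

# EMA-style update: state <- state + 0.1 * (grad - state), stochastically rounded.
stochastic_lerp_(state, grads, 0.1)

# Scaled in-place add: state <- state - 1e-3 * grad.
stochastic_add_(state, grads, -1e-3)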
@@ -409,9 +440,12 @@ def project(grad, Q, back: bool):


  class StatefulOptimizer(torch.optim.Optimizer):
- def __init__(self, params, defaults, foreach: bool = True):
+ ema_decay: float = 0.001
+
+ def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
  super().__init__(params, {**defaults, 'foreach': foreach})
  self.fake_groups = {}
+ self.use_ema = use_ema

  def key(self, param: torch.Tensor):
  return (param.data_ptr(), tuple(param.shape))
@@ -445,6 +479,54 @@ class StatefulOptimizer(torch.optim.Optimizer):
  def _step(self, group):
  raise NotImplementedError

+ def ema_update(self):
+ with torch.no_grad():
+ for top_group in self.param_groups:
+ for group in self.get_groups(top_group):
+ active_p = [p for p in group['params']]
+
+ if not active_p:
+ return
+
+ k = group['ema_step'] = group.get('ema_step', -1) + 1
+
+ for p in active_p:
+ if 'param_ema' not in self.state_(p):
+ self.state_(p)['param_ema'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+
+ y, param_ema = zip(*[(p.data, self.state_(p)['param_ema']) for p in active_p])
+ torch._foreach_lerp_(param_ema, y, weight=beta_debias(1 - self.ema_decay, k + 1))
+
+ def copy_emas_to_params(self):
+ with torch.no_grad():
+ for top_group in self.param_groups:
+ for group in self.get_groups(top_group):
+ active_p = [p for p in group['params']]
+
+ if not active_p:
+ return
+
+ for p in active_p:
+ if 'param_ema' in self.state_(p):
+ p_clone = p.data.clone()
+ set_(p.data, self.state_(p)['param_ema'])
+ set_(self.state_(p)['param_ema'], p_clone)
+
+ def copy_params_to_emas(self):
+ with torch.no_grad():
+ for top_group in self.param_groups:
+ for group in self.get_groups(top_group):
+ active_p = [p for p in group['params']]
+
+ if not active_p:
+ return
+
+ for p in active_p:
+ if 'param_ema' in self.state_(p):
+ ema_clone = self.state_(p)['param_ema'].data.clone()
+ set_(self.state_(p)['param_ema'], p.data)
+ set_(p.data, ema_clone)
+
  def step(self, closure: Optional[Callable] = None):
  if closure is None:
  loss = None
@@ -455,6 +537,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
  for top_group in self.param_groups:
  for group in self.get_groups(top_group):
  self._step(group)
+ if self.use_ema:
+ self.ema_update(group)
  return loss

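
`StatefulOptimizer` now carries an optional parameter EMA: `ema_update()` folds the current weights into a debiased exponential moving average held in optimizer state (`ema_decay` defaults to 0.001), and `copy_emas_to_params()` / `copy_params_to_emas()` each swap the live weights with the averaged copy, so calling one and then the other restores the original weights. The new `test/test_ema.py` drives these hooks manually after each step; a condensed sketch of that workflow, with the optimizer class, model, and hyperparameters chosen only for illustration:

import torch
from torch import nn

import heavyball

model = nn.Linear(256, 256).cuda()
opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-3)  # any heavyball optimizer

for _ in range(100):
    loss = model(torch.randn(64, 256, device='cuda')).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    opt.ema_update()           # fold the current weights into the EMA

opt.copy_emas_to_params()      # evaluate with the averaged weights
# ... run validation here ...
opt.copy_params_to_emas()      # swap back to the live training weights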
@@ -497,6 +581,32 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
  copy_stochastic_(t, s)


+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
+ beta1 = beta_debias(beta1, step)
+ beta2 = beta_debias(beta2, step)
+
+ g32, gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, grad_projected, exp_avg_sq]]
+
+ stochastic_lerp_(exp_avg, g32, 1 - beta1)
+ denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
+
+ copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+ return denom
+
+
+ def exp_avg_(exp_avg: List[torch.Tensor], exp_avg_sq: List[torch.Tensor], grad: List[torch.Tensor],
+ grad_projected: List[torch.Tensor], beta1: float, beta2: float, step: int):
+ if isinstance(beta1, float):
+ beta1 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta1)
+ if isinstance(beta2, float):
+ beta2 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta2)
+ if isinstance(step, int):
+ step = torch.empty((), dtype=torch.int32, device=exp_avg[0].device).fill_(step)
+ denom = _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step)
+ return denom
+
+
  # this can be dynamic for most optimizers - just not for PSGD. So, it's disabled for all
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True)
  def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
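
The SOAP variants now route their moment updates through this fused `exp_avg_` helper. Python floats and ints (`beta1`, `beta2`, `step`) are converted to 0-d tensors before the compiled body runs, which keeps `torch.compile` from specializing on the raw scalar values as the step count changes. A small call sketch with illustrative shapes; in the optimizers the lists are the per-group parameter state:

import torch

from heavyball.utils import exp_avg_

grad = [torch.randn(128, 128, device='cuda')]
grad_projected = [g.clone() for g in grad]      # SOAP passes the projected gradient here
exp_avg = [torch.zeros_like(g) for g in grad]
exp_avg_sq = [torch.zeros_like(g) for g in grad]

# Updates exp_avg / exp_avg_sq in place and returns the Adam-style denominator.
denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, 0.9, 0.95, 1)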
@@ -523,23 +633,26 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):


  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
- def _compilable_update_one_(p, u, decay, add_fn, lr):
- p32 = promote(p)
- u32 = promote(u.view(p.shape))
+ def _compilable_update_(p, u, decay, add_fn, lr):
+ u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
+ p32, u32 = [list(map(promote, x)) for x in [p, u]]
+
  if decay > 0:
- p32.mul_(1 - decay * lr)
- if add_fn is None:
- p32.add_(u32, alpha=lr)
- else:
- add_fn(p32, u32, lr)
- copy_stochastic_(p, p32)
+ torch._foreach_mul_(p32, 1 - decay * lr)
+
+ for p32_, u32_ in zip(p32, u32): # lr is data-dependent -> can't compile a foreach
+ if add_fn is None:
+ p32_.add_(u32_, alpha=lr)
+ else:
+ add_fn(p32_, u32_, lr)
+
+ copy_stochastic_list_(p, p32)


  def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
  add_fn: callable = None):
  lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
- for p, u in zip(param, update):
- _compilable_update_one_(p, u, decay, add_fn, lr_tensor)
+ _compilable_update_(param, update, decay, add_fn, lr_tensor)


  def precond_schedule(step, precond_scheduler, rng):
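
`update_param_` now feeds whole lists of parameters and updates into a single compiled `_compilable_update_` call instead of compiling one call per tensor; the Python-float learning rate is wrapped into a 0-d tensor first. A brief call sketch (the tensors are placeholders; in the optimizers the update list is the clipped, preconditioned gradient):

import torch

from heavyball.utils import update_param_

params = [torch.randn(64, 64, device='cuda') for _ in range(3)]
updates = [torch.randn_like(p) for p in params]

# p <- p * (1 - decay * lr) + lr * u, written back with stochastic rounding.
# The optimizers pass a negative lr because the update is the raw preconditioned
# gradient rather than its negation.
update_param_(params, updates, lr=-1e-3, decay=0.01)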
@@ -638,12 +751,13 @@ def psgd_balance_Q(Q_in):
  torch._foreach_mul_(Q_in, list(norms))


- def psgd_calc_A_and_conjB(exprA, G, Q, V):
- md = min_dtype(Q)
- A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md))
+ def psgd_calc_A_and_conjB(exprA, G, Q):
+ md = min_dtype(Q + [G])
+ A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
  order = G.dim()
  p = list(range(order))
- conjB = torch.permute(V.conj(), p[1:] + p[:1])
+ conjB = torch.randn(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
+ Q = [promote(q) for q in Q]
  for i, q in enumerate(Q):
  if q.dim() <= 1:
  conjB /= q
@@ -651,7 +765,7 @@ def psgd_calc_A_and_conjB(exprA, G, Q, V):
  unsqueeze = conjB.dim() <= 1
  if unsqueeze:
  conjB = conjB.unsqueeze(0)
- conjB = torch.linalg.solve_triangular(q, conjB, upper=True, left=False, out=conjB)
+ conjB = torch.linalg.solve_triangular(q, conjB, upper=True, left=False)
  if unsqueeze:
  conjB = conjB.squeeze(0)
  if i < order - 1:
@@ -661,33 +775,29 @@ def psgd_calc_A_and_conjB(exprA, G, Q, V):

  def psgd_lb(A, max_abs):
  A /= max_abs
- aa = torch.real(A * A.conj())
- value0, i = torch.max(torch.sum(aa, dim=0), 0)
- value1, j = torch.max(torch.sum(aa, dim=1), 0)
-
- ah = A.H
- comp = value0 > value1
- x = torch.where(comp, A[:, i], A[j])
- x = x.conj()
- if x.dim() > 1:
- x = torch.where(comp, x, x.T)
- torch.matmul(x, torch.where(comp, A, A.T), out=x.view(1, -1))
- x /= torch.linalg.vector_norm(x)
- torch.matmul(x, torch.where(comp, ah, ah.T), out=x.view(1, -1))
- x = torch.linalg.vector_norm(x)
+ a0 = torch.einsum('ij,ij->j', A, A)
+ i = torch.argmax(a0)
+
+ x = torch.index_select(A, 1, i).flatten().contiguous()
+
+ x = torch.einsum('i,ij->j', x, A)
+ x /= x.norm()
+ x = torch.einsum('j,kj->k', x, A)
+ x = x.norm()
  x *= max_abs
  return x


- def psgd_update_precond(Q, exprs, V, G, step, tiny):
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
  """Update Kronecker product preconditioner Q with pair (V, G)."""
  exprA, exprGs, _ = exprs

- A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)
+ A, conjB = psgd_calc_A_and_conjB(exprA, G, Q)

- for q, exprG in zip(Q, exprGs):
- term1 = torch.einsum(exprG, A, A.conj())
- term2 = torch.einsum(exprG, conjB.conj(), conjB)
+ for q, exprG, o in zip(Q, exprGs, oq):
+ term1 = promote(torch.einsum(exprG, A, A))
+ term2 = promote(torch.einsum(exprG, conjB, conjB))

  term2 += term1 # a + b
  term1 *= 2 # 2a
@@ -696,18 +806,22 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
  else:
  term1 = term1 - term2

- term1 *= step
+ term1 *= precond_lr
  norm = term2.norm(float('inf'))
  if q.dim() < 2:
- term1 *= q
- q.addcdiv_(term1, norm.clamp_(min=tiny), value=-1)
+ term1 *= q.to(term1.dtype)
+ term1 /= norm.clamp_(min=tiny)
  else:
  torch.triu(term1, out=term1)
- term1 /= torch.where(norm > 0, psgd_lb(term2, norm), norm).clamp_(tiny)
- q.addmm_(term1, q, alpha=-1)
+ term1 /= psgd_lb(term2, norm).clamp_(tiny)
+ torch.matmul(term1, q, out=term1)
+ if store_triu_as_line:
+ term1 = triu_to_line([term1])[0][1]
+ o = o[1]
+ stochastic_add_([o], [term1], -1)


- @decorator
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
  def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
  """Precondition gradient G with preconditioner Q."""
  md = min_dtype(Q)
@@ -838,18 +952,9 @@ class PSGDBase(StatefulOptimizer):
  group[name] = cumulative_prob + prob
  return int(group[name]) > int(cumulative_prob)

- def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
- store_triu_as_line=False):
- if original_q:
- if store_triu_as_line:
- update_fn = update_triu_
- else:
- update_fn = copy_stochastic_list_
- else:
- update_fn = lambda x, y: None
- for i, (p, grad, Q, oq) in enumerate(zip(p_list, grad_list, q_list, original_q)):
- psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
- update_fn(oq, Q)
+ def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: List, store_triu_as_line=False):
+ for p, grad, Q, oq in zip(p_list, grad_list, q_list, original_q):
+ psgd_update_precond(Q, self.state_(p)["exprs"], grad, precond_lr, self._tiny, oq, store_triu_as_line)

  if self.should_update(group, self.balance_probability, "balance_prob"):
  for g, q in zip(grad_list, original_q if original_q else q_list):
@@ -896,13 +1001,19 @@ def merge_group(group, *tensors):
  return out


- def split_p_and_g_in_group(group: dict, skip_none: bool = True):
+ def split_p_and_g_in_group(group: dict, skip_none: bool = True, should_promote: bool = True):
  for p in group["params"]:
  if skip_none and p.grad is None:
  continue

- grad = None if p.grad is None else promote(p.grad)
- p.grad = None
+ if p.grad is None:
+ grad = None
+ else:
+ if should_promote:
+ grad = promote(p.grad)
+ else:
+ grad = p.grad
+ p.grad = None

  p_views = merge_group(group, p)
  if grad is not None:
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.19.0
+ Version: 0.21.0
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
  largely static alternative to `torch.optim` with more and better optimizers.

- Currently (2024-11-22, 0.19), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+ Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

  ## Features
@@ -29,6 +29,7 @@ test/test_bf16_params.py
  test/test_bf16_q.py
  test/test_bf16_storage.py
  test/test_closure.py
+ test/test_ema.py
  test/test_foreach.py
  test/test_memory.py
  test/test_merge.py
@@ -10,7 +10,7 @@ setuptools.setup(
  name='heavyball',
  license='BSD',
  description='Efficient optimizers',
- version='0.19.0',
+ version='0.21.0',
  long_description=README,
  url='https://github.com/clashluke/heavyball',
  packages=setuptools.find_packages(),
@@ -20,10 +20,11 @@ def get_memory():

  @pytest.mark.parametrize("opt", heavyball.__all__)
  @pytest.mark.parametrize("size,depth", [(256, 2)])
- def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3):
+ def test_foreach(opt, size, depth: int, iterations: int = 512, outer_iterations: int = 3):
  set_torch()
  opt = getattr(heavyball, opt)

+
  peaks = []
  losses = []

@@ -0,0 +1,61 @@
+ import pytest
+ import torch
+ from torch import nn
+ from torch._dynamo import config
+
+ import heavyball
+ import heavyball.utils
+ from benchmark.utils import get_optim
+ from heavyball.utils import clean, set_torch
+
+ config.cache_size_limit = 128
+
+
+ def get_memory():
+ clean()
+ torch.cuda.synchronize()
+ clean()
+ torch.cuda.synchronize()
+ return torch.cuda.memory_allocated()
+
+
+ @pytest.mark.parametrize("opt", heavyball.__all__)
+ @pytest.mark.parametrize("size,depth", [(256, 2)])
+ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3):
+ set_torch()
+ opt = getattr(heavyball, opt)
+
+ peaks = []
+ losses = []
+
+ for do_ema in [True, False]:
+ torch.manual_seed(0x2131290)
+ peaks.append([])
+ losses.append([])
+
+ for i in range(outer_iterations):
+ model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
+ o = get_optim(opt, model.parameters(), lr=1e-3)
+
+ for _ in range(iterations):
+ loss = model(torch.randn((1024, size), device='cuda')).square().mean()
+ loss.backward()
+ o.step()
+ o.zero_grad()
+ if do_ema:
+ o.ema_update()
+ o.copy_emas_to_params()
+ o.copy_params_to_emas()
+ losses[-1].append(loss.detach())
+
+ if do_ema:
+ o.copy_emas_to_params()
+ loss = model(torch.randn((1024, size), device='cuda')).square().mean()
+ losses[-1].append(loss.detach())
+
+ del model, o
+ clean()
+
+ for i, (l0, l1) in enumerate(zip(*losses)):
+ print(i, l0.item(), l1.item())
+ assert l0.float() <= l1.float()
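
The new test parametrizes over every optimizer in `heavyball.__all__`, requires a CUDA device, and imports `get_optim` from the repository's `benchmark` package, so it is intended to be run from a source checkout rather than an installed wheel. A typical invocation would be `pytest test/test_ema.py -q`, or `pytest test/test_ema.py -k SOAP -q` to restrict it to a subset of optimizers (the exact selection flags are up to the user).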