heavyball 0.19.0__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/cached_delayed_psgd_kron.py CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
  import torch

  from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
- line_to_triu, triu_to_line, set_, einsum_base, promote
+ line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias


  class ForeachCachedDelayedPSGDKron(PSGDBase):
@@ -41,7 +41,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
- foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
+ foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+ storage_dtype: str = 'float32', #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -58,7 +59,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -74,14 +75,15 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  beta = group['beta']
  store_triu_as_line = group['store_triu_as_line']
  q_dtype = getattr(torch, group['q_dtype'])
+ storage_dtype = getattr(torch, group['storage_dtype'])

  vals = []

- for p, g in split_p_and_g_in_group(group):
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
  state = self.state_(p)

  if 'Q' not in state:
- state["exp_avg"] = torch.zeros_like(g)
+ state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
  memory_save_mode, dtype=q_dtype)
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -105,7 +107,9 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):

  group["step"] += 1

- torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+ stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))
+
+ lr = -warmup(lr, group['step'], group['warmup_steps'])

  grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
  exp_avg_list)
@@ -127,8 +131,4 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  else:
  torch.mul(q_.conj(), q_, out=c_)

- set_(g, new)
- grad_list = self.clip_fn(grad_list)
-
- lr = -warmup(lr, group['step'], group['warmup_steps'])
- update_param_(p_list, grad_list, lr, weight_decay)
+ update_param_([p], self.clip_fn([new]), lr, weight_decay)
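
Note on the new parameter: `exp_avg` is now allocated with `torch.zeros_like(g, dtype=storage_dtype)`, so `storage_dtype` sets the precision of the momentum state while `q_dtype` still controls the preconditioner. A minimal usage sketch (illustrative only; the import path and the choice of 'bfloat16' are assumptions, and it presumes a PyTorch build that supports the compiled code paths in this release):

    import torch
    from heavyball.cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron

    model = torch.nn.Linear(16, 16)
    # bf16 state roughly halves optimizer memory; stochastic rounding keeps small updates.
    opt = ForeachCachedDelayedPSGDKron(model.parameters(), lr=1e-3, storage_dtype='bfloat16')

    model(torch.randn(4, 16)).square().mean().backward()
    opt.step()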
heavyball/cached_psgd_kron.py CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
  import torch

  from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
- line_to_triu, triu_to_line, set_, einsum_base, promote
+ line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias


  class ForeachCachedPSGDKron(PSGDBase):
@@ -39,7 +39,8 @@ class ForeachCachedPSGDKron(PSGDBase):
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
- foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
+ foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+ storage_dtype: str = 'float32', #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -56,7 +57,8 @@ class ForeachCachedPSGDKron(PSGDBase):
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+ storage_dtype=storage_dtype)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -71,15 +73,16 @@ class ForeachCachedPSGDKron(PSGDBase):
  beta = group['beta']
  store_triu_as_line = group['store_triu_as_line']
  q_dtype = getattr(torch, group['q_dtype'])
+ storage_dtype = getattr(torch, group['storage_dtype'])
  should_update = self.should_update(group)

  vals = []

- for p, g in split_p_and_g_in_group(group):
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
  state = self.state_(p)

  if 'Q' not in state:
- state["exp_avg"] = torch.zeros_like(g)
+ state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
  memory_save_mode, dtype=q_dtype)
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -103,7 +106,9 @@ class ForeachCachedPSGDKron(PSGDBase):

  group["step"] += 1

- torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+ stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))
+
+ lr = -warmup(lr, group['step'], group['warmup_steps'])

  grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
  exp_avg_list)
@@ -123,9 +128,5 @@ class ForeachCachedPSGDKron(PSGDBase):
  else:
  torch.mul(q_.conj(), q_, out=c_)

- set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
-
- grad_list = self.clip_fn(grad_list)
-
- lr = -warmup(lr, group['step'], group['warmup_steps'])
- update_param_(p_list, grad_list, lr, weight_decay)
+ g = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
+ update_param_([p], self.clip_fn([g]), lr, weight_decay)
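
For reference, the rewritten momentum update in the two cached optimizers above uses the same interpolation weight as before: utils.py (later in this diff) defines `beta_debias(beta, step)` as `1 - (1 - beta) / (1 - beta ** step)`, so `1 - beta_debias(beta, step)` equals the old `(1 - beta) / (1 - beta ** step)` factor previously passed to `torch._foreach_lerp_`. A small plain-Python check:

    def beta_debias(beta, step):
        # definition from heavyball/utils.py
        return 1 - (1 - beta) / (1 - beta ** step)

    beta, step = 0.9, 10
    old_weight = (1 - beta) / (1 - beta ** step)  # previous torch._foreach_lerp_ weight
    new_weight = 1 - beta_debias(beta, step)      # weight now passed to stochastic_lerp_
    assert abs(old_weight - new_weight) < 1e-12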
heavyball/delayed_psgd.py CHANGED
@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
  """

  import torch
- from heavyball.utils import copy_stochastic_list_

+ from heavyball.utils import stochastic_lerp_, beta_debias
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
- precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_, promote
+ split_p_and_g_in_group, triu_to_line, line_to_triu, promote


  class ForeachDelayedPSGD(PSGDBase):
@@ -38,8 +38,8 @@ class ForeachDelayedPSGD(PSGDBase):
  def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
- split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
- foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
+ split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
+ q_dtype='float32', stochastic_schedule: bool = True, storage_dtype:str='float32', #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -55,12 +55,10 @@ class ForeachDelayedPSGD(PSGDBase):
  defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
- precond_init_scale=precond_init_scale,
- step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
- store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+ precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

-
  def _step(self, group):
  should_update = self.should_update(group)
  momentum_into_precond_update = group.get("momentum_into_precond_update", True)
@@ -74,14 +72,15 @@ class ForeachDelayedPSGD(PSGDBase):
  beta = group['beta']
  store_triu_as_line = group['store_triu_as_line']
  q_dtype = getattr(torch, group['q_dtype'])
+ storage_dtype = getattr(torch, group['storage_dtype'])

  vals = []

- for p, g in split_p_and_g_in_group(group):
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
  state = self.state_(p)

  if 'Q' not in state:
- state["exp_avg"] = torch.zeros_like(g)
+ state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
  memory_save_mode, dtype=q_dtype)
  state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
@@ -96,7 +95,9 @@ class ForeachDelayedPSGD(PSGDBase):

  group["step"] += 1

- torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+ stochastic_lerp_(exp_avg_list, grad_list, beta_debias(beta, group["step"]))
+
+ lr = -warmup(lr, group['step'], group['warmup_steps'])

  Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
  for i, (p, g) in enumerate(zip(p_list, grad_list)):
@@ -106,10 +107,6 @@ class ForeachDelayedPSGD(PSGDBase):
  new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
  if should_update:
  q32 = [promote(q_) for q_ in q]
- self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line)
- set_(g, new)
-
- grad_list = self.clip_fn(grad_list)
-
- lr = -warmup(lr, group['step'], group['warmup_steps'])
- update_param_(p_list, grad_list, lr, weight_decay)
+ self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
+ store_triu_as_line)
+ update_param_([p], self.clip_fn([new]), lr, weight_decay)
heavyball/foreach_soap.py CHANGED
@@ -1,7 +1,7 @@
  import torch

  from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
- split_p_and_g_in_group, StatefulOptimizer
+ split_p_and_g_in_group, StatefulOptimizer, exp_avg_


  class ForeachSOAP(StatefulOptimizer):
@@ -26,8 +26,7 @@ class ForeachSOAP(StatefulOptimizer):
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
- split: bool = False,
- foreach: bool = True):
+ split: bool = False, foreach: bool = True):
  defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
@@ -65,14 +64,12 @@ class ForeachSOAP(StatefulOptimizer):
  p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
  beta1, beta2 = group["betas"]

- old_debiased1 = beta_debias(beta1, step)
  old_debiased2 = beta_debias(beta2, step)

  # Decay the first and second moment running average coefficient
  # In-place operations to update the averages at the same time
- torch._foreach_mul_(exp_avg, old_debiased1)
- torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
- denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+ step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+ denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)

  for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
  state = self.state_(p)
heavyball/p_adam.py CHANGED
@@ -39,7 +39,7 @@ class ForeachPaLMPAdam(PSGDBase):
  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
  store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
- stochastic_schedule: bool = True, #
+ stochastic_schedule: bool = True, storage_dtype:str ='float32',#
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -57,7 +57,7 @@ class ForeachPaLMPAdam(PSGDBase):
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, beta=beta,
  beta2_scale=beta2_scale, merge_dims=merge_dims, split=split,
- store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+ store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -71,15 +71,16 @@ class ForeachPaLMPAdam(PSGDBase):
  lr = group['lr']
  store_triu_as_line = group['store_triu_as_line']
  q_dtype = getattr(torch, group['q_dtype'])
+ storage_dtype = getattr(torch, group['storage_dtype'])

  vals = []

- for p, g in split_p_and_g_in_group(group):
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
  state = self.state_(p)

  if 'Q' not in state:
- state['exp_avg'] = torch.zeros_like(g)
- state['exp_avg_sq'] = torch.zeros_like(g)
+ state['exp_avg'] = torch.zeros_like(g, dtype=storage_dtype)
+ state['exp_avg_sq'] = torch.zeros_like(g, dtype=storage_dtype)
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
  memory_save_mode, dtype=q_dtype)
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -103,6 +104,8 @@ class ForeachPaLMPAdam(PSGDBase):

  beta2 = 1 - group['step'] ** -group['beta2_scale']

+ lr = -warmup(lr, group['step'], group['warmup_steps'])
+
  for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
  psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
  ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
@@ -112,8 +115,5 @@ class ForeachPaLMPAdam(PSGDBase):
  divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
  divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
  """
+ update_param_([p], self.clip_fn([g]), lr, weight_decay)

- grad_list = self.clip_fn(grad_list)
-
- lr = -warmup(lr, group['step'], group['warmup_steps'])
- update_param_(p_list, grad_list, lr, weight_decay)
heavyball/palm_foreach_soap.py CHANGED
@@ -1,7 +1,8 @@
  import torch

  from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
- split_p_and_g_in_group, StatefulOptimizer
+ split_p_and_g_in_group, StatefulOptimizer, exp_avg_
+


  class PaLMForeachSOAP(StatefulOptimizer):
@@ -32,8 +33,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
  max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
- beta2_scale: float = 0.8, split: bool = False,
- foreach: bool = True):
+ beta2_scale: float = 0.8, split: bool = False, foreach: bool = True):
  if betas[0] is not None:
  beta = betas[0]
  defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -75,13 +75,13 @@ class PaLMForeachSOAP(StatefulOptimizer):
  beta1 = group["beta"]

  beta2 = 1 - step ** -group['beta2_scale']
- old_debiased1 = beta_debias(beta1, step)
  old_debiased2 = beta_debias(beta2, step)

  # Decay the first and second moment running average coefficient
  # In-place operations to update the averages at the same time
- torch._foreach_lerp_(exp_avg, grad, 1 - old_debiased1)
- denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+ beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
+ step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+ denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)

  for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
  state = self.state_(p)
heavyball/precond_schedule_foreach_soap.py CHANGED
@@ -2,8 +2,8 @@ import random

  import torch

- from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, \
- precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer
+ from .utils import init_preconditioner, update_preconditioner, project, beta_debias, update_param_, \
+ precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer, exp_avg_


  class PrecondScheduleForeachSOAP(StatefulOptimizer):
@@ -27,8 +27,7 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
- precond_scheduler=(1 / 3, 9), split: bool = False,
- foreach: bool = True):
+ precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True):
  defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
@@ -68,14 +67,12 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
  p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
  beta1, beta2 = group["betas"]

- old_debiased1 = beta_debias(beta1, step)
  old_debiased2 = beta_debias(beta2, step)

  # Decay the first and second moment running average coefficient
  # In-place operations to update the averages at the same time
- torch._foreach_mul_(exp_avg, old_debiased1)
- torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
- denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+ step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+ denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)

  update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
  for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
@@ -89,8 +86,7 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
  # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
  set_(d, project(exp_avg_projected / d, state['Q'], True))

- update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
- update_precond)
+ update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)

  # Why does this have to be rebiased here?
  step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
heavyball/precond_schedule_palm_foreach_soap.py CHANGED
@@ -2,7 +2,7 @@ import random

  import torch

- from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, \
+ from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_, update_param_, \
  precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer


@@ -81,9 +81,9 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):

  # Decay the first and second moment running average coefficient
  # In-place operations to update the averages at the same time
- torch._foreach_mul_(exp_avg, old_debiased1)
- torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
- denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+ beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
+ step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+ denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)

  update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
  for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
heavyball/precond_schedule_sfpsoap.py CHANGED
@@ -2,8 +2,19 @@ import random

  import torch

- from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, \
- exp_avg_sq_, beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, split_p_and_g_in_group
+ from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
+ beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, split_p_and_g_in_group, copy_stochastic_list_, \
+ promote
+
+
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
+ eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
+ denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
+ torch._foreach_div_(gp32, denom)
+
+ copy_stochastic_list_(exp_avg_sq, eas32)
+ copy_stochastic_list_(grad_projected, gp32)


  class PrecondScheduleSFPaLMSOAP(ScheduleFree):
@@ -40,8 +51,8 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
- weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9),
- betas=(None, None), split: bool = False, foreach: bool = True):
+ weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9), betas=(None, None),
+ split: bool = False, foreach: bool = True):
  if betas[0] is not None:
  beta = betas[0]
  defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -103,8 +114,8 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):

  # Decay the first and second moment running average coefficient
  # In-place operations to update the averages at the same time
- denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
- torch._foreach_div_(grad_projected, denom)
+ old_debiased_tensor = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(old_debiased2)
+ _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased_tensor, group["eps"])

  update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)

@@ -114,13 +125,12 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
  # to the original space
  set_(gp, project(gp, state['Q'], back=True))

- update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
- update_precond)
+ update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)

  # Weight decay calculated at y
  if group["weight_decay"] > 0:
  torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])

  lr = warmup(group['lr'], step, group['warmup_steps'])
- group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
- p_list, z, grad_projected, group['r'], step)
+ group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], p_list,
+ z, grad_projected, group['r'], step)
heavyball/psgd_kron.py CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
  import torch

  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
- split_p_and_g_in_group, line_to_triu, triu_to_line, set_, promote
+ split_p_and_g_in_group, line_to_triu, triu_to_line, promote, stochastic_lerp_, beta_debias


  class ForeachPSGDKron(PSGDBase):
@@ -39,7 +39,8 @@ class ForeachPSGDKron(PSGDBase):
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
- foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
+ foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+ storage_dtype: str = 'float32', #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -56,7 +57,7 @@ class ForeachPSGDKron(PSGDBase):
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -72,14 +73,15 @@ class ForeachPSGDKron(PSGDBase):
  beta = group['beta']
  store_triu_as_line = group['store_triu_as_line']
  q_dtype = getattr(torch, group['q_dtype'])
+ storage_dtype = getattr(torch, group['storage_dtype'])

  vals = []

- for p, g in split_p_and_g_in_group(group):
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
  state = self.state_(p)

  if 'Q' not in state:
- state["exp_avg"] = torch.zeros_like(g)
+ state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
  memory_save_mode, dtype=q_dtype)
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -94,9 +96,14 @@ class ForeachPSGDKron(PSGDBase):

  group["step"] += 1

- torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+ beta = beta_debias(beta, group["step"])
+ beta = torch.empty((), dtype=torch.float32, device=grad_list[0].device).fill_(1 - beta)
+ stochastic_lerp_(exp_avg_list, grad_list, 1 - beta)

  grad_list, Q_list, exp_avg_list = list(grad_list), list(Q_list), list(exp_avg_list)
+
+ lr = -warmup(lr, group['step'], group['warmup_steps'])
+
  for i, (p, g) in enumerate(zip(p_list, grad_list)):
  q_orig = Q_list.pop(0)
  ea = exp_avg_list.pop(0)
@@ -106,9 +113,5 @@ class ForeachPSGDKron(PSGDBase):
  q32 = [promote(q_) for q_ in q]
  self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
  store_triu_as_line)
- set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
-
- grad_list = self.clip_fn(grad_list)
-
- lr = -warmup(lr, group['step'], group['warmup_steps'])
- update_param_(p_list, grad_list, lr, weight_decay)
+ g = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+ update_param_([p], self.clip_fn([g]), lr, weight_decay)
heavyball/pure_psgd.py CHANGED
@@ -70,7 +70,7 @@ class ForeachPurePSGD(PSGDBase):

  vals = []

- for p, g in split_p_and_g_in_group(group):
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
  state = self.state_(p)

  if 'Q' not in state:
@@ -89,6 +89,7 @@ class ForeachPurePSGD(PSGDBase):
  group["step"] += 1

  Q_list = list(Q_list)
+ lr = -warmup(lr, group['step'], group['warmup_steps'])
  for i, (p, g) in enumerate(zip(p_list, grad_list)):
  q_orig = Q_list.pop(0)
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
@@ -97,8 +98,4 @@ class ForeachPurePSGD(PSGDBase):
  q32 = [promote(q_) for q_ in q]
  self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
  psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
-
- grad_list = self.clip_fn(grad_list)
-
- lr = -warmup(lr, group['step'], group['warmup_steps'])
- update_param_(p_list, grad_list, lr, weight_decay)
+ update_param_([p], self.clip_fn([g]), lr, weight_decay)
heavyball/schedule_free_palm_foreach_soap.py CHANGED
@@ -2,8 +2,18 @@ import random

  import torch

- from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, \
- exp_avg_sq_, beta_debias, schedule_free_, warmup, ScheduleFree, split_p_and_g_in_group
+ from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
+ beta_debias, schedule_free_, warmup, ScheduleFree, split_p_and_g_in_group, copy_stochastic_list_, promote
+
+
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
+ eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
+ denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
+ torch._foreach_div_(gp32, denom)
+
+ copy_stochastic_list_(exp_avg_sq, eas32)
+ copy_stochastic_list_(grad_projected, gp32)


  class SFPaLMForeachSOAP(ScheduleFree):
@@ -95,8 +105,8 @@ class SFPaLMForeachSOAP(ScheduleFree):

  # Decay the first and second moment running average coefficient
  # In-place operations to update the averages at the same time
- denom = exp_avg_sq_(exp_avg_sq, grad, new_debiased2, group["eps"])
- torch._foreach_div_(grad_projected, denom)
+ old_debiased_tensor = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(new_debiased2)
+ _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased_tensor, group["eps"])

  update_precond = group['step'] > 0 and group['step'] % group['precondition_frequency'] == 0

@@ -107,13 +117,12 @@ class SFPaLMForeachSOAP(ScheduleFree):
  # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
  set_(gp, project(gp, state['Q'], back=True))

- update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2,
- update_precond)
+ update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2, update_precond)

  # Weight decay calculated at y
  if group["weight_decay"] > 0:
  torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])

  lr = warmup(group['lr'], step, group['warmup_steps'])
- group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
- p_list, z, grad_projected, group['r'], step)
+ group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], p_list,
+ z, grad_projected, group['r'], step)
heavyball/utils.py CHANGED
@@ -3,7 +3,7 @@ import gc
  import math
  import random
  import string
- from typing import List, Optional, Tuple, Callable
+ from typing import List, Optional, Tuple, Callable, Union

  import numpy as np
  import torch
@@ -141,6 +141,7 @@ def beta_debias(beta, step):
  return 1 - (1 - beta) / (1 - beta ** step)


+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
  def exp_avg_sq_(state, grad, beta2, eps, out=None):
  if isinstance(state, torch.Tensor):
  state.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
@@ -327,6 +328,26 @@ def get_orthogonal_matrix(mat):
  return final


+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+ x32 = [promote(x_) for x_ in x]
+ y32 = [promote(y_) for y_ in y]
+
+ torch._foreach_lerp_(x32, y32, a)
+
+ copy_stochastic_list_(x, x32)
+
+
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+ x32 = [promote(x_) for x_ in x]
+ y32 = [promote(y_) for y_ in y]
+
+ [x_.add_(y_, alpha=alpha) for x_, y_ in zip(x32, y32)]
+
+ copy_stochastic_list_(x, x32)
+
+
  @decorator
  def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
  if grad.dim() == 1 and (not precondition_1d or grad.shape[0] > max_precond_dim):
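
Both helpers promote their inputs to fp32, run the foreach op, and write the result back through `copy_stochastic_list_`, so low-precision (e.g. bf16) state is updated with stochastic rounding rather than losing small increments to truncation. An illustrative call (toy shapes; assumes the compiled helper runs in your PyTorch build):

    import torch
    from heavyball.utils import stochastic_lerp_

    exp_avg = [torch.zeros(1024, dtype=torch.bfloat16)]
    grad = [torch.full((1024,), 1e-3)]
    stochastic_lerp_(exp_avg, grad, 0.05)  # exp_avg stays bf16; the write-back rounds stochastically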
@@ -409,9 +430,12 @@ def project(grad, Q, back: bool):


  class StatefulOptimizer(torch.optim.Optimizer):
- def __init__(self, params, defaults, foreach: bool = True):
+ ema_decay: float = 0.001
+
+ def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
  super().__init__(params, {**defaults, 'foreach': foreach})
  self.fake_groups = {}
+ self.use_ema = use_ema

  def key(self, param: torch.Tensor):
  return (param.data_ptr(), tuple(param.shape))
@@ -445,6 +469,54 @@ class StatefulOptimizer(torch.optim.Optimizer):
  def _step(self, group):
  raise NotImplementedError

+ def ema_update(self):
+ with torch.no_grad():
+ for top_group in self.param_groups:
+ for group in self.get_groups(top_group):
+ active_p = [p for p in group['params']]
+
+ if not active_p:
+ return
+
+ k = group['ema_step'] = group.get('ema_step', -1) + 1
+
+ for p in active_p:
+ if 'param_ema' not in self.state_(p):
+ self.state_(p)['param_ema'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+
+ y, param_ema = zip(*[(p.data, self.state_(p)['param_ema']) for p in active_p])
+ torch._foreach_lerp_(param_ema, y, weight=beta_debias(1 - self.ema_decay, k + 1))
+
+ def copy_emas_to_params(self):
+ with torch.no_grad():
+ for top_group in self.param_groups:
+ for group in self.get_groups(top_group):
+ active_p = [p for p in group['params']]
+
+ if not active_p:
+ return
+
+ for p in active_p:
+ if 'param_ema' in self.state_(p):
+ p_clone = p.data.clone()
+ set_(p.data, self.state_(p)['param_ema'])
+ set_(self.state_(p)['param_ema'], p_clone)
+
+ def copy_params_to_emas(self):
+ with torch.no_grad():
+ for top_group in self.param_groups:
+ for group in self.get_groups(top_group):
+ active_p = [p for p in group['params']]
+
+ if not active_p:
+ return
+
+ for p in active_p:
+ if 'param_ema' in self.state_(p):
+ ema_clone = self.state_(p)['param_ema'].data.clone()
+ set_(self.state_(p)['param_ema'], p.data)
+ set_(p.data, ema_clone)
+
  def step(self, closure: Optional[Callable] = None):
  if closure is None:
  loss = None
@@ -455,6 +527,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
  for top_group in self.param_groups:
  for group in self.get_groups(top_group):
  self._step(group)
+ if self.use_ema:
+ self.ema_update(group)
  return loss

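
The EMA additions keep a `param_ema` copy of each parameter and lerp it toward the live weights with a debiased decay; `copy_emas_to_params` and `copy_params_to_emas` swap the two sets in place. Since `step()` above passes `group` to `ema_update()`, which is defined without a `group` parameter, the sketch below drives the EMA manually instead of via `use_ema=True` (illustrative only; `ForeachPSGDKron` stands in for any `StatefulOptimizer` subclass, and it assumes the compiled PSGD paths run in your environment):

    import torch
    from heavyball.psgd_kron import ForeachPSGDKron

    model = torch.nn.Linear(8, 8)
    opt = ForeachPSGDKron(model.parameters(), lr=1e-3)

    for _ in range(3):
        model(torch.randn(4, 8)).square().mean().backward()
        opt.step()
        opt.zero_grad()
        opt.ema_update()          # maintain the param_ema buffers

    opt.copy_emas_to_params()     # evaluate with the averaged weights ...
    opt.copy_params_to_emas()     # ... then swap the live weights back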
@@ -497,6 +571,20 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
  copy_stochastic_(t, s)


+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
+ beta1 = beta_debias(beta1, step)
+ beta2 = beta_debias(beta2, step)
+
+ g32, gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, grad_projected, exp_avg_sq]]
+
+ stochastic_lerp_(exp_avg, g32, 1 - beta1)
+ denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
+
+ copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+ return denom
+
+
  # this can be dynamic for most optimizers - just not for PSGD. So, it's disabled for all
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True)
  def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
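
`exp_avg_` fuses the debiased first-moment lerp with the second-moment update and returns the denominator that the SOAP variants then divide by; the call below mirrors the call sites added in foreach_soap.py (toy buffers; the betas are arbitrary and running it assumes these compiled foreach paths work in your PyTorch build):

    import torch
    from heavyball.utils import exp_avg_

    exp_avg = [torch.zeros(8)]
    exp_avg_sq = [torch.zeros(8)]
    grad = [torch.randn(8)]
    grad_projected = [g.clone() for g in grad]

    step_tensor = torch.empty((), dtype=torch.int32, device=grad[0].device).fill_(1)
    denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, 0.9, 0.95, step_tensor)
    # exp_avg and exp_avg_sq are updated in place; denom later divides the projected update.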
@@ -523,23 +611,26 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):


  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
- def _compilable_update_one_(p, u, decay, add_fn, lr):
- p32 = promote(p)
- u32 = promote(u.view(p.shape))
+ def _compilable_update_(p, u, decay, add_fn, lr):
+ u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
+ p32, u32 = [list(map(promote, x)) for x in [p, u]]
+
  if decay > 0:
- p32.mul_(1 - decay * lr)
- if add_fn is None:
- p32.add_(u32, alpha=lr)
- else:
- add_fn(p32, u32, lr)
- copy_stochastic_(p, p32)
+ torch._foreach_mul_(p32, 1 - decay * lr)
+
+ for p32_, u32_ in zip(p32, u32): # lr is data-dependent -> can't compile a foreach
+ if add_fn is None:
+ p32_.add_(u32_, alpha=lr)
+ else:
+ add_fn(p32_, u32_, lr)
+
+ copy_stochastic_list_(p, p32)


  def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
  add_fn: callable = None):
  lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
- for p, u in zip(param, update):
- _compilable_update_one_(p, u, decay, add_fn, lr_tensor)
+ _compilable_update_(param, update, decay, add_fn, lr_tensor)


  def precond_schedule(step, precond_scheduler, rng):
@@ -630,7 +721,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
  return [Q, (exprA, tuple(exprGs), exprP)]


- @decorator
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
  def psgd_balance_Q(Q_in):
  norms = torch.stack([q.norm(float("inf")) for q in Q_in])
  geometric_mean = norms.log().mean().exp()
@@ -638,12 +729,14 @@ def psgd_balance_Q(Q_in):
  torch._foreach_mul_(Q_in, list(norms))


- def psgd_calc_A_and_conjB(exprA, G, Q, V):
- md = min_dtype(Q)
- A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md))
+ def psgd_calc_A_and_conjB(exprA, G, Q):
+ md = min_dtype(Q + [G])
+ A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
  order = G.dim()
  p = list(range(order))
- conjB = torch.permute(V.conj(), p[1:] + p[:1])
+ V = torch.randn_like(G, dtype=promote(G.dtype))
+ conjB = torch.permute(V, p[1:] + p[:1])
+ Q = [promote(q) for q in Q]
  for i, q in enumerate(Q):
  if q.dim() <= 1:
  conjB /= q
@@ -651,7 +744,7 @@ def psgd_calc_A_and_conjB(exprA, G, Q, V):
  unsqueeze = conjB.dim() <= 1
  if unsqueeze:
  conjB = conjB.unsqueeze(0)
- conjB = torch.linalg.solve_triangular(q, conjB, upper=True, left=False, out=conjB)
+ conjB = torch.linalg.solve_triangular(q, conjB, upper=True, left=False)
  if unsqueeze:
  conjB = conjB.squeeze(0)
  if i < order - 1:
@@ -661,33 +754,37 @@ def psgd_calc_A_and_conjB(exprA, G, Q, V):

  def psgd_lb(A, max_abs):
  A /= max_abs
- aa = torch.real(A * A.conj())
- value0, i = torch.max(torch.sum(aa, dim=0), 0)
- value1, j = torch.max(torch.sum(aa, dim=1), 0)
+ a0 = torch.einsum('ij,ij->j', A, A)
+ a1 = torch.einsum('ij,ij->i', A, A)
+ value0 = torch.max(a0)
+ value1 = torch.max(a1)
+ i = torch.argmax(a0)
+ j = torch.argmax(a1)

- ah = A.H
  comp = value0 > value1
- x = torch.where(comp, A[:, i], A[j])
- x = x.conj()
- if x.dim() > 1:
- x = torch.where(comp, x, x.T)
- torch.matmul(x, torch.where(comp, A, A.T), out=x.view(1, -1))
- x /= torch.linalg.vector_norm(x)
- torch.matmul(x, torch.where(comp, ah, ah.T), out=x.view(1, -1))
- x = torch.linalg.vector_norm(x)
+ x = torch.cond(comp, lambda a: torch.index_select(a, 1, i).flatten().contiguous(), #
+ lambda a: torch.index_select(a, 0, j).flatten().contiguous(), (A,))
+
+ x = torch.cond(comp, lambda x_, a: torch.einsum('i,ij->j', x_, a), lambda x_, a: torch.einsum('i,ji->j', x_, a),
+ (x, A,))
+ x /= x.norm()
+ x = torch.cond(comp, lambda x_, a: torch.einsum('j,kj->k', x_, a), lambda x_, a: torch.einsum('j,jk->k', x_, a),
+ (x, A,))
+ x = x.norm()
  x *= max_abs
  return x


- def psgd_update_precond(Q, exprs, V, G, step, tiny):
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
  """Update Kronecker product preconditioner Q with pair (V, G)."""
  exprA, exprGs, _ = exprs

- A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)
+ A, conjB = psgd_calc_A_and_conjB(exprA, G, Q)

- for q, exprG in zip(Q, exprGs):
- term1 = torch.einsum(exprG, A, A.conj())
- term2 = torch.einsum(exprG, conjB.conj(), conjB)
+ for q, exprG, o in zip(Q, exprGs, oq):
+ term1 = promote(torch.einsum(exprG, A, A))
+ term2 = promote(torch.einsum(exprG, conjB, conjB))

  term2 += term1 # a + b
  term1 *= 2 # 2a
@@ -696,15 +793,19 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
  else:
  term1 = term1 - term2

- term1 *= step
+ term1 *= precond_lr
  norm = term2.norm(float('inf'))
  if q.dim() < 2:
- term1 *= q
- q.addcdiv_(term1, norm.clamp_(min=tiny), value=-1)
+ term1 *= q.to(term1.dtype)
+ term1 /= norm.clamp_(min=tiny)
  else:
  torch.triu(term1, out=term1)
- term1 /= torch.where(norm > 0, psgd_lb(term2, norm), norm).clamp_(tiny)
- q.addmm_(term1, q, alpha=-1)
+ term1 /= psgd_lb(term2, norm).clamp_(tiny)
+ torch.matmul(term1, q, out=term1)
+ if store_triu_as_line:
+ term1 = triu_to_line([term1])[0][1]
+ o = o[1]
+ stochastic_add_([o], [term1], -1)


  @decorator
@@ -838,18 +939,9 @@ class PSGDBase(StatefulOptimizer):
  group[name] = cumulative_prob + prob
  return int(group[name]) > int(cumulative_prob)

- def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
- store_triu_as_line=False):
- if original_q:
- if store_triu_as_line:
- update_fn = update_triu_
- else:
- update_fn = copy_stochastic_list_
- else:
- update_fn = lambda x, y: None
- for i, (p, grad, Q, oq) in enumerate(zip(p_list, grad_list, q_list, original_q)):
- psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
- update_fn(oq, Q)
+ def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: List, store_triu_as_line=False):
+ for p, grad, Q, oq in zip(p_list, grad_list, q_list, original_q):
+ psgd_update_precond(Q, self.state_(p)["exprs"], grad, precond_lr, self._tiny, oq, store_triu_as_line)

  if self.should_update(group, self.balance_probability, "balance_prob"):
  for g, q in zip(grad_list, original_q if original_q else q_list):
@@ -896,13 +988,19 @@ def merge_group(group, *tensors):
  return out


- def split_p_and_g_in_group(group: dict, skip_none: bool = True):
+ def split_p_and_g_in_group(group: dict, skip_none: bool = True, should_promote: bool = True):
  for p in group["params"]:
  if skip_none and p.grad is None:
  continue

- grad = None if p.grad is None else promote(p.grad)
- p.grad = None
+ if p.grad is None:
+ grad = None
+ else:
+ if should_promote:
+ grad = promote(p.grad)
+ else:
+ grad = p.grad
+ p.grad = None

  p_views = merge_group(group, p)
  if grad is not None:
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.19.0
+ Version: 0.20.0
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
  largely static alternative to `torch.optim` with more and better optimizers.

- Currently (2024-11-22, 0.19), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+ Currently (2024-11-22, 0.19.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

  ## Features
@@ -0,0 +1,24 @@
+ heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
+ heavyball/cached_delayed_psgd_kron.py,sha256=apVzESMaQ8uxunHvfvYfyWA8HLbS25wQSd3j_YNEjGs,6603
+ heavyball/cached_psgd_kron.py,sha256=3IETfsC0Ufu_8TPfo9SByGmztwjW6ktSFPwHNrUWkys,6601
+ heavyball/delayed_psgd.py,sha256=0LaazbiBZOdx78EDS-945cW3bmeORjUvdFOGqdw3aMs,5631
+ heavyball/foreach_adamw.py,sha256=Rb5U80cgUcEqlEbUU250UTWdoqA7nyiqkV5w1U4bWX4,2445
+ heavyball/foreach_adopt.py,sha256=ecdi1fKg9i087OGjtKWVbE_DD6Yf4pvpzv4ELCcusvQ,3211
+ heavyball/foreach_laprop.py,sha256=vi6C_gfjXxw5uN0KHgzxI9itUI1dcgOf3ufoO_VVMp0,2471
+ heavyball/foreach_sfadamw.py,sha256=rLZORmCIMu9G09FdDgMSiI6pNq34IVoxsPVWtmeDdbQ,2753
+ heavyball/foreach_soap.py,sha256=4mWSMWYTdjgiXiboI5DwdigecruDtNGKylGAFAVhCRA,4562
+ heavyball/p_adam.py,sha256=J5QqFAlyLTQ1eQzM0LGxPdv4fEtZikIv9mJ_SSkO3ZY,6033
+ heavyball/palm_foreach_sfadamw.py,sha256=JbNrcoquBGGUI5XNMFouDjpNurVHUW9DbX1A3tSrtno,3025
+ heavyball/palm_foreach_soap.py,sha256=GzAwM8kOt1X0QCmUZDTdHwPxbJwjH8ic43dyAK5BYCA,6015
+ heavyball/precond_schedule_foreach_soap.py,sha256=HcObXLfSNN_lKNb4nmC6tkdHcqDIMNX6hILpHKScqLc,4744
+ heavyball/precond_schedule_palm_foreach_soap.py,sha256=xZ7CJvIfdu2RNAZt2g1S7Xb0Jyy1hNC4MowOFU3nWkk,6283
+ heavyball/precond_schedule_sfpsoap.py,sha256=PNneiOkrRyV1yIZn91lPmYofd1_OiLqJTDy75RLpXJk,7273
+ heavyball/psgd_kron.py,sha256=RSLJi5FnSmYjvYBufNDClnYRm-eN_Kpa1Ar2tNP6-X0,5824
+ heavyball/pure_psgd.py,sha256=LZK0qmvZkBF8g00evaVLtW-sIUJmdoxag1K7O26AqEo,4820
+ heavyball/schedule_free_palm_foreach_soap.py,sha256=_AqrnChY6iDlQUkF2YUxS7eLjSWCIuvEUOHvMHVM1yY,6873
+ heavyball/utils.py,sha256=ESazD0yv14Aa8XKi_pz2CyfVkpcbgYcG2-WMvhQOnxk,35719
+ heavyball-0.20.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+ heavyball-0.20.0.dist-info/METADATA,sha256=dJ43LOTrNqh7cDTDzZDSu57goP1gNhU3dfZ26BUK9hA,11926
+ heavyball-0.20.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ heavyball-0.20.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+ heavyball-0.20.0.dist-info/RECORD,,
@@ -1,24 +0,0 @@
- heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
- heavyball/cached_delayed_psgd_kron.py,sha256=PQAER6UgVh5l87DGRZrJ8CVP9UhyCG5wJD9rPLnj_G8,6460
- heavyball/cached_psgd_kron.py,sha256=GaeneBp0irksCSBIrJY4D_0hCpZ-uSRPMhqVX_a-og8,6417
- heavyball/delayed_psgd.py,sha256=fhBWFLTSl1S2gHWCeYak-STaXRwpC56sWZGLFMKFEJM,5589
- heavyball/foreach_adamw.py,sha256=Rb5U80cgUcEqlEbUU250UTWdoqA7nyiqkV5w1U4bWX4,2445
- heavyball/foreach_adopt.py,sha256=ecdi1fKg9i087OGjtKWVbE_DD6Yf4pvpzv4ELCcusvQ,3211
- heavyball/foreach_laprop.py,sha256=vi6C_gfjXxw5uN0KHgzxI9itUI1dcgOf3ufoO_VVMp0,2471
- heavyball/foreach_sfadamw.py,sha256=rLZORmCIMu9G09FdDgMSiI6pNq34IVoxsPVWtmeDdbQ,2753
- heavyball/foreach_soap.py,sha256=h6ptMch7oaynvu3eIJtWnVXypDA_5JDVm3Zb3PNEma0,4634
- heavyball/p_adam.py,sha256=4zJDGJrpgUyVzr3GiELETFre4xr3-PE10OuAZj-jFM8,5883
- heavyball/palm_foreach_sfadamw.py,sha256=JbNrcoquBGGUI5XNMFouDjpNurVHUW9DbX1A3tSrtno,3025
- heavyball/palm_foreach_soap.py,sha256=g4hbiGRcti-J-a0SwAkP4ii5pU-aalsZH5bssyhroLk,5938
- heavyball/precond_schedule_foreach_soap.py,sha256=WLg5SzpJnKPZUvFyIvdwSZa1Umt5cpr3Kow_42orM-E,4863
- heavyball/precond_schedule_palm_foreach_soap.py,sha256=ammQrvRZFF-wc-wEiPEoFhS_7b8pdV61QfcLoQfimSo,6211
- heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdojklrY8pA,6770
- heavyball/psgd_kron.py,sha256=wKjtI56iUnL5D8DseW60kxiXTAlMYNEf52CrvQaQMnI,5547
- heavyball/pure_psgd.py,sha256=iUy7mMKWxwNiVUMYrQ7SBnreu3t_XSbnhTW3a1yw4m0,4835
- heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
- heavyball/utils.py,sha256=BWscCHlGOw1_zfKYxNAAmfFeOXVpSJHuvqqlfL5A7_0,31690
- heavyball-0.19.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
- heavyball-0.19.0.dist-info/METADATA,sha256=1wORoS9rrjlug9tuJqXsbtVA9PphOBGcifiLRxmZNjs,11924
- heavyball-0.19.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- heavyball-0.19.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
- heavyball-0.19.0.dist-info/RECORD,,