heavyball 0.14.7__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
heavyball/foreach_soap.py CHANGED
@@ -34,73 +34,59 @@ class ForeachSOAP(StatefulOptimizer):
         super().__init__(params, defaults)
         self._data_format = data_format
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        """
-        Performs a single optimization step.
-
-        Arguments:
-            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
-        """
-        if closure is None:
-            loss = None
-        else:
-            loss = closure()
-
-        for group in self.param_groups:
-            vals = []
-            step = 0
-
-            max_precond_dim = group['max_precond_dim']
-            precondition_1d = group['precondition_1d']
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
-                step = state['step'] = state.get("step", -1) + 1
-
-                if "exp_avg" not in state:
-                    state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
-                    state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
-                    init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                    continue  # first step is skipped so that we never use the current gradients in the projection.
-
-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                grad_projected = project(g, state['Q'], False)
-                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-                vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-            if not vals:
-                continue
-
-            p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-            beta1, beta2 = group["betas"]
-
-            old_debiased1 = beta_debias(beta1, step)
-            old_debiased2 = beta_debias(beta2, step)
-
-            # Decay the first and second moment running average coefficient
-            # In-place operations to update the averages at the same time
-            torch._foreach_mul_(exp_avg, old_debiased1)
-            torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
-            denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
-
-            for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
-                state = self.state_(p)
-                # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                exp_avg_projected = project(ea, state['Q'], False)
-
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-                set_(d, project(exp_avg_projected / d, state['Q'], True))
-
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                      step > 0 and step % group['precondition_frequency'] == 0)
-
-            # Why does this have to be rebiased here?
-            step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-            update_param_(p_list, denom, step_size, group["weight_decay"])
-        return loss
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1, beta2 = group["betas"]
+
+        old_debiased1 = beta_debias(beta1, step)
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        torch._foreach_mul_(exp_avg, old_debiased1)
+        torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+
+        for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
+            state = self.state_(p)
+            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            exp_avg_projected = project(ea, state['Q'], False)
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            set_(d, project(exp_avg_projected / d, state['Q'], True))
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
+                                  step > 0 and step % group['precondition_frequency'] == 0)
+
+        # Why does this have to be rebiased here?
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+        update_param_(p_list, denom, step_size, group["weight_decay"])
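The pattern in this hunk repeats across the release: the per-optimizer `step(closure)` override disappears, the per-group logic moves into a `_step(group)` hook, and `continue` on an empty group becomes `return`. The shared driver is not shown in this diff; below is a minimal sketch of what a base-class `step()` along these lines could look like, assuming (as the removed code suggests) that the base class owns closure handling, the no-grad context, and the loop over parameter groups. This is a hypothetical illustration, not heavyball's actual `StatefulOptimizer`.

import torch


class StatefulOptimizerSketch(torch.optim.Optimizer):
    # Hypothetical base class: subclasses such as ForeachSOAP only implement _step(group).
    def _step(self, group):
        raise NotImplementedError

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()  # re-evaluate the model, as the removed step() methods did
        for group in self.param_groups:
            self._step(group)  # per-group work now lives in the subclass hook
        return loss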
heavyball/p_adam.py CHANGED
@@ -5,6 +5,7 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
+from heavyball.utils import triu_to_line, line_to_triu
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
     exp_avg_sq_, beta_debias, split_p_and_g_in_group
@@ -36,7 +37,8 @@ class ForeachPaLMPAdam(PSGDBase):
     def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
-                 beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None):
+                 beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
+                 store_triu_as_line: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -57,18 +59,12 @@ class ForeachPaLMPAdam(PSGDBase):
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
-                        split=split)
+                        split=split, store_triu_as_line=store_triu_as_line)
         super().__init__(params, defaults)
 
         self._prob_step = 0
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
@@ -76,57 +72,57 @@ class ForeachPaLMPAdam(PSGDBase):
         do_update = self.rng.random() < update_prob
         self._prob_step += 1
 
-        for group in self.param_groups:
-            precond_init_scale = group['precond_init_scale']
-            max_size_triangular = group['max_size_triangular']
-            min_ndim_triangular = group['min_ndim_triangular']
-            memory_save_mode = group['memory_save_mode']
-            precond_lr = group['precond_lr']
-            weight_decay = group['weight_decay']
-            lr = group['lr']
-
-            vals = []
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        store_triu_as_line = group['store_triu_as_line']
 
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
+        vals = []
 
-                if 'Q' not in state:
-                    state['exp_avg'] = torch.zeros_like(g)
-                    state['exp_avg_sq'] = torch.zeros_like(g)
-                    state["Q"], state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
-                                                              min_ndim_triangular, memory_save_mode, dtype=g.dtype)
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
 
-                vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
+            if 'Q' not in state:
+                state['exp_avg'] = torch.zeros_like(g)
+                state['exp_avg_sq'] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
+                                                 min_ndim_triangular, memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
-            if not vals:
-                continue
+            vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
 
-            p_list, grad_list, Q_list, exp_avg, exp_avg_sq = zip(*vals)
-            del vals
+        if not vals:
+            return
 
-            group["step"] += 1
+        p_list, grad_list, Q_list, exp_avg, exp_avg_sq = zip(*vals)
+        del vals
 
-            self.balance(do_update, grad_list, Q_list)
-            if do_update:
-                self.do_update(p_list, grad_list, Q_list, precond_lr)
+        group["step"] += 1
 
-            torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
+        Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
+        if do_update:
+            self.balance(grad_list, Q_triu)
+            self.do_update(p_list, grad_list, Q_triu, precond_lr, Q_list if store_triu_as_line else None)
 
-            beta2 = 1 - group['step'] ** -group['beta2_scale']
+        torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
 
-            for p, Q, g, ea, eas in zip(p_list, Q_list, grad_list, exp_avg, exp_avg_sq):
-                psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
-                ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
-                exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
-                torch.div(ea, g, out=g)
-                """
-                divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
-                divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
-                """
+        beta2 = 1 - group['step'] ** -group['beta2_scale']
 
-            grad_list = self.clip_fn(grad_list)
+        for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
+            psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
+            ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
+            exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
+            torch.div(ea, g, out=g)
+            """
+            divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
+            divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
+            """
 
-            lr = -warmup(lr, group['step'], group['warmup_steps'])
-            update_param_(p_list, grad_list, lr, weight_decay)
+        grad_list = self.clip_fn(grad_list)
 
-        return loss
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
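The new `store_triu_as_line` option keeps each triangular Kronecker factor in a compact flattened form between steps and expands it via `line_to_triu` only while it is needed; since an n-by-n matrix has only n(n+1)/2 upper-triangular entries, this roughly halves the resident memory per square factor (diagonal factors produced by `memory_save_mode` would be unaffected). heavyball's real `triu_to_line`/`line_to_triu` live in `heavyball.utils` and are not shown in this diff; the sketch below only illustrates the storage idea, with hypothetical helper names.

import torch


def triu_to_line_sketch(mats):
    # Keep only the upper-triangular entries of each square factor as a flat vector,
    # along with the side length needed to rebuild it.
    out = []
    for m in mats:
        n = m.shape[0]
        idx = torch.triu_indices(n, n, device=m.device)
        out.append((n, m[idx[0], idx[1]]))
    return out


def line_to_triu_sketch(lines):
    # Rebuild full upper-triangular matrices from the flattened storage.
    mats = []
    for n, flat in lines:
        m = torch.zeros(n, n, dtype=flat.dtype, device=flat.device)
        idx = torch.triu_indices(n, n, device=flat.device)
        m[idx[0], idx[1]] = flat
        mats.append(m)
    return mats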
@@ -9,58 +9,48 @@ class PaLMForeachSFAdamW(ScheduleFree):
                  weight_lr_power=2.0, beta2_scale: float = 0.8):
         if betas[0] is not None:
             beta = betas[0]
-        defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
-                        weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
+        defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
                         beta2_scale=beta2_scale)
         super().__init__(params, defaults)
 
-    def step(self, closure=None):
-        """Performs a single optimization step.
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
 
-        Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-        """
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
 
-        loss = None
-        if closure is not None:
-            loss = closure()
+        active_p = [p for p in group['params'] if p.grad is not None]
 
-        for group in self.param_groups:
-            eps = group['eps']
-            decay = group['weight_decay']
-            k = group['k']
+        if not active_p:
+            return
 
-            if not group['train_mode']:
-                raise Exception("Not in train mode!")
+        for p in active_p:
+            if 'z' not in self.state_(p):
+                self.state_(p)['z'] = torch.clone(p.data)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
 
-            active_p = [p for p in group['params'] if p.grad is not None]
+        y, grad, exp_avg_sq, z = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
 
-            for p in active_p:
-                if 'z' not in self.state_(p):
-                    self.state_(p)['z'] = torch.clone(p.data)
-                    self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+        # Decay the first moment running average coefficient
+        beta2 = 1 - (k + 1) ** -group['beta2_scale']
+        old_debiased = beta_debias(beta2, k + 1)
 
-            y, grad, exp_avg_sq, z = zip(
-                *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
+        # Decay the first and second moment running average coefficient
+        denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
 
-            # Decay the first moment running average coefficient
-            beta2 = 1 - (k + 1) ** -group['beta2_scale']
-            old_debiased = beta_debias(beta2, k + 1)
+        # Normalize grad in-place for memory efficiency
+        torch._foreach_div_(grad, denom)
 
-            # Decay the first and second moment running average coefficient
-            denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
+        # Weight decay calculated at y
+        if decay != 0:
+            torch._foreach_add_(grad, y, alpha=decay)
 
-            # Normalize grad in-place for memory efficiency
-            torch._foreach_div_(grad, denom)
+        lr = warmup(group['lr'], k + 1, group['warmup_steps'])
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], y, z,
+                                             grad, group['r'], k + 1)
 
-            # Weight decay calculated at y
-            if decay != 0:
-                torch._foreach_add_(grad, y, alpha=decay)
-
-            lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-            group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
-                                                 y, z, grad, group['r'], k + 1)
-
-            group['k'] = k + 1
-        return loss
+        group['k'] = k + 1
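Both the ScheduleFree and the PaLM-flavoured variants derive beta2 from the step counter rather than taking a fixed constant: beta2 = 1 - (k + 1) ** -beta2_scale, so the second-moment horizon lengthens as training progresses. A small standalone illustration of that schedule (the helper name is mine, not the library's):

def palm_beta2(step: int, beta2_scale: float = 0.8) -> float:
    # beta2 schedule used above: approaches 1 as the step count grows
    return 1 - step ** -beta2_scale

for k in (1, 10, 100, 1_000, 10_000):
    print(k, round(palm_beta2(k), 4))
# approximate output: 1 -> 0.0, 10 -> 0.8415, 100 -> 0.9749, 1000 -> 0.996, 10000 -> 0.9994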
@@ -43,73 +43,59 @@ class PaLMForeachSOAP(StatefulOptimizer):
         super().__init__(params, defaults)
         self._data_format = data_format
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        """
-        Performs a single optimization step.
-
-        Arguments:
-            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
-        """
-        if closure is None:
-            loss = None
-        else:
-            loss = closure()
-
-        for group in self.param_groups:
-            vals = []
-            step = 0
-
-            max_precond_dim = group['max_precond_dim']
-            precondition_1d = group['precondition_1d']
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
-                step = state['step'] = state.get("step", -1) + 1
-
-                if "exp_avg" not in state:
-                    state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
-                    state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
-                    init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                    continue  # first step is skipped so that we never use the current gradients in the projection.
-
-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                grad_projected = project(g, state['Q'], False)
-                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-                vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-            if not vals:
-                continue
-
-            p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-            beta1 = group["beta"]
-
-            beta2 = 1 - step ** -group['beta2_scale']
-            old_debiased1 = beta_debias(beta1, step)
-            old_debiased2 = beta_debias(beta2, step)
-
-            # Decay the first and second moment running average coefficient
-            # In-place operations to update the averages at the same time
-            torch._foreach_lerp_(exp_avg, grad, 1 - old_debiased1)
-            denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
-
-            for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
-                state = self.state_(p)
-                # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                exp_avg_projected = project(ea, state['Q'], False)
-
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-                set_(d, project(exp_avg_projected / d, state['Q'], True))
-
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                      step > 0 and step % group['precondition_frequency'] == 0)
-
-            # Why does this have to be rebiased here?
-            step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-            update_param_(p_list, denom, step_size, group["weight_decay"])
-        return loss
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1 = group["beta"]
+
+        beta2 = 1 - step ** -group['beta2_scale']
+        old_debiased1 = beta_debias(beta1, step)
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        torch._foreach_lerp_(exp_avg, grad, 1 - old_debiased1)
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+
+        for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
+            state = self.state_(p)
+            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            exp_avg_projected = project(ea, state['Q'], False)
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            set_(d, project(exp_avg_projected / d, state['Q'], True))
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
+                                  step > 0 and step % group['precondition_frequency'] == 0)
+
+        # Why does this have to be rebiased here?
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+        update_param_(p_list, denom, step_size, group["weight_decay"])
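The comments in these hunks describe the unchanged SOAP data flow: the gradient is rotated into the eigenbases stored in state['Q'] (eigenvectors of the accumulated GG statistics), Adam's second moment is maintained in that rotated space, and the preconditioned update is rotated back. A self-contained toy version of that round-trip for a single 2-D tensor, using stand-in orthogonal factors rather than heavyball's project() and optimizer state:

import torch

def project_2d(x, q_left, q_right, back: bool = False):
    # Rotate into (or, with back=True, out of) the eigenbases of the two factors.
    if back:
        return q_left @ x @ q_right.T
    return q_left.T @ x @ q_right

g = torch.randn(4, 3)
q_l, _ = torch.linalg.qr(torch.randn(4, 4))   # stand-ins for the eigenvector matrices in state['Q']
q_r, _ = torch.linalg.qr(torch.randn(3, 3))
g_proj = project_2d(g, q_l, q_r)              # second-moment statistics are accumulated on this
g_back = project_2d(g_proj, q_l, q_r, back=True)
assert torch.allclose(g, g_back, atol=1e-5)   # orthogonal round-trip recovers the original tensor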
@@ -37,74 +37,60 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        """
-        Performs a single optimization step.
-
-        Arguments:
-            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
-        """
-        if closure is None:
-            loss = None
-        else:
-            loss = closure()
-
-        for group in self.param_groups:
-            vals = []
-            step = 0
-
-            max_precond_dim = group['max_precond_dim']
-            precondition_1d = group['precondition_1d']
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
-                step = state['step'] = state.get("step", -1) + 1
-
-                if "exp_avg" not in state:
-                    state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
-                    state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
-                    init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                    continue  # first step is skipped so that we never use the current gradients in the projection.
-
-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                grad_projected = project(g, state['Q'], False)
-                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-                vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-            if not vals:
-                continue
-
-            p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-            beta1, beta2 = group["betas"]
-
-            old_debiased1 = beta_debias(beta1, step)
-            old_debiased2 = beta_debias(beta2, step)
-
-            # Decay the first and second moment running average coefficient
-            # In-place operations to update the averages at the same time
-            torch._foreach_mul_(exp_avg, old_debiased1)
-            torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
-            denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
-
-            update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
-            for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
-                state = self.state_(p)
-                # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                exp_avg_projected = project(ea, state['Q'], False)
-
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-                set_(d, project(exp_avg_projected / d, state['Q'], True))
-
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                      update_precond)
-
-            # Why does this have to be rebiased here?
-            step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-            update_param_(p_list, denom, step_size, group["weight_decay"])
-        return loss
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1, beta2 = group["betas"]
+
+        old_debiased1 = beta_debias(beta1, step)
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        torch._foreach_mul_(exp_avg, old_debiased1)
+        torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+
+        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+        for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
+            state = self.state_(p)
+            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            exp_avg_projected = project(ea, state['Q'], False)
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            set_(d, project(exp_avg_projected / d, state['Q'], True))
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
+                                  update_precond)
+
+        # Why does this have to be rebiased here?
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+        update_param_(p_list, denom, step_size, group["weight_decay"])
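None of these refactors change the public API: each optimizer is still constructed and driven like any torch.optim optimizer, and closures keep working through the shared step(). A minimal smoke-test sketch, assuming the classes are re-exported at package level as in earlier releases; the model, data, and hyperparameters are placeholders, and only arguments visible in the signature above are used:

import torch
import heavyball

model = torch.nn.Linear(16, 4)
opt = heavyball.ForeachPaLMPAdam(model.parameters(), lr=1e-3)

for _ in range(10):
    opt.zero_grad()
    loss = model(torch.randn(8, 16)).square().mean()
    loss.backward()
    opt.step()  # the shared step() dispatches to the optimizer's _step per param group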