heavyball 0.14.6__py3-none-any.whl → 0.15.0__py3-none-any.whl

heavyball/p_adam.py CHANGED
@@ -62,13 +62,7 @@ class ForeachPaLMPAdam(PSGDBase):
 
         self._prob_step = 0
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
@@ -76,57 +70,54 @@ class ForeachPaLMPAdam(PSGDBase):
         do_update = self.rng.random() < update_prob
         self._prob_step += 1
 
-        for group in self.param_groups:
-            precond_init_scale = group['precond_init_scale']
-            max_size_triangular = group['max_size_triangular']
-            min_ndim_triangular = group['min_ndim_triangular']
-            memory_save_mode = group['memory_save_mode']
-            precond_lr = group['precond_lr']
-            weight_decay = group['weight_decay']
-            lr = group['lr']
-
-            vals = []
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
 
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
+        vals = []
 
-                if 'Q' not in state:
-                    state['exp_avg'] = torch.zeros_like(g)
-                    state['exp_avg_sq'] = torch.zeros_like(g)
-                    state["Q"], state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
-                                                              min_ndim_triangular, memory_save_mode, dtype=g.dtype)
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
 
-                vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
+            if 'Q' not in state:
+                state['exp_avg'] = torch.zeros_like(g)
+                state['exp_avg_sq'] = torch.zeros_like(g)
+                state["Q"], state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
+                                                          min_ndim_triangular, memory_save_mode, dtype=g.dtype)
 
-            if not vals:
-                continue
+            vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
 
-            p_list, grad_list, Q_list, exp_avg, exp_avg_sq = zip(*vals)
-            del vals
+        if not vals:
+            return
 
-            group["step"] += 1
+        p_list, grad_list, Q_list, exp_avg, exp_avg_sq = zip(*vals)
+        del vals
 
-            self.balance(do_update, grad_list, Q_list)
-            if do_update:
-                self.do_update(p_list, grad_list, Q_list, precond_lr)
+        group["step"] += 1
 
-            torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
+        self.balance(do_update, grad_list, Q_list)
+        if do_update:
+            self.do_update(p_list, grad_list, Q_list, precond_lr)
 
-            beta2 = 1 - group['step'] ** -group['beta2_scale']
+        torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
 
-            for p, Q, g, ea, eas in zip(p_list, Q_list, grad_list, exp_avg, exp_avg_sq):
-                psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
-                ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
-                exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
-                torch.div(ea, g, out=g)
-                """
-                divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
-                divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
-                """
+        beta2 = 1 - group['step'] ** -group['beta2_scale']
 
-            grad_list = self.clip_fn(grad_list)
+        for p, Q, g, ea, eas in zip(p_list, Q_list, grad_list, exp_avg, exp_avg_sq):
+            psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
+            ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
+            exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
+            torch.div(ea, g, out=g)
+            """
+            divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
+            divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
+            """
 
-            lr = -warmup(lr, group['step'], group['warmup_steps'])
-            update_param_(p_list, grad_list, lr, weight_decay)
+        grad_list = self.clip_fn(grad_list)
 
-        return loss
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
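
This pattern repeats across every optimizer in the release: the public step(closure) method, with its closure handling and "for group in self.param_groups" loop, is removed, and each optimizer now only implements a per-group _step(group) hook. The shared dispatcher is presumably supplied by the base classes (PSGDBase, ScheduleFree, StatefulOptimizer), which are not part of this diff. A minimal sketch of what such a dispatcher could look like, under that assumption:

    # Hypothetical sketch, not the actual heavyball base class: a shared step()
    # entry point that dispatches to the new per-group _step(group) hook.
    import torch


    class PerGroupOptimizer(torch.optim.Optimizer):
        def _step(self, group):
            # Subclasses (e.g. the optimizers in this diff) implement the per-group update here.
            raise NotImplementedError

        @torch.no_grad()
        def step(self, closure=None):
            # Closure handling moves here, out of every subclass.
            loss = None
            if closure is not None:
                with torch.enable_grad():
                    loss = closure()
            for group in self.param_groups:
                self._step(group)
            return loss

If the base class looks roughly like this, the external opt.step() / opt.step(closure) contract stays unchanged while the duplicated boilerplate disappears from each subclass.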
@@ -9,58 +9,48 @@ class PaLMForeachSFAdamW(ScheduleFree):
                  weight_lr_power=2.0, beta2_scale: float = 0.8):
         if betas[0] is not None:
             beta = betas[0]
-        defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
-                        weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
+        defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
                         beta2_scale=beta2_scale)
         super().__init__(params, defaults)
 
-    def step(self, closure=None):
-        """Performs a single optimization step.
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
 
-        Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-        """
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
 
-        loss = None
-        if closure is not None:
-            loss = closure()
+        active_p = [p for p in group['params'] if p.grad is not None]
 
-        for group in self.param_groups:
-            eps = group['eps']
-            decay = group['weight_decay']
-            k = group['k']
+        if not active_p:
+            return
 
-            if not group['train_mode']:
-                raise Exception("Not in train mode!")
+        for p in active_p:
+            if 'z' not in self.state_(p):
+                self.state_(p)['z'] = torch.clone(p.data)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
 
-            active_p = [p for p in group['params'] if p.grad is not None]
+        y, grad, exp_avg_sq, z = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
 
-            for p in active_p:
-                if 'z' not in self.state_(p):
-                    self.state_(p)['z'] = torch.clone(p.data)
-                    self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+        # Decay the first moment running average coefficient
+        beta2 = 1 - (k + 1) ** -group['beta2_scale']
+        old_debiased = beta_debias(beta2, k + 1)
 
-            y, grad, exp_avg_sq, z = zip(
-                *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
+        # Decay the first and second moment running average coefficient
+        denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
 
-            # Decay the first moment running average coefficient
-            beta2 = 1 - (k + 1) ** -group['beta2_scale']
-            old_debiased = beta_debias(beta2, k + 1)
+        # Normalize grad in-place for memory efficiency
+        torch._foreach_div_(grad, denom)
 
-            # Decay the first and second moment running average coefficient
-            denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
+        # Weight decay calculated at y
+        if decay != 0:
+            torch._foreach_add_(grad, y, alpha=decay)
 
-            # Normalize grad in-place for memory efficiency
-            torch._foreach_div_(grad, denom)
+        lr = warmup(group['lr'], k + 1, group['warmup_steps'])
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], y, z,
+                                             grad, group['r'], k + 1)
 
-            # Weight decay calculated at y
-            if decay != 0:
-                torch._foreach_add_(grad, y, alpha=decay)
-
-            lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-            group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
-                                                 y, z, grad, group['r'], k + 1)
-
-            group['k'] = k + 1
-        return loss
+        group['k'] = k + 1
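
For reference, the PaLM-style second-moment schedule used above, beta2 = 1 - (k + 1) ** -beta2_scale, starts out aggressive and approaches 1 as training progresses. An illustration only (beta2_scale = 0.8 is the constructor default shown in the context lines of this hunk):

    # Illustration: how beta2 = 1 - step ** -beta2_scale evolves with the step count.
    def palm_beta2(step: int, beta2_scale: float = 0.8) -> float:
        return 1 - step ** -beta2_scale

    for k in (1, 10, 100, 1_000, 10_000):
        print(k, round(palm_beta2(k), 4))
    # 1 0.0
    # 10 0.8415
    # 100 0.9749
    # 1000 0.996
    # 10000 0.9994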
@@ -43,73 +43,59 @@ class PaLMForeachSOAP(StatefulOptimizer):
         super().__init__(params, defaults)
         self._data_format = data_format
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        """
-        Performs a single optimization step.
-
-        Arguments:
-            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
-        """
-        if closure is None:
-            loss = None
-        else:
-            loss = closure()
-
-        for group in self.param_groups:
-            vals = []
-            step = 0
-
-            max_precond_dim = group['max_precond_dim']
-            precondition_1d = group['precondition_1d']
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
-                step = state['step'] = state.get("step", -1) + 1
-
-                if "exp_avg" not in state:
-                    state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
-                    state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
-                    init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                    continue  # first step is skipped so that we never use the current gradients in the projection.
-
-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                grad_projected = project(g, state['Q'], False)
-                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-                vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-            if not vals:
-                continue
-
-            p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-            beta1 = group["beta"]
-
-            beta2 = 1 - step ** -group['beta2_scale']
-            old_debiased1 = beta_debias(beta1, step)
-            old_debiased2 = beta_debias(beta2, step)
-
-            # Decay the first and second moment running average coefficient
-            # In-place operations to update the averages at the same time
-            torch._foreach_lerp_(exp_avg, grad, 1 - old_debiased1)
-            denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
-
-            for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
-                state = self.state_(p)
-                # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                exp_avg_projected = project(ea, state['Q'], False)
-
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-                set_(d, project(exp_avg_projected / d, state['Q'], True))
-
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                      step > 0 and step % group['precondition_frequency'] == 0)
-
-            # Why does this have to be rebiased here?
-            step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-            update_param_(p_list, denom, step_size, group["weight_decay"])
-        return loss
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1 = group["beta"]
+
+        beta2 = 1 - step ** -group['beta2_scale']
+        old_debiased1 = beta_debias(beta1, step)
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        torch._foreach_lerp_(exp_avg, grad, 1 - old_debiased1)
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+
+        for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
+            state = self.state_(p)
+            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            exp_avg_projected = project(ea, state['Q'], False)
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            set_(d, project(exp_avg_projected / d, state['Q'], True))
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
+                                  step > 0 and step % group['precondition_frequency'] == 0)
+
+        # Why does this have to be rebiased here?
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+        update_param_(p_list, denom, step_size, group["weight_decay"])
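
Several of these _step methods rely on the exp_avg_sq_ helper to produce denom. Its implementation lives in heavyball's utilities and is not part of this diff; the following single-tensor sketch only illustrates the assumed semantics (in-place EMA of squared gradients, Adam-style denominator returned), while the real helper operates on lists of tensors via the foreach ops:

    # Assumed semantics only; not the heavyball implementation.
    import torch

    def exp_avg_sq_sketch(exp_avg_sq: torch.Tensor, grad: torch.Tensor,
                          beta2: float, eps: float) -> torch.Tensor:
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # in-place EMA of grad**2
        return exp_avg_sq.sqrt().add_(eps)                            # denom = sqrt(EMA) + eps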
@@ -37,74 +37,60 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        """
-        Performs a single optimization step.
-
-        Arguments:
-            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
-        """
-        if closure is None:
-            loss = None
-        else:
-            loss = closure()
-
-        for group in self.param_groups:
-            vals = []
-            step = 0
-
-            max_precond_dim = group['max_precond_dim']
-            precondition_1d = group['precondition_1d']
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
-                step = state['step'] = state.get("step", -1) + 1
-
-                if "exp_avg" not in state:
-                    state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
-                    state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
-                    init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                    continue  # first step is skipped so that we never use the current gradients in the projection.
-
-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                grad_projected = project(g, state['Q'], False)
-                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-                vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-            if not vals:
-                continue
-
-            p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-            beta1, beta2 = group["betas"]
-
-            old_debiased1 = beta_debias(beta1, step)
-            old_debiased2 = beta_debias(beta2, step)
-
-            # Decay the first and second moment running average coefficient
-            # In-place operations to update the averages at the same time
-            torch._foreach_mul_(exp_avg, old_debiased1)
-            torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
-            denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
-
-            update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
-            for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
-                state = self.state_(p)
-                # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                exp_avg_projected = project(ea, state['Q'], False)
-
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-                set_(d, project(exp_avg_projected / d, state['Q'], True))
-
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                      update_precond)
-
-            # Why does this have to be rebiased here?
-            step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-            update_param_(p_list, denom, step_size, group["weight_decay"])
-        return loss
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1, beta2 = group["betas"]
+
+        old_debiased1 = beta_debias(beta1, step)
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        torch._foreach_mul_(exp_avg, old_debiased1)
+        torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+
+        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+        for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
+            state = self.state_(p)
+            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            exp_avg_projected = project(ea, state['Q'], False)
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            set_(d, project(exp_avg_projected / d, state['Q'], True))
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
+                                  update_precond)
+
+        # Why does this have to be rebiased here?
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+        update_param_(p_list, denom, step_size, group["weight_decay"])
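
Note that the PaLMForeachSOAP hunk above uses torch._foreach_lerp_(exp_avg, grad, 1 - old_debiased1), while this class keeps the explicit _foreach_mul_ / _foreach_add_ pair. The two formulations compute the same EMA update, since lerp(a, b, w) = a * (1 - w) + b * w; a quick single-tensor check:

    # Illustrative equivalence check: lerp_ with weight (1 - beta) matches mul_/add_ with beta.
    import torch

    a1 = torch.randn(4)
    a2 = a1.clone()
    g = torch.randn(4)
    beta = 0.9

    a1.lerp_(g, 1 - beta)                  # a1 = beta * a1 + (1 - beta) * g
    a2.mul_(beta).add_(g, alpha=1 - beta)  # same update, spelled out
    assert torch.allclose(a1, a2)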
@@ -44,76 +44,61 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        """
-        Performs a single optimization step.
-
-        Arguments:
-            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
-        """
-        if closure is None:
-            loss = None
-        else:
-            loss = closure()
-
-        for group in self.param_groups:
-            vals = []
-            step = 0
-
-            max_precond_dim = group['max_precond_dim']
-            precondition_1d = group['precondition_1d']
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
-                step = state['step'] = state.get("step", -1) + 1
-
-                if "exp_avg" not in state:
-                    state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
-                    state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
-                    init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                    continue  # first step is skipped so that we never use the current gradients in the projection.
-
-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                grad_projected = project(g, state['Q'], False)
-                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-                vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-            if not vals:
-                continue
-
-            p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-            beta1 = group["beta"]
-
-            beta2 = 1 - max(step, 1) ** -group['beta2_scale']
-            old_debiased1 = beta_debias(beta1, step)
-            old_debiased2 = beta_debias(beta2, step)
-
-            # Decay the first and second moment running average coefficient
-            # In-place operations to update the averages at the same time
-            torch._foreach_mul_(exp_avg, old_debiased1)
-            torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
-            denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
-
-            update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
-            for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
-                state = self.state_(p)
-                # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                exp_avg_projected = project(ea, state['Q'], False)
-
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-                exp_avg_projected = exp_avg_projected / d
-                set_(d, project(exp_avg_projected, state['Q'], True))
-
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                      update_precond)
-
-            # Why does this have to be rebiased here?
-            step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-            update_param_(p_list, denom, step_size, group["weight_decay"])
-        return loss
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1 = group["beta"]
+
+        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
+        old_debiased1 = beta_debias(beta1, step)
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        torch._foreach_mul_(exp_avg, old_debiased1)
+        torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+
+        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+        for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
+            state = self.state_(p)
+            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            exp_avg_projected = project(ea, state['Q'], False)
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            exp_avg_projected = exp_avg_projected / d
+            set_(d, project(exp_avg_projected, state['Q'], True))
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
+
+        # Why does this have to be rebiased here?
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+        update_param_(p_list, denom, step_size, group["weight_decay"])
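
From the caller's perspective nothing should change: construction and stepping still go through the usual torch optimizer interface, with step()/step(closure) now presumably handled by the shared base class. A usage sketch only; the top-level import path, constructor keywords, and defaults are assumptions, not verified against this release:

    # Usage sketch; import path and constructor arguments are assumptions.
    import torch
    from heavyball import PaLMForeachSOAP

    model = torch.nn.Linear(16, 4)
    opt = PaLMForeachSOAP(model.parameters(), lr=1e-3, warmup_steps=100)

    x, y = torch.randn(8, 16), torch.randn(8, 4)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()       # dispatches to the per-group _step(group) hook
    opt.zero_grad()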