heavyball 0.14.6__py3-none-any.whl → 0.15.0__py3-none-any.whl

heavyball/__init__.py CHANGED
@@ -1,3 +1,4 @@
+from .delayed_psgd import ForeachDelayedPSGD
 from .foreach_adamw import ForeachAdamW
 from .foreach_adopt import ForeachADOPT
 from .foreach_laprop import ForeachLaProp
@@ -12,7 +13,6 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
 from .psgd_kron import ForeachPSGDKron
 from .pure_psgd import ForeachPurePSGD
 from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
-from .delayed_psgd import ForeachDelayedPSGD
 
 PalmForEachSoap = PaLMForeachSOAP
 
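
Note: the `ForeachDelayedPSGD` import only moves to the top of `__init__.py` to keep the imports sorted; the class is re-exported from the package root in both versions. A minimal usage sketch follows — calling the constructor with only `lr` and relying on the remaining defaults is an assumption for illustration, not taken from this diff:

    # Hypothetical usage sketch for the re-exported optimizer.
    import torch
    from heavyball import ForeachDelayedPSGD

    model = torch.nn.Linear(16, 1)
    opt = ForeachDelayedPSGD(model.parameters(), lr=1e-3)  # other arguments left at their defaults

    loss = model(torch.randn(4, 16)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
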
heavyball/delayed_psgd.py CHANGED
@@ -5,8 +5,8 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-
 from heavyball.utils import copy_stochastic_list_
+
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
     precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_
 
@@ -63,13 +63,7 @@ class ForeachDelayedPSGD(PSGDBase):
 
         self._prob_step = 0
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
@@ -77,55 +71,52 @@ class ForeachDelayedPSGD(PSGDBase):
         do_update = self.rng.random() < update_prob
         self._prob_step += 1
 
-        for group in self.param_groups:
-            momentum_into_precond_update = group.get("momentum_into_precond_update", True)
-            precond_init_scale = group['precond_init_scale']
-            max_size_triangular = group['max_size_triangular']
-            min_ndim_triangular = group['min_ndim_triangular']
-            memory_save_mode = group['memory_save_mode']
-            precond_lr = group['precond_lr']
-            weight_decay = group['weight_decay']
-            lr = group['lr']
-            beta = group['beta']
-
-            vals = []
+        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        beta = group['beta']
 
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
+        vals = []
 
-                if 'Q' not in state:
-                    state["exp_avg"] = torch.zeros_like(g)
-                    Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                     memory_save_mode, dtype=g.dtype)
-                    state["Q"] = triu_to_line(Q)
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
 
-                vals.append((p, g, state["exp_avg"], state["Q"]))
+            if 'Q' not in state:
+                state["exp_avg"] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state["Q"] = triu_to_line(Q)
 
-            if not vals:
-                continue
+            vals.append((p, g, state["exp_avg"], state["Q"]))
 
-            p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
-            del vals
+        if not vals:
+            return
 
-            group["step"] += 1
+        p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
+        del vals
 
-            torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+        group["step"] += 1
 
-            Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
-            for i, (p, g) in enumerate(zip(p_list, grad_list)):
-                q_orig = Q_list.pop(0)
-                ea = exp_avg_list.pop(0)
-                q = line_to_triu(q_orig)
-                self.balance(do_update, [g], [q])
-                new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+        torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
 
-                if do_update:
-                    self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
-                set_(g, new)
+        Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            q_orig = Q_list.pop(0)
+            ea = exp_avg_list.pop(0)
+            q = line_to_triu(q_orig)
+            self.balance(do_update, [g], [q])
+            new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
 
-            grad_list = self.clip_fn(grad_list)
+            if do_update:
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
+            set_(g, new)
 
-            lr = -warmup(lr, group['step'], group['warmup_steps'])
-            update_param_(p_list, grad_list, lr, weight_decay)
+        grad_list = self.clip_fn(grad_list)
 
-        return loss
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
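
The pattern above repeats for every optimizer in this release: the per-optimizer `@torch.no_grad() step(closure)` method, its closure handling, and its loop over `self.param_groups` are removed, and only a per-group `_step(group)` hook remains. The shared machinery presumably moves into the base classes (`PSGDBase`, `StatefulOptimizer`, `ScheduleFree`), which are not part of the hunks shown here. A minimal sketch of what such a wrapper could look like — only the `_step` name is taken from this diff, everything else is an assumption:

    # Hypothetical base-class wrapper; heavyball's real base classes are not shown in this diff.
    import torch


    class SketchStatefulOptimizer(torch.optim.Optimizer):
        def _step(self, group):
            raise NotImplementedError

        @torch.no_grad()
        def step(self, closure=None):
            # Same closure handling the removed per-optimizer step() methods used.
            loss = None
            if closure is not None:
                with torch.enable_grad():
                    loss = closure()
            for group in self.param_groups:
                self._step(group)  # subclasses only implement the per-group update
            return loss
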
heavyball/foreach_adamw.py CHANGED
@@ -10,42 +10,32 @@ class ForeachAdamW(StatefulOptimizer):
                         lr_max=-1.0, weight_decay=weight_decay)
         super().__init__(params, defaults)
 
-    def step(self, closure=None):
-        """Performs a single optimization step.
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
 
-        Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-        """
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
 
-        loss = None
-        if closure is not None:
-            loss = closure()
+        active_p = [p for p in group['params'] if p.grad is not None]
 
-        for group in self.param_groups:
-            eps = group['eps']
-            decay = group['weight_decay']
-            k = group['k']
+        if not active_p:
+            return
 
-            if not group['train_mode']:
-                raise Exception("Not in train mode!")
+        for p in active_p:
+            if 'exp_avg' not in self.state_(p):
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
 
-            active_p = [p for p in group['params'] if p.grad is not None]
+        y, grad, exp_avg_sq, exp_avg = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
 
-            for p in active_p:
-                if 'exp_avg' not in self.state_(p):
-                    self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
-                    self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+        # Decay the first and second moment running average coefficient
+        torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
+        denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
 
-            y, grad, exp_avg_sq, exp_avg = zip(
-                *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
-
-            # Decay the first and second moment running average coefficient
-            torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
-            denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
-
-            # Normalize grad in-place for memory efficiency
-            lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-            update_param_(y, exp_avg, lr, decay, lambda p, e, l: torch._foreach_addcdiv_(p, e, denom, l))
-            group['k'] = k + 1
-        return loss
+        # Normalize grad in-place for memory efficiency
+        lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
+        update_param_(y, exp_avg, lr, decay, lambda p, e, l: torch._foreach_addcdiv_(p, e, denom, l))
+        group['k'] = k + 1
heavyball/foreach_adopt.py CHANGED
@@ -11,51 +11,41 @@ class ForeachADOPT(StatefulOptimizer):
                         lr_max=-1.0, weight_decay=weight_decay)
         super().__init__(params, defaults)
 
-    def step(self, closure=None):
-        """Performs a single optimization step.
-
-        Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-        """
-
-        loss = None
-        if closure is not None:
-            loss = closure()
-
-        for group in self.param_groups:
-            eps = group['eps']
-            decay = group['weight_decay']
-            k = group['k']
-
-            if not group['train_mode']:
-                raise Exception("Not in train mode!")
-
-            active_p = [p for p in group['params'] if p.grad is not None]
-
-            for p in active_p:
-                if 'exp_avg' not in self.state_(p):
-                    self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
-                    self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
-
-            y, grad, exp_avg_sq, exp_avg = zip(
-                *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
-
-            if k > 1:
-                lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
-
-                update_param_(y, exp_avg, lr, decay)
-            if k > 0:
-                beta1 = beta_debias(group['betas'][0], k)
-                denom = torch._foreach_sqrt(exp_avg_sq)
-                torch._foreach_maximum_(denom, eps)
-                torch._foreach_mul_(exp_avg, beta1)
-                torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
-
-            beta2 = beta_debias(group['betas'][1], k + 1)
-            torch._foreach_mul_(exp_avg_sq, beta2)
-            torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)
-            del grad
-
-            group['k'] = k + 1
-        return loss
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
+
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
+
+        active_p = [p for p in group['params'] if p.grad is not None]
+
+        if not active_p:
+            return
+
+        for p in active_p:
+            if 'exp_avg' not in self.state_(p):
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+
+        y, grad, exp_avg_sq, exp_avg = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
+
+        if k > 1:
+            lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
+
+            update_param_(y, exp_avg, lr, decay)
+        if k > 0:
+            beta1 = beta_debias(group['betas'][0], k)
+            denom = torch._foreach_sqrt(exp_avg_sq)
+            torch._foreach_maximum_(denom, eps)
+            torch._foreach_mul_(exp_avg, beta1)
+            torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
+
+        beta2 = beta_debias(group['betas'][1], k + 1)
+        torch._foreach_mul_(exp_avg_sq, beta2)
+        torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)
+        del grad
+
+        group['k'] = k + 1
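
The `k` gating above staggers ADOPT's updates so that every quantity is consumed one step after it is produced: step 0 only initialises the second moment, step 1 starts the (normalised) first moment, and parameters are first moved at step 2. A single-tensor sketch of that schedule, omitting heavyball's beta debiasing, warmup, and weight decay:

    # Simplified single-tensor illustration of the staggered schedule; not heavyball's API.
    import torch


    def adopt_like_step(param, grad, exp_avg, exp_avg_sq, k, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        if k > 1:
            param.add_(exp_avg, alpha=-lr)  # uses statistics from earlier steps only
        if k > 0:
            denom = exp_avg_sq.sqrt().clamp_min_(eps)
            exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
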
heavyball/foreach_laprop.py CHANGED
@@ -11,46 +11,36 @@ class ForeachLaProp(StatefulOptimizer):
                         lr_max=-1.0, weight_decay=weight_decay)
         super().__init__(params, defaults)
 
-    def step(self, closure=None):
-        """Performs a single optimization step.
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
 
-        Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-        """
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
 
-        loss = None
-        if closure is not None:
-            loss = closure()
+        active_p = [p for p in group['params'] if p.grad is not None]
 
-        for group in self.param_groups:
-            eps = group['eps']
-            decay = group['weight_decay']
-            k = group['k']
+        if not active_p:
+            return
 
-            if not group['train_mode']:
-                raise Exception("Not in train mode!")
+        for p in active_p:
+            if 'exp_avg' not in self.state_(p):
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
 
-            active_p = [p for p in group['params'] if p.grad is not None]
+        y, grad, exp_avg_sq, exp_avg = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
 
-            for p in active_p:
-                if 'exp_avg' not in self.state_(p):
-                    self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
-                    self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+        # Decay the first and second moment running average coefficient
+        denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
+        beta1 = beta_debias(group['betas'][0], k + 1)
+        torch._foreach_mul_(exp_avg, beta1)
+        torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
+        del grad
 
-            y, grad, exp_avg_sq, exp_avg = zip(
-                *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
+        # Normalize grad in-place for memory efficiency
+        lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
+        update_param_(y, exp_avg, lr, decay)
 
-            # Decay the first and second moment running average coefficient
-            denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
-            beta1 = beta_debias(group['betas'][0], k + 1)
-            torch._foreach_mul_(exp_avg, beta1)
-            torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
-            del grad
-
-            # Normalize grad in-place for memory efficiency
-            lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-            update_param_(y, exp_avg, lr, decay)
-
-            group['k'] = k + 1
-        return loss
+        group['k'] = k + 1
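
For comparison with the `ForeachAdamW` hunk above: LaProp computes the second-moment denominator first and feeds the normalised gradient into the momentum buffer (`_foreach_addcdiv_` into `exp_avg`), and `update_param_` then applies `exp_avg` without any further division, whereas AdamW accumulates the raw gradient and divides only at the parameter update. Informally, per element and ignoring bias correction:

    # Adam/AdamW:  m <- beta1 * m + (1 - beta1) * g            update ~ m / max(sqrt(v), eps)
    # LaProp:      m <- beta1 * m + (1 - beta1) * g / sqrt(v)   update ~ m
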
heavyball/foreach_sfadamw.py CHANGED
@@ -13,52 +13,42 @@ class ForeachSFAdamW(ScheduleFree):
                         foreach=foreach)
         super().__init__(params, defaults)
 
-    def step(self, closure=None):
-        """Performs a single optimization step.
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
 
-        Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-        """
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
 
-        loss = None
-        if closure is not None:
-            loss = closure()
+        active_p = [p for p in group['params'] if p.grad is not None]
 
-        for group in self.param_groups:
-            eps = group['eps']
-            decay = group['weight_decay']
-            k = group['k']
+        if not active_p:
+            return
 
-            if not group['train_mode']:
-                raise Exception("Not in train mode!")
+        for p in active_p:
+            if 'z' not in self.state_(p):
+                self.state_(p)['z'] = torch.clone(p.data)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
 
-            active_p = [p for p in group['params'] if p.grad is not None]
+        y, grad, exp_avg_sq, z = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
 
-            for p in active_p:
-                if 'z' not in self.state_(p):
-                    self.state_(p)['z'] = torch.clone(p.data)
-                    self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+        # Decay the first moment running average coefficient
+        old_debiased = beta_debias(group['betas'][1], k + 1)
 
-            y, grad, exp_avg_sq, z = zip(
-                *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
+        # Decay the first and second moment running average coefficient
+        denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
 
-            # Decay the first moment running average coefficient
-            old_debiased = beta_debias(group['betas'][1], k + 1)
+        # Normalize grad in-place for memory efficiency
+        torch._foreach_div_(grad, denom)
 
-            # Decay the first and second moment running average coefficient
-            denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
+        # Weight decay calculated at y
+        if decay != 0:
+            torch._foreach_add_(grad, y, alpha=decay)
 
-            # Normalize grad in-place for memory efficiency
-            torch._foreach_div_(grad, denom)
+        lr = warmup(group['lr'], k + 1, group['warmup_steps'])
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
+                                             y, z, grad, group['r'], k + 1)
 
-            # Weight decay calculated at y
-            if decay != 0:
-                torch._foreach_add_(grad, y, alpha=decay)
-
-            lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-            group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
-                                                 y, z, grad, group['r'], k + 1)
-
-            group['k'] = k + 1
-        return loss
+        group['k'] = k + 1
heavyball/foreach_soap.py CHANGED
@@ -34,73 +34,59 @@ class ForeachSOAP(StatefulOptimizer):
         super().__init__(params, defaults)
         self._data_format = data_format
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        """
-        Performs a single optimization step.
-
-        Arguments:
-            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
-        """
-        if closure is None:
-            loss = None
-        else:
-            loss = closure()
-
-        for group in self.param_groups:
-            vals = []
-            step = 0
-
-            max_precond_dim = group['max_precond_dim']
-            precondition_1d = group['precondition_1d']
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
-                step = state['step'] = state.get("step", -1) + 1
-
-                if "exp_avg" not in state:
-                    state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
-                    state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
-                    init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                    continue  # first step is skipped so that we never use the current gradients in the projection.
-
-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                grad_projected = project(g, state['Q'], False)
-                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-                vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-            if not vals:
-                continue
-
-            p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-            beta1, beta2 = group["betas"]
-
-            old_debiased1 = beta_debias(beta1, step)
-            old_debiased2 = beta_debias(beta2, step)
-
-            # Decay the first and second moment running average coefficient
-            # In-place operations to update the averages at the same time
-            torch._foreach_mul_(exp_avg, old_debiased1)
-            torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
-            denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
-
-            for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
-                state = self.state_(p)
-                # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                exp_avg_projected = project(ea, state['Q'], False)
-
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-                set_(d, project(exp_avg_projected / d, state['Q'], True))
-
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                      step > 0 and step % group['precondition_frequency'] == 0)
-
-            # Why does this have to be rebiased here?
-            step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-            update_param_(p_list, denom, step_size, group["weight_decay"])
-        return loss
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1, beta2 = group["betas"]
+
+        old_debiased1 = beta_debias(beta1, step)
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        torch._foreach_mul_(exp_avg, old_debiased1)
+        torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+
+        for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
+            state = self.state_(p)
+            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            exp_avg_projected = project(ea, state['Q'], False)
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            set_(d, project(exp_avg_projected / d, state['Q'], True))
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
+                                  step > 0 and step % group['precondition_frequency'] == 0)
+
+        # Why does this have to be rebiased here?
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+        update_param_(p_list, denom, step_size, group["weight_decay"])