heavyball 0.14.7__py3-none-any.whl → 0.15.1__py3-none-any.whl

@@ -44,76 +44,61 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        """
-        Performs a single optimization step.
-
-        Arguments:
-            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
-        """
-        if closure is None:
-            loss = None
-        else:
-            loss = closure()
-
-        for group in self.param_groups:
-            vals = []
-            step = 0
-
-            max_precond_dim = group['max_precond_dim']
-            precondition_1d = group['precondition_1d']
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
-                step = state['step'] = state.get("step", -1) + 1
-
-                if "exp_avg" not in state:
-                    state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
-                    state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
-                    init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                    continue # first step is skipped so that we never use the current gradients in the projection.
-
-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                grad_projected = project(g, state['Q'], False)
-                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-                vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
-
-            if not vals:
-                continue
-
-            p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
-            beta1 = group["beta"]
-
-            beta2 = 1 - max(step, 1) ** -group['beta2_scale']
-            old_debiased1 = beta_debias(beta1, step)
-            old_debiased2 = beta_debias(beta2, step)
-
-            # Decay the first and second moment running average coefficient
-            # In-place operations to update the averages at the same time
-            torch._foreach_mul_(exp_avg, old_debiased1)
-            torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
-            denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
-
-            update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
-            for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
-                state = self.state_(p)
-                # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                exp_avg_projected = project(ea, state['Q'], False)
-
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-                exp_avg_projected = exp_avg_projected / d
-                set_(d, project(exp_avg_projected, state['Q'], True))
-
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                      update_precond)
-
-            # Why does this have to be rebiased here?
-            step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-            update_param_(p_list, denom, step_size, group["weight_decay"])
-        return loss
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1 = group["beta"]
+
+        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
+        old_debiased1 = beta_debias(beta1, step)
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        torch._foreach_mul_(exp_avg, old_debiased1)
+        torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+
+        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+        for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
+            state = self.state_(p)
+            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            exp_avg_projected = project(ea, state['Q'], False)
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            exp_avg_projected = exp_avg_projected / d
+            set_(d, project(exp_avg_projected, state['Q'], True))
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
+
+        # Why does this have to be rebiased here?
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+        update_param_(p_list, denom, step_size, group["weight_decay"])
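The optimizer above (and the others in this diff) now implements only a per-group _step(group); the closure handling and the loop over param_groups that the deleted step() methods contained are presumably centralized in the shared base class. A minimal sketch of that assumed dispatch, with illustrative names that are not taken from this diff:

    import torch

    class _StepDispatchSketch:
        # Assumed shape of the shared base class; heavyball's actual base may differ.
        def __init__(self):
            self.param_groups = []          # populated as in torch.optim.Optimizer

        def _step(self, group):             # subclasses implement the per-group update
            raise NotImplementedError

        @torch.no_grad()
        def step(self, closure=None):
            loss = None
            if closure is not None:
                with torch.enable_grad():   # the closure re-runs forward/backward
                    loss = closure()
            for group in self.param_groups:
                self._step(group)           # per-group work, as in the hunks above
            return loss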
@@ -54,85 +54,73 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        """
-        Performs a single optimization step.
-
-        Arguments:
-            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
-        """
-        if closure is None:
-            loss = None
-        else:
-            loss = closure()
-
-        for group in self.param_groups:
-            vals = []
-            max_precond_dim = group['max_precond_dim']
-            precondition_1d = group['precondition_1d']
-
-            step = group['step'] = group.get("step", -1) + 1
-
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-                grad = p.grad.float()
-                vals.append((p, grad))
-
-            p_list, grad = zip(*vals)
-            vals = []
-
-            # adaptive gradient clipping
-            adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
-
-                if "z" not in state:
-                    state["z"] = torch.clone(p.data)
-                    state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
-                    init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                    continue # first step is skipped so that we never use the current gradients in the projection.
-
-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                grad_projected = project(g, state['Q'], False)
-                z, exp_avg_sq = state["z"], state["exp_avg_sq"]
-                vals.append((p, g, grad_projected, z, exp_avg_sq))
-
-            if not vals:
+    def _step(self, group):
+        vals = []
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        step = group['step'] = group.get("step", -1) + 1
+
+        for p in group["params"]:
+            if p.grad is None:
                 continue
+            grad = p.grad.float()
+            vals.append((p, grad))
+
+        if not vals:
+            return
+
+        p_list, grad = zip(*vals)
+        vals = []
+
+        # adaptive gradient clipping
+        adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+
+            if "z" not in state:
+                state["z"] = torch.clone(p.data)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            z, exp_avg_sq = state["z"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, z, exp_avg_sq))
 
-            p_list, grad, grad_projected, z, exp_avg_sq = zip(*vals)
-            del vals
+        if not vals:
+            return
 
-            beta2 = 1 - max(step, 1) ** -group['beta2_scale']
-            old_debiased2 = beta_debias(beta2, step)
+        p_list, grad, grad_projected, z, exp_avg_sq = zip(*vals)
+        del vals
 
-            # Decay the first and second moment running average coefficient
-            # In-place operations to update the averages at the same time
-            denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
-            torch._foreach_div_(grad_projected, denom)
+        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
+        old_debiased2 = beta_debias(beta2, step)
 
-            update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+        torch._foreach_div_(grad_projected, denom)
 
-            for p, g, gp in zip(p_list, grad, grad_projected):
-                state = self.state_(p)
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                set_(gp, project(gp, state['Q'], back=True))
+        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
 
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                      update_precond)
+        for p, g, gp in zip(p_list, grad, grad_projected):
+            state = self.state_(p)
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            set_(gp, project(gp, state['Q'], back=True))
 
-            # Weight decay calculated at y
-            if group["weight_decay"] > 0:
-                torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
+                                  update_precond)
 
-            lr = warmup(group['lr'], step, group['warmup_steps'])
-            group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
-                                                 p_list, z, grad_projected, group['r'], step)
+        # Weight decay calculated at y
+        if group["weight_decay"] > 0:
+            torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])
 
-        return loss
+        lr = warmup(group['lr'], step, group['warmup_steps'])
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
+                                             p_list, z, grad_projected, group['r'], step)
heavyball/psgd_kron.py CHANGED
@@ -4,9 +4,10 @@ Modified under Creative Commons Attribution 4.0 International
 Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
 """
 
-import torch
 from typing import Optional
 
+import torch
+
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
     precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_
 
@@ -37,7 +38,7 @@ class ForeachPSGDKron(PSGDBase):
     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: Optional[callable] = None):
+                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -57,18 +58,13 @@ class ForeachPSGDKron(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0, # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split)
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
         super().__init__(params, defaults)
 
         self._prob_step = 0
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
@@ -76,54 +72,52 @@ class ForeachPSGDKron(PSGDBase):
         do_update = self.rng.random() < update_prob
         self._prob_step += 1
 
-        for group in self.param_groups:
-            momentum_into_precond_update = group.get("momentum_into_precond_update", True)
-            precond_init_scale = group['precond_init_scale']
-            max_size_triangular = group['max_size_triangular']
-            min_ndim_triangular = group['min_ndim_triangular']
-            memory_save_mode = group['memory_save_mode']
-            precond_lr = group['precond_lr']
-            weight_decay = group['weight_decay']
-            lr = group['lr']
-            beta = group['beta']
-
-            vals = []
+        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        beta = group['beta']
+        store_triu_as_line = group['store_triu_as_line']
 
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
+        vals = []
 
-            if 'Q' not in state:
-                    state["exp_avg"] = torch.zeros_like(g)
-                    Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                     memory_save_mode, dtype=g.dtype)
-                    state['Q'] = triu_to_line(Q)
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
 
-                vals.append((p, g, state["exp_avg"], state["Q"]))
+            if 'Q' not in state:
+                state["exp_avg"] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
-            if not vals:
-                continue
+            vals.append((p, g, state["exp_avg"], state["Q"]))
 
-            p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
-            del vals
+        if not vals:
+            return
 
-            group["step"] += 1
+        p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
+        del vals
 
-            torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+        group["step"] += 1
 
-            grad_list, Q_list, exp_avg_list = list(grad_list), list(Q_list), list(exp_avg_list)
-            for i, (p, g) in enumerate(zip(p_list, grad_list)):
-                q_orig = Q_list.pop(0)
-                ea = exp_avg_list.pop(0)
-                q = line_to_triu(q_orig)
+        torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
 
-                self.balance(do_update, [g], [q])
-                if do_update:
-                    self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
-                set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
+        grad_list, Q_list, exp_avg_list = list(grad_list), list(Q_list), list(exp_avg_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            q_orig = Q_list.pop(0)
+            ea = exp_avg_list.pop(0)
+            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            grad_list = self.clip_fn(grad_list)
+            if do_update:
+                self.balance([g], [q])
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
+            set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
 
-            lr = -warmup(lr, group['step'], group['warmup_steps'])
-            update_param_(p_list, grad_list, lr, weight_decay)
+        grad_list = self.clip_fn(grad_list)
 
-        return loss
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
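The new store_triu_as_line option toggles whether each Kronecker factor is kept in the compact triu_to_line form (the default) or as a full triangular matrix that line_to_triu would otherwise rebuild every step. A brief usage sketch, assuming the class is still exported from the package top level as in earlier releases:

    import torch
    from heavyball import ForeachPSGDKron  # assumed top-level export

    model = torch.nn.Linear(64, 64)
    # store_triu_as_line=True (default): factors stored compactly and rebuilt via
    # line_to_triu each step; False keeps dense triangular factors and skips the round-trip.
    opt = ForeachPSGDKron(model.parameters(), lr=1e-3, store_triu_as_line=False)

    loss = model(torch.randn(8, 64)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()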
heavyball/pure_psgd.py CHANGED
@@ -36,7 +36,7 @@ class ForeachPurePSGD(PSGDBase):
     def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None):
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -54,18 +54,13 @@ class ForeachPurePSGD(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0, # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split)
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
         super().__init__(params, defaults)
 
         self._prob_step = 0
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
@@ -73,48 +68,46 @@ class ForeachPurePSGD(PSGDBase):
         do_update = self.rng.random() < update_prob
         self._prob_step += 1
 
-        for group in self.param_groups:
-            precond_init_scale = group['precond_init_scale']
-            max_size_triangular = group['max_size_triangular']
-            min_ndim_triangular = group['min_ndim_triangular']
-            memory_save_mode = group['memory_save_mode']
-            precond_lr = group['precond_lr']
-            weight_decay = group['weight_decay']
-            lr = group['lr']
-
-            vals = []
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        store_triu_as_line = group['store_triu_as_line']
 
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
+        vals = []
 
-            if 'Q' not in state:
-                    Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                     memory_save_mode, dtype=g.dtype)
-                    state['Q'] = triu_to_line(Q)
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
 
-                vals.append((p, g, state["Q"]))
+            if 'Q' not in state:
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
-            if not vals:
-                continue
+            vals.append((p, g, state["Q"]))
 
-            p_list, grad_list, Q_list = zip(*vals)
-            del vals
+        if not vals:
+            return
 
-            group["step"] += 1
+        p_list, grad_list, Q_list = zip(*vals)
+        del vals
 
-            Q_list = list(Q_list)
-            for i, (p, g) in enumerate(zip(p_list, grad_list)):
-                q_orig = Q_list.pop(0)
-                q = line_to_triu(q_orig)
+        group["step"] += 1
 
-                self.balance(do_update, [g], [q])
-                if do_update:
-                    self.do_update([p], [g], [q], precond_lr, [q_orig])
-                psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
+        Q_list = list(Q_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            q_orig = Q_list.pop(0)
+            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            grad_list = self.clip_fn(grad_list)
+            if do_update:
+                self.balance([g], [q])
+                self.do_update([p], [g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
+            psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
 
-            lr = -warmup(lr, group['step'], group['warmup_steps'])
-            update_param_(p_list, grad_list, lr, weight_decay)
+        grad_list = self.clip_fn(grad_list)
 
-        return loss
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
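For intuition only, here is a toy illustration of the storage trade-off that store_triu_as_line controls; to_line/from_line below are hypothetical stand-ins, not the package's triu_to_line/line_to_triu:

    import torch

    def to_line(q: torch.Tensor):
        # keep only the n*(n+1)/2 nonzero entries of an upper-triangular factor, row by row
        n = q.shape[0]
        return n, [q[i, i:].clone() for i in range(n)]

    def from_line(n, rows):
        # rebuild the dense triangular matrix when it is actually needed
        q = torch.zeros(n, n, dtype=rows[0].dtype)
        for i, row in enumerate(rows):
            q[i, i:] = row
        return q

    q = torch.triu(torch.randn(4, 4))
    n, rows = to_line(q)
    assert torch.equal(from_line(n, rows), q)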