heavyball 0.14.7__py3-none-any.whl → 0.15.1__py3-none-any.whl
- heavyball/__init__.py +25 -3
- heavyball/cached_psgd_kron.py +141 -0
- heavyball/delayed_psgd.py +43 -51
- heavyball/foreach_adamw.py +22 -32
- heavyball/foreach_adopt.py +38 -48
- heavyball/foreach_laprop.py +25 -35
- heavyball/foreach_sfadamw.py +28 -38
- heavyball/foreach_soap.py +56 -70
- heavyball/p_adam.py +46 -50
- heavyball/palm_foreach_sfadamw.py +31 -41
- heavyball/palm_foreach_soap.py +56 -70
- heavyball/precond_schedule_foreach_soap.py +57 -71
- heavyball/precond_schedule_palm_foreach_soap.py +58 -73
- heavyball/precond_schedule_sfpsoap.py +60 -72
- heavyball/psgd_kron.py +43 -49
- heavyball/pure_psgd.py +36 -43
- heavyball/schedule_free_palm_foreach_soap.py +61 -72
- heavyball/utils.py +23 -7
- {heavyball-0.14.7.dist-info → heavyball-0.15.1.dist-info}/METADATA +1 -1
- heavyball-0.15.1.dist-info/RECORD +23 -0
- heavyball-0.14.7.dist-info/RECORD +0 -22
- {heavyball-0.14.7.dist-info → heavyball-0.15.1.dist-info}/LICENSE +0 -0
- {heavyball-0.14.7.dist-info → heavyball-0.15.1.dist-info}/WHEEL +0 -0
- {heavyball-0.14.7.dist-info → heavyball-0.15.1.dist-info}/top_level.txt +0 -0
heavyball/precond_schedule_palm_foreach_soap.py
CHANGED
```diff
@@ -44,76 +44,61 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
         self._data_format = data_format
         self.rng = random.Random(0x120983109)

[old lines 47-104 were not captured in this diff view]
-                exp_avg_projected = project(ea, state['Q'], False)
-
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-                exp_avg_projected = exp_avg_projected / d
-                set_(d, project(exp_avg_projected, state['Q'], True))
-
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                      update_precond)
-
-            # Why does this have to be rebiased here?
-            step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-            update_param_(p_list, denom, step_size, group["weight_decay"])
-        return loss
+    def _step(self, group):
+        vals = []
+        step = 0
+
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+            step = state['step'] = state.get("step", -1) + 1
+
+            if "exp_avg" not in state:
+                state["exp_avg"] = torch.zeros_like(g, dtype=torch.float32)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
+
+        if not vals:
+            return
+
+        p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
+        beta1 = group["beta"]
+
+        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
+        old_debiased1 = beta_debias(beta1, step)
+        old_debiased2 = beta_debias(beta2, step)
+
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        torch._foreach_mul_(exp_avg, old_debiased1)
+        torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+
+        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+        for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
+            state = self.state_(p)
+            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            exp_avg_projected = project(ea, state['Q'], False)
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            exp_avg_projected = exp_avg_projected / d
+            set_(d, project(exp_avg_projected, state['Q'], True))
+
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
+
+        # Why does this have to be rebiased here?
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+        update_param_(p_list, denom, step_size, group["weight_decay"])
```
heavyball/precond_schedule_sfpsoap.py
CHANGED
```diff
@@ -54,85 +54,73 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
         self._data_format = data_format
         self.rng = random.Random(0x120983109)

[old lines 57-65 were not captured in this diff view]
-            loss = None
-        else:
-            loss = closure()
-
-        for group in self.param_groups:
-            vals = []
-            max_precond_dim = group['max_precond_dim']
-            precondition_1d = group['precondition_1d']
-
-            step = group['step'] = group.get("step", -1) + 1
-
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-                grad = p.grad.float()
-                vals.append((p, grad))
-
-            p_list, grad = zip(*vals)
-            vals = []
-
-            # adaptive gradient clipping
-            adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
-
-                if "z" not in state:
-                    state["z"] = torch.clone(p.data)
-                    state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
-                    init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                    continue  # first step is skipped so that we never use the current gradients in the projection.
-
-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                grad_projected = project(g, state['Q'], False)
-                z, exp_avg_sq = state["z"], state["exp_avg_sq"]
-                vals.append((p, g, grad_projected, z, exp_avg_sq))
-
-            if not vals:
+    def _step(self, group):
+        vals = []
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        step = group['step'] = group.get("step", -1) + 1
+
+        for p in group["params"]:
+            if p.grad is None:
                 continue
+            grad = p.grad.float()
+            vals.append((p, grad))
+
+        if not vals:
+            return
+
+        p_list, grad = zip(*vals)
+        vals = []
+
+        # adaptive gradient clipping
+        adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+
+            if "z" not in state:
+                state["z"] = torch.clone(p.data)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            z, exp_avg_sq = state["z"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, z, exp_avg_sq))

-
-
+        if not vals:
+            return

-
-
+        p_list, grad, grad_projected, z, exp_avg_sq = zip(*vals)
+        del vals

-
-
-            denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
-            torch._foreach_div_(grad_projected, denom)
+        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
+        old_debiased2 = beta_debias(beta2, step)

-
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+        torch._foreach_div_(grad_projected, denom)

-
-                state = self.state_(p)
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                set_(gp, project(gp, state['Q'], back=True))
+        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)

-
-
+        for p, g, gp in zip(p_list, grad, grad_projected):
+            state = self.state_(p)
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            set_(gp, project(gp, state['Q'], back=True))

-
-
-                torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
+                                  update_precond)

-
-
-
+        # Weight decay calculated at y
+        if group["weight_decay"] > 0:
+            torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])

-
+        lr = warmup(group['lr'], step, group['warmup_steps'])
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
+                                             p_list, z, grad_projected, group['r'], step)
```
heavyball/psgd_kron.py
CHANGED
```diff
@@ -4,9 +4,10 @@ Modified under Creative Commons Attribution 4.0 International
 Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
 """

-import torch
 from typing import Optional

+import torch
+
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
     precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_

@@ -37,7 +38,7 @@ class ForeachPSGDKron(PSGDBase):
     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: Optional[callable] = None):
+                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -57,18 +58,13 @@ class ForeachPSGDKron(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split)
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
         super().__init__(params, defaults)

         self._prob_step = 0

-
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
@@ -76,54 +72,52 @@ class ForeachPSGDKron(PSGDBase):
             do_update = self.rng.random() < update_prob
         self._prob_step += 1

[old lines 79-89 were not captured in this diff view]
-        vals = []
+        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        beta = group['beta']
+        store_triu_as_line = group['store_triu_as_line']

-
-            state = self.state_(p)
+        vals = []

-
-
-                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                 memory_save_mode, dtype=g.dtype)
-                state['Q'] = triu_to_line(Q)
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)

-
+            if 'Q' not in state:
+                state["exp_avg"] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q

-
-                continue
+            vals.append((p, g, state["exp_avg"], state["Q"]))

-
-
+        if not vals:
+            return

-
+        p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
+        del vals

-
+        group["step"] += 1

-
-        for i, (p, g) in enumerate(zip(p_list, grad_list)):
-            q_orig = Q_list.pop(0)
-            ea = exp_avg_list.pop(0)
-            q = line_to_triu(q_orig)
+        torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))

-
-
-
-
+        grad_list, Q_list, exp_avg_list = list(grad_list), list(Q_list), list(exp_avg_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            q_orig = Q_list.pop(0)
+            ea = exp_avg_list.pop(0)
+            q = line_to_triu(q_orig) if store_triu_as_line else q_orig

-
+            if do_update:
+                self.balance([g], [q])
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
+            set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))

-
-        update_param_(p_list, grad_list, lr, weight_decay)
+        grad_list = self.clip_fn(grad_list)

-
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
```
heavyball/pure_psgd.py
CHANGED
```diff
@@ -36,7 +36,7 @@ class ForeachPurePSGD(PSGDBase):
     def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None):
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -54,18 +54,13 @@ class ForeachPurePSGD(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split)
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
         super().__init__(params, defaults)

         self._prob_step = 0

-
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
@@ -73,48 +68,46 @@ class ForeachPurePSGD(PSGDBase):
             do_update = self.rng.random() < update_prob
         self._prob_step += 1

[old lines 76-84 were not captured in this diff view]
-        vals = []
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        store_triu_as_line = group['store_triu_as_line']

-
-            state = self.state_(p)
+        vals = []

-
-
-                                                 memory_save_mode, dtype=g.dtype)
-                state['Q'] = triu_to_line(Q)
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)

-
+            if 'Q' not in state:
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q

-
-                continue
+            vals.append((p, g, state["Q"]))

-
-
+        if not vals:
+            return

-
+        p_list, grad_list, Q_list = zip(*vals)
+        del vals

-
-        for i, (p, g) in enumerate(zip(p_list, grad_list)):
-            q_orig = Q_list.pop(0)
-            q = line_to_triu(q_orig)
+        group["step"] += 1

-
-
-
-
+        Q_list = list(Q_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            q_orig = Q_list.pop(0)
+            q = line_to_triu(q_orig) if store_triu_as_line else q_orig

-
+            if do_update:
+                self.balance([g], [q])
+                self.do_update([p], [g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
+            psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)

-
-        update_param_(p_list, grad_list, lr, weight_decay)
+        grad_list = self.clip_fn(grad_list)

-
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
```