heavyball 0.14.7__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +1 -1
- heavyball/delayed_psgd.py +39 -48
- heavyball/foreach_adamw.py +22 -32
- heavyball/foreach_adopt.py +38 -48
- heavyball/foreach_laprop.py +25 -35
- heavyball/foreach_sfadamw.py +28 -38
- heavyball/foreach_soap.py +56 -70
- heavyball/p_adam.py +39 -48
- heavyball/palm_foreach_sfadamw.py +31 -41
- heavyball/palm_foreach_soap.py +56 -70
- heavyball/precond_schedule_foreach_soap.py +57 -71
- heavyball/precond_schedule_palm_foreach_soap.py +58 -73
- heavyball/precond_schedule_sfpsoap.py +60 -72
- heavyball/psgd_kron.py +39 -47
- heavyball/pure_psgd.py +32 -41
- heavyball/schedule_free_palm_foreach_soap.py +61 -72
- heavyball/utils.py +17 -1
- {heavyball-0.14.7.dist-info → heavyball-0.15.0.dist-info}/METADATA +1 -1
- heavyball-0.15.0.dist-info/RECORD +22 -0
- heavyball-0.14.7.dist-info/RECORD +0 -22
- {heavyball-0.14.7.dist-info → heavyball-0.15.0.dist-info}/LICENSE +0 -0
- {heavyball-0.14.7.dist-info → heavyball-0.15.0.dist-info}/WHEEL +0 -0
- {heavyball-0.14.7.dist-info → heavyball-0.15.0.dist-info}/top_level.txt +0 -0
heavyball/precond_schedule_sfpsoap.py
CHANGED
@@ -54,85 +54,73 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
         self._data_format = data_format
         self.rng = random.Random(0x120983109)

-
-
-
-
-
-
-
-        ""
-
-            loss = None
-        else:
-            loss = closure()
-
-        for group in self.param_groups:
-            vals = []
-            max_precond_dim = group['max_precond_dim']
-            precondition_1d = group['precondition_1d']
-
-            step = group['step'] = group.get("step", -1) + 1
-
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-                grad = p.grad.float()
-                vals.append((p, grad))
-
-            p_list, grad = zip(*vals)
-            vals = []
-
-            # adaptive gradient clipping
-            adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
-
-                if "z" not in state:
-                    state["z"] = torch.clone(p.data)
-                    state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
-                    init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                    continue  # first step is skipped so that we never use the current gradients in the projection.
-
-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                grad_projected = project(g, state['Q'], False)
-                z, exp_avg_sq = state["z"], state["exp_avg_sq"]
-                vals.append((p, g, grad_projected, z, exp_avg_sq))
-
-            if not vals:
+    def _step(self, group):
+        vals = []
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        step = group['step'] = group.get("step", -1) + 1
+
+        for p in group["params"]:
+            if p.grad is None:
                 continue
+            grad = p.grad.float()
+            vals.append((p, grad))
+
+        if not vals:
+            return
+
+        p_list, grad = zip(*vals)
+        vals = []
+
+        # adaptive gradient clipping
+        adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+
+            if "z" not in state:
+                state["z"] = torch.clone(p.data)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            z, exp_avg_sq = state["z"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, z, exp_avg_sq))

-
-
+        if not vals:
+            return

-
-
+        p_list, grad, grad_projected, z, exp_avg_sq = zip(*vals)
+        del vals

-
-
-            denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
-            torch._foreach_div_(grad_projected, denom)
+        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
+        old_debiased2 = beta_debias(beta2, step)

-
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+        torch._foreach_div_(grad_projected, denom)

-
-                state = self.state_(p)
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                set_(gp, project(gp, state['Q'], back=True))
+        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)

-
-
+        for p, g, gp in zip(p_list, grad, grad_projected):
+            state = self.state_(p)
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            set_(gp, project(gp, state['Q'], back=True))

-
-
-            torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
+                                  update_precond)

-
-
-
+        # Weight decay calculated at y
+        if group["weight_decay"] > 0:
+            torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])

-
+        lr = warmup(group['lr'], step, group['warmup_steps'])
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
+                                             p_list, z, grad_projected, group['r'], step)
heavyball/psgd_kron.py
CHANGED
@@ -4,9 +4,10 @@ Modified under Creative Commons Attribution 4.0 International
 Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
 """

-import torch
 from typing import Optional

+import torch
+
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
     precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_

@@ -62,13 +63,7 @@ class ForeachPSGDKron(PSGDBase):

         self._prob_step = 0

-
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
@@ -76,54 +71,51 @@ class ForeachPSGDKron(PSGDBase):
         do_update = self.rng.random() < update_prob
         self._prob_step += 1

-
-
-
-
-
-
-
-
-
-            beta = group['beta']
-
-            vals = []
+        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        beta = group['beta']

-
-                state = self.state_(p)
+        vals = []

-
-
-                    Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                     memory_save_mode, dtype=g.dtype)
-                    state['Q'] = triu_to_line(Q)
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)

-
+            if 'Q' not in state:
+                state["exp_avg"] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q)

-
-                    continue
+            vals.append((p, g, state["exp_avg"], state["Q"]))

-
-
+        if not vals:
+            return

-
+        p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
+        del vals

-
+        group["step"] += 1

-
-            for i, (p, g) in enumerate(zip(p_list, grad_list)):
-                q_orig = Q_list.pop(0)
-                ea = exp_avg_list.pop(0)
-                q = line_to_triu(q_orig)
+        torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))

-
-
-
-
+        grad_list, Q_list, exp_avg_list = list(grad_list), list(Q_list), list(exp_avg_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            q_orig = Q_list.pop(0)
+            ea = exp_avg_list.pop(0)
+            q = line_to_triu(q_orig)

-
+            self.balance(do_update, [g], [q])
+            if do_update:
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
+            set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))

-
-            update_param_(p_list, grad_list, lr, weight_decay)
+        grad_list = self.clip_fn(grad_list)

-
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
heavyball/pure_psgd.py
CHANGED
@@ -59,13 +59,7 @@ class ForeachPurePSGD(PSGDBase):

         self._prob_step = 0

-
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
@@ -73,48 +67,45 @@ class ForeachPurePSGD(PSGDBase):
         do_update = self.rng.random() < update_prob
         self._prob_step += 1

-
-
-
-
-
-
-
-            lr = group['lr']
-
-            vals = []
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']

-
-                state = self.state_(p)
+        vals = []

-
-
-                                                     memory_save_mode, dtype=g.dtype)
-                    state['Q'] = triu_to_line(Q)
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)

-
+            if 'Q' not in state:
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q)

-
-                    continue
+            vals.append((p, g, state["Q"]))

-
-
+        if not vals:
+            return

-
+        p_list, grad_list, Q_list = zip(*vals)
+        del vals

-
-            for i, (p, g) in enumerate(zip(p_list, grad_list)):
-                q_orig = Q_list.pop(0)
-                q = line_to_triu(q_orig)
+        group["step"] += 1

-
-
-
-
+        Q_list = list(Q_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            q_orig = Q_list.pop(0)
+            q = line_to_triu(q_orig)

-
+            self.balance(do_update, [g], [q])
+            if do_update:
+                self.do_update([p], [g], [q], precond_lr, [q_orig])
+            psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)

-
-            update_param_(p_list, grad_list, lr, weight_decay)
+        grad_list = self.clip_fn(grad_list)

-
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
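With `step(closure)` gone from the subclasses, the calling convention is unchanged; the base class (see the heavyball/utils.py hunk below) evaluates the closure under `torch.enable_grad()` and runs each `_step(group)` under `torch.no_grad()`. A usage sketch, assuming `ForeachPurePSGD` is re-exported from heavyball/__init__.py and accepts `lr` as shown:

import torch

import heavyball  # assumption: ForeachPurePSGD is exposed at the package top level

model = torch.nn.Linear(4, 2)
opt = heavyball.ForeachPurePSGD(model.parameters(), lr=1e-3)

def closure():
    opt.zero_grad()
    loss = model(torch.randn(8, 4)).square().mean()
    loss.backward()
    return loss

loss = opt.step(closure)  # closure runs with grads enabled; _step(group) runs under no_grad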
heavyball/schedule_free_palm_foreach_soap.py
CHANGED
@@ -46,84 +46,73 @@ class SFPaLMForeachSOAP(ScheduleFree):
         self._data_format = data_format
         self.rng = random.Random(0x120983109)

-
-
-
-
-
-
-
-        ""
-
-            loss = None
-        else:
-            loss = closure()
-
-        for group in self.param_groups:
-            vals = []
-            max_precond_dim = group['max_precond_dim']
-            precondition_1d = group['precondition_1d']
-
-            step = group['step'] = group.get("step", -1) + 1
-
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-                grad = p.grad.float()
-                vals.append((p, grad))
-
-            p_list, grad = zip(*vals)
-
-            adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
-
-            vals = []
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
-
-                if "z" not in state:
-                    state["z"] = torch.clone(p).float()
-                    state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
-                    init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                    continue  # first step is skipped so that we never use the current gradients in the projection.
-
-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                grad_projected = project(g, state['Q'], False)
-                z, exp_avg_sq = state["z"], state["exp_avg_sq"]
-                vals.append((p, g, grad_projected, z, exp_avg_sq))
-
-            if not vals:
+    def _step(self, group):
+        vals = []
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        step = group['step'] = group.get("step", -1) + 1
+
+        for p in group["params"]:
+            if p.grad is None:
                 continue
+            grad = p.grad.float()
+            vals.append((p, grad))
+
+        if not vals:
+            return
+
+        p_list, grad = zip(*vals)
+
+        adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
+
+        vals = []
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+
+            if "z" not in state:
+                state["z"] = torch.clone(p).float()
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            z, exp_avg_sq = state["z"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, z, exp_avg_sq))
+
+        if not vals:
+            return

-
+        p_list, grad, grad_projected, z, exp_avg_sq = zip(*vals)

-
-
+        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
+        new_debiased2 = beta_debias(beta2, step)

-
-
-
-
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        denom = exp_avg_sq_(exp_avg_sq, grad, new_debiased2, group["eps"])
+        torch._foreach_div_(grad_projected, denom)

-
+        update_precond = group['step'] > 0 and group['step'] % group['precondition_frequency'] == 0

-
-
-
-
-
-
+        for p, g, gp in zip(p_list, grad, grad_projected):
+            state = self.state_(p)
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            set_(gp, project(gp, state['Q'], back=True))

-
-
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2,
+                                  update_precond)

-
-
-
+        # Weight decay calculated at y
+        if group["weight_decay"] > 0:
+            torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])

-
-
-
-            return loss
+        lr = warmup(group['lr'], step, group['warmup_steps'])
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
+                                             p_list, z, grad_projected, group['r'], step)
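Unlike the precond-schedule variant at the top of this diff, this optimizer keeps a fixed `precondition_frequency` gate, and its Adam-style denominator comes from `exp_avg_sq_` followed by an in-place `torch._foreach_div_`. A rough single-tensor stand-in for that pair, assuming `exp_avg_sq_` has the usual Adam second-moment semantics (EMA of squared gradients, then sqrt plus eps); the real helper is in heavyball/utils.py:

import torch

# Hypothetical stand-in for heavyball.utils.exp_avg_sq_; semantics assumed, not copied.
def exp_avg_sq_sketch(exp_avg_sq, grad, beta2, eps):
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # EMA of grad ** 2
    return exp_avg_sq.sqrt().add_(eps)                            # Adam-style denominator

g = torch.randn(4)      # one tensor from `grad`
v = torch.zeros(4)      # stands in for state["exp_avg_sq"]
denom = exp_avg_sq_sketch(v, g, beta2=0.95, eps=1e-8)
precond_g = g / denom   # per-tensor view of torch._foreach_div_(grad_projected, denom)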
heavyball/utils.py
CHANGED
@@ -3,7 +3,7 @@ import gc
 import math
 import random
 import string
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Callable

 import numpy as np
 import torch
@@ -399,6 +399,20 @@ class StatefulOptimizer(torch.optim.Optimizer):
             tree_map(_add, self.state_(p))
         return total_bytes

+    def _step(self, group):
+        raise NotImplementedError
+
+    def step(self, closure: Optional[Callable] = None):
+        if closure is None:
+            loss = None
+        else:
+            with torch.enable_grad():
+                loss = closure()
+        with torch.no_grad():
+            for group in self.param_groups:
+                self._step(group)
+        return loss
+

 class ScheduleFree(StatefulOptimizer):
     def eval(self):
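This is the hook that every optimizer above now relies on: `StatefulOptimizer.step` becomes a template method that owns closure handling, gradient modes, and the loop over parameter groups, while subclasses implement only `_step(group)`. A standalone sketch of the same pattern on a vanilla `torch.optim.Optimizer` (illustrative, not heavyball code):

from typing import Callable, Optional

import torch


class TemplateOptimizer(torch.optim.Optimizer):
    """Same step/_step contract as the StatefulOptimizer methods added above."""

    def _step(self, group):
        raise NotImplementedError

    def step(self, closure: Optional[Callable] = None):
        if closure is None:
            loss = None
        else:
            with torch.enable_grad():
                loss = closure()  # the closure needs autograd
        with torch.no_grad():
            for group in self.param_groups:
                self._step(group)  # subclasses do the actual update
        return loss


class SketchSGD(TemplateOptimizer):
    """Toy subclass: plain SGD expressed as a _step(group) override."""

    def __init__(self, params, lr: float = 1e-3):
        super().__init__(params, dict(lr=lr))

    def _step(self, group):
        for p in group["params"]:
            if p.grad is not None:
                p.add_(p.grad, alpha=-group["lr"])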
@@ -684,9 +698,11 @@ def a_law_compress(x, A=87.6):
     torch._foreach_mul_(xa, 1 / (1 + math.log(A)))
     return xa

+
 def identity(x):
     return x

+
 def trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
     torch._foreach_mul_(grad, 1 / scale)
     tanh = torch._foreach_tanh(grad)
@@ -0,0 +1,22 @@
+heavyball/__init__.py,sha256=AGNWRYDkPFZ9Px3117ChPgUgEk2gqflWWXO4UxHlmYc,1156
+heavyball/delayed_psgd.py,sha256=28osHU-2khgdQ1ASglxTtA5MA1j2GiYP3OmNirkqMso,5574
+heavyball/foreach_adamw.py,sha256=NSzoIgNm7eavzbJgkAF0k7TUEnWAgOpt9-4juIFoaSA,1729
+heavyball/foreach_adopt.py,sha256=WA07m5jocLfb1GPU8s6mJ2PteS-03ronkKm-VJrAm5I,1863
+heavyball/foreach_laprop.py,sha256=mE2NDGX9XgvRhsewcWnk_-FulZPqGA65ejYF_9-A1Xk,1768
+heavyball/foreach_sfadamw.py,sha256=ussHfPd99u3RTfMrCuu5oIbwNFLXK19wO1Fbz3JShlc,2097
+heavyball/foreach_soap.py,sha256=WWvssYKg607uoEJHftp8ag8mtKSKSeHrT0QTgqBucVg,4587
+heavyball/p_adam.py,sha256=xOJuws2ELPcL-TUyH-2fPvwRdBNZUmaqiKDJFK33bPM,5694
+heavyball/palm_foreach_sfadamw.py,sha256=wjUb_fNZNUmzWXyKvwB0unP9lvNMmaYSQo5YoeS5cj0,2200
+heavyball/palm_foreach_soap.py,sha256=2Sb4hUHQeexJcCgjHeQM_ENkZ6lG1DVxW72ryrvR6iY,5890
+heavyball/precond_schedule_foreach_soap.py,sha256=bHsDyh-UvHpHjumjqqy0PePoR1ZMsJV6o5wWvpLAA04,4815
+heavyball/precond_schedule_palm_foreach_soap.py,sha256=myLTJNQKLtZ3Xi3MVTB-RYtx_XeMRJw5CIMJW75ndUY,6163
+heavyball/precond_schedule_sfpsoap.py,sha256=xeNWetBzBEYqfOSzl98aAVJsHk43QkrUUhHH_YD_mS4,6740
+heavyball/psgd_kron.py,sha256=7PxvVNeXQcxHMDdn0hFn4psEH89xOctflKt5sKjryMU,5554
+heavyball/pure_psgd.py,sha256=1vafWGQ5vtEE01T1qhI9GDXvzFw5zzq0rJrewa1jY4E,4847
+heavyball/schedule_free_palm_foreach_soap.py,sha256=w0P7lMmoijTpL9V7NwOHcNBFJQ7S1TS9aCiwPhY2yVw,6319
+heavyball/utils.py,sha256=xNqBJBZyK5n5EKy2g4qkXf342uGPGvk6pzjFPzeBncM,27861
+heavyball-0.15.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-0.15.0.dist-info/METADATA,sha256=CuUQyfkSwtwfsnVo6vLM_0hIMwJfB_J26-baYAlqDvM,11667
+heavyball-0.15.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-0.15.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-0.15.0.dist-info/RECORD,,
@@ -1,22 +0,0 @@
-heavyball/__init__.py,sha256=ef7IWcPF8Uh3WQHzMiMqOFvUnU_LdG7BO9XVecJOph4,1156
-heavyball/delayed_psgd.py,sha256=Gfa1ogkFPPL7ohayYAwbugB8hyLRUI5FgcJfsK69KGI,5936
-heavyball/foreach_adamw.py,sha256=L727chOuVqdiVbYYzloy4g3oKH2FmQh40o_bqWeZtk8,2106
-heavyball/foreach_adopt.py,sha256=M4zZVcYlhGjqO6qekivCuYfX6JtMcp4cJi_RrSHT3H8,2268
-heavyball/foreach_laprop.py,sha256=htYGqgvlZsF_JpILdPMTnX72SqfrokBJ2J4nDeT0JVU,2157
-heavyball/foreach_sfadamw.py,sha256=KIGol7Phfq1DHE_nEle4wDuqNdbGsT3kUcMKzJX3msg,2498
-heavyball/foreach_soap.py,sha256=Ccz9Mc_xaHnrJ_7jUq9ZVxyR0WEqopzOXTUqUY-V8G8,5137
-heavyball/p_adam.py,sha256=jQgTkKekqnLj1XPA4-fgpWG8P_BtUq2976zEt2QymTo,6060
-heavyball/palm_foreach_sfadamw.py,sha256=8IGlRCdkfMzUqcSfmTM3Ce04NLNyrT2xfiBcPkrWwqc,2605
-heavyball/palm_foreach_soap.py,sha256=NEJ3Xeh7pqURUk3cAP2qJe8z2WzYKg60pQe4bsGiaY4,6441
-heavyball/precond_schedule_foreach_soap.py,sha256=H6Oc5IAL5MR-fgu92AboPs3Xm8mBmYUMPLsEcuJ12VI,5370
-heavyball/precond_schedule_palm_foreach_soap.py,sha256=v81hRjcqS6Mm-KxT5Rk3TEiKAE8WI2IbmVbSa-YfBkE,6760
-heavyball/precond_schedule_sfpsoap.py,sha256=7ScnN0in8u9hPiJE7QnOoZOH6Tn-6HeVy4f-bO3bHzY,7279
-heavyball/psgd_kron.py,sha256=AH8ugd_IxKGVtY9y_Ot7myVSxFDbLlRJIqr2bBlAYy8,5911
-heavyball/pure_psgd.py,sha256=jp5fnawUdgccEFlZDPrZr4ZbxYV85IIrev4tybZxBVU,5185
-heavyball/schedule_free_palm_foreach_soap.py,sha256=bV7H-FNNoH5WpposLrNhkqU7mBicMorqKEALBSdROEM,6853
-heavyball/utils.py,sha256=y5VAd9CQjcl_a1WUcORviAYf7Jz_c7n3-b7i5kLUJIA,27464
-heavyball-0.14.7.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-0.14.7.dist-info/METADATA,sha256=5QWB3nuNAp8YjeX0-Y5Uzkek_wjuGF3XG6UWrQk8R0c,11667
-heavyball-0.14.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-0.14.7.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-0.14.7.dist-info/RECORD,,