heavyball 0.14.7__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +25 -3
- heavyball/cached_psgd_kron.py +141 -0
- heavyball/delayed_psgd.py +43 -51
- heavyball/foreach_adamw.py +22 -32
- heavyball/foreach_adopt.py +38 -48
- heavyball/foreach_laprop.py +25 -35
- heavyball/foreach_sfadamw.py +28 -38
- heavyball/foreach_soap.py +56 -70
- heavyball/p_adam.py +46 -50
- heavyball/palm_foreach_sfadamw.py +31 -41
- heavyball/palm_foreach_soap.py +56 -70
- heavyball/precond_schedule_foreach_soap.py +57 -71
- heavyball/precond_schedule_palm_foreach_soap.py +58 -73
- heavyball/precond_schedule_sfpsoap.py +60 -72
- heavyball/psgd_kron.py +43 -49
- heavyball/pure_psgd.py +36 -43
- heavyball/schedule_free_palm_foreach_soap.py +61 -72
- heavyball/utils.py +23 -7
- {heavyball-0.14.7.dist-info → heavyball-0.15.1.dist-info}/METADATA +1 -1
- heavyball-0.15.1.dist-info/RECORD +23 -0
- heavyball-0.14.7.dist-info/RECORD +0 -22
- {heavyball-0.14.7.dist-info → heavyball-0.15.1.dist-info}/LICENSE +0 -0
- {heavyball-0.14.7.dist-info → heavyball-0.15.1.dist-info}/WHEEL +0 -0
- {heavyball-0.14.7.dist-info → heavyball-0.15.1.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -1,3 +1,5 @@
+from .cached_psgd_kron import ForeachCachedPSGDKron
+from .delayed_psgd import ForeachDelayedPSGD
 from .foreach_adamw import ForeachAdamW
 from .foreach_adopt import ForeachADOPT
 from .foreach_laprop import ForeachLaProp
@@ -12,11 +14,31 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
 from .psgd_kron import ForeachPSGDKron
 from .pure_psgd import ForeachPurePSGD
 from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
-from .delayed_psgd import ForeachDelayedPSGD
 
 PalmForEachSoap = PaLMForeachSOAP
 
+PaLMSOAP = PaLMForeachSOAP
+PaLMSFAdamW = PaLMForeachSFAdamW
+PaLMSFSoap = SFPaLMForeachSOAP
+PaLMForeachSOAP = PaLMForeachSOAP
+PrecondScheduleSFPaLMSOAP = PrecondScheduleSFPaLMSOAP
+SOAP = ForeachSOAP
+SFAdamW = ForeachSFAdamW
+LaProp = ForeachLaProp
+ADOPT = ForeachADOPT
+PrecondScheduleForeachSOAP = PrecondScheduleForeachSOAP
+PrecondSchedulePaLMForeachSOAP = PrecondSchedulePaLMForeachSOAP
+PSGDKron = ForeachPSGDKron
+AdamW = ForeachAdamW
+PurePSGD = ForeachPurePSGD
+PaLMPAdam = ForeachPaLMPAdam
+DelayedPSGD = ForeachDelayedPSGD
+CachedPSGDKron = ForeachCachedPSGDKron
+
 __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
            'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
-           'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD',
-           '
+           'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
+           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', #
+           'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
+           'SOAP', 'SFAdamW', 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD',
+           'CachedPSGDKron']
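Every optimizer is now also exported under a short alias (SOAP, AdamW, PSGDKron, ...). A minimal usage sketch of the new names; the model and learning rate below are illustrative and assume ForeachAdamW accepts the usual params/lr arguments:

```python
import torch
import heavyball

# The aliases added above point at the existing classes.
assert heavyball.SOAP is heavyball.ForeachSOAP
assert heavyball.CachedPSGDKron is heavyball.ForeachCachedPSGDKron

model = torch.nn.Linear(16, 4)
opt = heavyball.AdamW(model.parameters(), lr=1e-3)  # same class as heavyball.ForeachAdamW
```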
heavyball/cached_psgd_kron.py
ADDED
@@ -0,0 +1,141 @@
+"""
+Originally from Evan Walters and Omead Pooladzandi, 2024
+Modified under Creative Commons Attribution 4.0 International
+Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
+"""
+
+from typing import Optional
+
+import torch
+from heavyball.utils import einsum_base
+
+from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
+    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base
+
+
+class ForeachCachedPSGDKron(PSGDBase):
+    """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch with cached preconditioners.
+
+    Args:
+        params (iterable): Iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float): Learning rate.
+        b1 (float): Momentum parameter.
+        weight_decay (float): Weight decay (L2 penalty).
+        preconditioner_update_probability (callable or float, optional): Probability of
+            updating the preconditioner. If None, defaults to a schedule that anneals
+            from 1.0 to 0.03 by 4000 steps.
+        max_size_triangular (int): Max size for dim's preconditioner to be triangular.
+        min_ndim_triangular (int): Minimum number of dimensions a layer needs
+            to have triangular preconditioners.
+        memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
+            to set all preconditioners to be triangular, 'one_diag' sets the largest
+            or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
+            to be diagonal.
+        momentum_into_precond_update: (bool), whether to send momentum into preconditioner
+            update instead of raw gradients.
+    """
+
+    def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
+                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
+                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
+                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= beta < 1.0:
+            raise ValueError(f"Invalid beta parameter: {beta}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        if preconditioner_update_probability is None:
+            preconditioner_update_probability = precond_update_prob_schedule()
+        if clip_fn is None:
+            clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
+        self.preconditioner_update_probability = preconditioner_update_probability
+        self.clip_fn = clip_fn
+
+        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
+                        min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
+                        # precond lr hardcoded to 0.1
+                        precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
+        super().__init__(params, defaults)
+
+        self._prob_step = 0
+
+    def _step(self, group):
+        # update preconditioners all together
+        update_prob = self.preconditioner_update_probability
+        if callable(update_prob):
+            update_prob = update_prob(self._prob_step)
+        do_update = self.rng.random() < update_prob
+        self._prob_step += 1
+
+        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        beta = group['beta']
+        store_triu_as_line = group['store_triu_as_line']
+
+        vals = []
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+
+            if 'Q' not in state:
+                state["exp_avg"] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
+                state['Q_cache'] = [torch.empty_like(q) for q in Q]
+
+                expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(einsum_base, Q)]
+                expr = ','.join(expr)
+                grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
+                out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
+                expr = f'{expr},{grad_expr}->{out_expr}'
+
+                state['cache_expr'] = expr
+
+            vals.append((p, g, state["exp_avg"], state["Q"], state['Q_cache']))
+
+        if not vals:
+            return
+
+        p_list, grad_list, exp_avg_list, Q_list, Q_cache_list = zip(*vals)
+        del vals
+
+        group["step"] += 1
+
+        torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+
+        grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
+            exp_avg_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            cached_q = Q_cache_list.pop(0)
+            q_orig = Q_list.pop(0)
+            ea = exp_avg_list.pop(0)
+
+            if do_update:
+                q = line_to_triu(q_orig) if store_triu_as_line else q_orig
+                self.balance([g], [q])
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
+                               [q_orig] if store_triu_as_line else None)
+                for c_, q_ in zip(cached_q, q):
+                    if q_.ndim == 2:
+                        torch.matmul(q_.T.conj(), q_, out=c_)
+                    else:
+                        torch.mul(q_.conj(), q_, out=c_)
+
+            set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
+        grad_list = self.clip_fn(grad_list)
+
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
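For orientation, a minimal training-loop sketch for the new class. The constructor arguments come from the signature above; the model, the loss, and the assumption that the PSGDBase/StatefulOptimizer machinery exposes the standard step()/zero_grad() interface (dispatching to the _step(group) hook) are illustrative:

```python
import torch
from heavyball import ForeachCachedPSGDKron  # also exported as CachedPSGDKron

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
opt = ForeachCachedPSGDKron(model.parameters(), lr=1e-3, beta=0.9, weight_decay=0.0,
                            store_triu_as_line=True)

for _ in range(100):
    opt.zero_grad()
    loss = model(torch.randn(64, 32)).square().mean()
    loss.backward()
    opt.step()
```

Compared to ForeachPSGDKron, this class recomputes Q^T Q (or |q|^2 for diagonal factors) into Q_cache only when the preconditioner itself is refreshed, so the per-step preconditioning collapses to a single einsum (state['cache_expr']) against the cached products and the momentum buffer.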
heavyball/delayed_psgd.py
CHANGED
@@ -5,8 +5,8 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-
 from heavyball.utils import copy_stochastic_list_
+
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
     precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_
 
@@ -38,7 +38,7 @@ class ForeachDelayedPSGD(PSGDBase):
     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None):
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -58,18 +58,13 @@ class ForeachDelayedPSGD(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
         super().__init__(params, defaults)
 
         self._prob_step = 0
 
-
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
@@ -77,55 +72,52 @@ class ForeachDelayedPSGD(PSGDBase):
         do_update = self.rng.random() < update_prob
         self._prob_step += 1
 
-
-
-
-
-
-
-
-
-
-
-
-            vals = []
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
+        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        beta = group['beta']
+        store_triu_as_line = group['store_triu_as_line']
 
-
-                state["exp_avg"] = torch.zeros_like(g)
-                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                 memory_save_mode, dtype=g.dtype)
-                state["Q"] = triu_to_line(Q)
+        vals = []
 
-
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
 
-                if not
-
+            if 'Q' not in state:
+                state["exp_avg"] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
 
-
-            del vals
+            vals.append((p, g, state["exp_avg"], state["Q"]))
 
-
+        if not vals:
+            return
 
-
+        p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
+        del vals
 
-
-            for i, (p, g) in enumerate(zip(p_list, grad_list)):
-                q_orig = Q_list.pop(0)
-                ea = exp_avg_list.pop(0)
-                q = line_to_triu(q_orig)
-                self.balance(do_update, [g], [q])
-                new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+        group["step"] += 1
 
-
-                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
-                set_(g, new)
+        torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
 
-
+        Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            q_orig = Q_list.pop(0)
+            ea = exp_avg_list.pop(0)
+            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
+            new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+            if do_update:
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
+                self.balance([g], [q])
+            set_(g, new)
 
-
-            update_param_(p_list, grad_list, lr, weight_decay)
+        grad_list = self.clip_fn(grad_list)
 
-
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
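Like the other files in this release, delayed_psgd.py drops its per-class step(closure) in favour of a _step(group) hook, which implies the shared base class now owns the public step(). A hypothetical sketch of that dispatch pattern, for orientation only; the real base classes live in heavyball/utils.py and are not shown in this diff:

```python
import torch

class StatefulOptimizerSketch(torch.optim.Optimizer):
    """Illustrative stand-in for heavyball's StatefulOptimizer/PSGDBase base classes."""

    def state_(self, p):
        # per-parameter state dict, as used by the subclasses in this diff
        return self.state[p]

    def _step(self, group):
        raise NotImplementedError  # implemented by ForeachDelayedPSGD, ForeachAdamW, ...

    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            self._step(group)  # subclasses no longer loop over param_groups themselves
        return loss
```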
heavyball/foreach_adamw.py
CHANGED
@@ -10,42 +10,32 @@ class ForeachAdamW(StatefulOptimizer):
                         lr_max=-1.0, weight_decay=weight_decay)
         super().__init__(params, defaults)
 
-    def
-
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
 
-
-
-            and returns the loss.
-        """
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
 
-
-        if closure is not None:
-            loss = closure()
+        active_p = [p for p in group['params'] if p.grad is not None]
 
-
-
-            decay = group['weight_decay']
-            k = group['k']
+        if not active_p:
+            return
 
-
-
+        for p in active_p:
+            if 'exp_avg' not in self.state_(p):
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
 
-
+        y, grad, exp_avg_sq, exp_avg = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
 
-
-
-
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+        # Decay the first and second moment running average coefficient
+        torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
+        denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
 
-
-
-
-
-            torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
-            denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
-
-            # Normalize grad in-place for memory efficiency
-            lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-            update_param_(y, exp_avg, lr, decay, lambda p, e, l: torch._foreach_addcdiv_(p, e, denom, l))
-            group['k'] = k + 1
-        return loss
+        # Normalize grad in-place for memory efficiency
+        lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
+        update_param_(y, exp_avg, lr, decay, lambda p, e, l: torch._foreach_addcdiv_(p, e, denom, l))
+        group['k'] = k + 1
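For readers unfamiliar with the fused torch._foreach_* calls, the per-tensor arithmetic of the new _step is roughly the sketch below. It assumes beta_debias(beta, step) returns the bias-corrected decay 1 - (1 - beta) / (1 - beta**step), that exp_avg_sq_ updates the squared-gradient EMA in place and returns sqrt(EMA) + eps, and that update_param_ applies decoupled weight decay; all three are heavyball.utils helpers whose definitions are outside this diff:

```python
import torch

def adamw_step_sketch(p, exp_avg, exp_avg_sq, lr, eps, beta1, beta2, step, decay):
    """Single-tensor view of the fused update above (illustrative only, step >= 1)."""
    g = p.grad.float()
    b1 = 1 - (1 - beta1) / (1 - beta1 ** step)        # assumed beta_debias(beta1, step)
    b2 = 1 - (1 - beta2) / (1 - beta2 ** step)        # assumed beta_debias(beta2, step)
    exp_avg.lerp_(g, 1 - b1)                          # torch._foreach_lerp_ above
    exp_avg_sq.mul_(b2).addcmul_(g, g, value=1 - b2)  # assumed exp_avg_sq_ EMA update
    denom = exp_avg_sq.sqrt().add_(eps)               # assumed exp_avg_sq_ return value
    p.data.mul_(1 - lr * decay)                       # decoupled weight decay (assumed form)
    p.data.addcdiv_(exp_avg, denom, value=-lr)        # the addcdiv in the lambda above
```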
heavyball/foreach_adopt.py
CHANGED
@@ -11,51 +11,41 @@ class ForeachADOPT(StatefulOptimizer):
                         lr_max=-1.0, weight_decay=weight_decay)
         super().__init__(params, defaults)
 
-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            torch._foreach_mul_(exp_avg, beta1)
-            torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
-
-            beta2 = beta_debias(group['betas'][1], k + 1)
-            torch._foreach_mul_(exp_avg_sq, beta2)
-            torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)
-            del grad
-
-            group['k'] = k + 1
-        return loss
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
+
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
+
+        active_p = [p for p in group['params'] if p.grad is not None]
+
+        if not active_p:
+            return
+
+        for p in active_p:
+            if 'exp_avg' not in self.state_(p):
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+
+        y, grad, exp_avg_sq, exp_avg = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
+
+        if k > 1:
+            lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
+
+            update_param_(y, exp_avg, lr, decay)
+        if k > 0:
+            beta1 = beta_debias(group['betas'][0], k)
+            denom = torch._foreach_sqrt(exp_avg_sq)
+            torch._foreach_maximum_(denom, eps)
+            torch._foreach_mul_(exp_avg, beta1)
+            torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
+
+        beta2 = beta_debias(group['betas'][1], k + 1)
+        torch._foreach_mul_(exp_avg_sq, beta2)
+        torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)
+        del grad
+
+        group['k'] = k + 1
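The rewritten ADOPT step keeps the algorithm's characteristic ordering: parameters are updated with the momentum from the previous step (only once k > 1), the momentum is refreshed against the second moment accumulated before the current gradient (once k > 0), and the squared-gradient EMA is updated last. A compressed single-tensor sketch of that control flow, with the debiased betas and weight decay of the real code omitted for brevity:

```python
import torch

def adopt_step_sketch(p, exp_avg, exp_avg_sq, g, lr, eps, beta1, beta2, k):
    if k > 1:
        p.data.add_(exp_avg, alpha=-lr)            # use last step's momentum
    if k > 0:
        denom = exp_avg_sq.sqrt().clamp_(min=eps)  # second moment excludes current grad
        exp_avg.mul_(beta1).addcdiv_(g, denom, value=1 - beta1)
    # refreshed last, so the current gradient never normalizes itself
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)
```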
heavyball/foreach_laprop.py
CHANGED
@@ -11,46 +11,36 @@ class ForeachLaProp(StatefulOptimizer):
                         lr_max=-1.0, weight_decay=weight_decay)
         super().__init__(params, defaults)
 
-    def
-
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
 
-
-
-            and returns the loss.
-        """
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
 
-
-        if closure is not None:
-            loss = closure()
+        active_p = [p for p in group['params'] if p.grad is not None]
 
-
-
-            decay = group['weight_decay']
-            k = group['k']
+        if not active_p:
+            return
 
-
-
+        for p in active_p:
+            if 'exp_avg' not in self.state_(p):
+                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
 
-
+        y, grad, exp_avg_sq, exp_avg = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
 
-
-
-
-
+        # Decay the first and second moment running average coefficient
+        denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
+        beta1 = beta_debias(group['betas'][0], k + 1)
+        torch._foreach_mul_(exp_avg, beta1)
+        torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
+        del grad
 
-
-
+        # Normalize grad in-place for memory efficiency
+        lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
+        update_param_(y, exp_avg, lr, decay)
 
-
-            denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
-            beta1 = beta_debias(group['betas'][0], k + 1)
-            torch._foreach_mul_(exp_avg, beta1)
-            torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
-            del grad
-
-            # Normalize grad in-place for memory efficiency
-            lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-            update_param_(y, exp_avg, lr, decay)
-
-            group['k'] = k + 1
-        return loss
+        group['k'] = k + 1
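LaProp's distinguishing trait, visible in the added lines, is that the gradient is normalized by the second-moment denominator before it enters the momentum buffer, and the parameter step then uses that momentum directly. A single-tensor sketch of the ordering (bias correction and weight decay omitted):

```python
import torch

def laprop_step_sketch(p, exp_avg, exp_avg_sq, g, lr, eps, beta1, beta2):
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)   # second moment first
    denom = exp_avg_sq.sqrt().add_(eps)
    exp_avg.mul_(beta1).addcdiv_(g, denom, value=1 - beta1)  # momentum of the normalized grad
    p.data.add_(exp_avg, alpha=-lr)                          # update uses momentum directly
```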
heavyball/foreach_sfadamw.py
CHANGED
@@ -13,52 +13,42 @@ class ForeachSFAdamW(ScheduleFree):
                         foreach=foreach)
         super().__init__(params, defaults)
 
-    def
-
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
 
-
-
-            and returns the loss.
-        """
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
 
-
-        if closure is not None:
-            loss = closure()
+        active_p = [p for p in group['params'] if p.grad is not None]
 
-
-
-            decay = group['weight_decay']
-            k = group['k']
+        if not active_p:
+            return
 
-
-
+        for p in active_p:
+            if 'z' not in self.state_(p):
+                self.state_(p)['z'] = torch.clone(p.data)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
 
-
+        y, grad, exp_avg_sq, z = zip(
+            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
 
-
-
-                self.state_(p)['z'] = torch.clone(p.data)
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+        # Decay the first moment running average coefficient
+        old_debiased = beta_debias(group['betas'][1], k + 1)
 
-
-
+        # Decay the first and second moment running average coefficient
+        denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
 
-
-
+        # Normalize grad in-place for memory efficiency
+        torch._foreach_div_(grad, denom)
 
-
-
+        # Weight decay calculated at y
+        if decay != 0:
+            torch._foreach_add_(grad, y, alpha=decay)
 
-
-
+        lr = warmup(group['lr'], k + 1, group['warmup_steps'])
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
+                                             y, z, grad, group['r'], k + 1)
 
-
-            if decay != 0:
-                torch._foreach_add_(grad, y, alpha=decay)
-
-            lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-            group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
-                                                 y, z, grad, group['r'], k + 1)
-
-            group['k'] = k + 1
-        return loss
+        group['k'] = k + 1