heavyball 0.21.8__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +6 -5
- heavyball/cached_delayed_psgd_kron.py +6 -5
- heavyball/cached_psgd_kron.py +7 -5
- heavyball/delayed_psgd.py +12 -9
- heavyball/foreach_adamw.py +14 -7
- heavyball/foreach_adopt.py +11 -6
- heavyball/foreach_laprop.py +12 -6
- heavyball/foreach_sfadamw.py +10 -3
- heavyball/foreach_soap.py +10 -8
- heavyball/p_adam.py +9 -7
- heavyball/palm_foreach_sfadamw.py +11 -3
- heavyball/palm_foreach_soap.py +8 -9
- heavyball/precond_schedule_foreach_soap.py +10 -8
- heavyball/precond_schedule_palm_foreach_soap.py +9 -9
- heavyball/precond_schedule_sfpsoap.py +10 -5
- heavyball/psgd_kron.py +8 -5
- heavyball/pure_psgd.py +10 -6
- heavyball/schedule_free_palm_foreach_soap.py +13 -5
- heavyball/utils.py +112 -46
- {heavyball-0.21.8.dist-info → heavyball-0.22.0.dist-info}/METADATA +2 -2
- heavyball-0.22.0.dist-info/RECORD +24 -0
- heavyball-0.21.8.dist-info/RECORD +0 -24
- {heavyball-0.21.8.dist-info → heavyball-0.22.0.dist-info}/LICENSE +0 -0
- {heavyball-0.21.8.dist-info → heavyball-0.22.0.dist-info}/WHEEL +0 -0
- {heavyball-0.21.8.dist-info → heavyball-0.22.0.dist-info}/top_level.txt +0 -0
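
The change that repeats across every optimizer module below is a set of three new keyword arguments, `mars`, `caution`, and `mars_gamma`, plus the plumbing that forwards them through the param-group defaults into the shared helpers in `heavyball/utils.py`. As a minimal usage sketch, assuming only the constructor signatures visible in this diff (the model and loss here are placeholders):

    import torch
    import heavyball

    model = torch.nn.Linear(16, 1)  # placeholder module

    # New in 0.22.0: mars / caution / mars_gamma keyword arguments (defaults shown in the diff).
    opt = heavyball.ForeachAdamW(model.parameters(), lr=0.0025,
                                 mars=True,          # apply the MARS gradient correction
                                 caution=True,       # apply the cautious (sign-agreement) update mask
                                 mars_gamma=0.0025)  # strength of the MARS correction

    loss = model(torch.randn(4, 16)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()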
heavyball/__init__.py
CHANGED
@@ -1,3 +1,4 @@
+from .cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron
 from .cached_psgd_kron import ForeachCachedPSGDKron
 from .delayed_psgd import ForeachDelayedPSGD
 from .foreach_adamw import ForeachAdamW
@@ -14,7 +15,6 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
 from .psgd_kron import ForeachPSGDKron
 from .pure_psgd import ForeachPurePSGD
 from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
-from .cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron
 
 PalmForEachSoap = PaLMForeachSOAP
 
@@ -39,7 +39,8 @@ CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
 __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
 'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
 'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
-'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron',
-
-'
-'
+'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron',
+#
+'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP', 'SOAP', 'SFAdamW',
+'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD', 'CachedPSGDKron',
+'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP']
heavyball/cached_delayed_psgd_kron.py
CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 from heavyball.utils import min_dtype, precond_grad_cached_
 
-from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase,
+from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
 line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_
 
 
@@ -43,7 +43,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
-storage_dtype: str = 'float32',
+storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
+#
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -61,7 +62,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
 split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
-storage_dtype=storage_dtype)
+storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
 def _step(self, group):
@@ -81,7 +82,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group, should_promote=False):
+for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
 state = self.state_(p)
 
 if 'Q' not in state:
@@ -120,7 +121,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 q_orig = Q_list.pop(0)
 ea = exp_avg_list.pop(0)
 
-precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn)
+precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn, group['caution'], g)
 
 if should_update:
 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
heavyball/cached_psgd_kron.py
CHANGED
@@ -8,7 +8,7 @@ from typing import Optional
 
 import torch
 
-from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase,
+from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
 line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_
 
 
@@ -40,7 +40,8 @@ class ForeachCachedPSGDKron(PSGDBase):
 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
-storage_dtype: str = 'float32',
+storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
+#
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -58,7 +59,7 @@ class ForeachCachedPSGDKron(PSGDBase):
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
 split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
-storage_dtype=storage_dtype)
+storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
 def _step(self, group):
@@ -78,7 +79,7 @@ class ForeachCachedPSGDKron(PSGDBase):
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group, should_promote=False):
+for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
 state = self.state_(p)
 
 if 'Q' not in state:
@@ -128,4 +129,5 @@ class ForeachCachedPSGDKron(PSGDBase):
 else:
 torch.mul(q_.conj(), q_, out=c_)
 
-precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn
+precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn,
+group['caution'], g)
heavyball/delayed_psgd.py
CHANGED
@@ -8,14 +8,13 @@ import torch
 from heavyball.utils import stochastic_lerp_, beta_debias
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-
+triu_to_line, line_to_triu, promote
+
 
-# TODO: E1123 00:51:55.423000 159394 site-packages/torch/_guards.py:283] [5/0] Error while creating guard:
-# E1123 00:51:55.423000 159394 site-packages/torch/_guards.py:283] [5/0] Name: "G['psgd_precond_grad'].__defaults__[0]"
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr,
+def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_deca, clip_fn, caution, grad):
 new = psgd_precond_grad(q, exprs, ea)
-update_param_([p], clip_fn([new]), lr, weight_decay)
+update_param_([p], clip_fn([new]), lr, weight_decay, caution=caution, grad=grad)
 
 
 class ForeachDelayedPSGD(PSGDBase):
@@ -46,7 +45,8 @@ class ForeachDelayedPSGD(PSGDBase):
 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
-q_dtype='float32', stochastic_schedule: bool = True, storage_dtype: str = 'float32',
+q_dtype='float32', stochastic_schedule: bool = True, storage_dtype: str = 'float32',
+mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, #
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -63,7 +63,9 @@ class ForeachDelayedPSGD(PSGDBase):
 min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+storage_dtype=storage_dtype,
+caution=caution, mars_gamma=mars_gamma, mars=mars)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
 def _step(self, group):
@@ -83,7 +85,7 @@ class ForeachDelayedPSGD(PSGDBase):
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group, should_promote=False):
+for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
 state = self.state_(p)
 
 if 'Q' not in state:
@@ -112,7 +114,8 @@ class ForeachDelayedPSGD(PSGDBase):
 q_orig = Q_list.pop(0)
 ea = exp_avg_list.pop(0)
 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-_compilable_psgd_precond_grad_(q, self.state_(p)["exprs"], ea, p, lr, weight_decay, self.clip_fn
+_compilable_psgd_precond_grad_(q, self.state_(p)["exprs"], ea, p, lr, weight_decay, self.clip_fn, group['caution'],
+g)
 if should_update:
 q32 = [promote(q_) for q_ in q]
 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
heavyball/foreach_adamw.py
CHANGED
@@ -1,18 +1,19 @@
 import torch
 import torch.optim
-
 from heavyball.utils import copy_stochastic_list_
+
 from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
-def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
 g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
 
 torch._foreach_lerp_(exp_avg32, g32, 1 - beta_debias(beta1, step + 1))
 denom = list(exp_avg_sq_(exp_avg_sq32, g32, beta_debias(beta2, step + 1), eps))
 
-update_param_(y, exp_avg32, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l)
+update_param_(y, exp_avg32, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l), caution=caution,
+grad=g32)
 
 copy_stochastic_list_(exp_avg, exp_avg32)
 copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
@@ -20,9 +21,11 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
 
 class ForeachAdamW(StatefulOptimizer):
 def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-foreach: bool = True, storage_dtype: str = 'float32'
+foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+mars_gamma: float = 0.0025):
 defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype
+lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+mars_gamma=mars_gamma)
 super().__init__(params, defaults, foreach)
 
 def _step(self, group):
@@ -48,9 +51,13 @@ class ForeachAdamW(StatefulOptimizer):
 y, grad, exp_avg_sq, exp_avg = zip(
 *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
 
+if group['mars']:
+self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
 lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
 lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
 step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k)
-_compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay
+_compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay,
+group['caution'])
 
 group['k'] = k + 1
heavyball/foreach_adopt.py
CHANGED
@@ -5,10 +5,10 @@ from heavyball.utils import copy_stochastic_list_
 from .utils import warmup, beta_debias, update_param_, StatefulOptimizer, promote
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
-def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
 g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
-update_param_(y, exp_avg, lr, decay)
+update_param_(y, exp_avg, lr, decay, caution=caution, grad=g32)
 
 beta1 = beta_debias(beta1, step)
 denom = torch._foreach_sqrt(exp_avg_sq32)
@@ -27,9 +27,11 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
 class ForeachADOPT(StatefulOptimizer):
 
 def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-foreach: bool = True, storage_dtype: str = 'float32'
+foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+mars_gamma: float = 0.0025):
 defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype
+lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+mars_gamma=mars_gamma)
 super().__init__(params, defaults, foreach)
 
 def _step(self, group):
@@ -57,11 +59,14 @@ class ForeachADOPT(StatefulOptimizer):
 
 group['k'] = k + 1
 
+if group['mars']:
+self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
 if k > 1:
 lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
 lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
 k = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k)
-_compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], k, lr, eps, decay)
+_compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], k, lr, eps, decay, group['caution'])
 return
 
 grad = [promote(g) for g in grad]
heavyball/foreach_laprop.py
CHANGED
@@ -4,8 +4,8 @@ import torch.optim
 from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote, copy_stochastic_list_
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
-def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
 g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
 
 denom = exp_avg_sq_(exp_avg_sq32, g32, beta_debias(beta2, step), eps)
@@ -14,7 +14,7 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
 torch._foreach_mul_(exp_avg32, beta1)
 [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
 
-update_param_(y, exp_avg32, lr, decay)
+update_param_(y, exp_avg32, lr, decay, caution=caution, grad=g32)
 
 copy_stochastic_list_(exp_avg, exp_avg32)
 copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
@@ -23,9 +23,11 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
 class ForeachLaProp(StatefulOptimizer):
 
 def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
-foreach: bool = True, storage_dtype: str = 'float32'
+foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+mars_gamma: float = 0.0025):
 defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype
+lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+mars_gamma=mars_gamma)
 super().__init__(params, defaults, foreach)
 
 def _step(self, group):
@@ -52,10 +54,14 @@ class ForeachLaProp(StatefulOptimizer):
 *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) #
 for p in active_p])
 
+if group['mars']:
+self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
 lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
 lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
 step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k + 1)
 
-_compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay
+_compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay,
+group['caution'])
 
 group['k'] = k + 1
heavyball/foreach_sfadamw.py
CHANGED
@@ -5,7 +5,7 @@ from heavyball.utils import get_ckp1, copy_stochastic_list_
 from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, promote, _compilable_schedule_free_
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
 old_debiased2 = beta_debias(beta2, step)
 
@@ -21,13 +21,17 @@ def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, dec
 
 copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
 
+
 class ForeachSFAdamW(ScheduleFree):
 def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32'
+weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
+caution: bool = False, mars_gamma: float = 0.0025):
+
+assert not caution, "Caution not implemented for SFAdamW"
 
 defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
 weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
-foreach=foreach, storage_dtype=storage_dtype)
+foreach=foreach, storage_dtype=storage_dtype, mars=mars, caution=caution, mars_gamma=mars_gamma)
 super().__init__(params, defaults, foreach)
 
 def _step(self, group):
@@ -53,6 +57,9 @@ class ForeachSFAdamW(ScheduleFree):
 y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z']) #
 for p in active_p])
 
+if group['mars']:
+self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
 lr = warmup(group['lr'], k + 1, group['warmup_steps'])
 ckp1, group['weight_sum'] = get_ckp1(lr, group['weight_lr_power'], group['weight_sum'], group['r'], k + 1)
 
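
The schedule-free optimizers accept the new keyword arguments for API consistency, but `caution=True` is rejected by the assertion added above, while `mars=True` is honored through `mars_correct_list`. A small sketch of the resulting behavior, using a hypothetical one-parameter list:

    import torch
    from heavyball import ForeachSFAdamW

    params = [torch.nn.Parameter(torch.zeros(4))]  # hypothetical parameter list

    opt = ForeachSFAdamW(params, mars=True)  # accepted; the MARS correction runs inside _step
    try:
        ForeachSFAdamW(params, caution=True)
    except AssertionError as err:
        print(err)  # "Caution not implemented for SFAdamW"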
heavyball/foreach_soap.py
CHANGED
@@ -1,7 +1,7 @@
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-
+StatefulOptimizer, exp_avg_
 
 
 class ForeachSOAP(StatefulOptimizer):
@@ -26,11 +26,13 @@ class ForeachSOAP(StatefulOptimizer):
 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-split: bool = False, foreach: bool = True
+split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
+mars_gamma: float = 0.0025):
 defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
 "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
 "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-"correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split
+"correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split, 'mars': mars,
+'caution': caution, 'mars_gamma': mars_gamma}
 super().__init__(params, defaults, foreach)
 self._data_format = data_format
 
@@ -41,7 +43,7 @@ class ForeachSOAP(StatefulOptimizer):
 max_precond_dim = group['max_precond_dim']
 precondition_1d = group['precondition_1d']
 
-for p, g in split_p_and_g_in_group(group):
+for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
 state = self.state_(p)
 step = state['step'] = state.get("step", -1) + 1
 
@@ -71,6 +73,8 @@ class ForeachSOAP(StatefulOptimizer):
 step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
 denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
+step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
 for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
 state = self.state_(p)
 # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
@@ -80,11 +84,9 @@ class ForeachSOAP(StatefulOptimizer):
 # Projecting back the preconditioned (by Adam) exponential moving average of gradients
 # to the original space
 # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-
+precond = project(exp_avg_projected / d, state['Q'], True)
 
 update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
 step > 0 and step % group['precondition_frequency'] == 0)
 
-
-step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-update_param_(p_list, denom, step_size, group["weight_decay"])
+update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
heavyball/p_adam.py
CHANGED
@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-
 from heavyball.utils import triu_to_line, line_to_triu, identity, stochastic_lerp_
+
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, exp_avg_sq_, beta_debias, \
-
+promote
 
 
 class ForeachPaLMPAdam(PSGDBase):
@@ -39,7 +39,8 @@ class ForeachPaLMPAdam(PSGDBase):
 momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
 beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
 store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
-stochastic_schedule: bool = True,
+stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
+caution: bool = False, mars_gamma: float = 0.0025, #
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -57,7 +58,8 @@ class ForeachPaLMPAdam(PSGDBase):
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, beta=beta,
 beta2_scale=beta2_scale, merge_dims=merge_dims, split=split,
-store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype
+store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype,
+mars=mars, caution=caution, mars_gamma=mars_gamma)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
 def _step(self, group):
@@ -75,7 +77,7 @@ class ForeachPaLMPAdam(PSGDBase):
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group, should_promote=False):
+for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=group['beta']):
 state = self.state_(p)
 
 if 'Q' not in state:
@@ -107,6 +109,7 @@ class ForeachPaLMPAdam(PSGDBase):
 lr = -warmup(lr, group['step'], group['warmup_steps'])
 
 for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
+gc = g.clone() if group['caution'] else None
 psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
 ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
 exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
@@ -115,5 +118,4 @@ class ForeachPaLMPAdam(PSGDBase):
 divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
 divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
 """
-update_param_([p], self.clip_fn([g]), lr, weight_decay)
-
+update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=gc)
heavyball/palm_foreach_sfadamw.py
CHANGED
@@ -5,7 +5,7 @@ from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, get_ckp1, pro
 _compilable_schedule_free_, copy_stochastic_list_
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
 old_debiased2 = beta_debias(beta2, step)
 
@@ -24,12 +24,17 @@ def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, dec
 
 class PaLMForeachSFAdamW(ScheduleFree):
 def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-weight_lr_power=2.0, beta2_scale: float = 0.8, foreach: bool = True, storage_dtype: str = 'float32'
+weight_lr_power=2.0, beta2_scale: float = 0.8, foreach: bool = True, storage_dtype: str = 'float32',
+mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
 if betas[0] is not None:
 beta = betas[0]
+
+assert not caution, "Caution not implemented for SFAdamW"
+
 defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
 lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
-beta2_scale=beta2_scale, storage_dtype=storage_dtype
+beta2_scale=beta2_scale, storage_dtype=storage_dtype, mars=mars, caution=caution,
+mars_gamma=mars_gamma)
 super().__init__(params, defaults, foreach)
 
 def _step(self, group):
@@ -58,6 +63,9 @@ class PaLMForeachSFAdamW(ScheduleFree):
 y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z']) #
 for p in active_p])
 
+if group['mars']:
+self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
 lr = warmup(group['lr'], k + 1, group['warmup_steps'])
 ckp1, group['weight_sum'] = get_ckp1(lr, group['weight_lr_power'], group['weight_sum'], group['r'], k + 1)
 
heavyball/palm_foreach_soap.py
CHANGED
@@ -1,8 +1,7 @@
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-
-
+StatefulOptimizer, exp_avg_
 
 
 class PaLMForeachSOAP(StatefulOptimizer):
@@ -33,14 +32,15 @@ class PaLMForeachSOAP(StatefulOptimizer):
 max_precond_dim: int = 2048, #
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-beta2_scale: float = 0.8, split: bool = False, foreach: bool = True
+beta2_scale: float = 0.8, split: bool = False, foreach: bool = True, mars: bool = False,
+caution: bool = False, mars_gamma: float = 0.0025):
 if betas[0] is not None:
 beta = betas[0]
 defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
 "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
 "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
 "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
-'split': split}
+'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
 super().__init__(params, defaults, foreach)
 self._data_format = data_format
 
@@ -51,7 +51,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
 max_precond_dim = group['max_precond_dim']
 precondition_1d = group['precondition_1d']
 
-for p, g in split_p_and_g_in_group(group):
+for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
 state = self.state_(p)
 step = state['step'] = state.get("step", -1) + 1
 
@@ -82,6 +82,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
 beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
 step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
 denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
+step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
 
 for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
 state = self.state_(p)
@@ -92,11 +93,9 @@ class PaLMForeachSOAP(StatefulOptimizer):
 # Projecting back the preconditioned (by Adam) exponential moving average of gradients
 # to the original space
 # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-
+precond = project(exp_avg_projected / d, state['Q'], True)
 
 update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
 step > 0 and step % group['precondition_frequency'] == 0)
 
-
-step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-update_param_(p_list, denom, step_size, group["weight_decay"])
+update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
heavyball/precond_schedule_foreach_soap.py
CHANGED
@@ -3,7 +3,7 @@ import random
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, beta_debias, update_param_, \
-precond_schedule, set_,
+precond_schedule, set_, StatefulOptimizer, exp_avg_
 
 
 class PrecondScheduleForeachSOAP(StatefulOptimizer):
@@ -27,12 +27,13 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True
+precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True, mars: bool = False,
+caution: bool = False, mars_gamma: float = 0.0025):
 defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
 "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
 "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
 "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
-'split': split}
+'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
 super().__init__(params, defaults, foreach)
 self._data_format = data_format
 self.rng = random.Random(0x120983109)
@@ -44,7 +45,7 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
 max_precond_dim = group['max_precond_dim']
 precondition_1d = group['precondition_1d']
 
-for p, g in split_p_and_g_in_group(group):
+for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
 state = self.state_(p)
 step = state['step'] = state.get("step", -1) + 1
 
@@ -75,6 +76,8 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
 denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
 update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
 for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
 state = self.state_(p)
 # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
@@ -84,10 +87,9 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
 # Projecting back the preconditioned (by Adam) exponential moving average of gradients
 # to the original space
 # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-
+precond = project(exp_avg_projected / d, state['Q'], True)
 
 update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
 
-
-
-update_param_(p_list, denom, step_size, group["weight_decay"])
+update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
+
heavyball/precond_schedule_palm_foreach_soap.py
CHANGED
@@ -3,7 +3,7 @@ import random
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_, update_param_, \
-precond_schedule, set_,
+precond_schedule, set_, StatefulOptimizer
 
 
 class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
@@ -33,14 +33,15 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
 precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
-foreach: bool = True):
+foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
 if betas[0] is not None:
 beta = betas[0]
 defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
 "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
 "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
 "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
-'beta2_scale': beta2_scale, 'split': split
+'beta2_scale': beta2_scale, 'split': split, 'mars': mars, 'caution': caution,
+'mars_gamma': mars_gamma}
 super().__init__(params, defaults, foreach)
 self._data_format = data_format
 self.rng = random.Random(0x120983109)
@@ -52,7 +53,7 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
 max_precond_dim = group['max_precond_dim']
 precondition_1d = group['precondition_1d']
 
-for p, g in split_p_and_g_in_group(group):
+for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
 state = self.state_(p)
 step = state['step'] = state.get("step", -1) + 1
 
@@ -86,6 +87,8 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
 denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
 update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
 for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
 state = self.state_(p)
 # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
@@ -96,10 +99,7 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
 # to the original space
 # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
 exp_avg_projected = exp_avg_projected / d
-
+precond = project(exp_avg_projected, state['Q'], True)
 
 update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
-
-# Why does this have to be rebiased here?
-step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-update_param_(p_list, denom, step_size, group["weight_decay"])
+update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
heavyball/precond_schedule_sfpsoap.py
CHANGED
@@ -3,11 +3,11 @@ import random
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
-beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule,
+beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, copy_stochastic_list_, \
 promote
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
 eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
 denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
@@ -52,15 +52,20 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
 weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9), betas=(None, None),
-split: bool = False, foreach: bool = True
+split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
+mars_gamma: float = 0.0025):
 if betas[0] is not None:
 beta = betas[0]
+
+assert not caution, "Caution is not implemented in ScheduleFree optimizers"
+
 defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
 "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
 "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
 "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
 'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1, 'weight_sum': 0,
-'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split
+'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split,
+'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
 super().__init__(params, defaults, foreach)
 self._data_format = data_format
 self.rng = random.Random(0x120983109)
@@ -87,7 +92,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
 # adaptive gradient clipping
 adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
 
-for p, g in split_p_and_g_in_group(group):
+for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
 state = self.state_(p)
 
 if "z" not in state:
heavyball/psgd_kron.py
CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-
+line_to_triu, triu_to_line, promote, stochastic_lerp_, beta_debias
 
 
 class ForeachPSGDKron(PSGDBase):
@@ -40,7 +40,8 @@ class ForeachPSGDKron(PSGDBase):
 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
-storage_dtype: str = 'float32',
+storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
+#
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -57,7 +58,9 @@ class ForeachPSGDKron(PSGDBase):
 min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+storage_dtype=storage_dtype,
+mars=mars, caution=caution, mars_gamma=mars_gamma)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
 def _step(self, group):
@@ -77,7 +80,7 @@ class ForeachPSGDKron(PSGDBase):
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group, should_promote=False):
+for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
 state = self.state_(p)
 
 if 'Q' not in state:
@@ -114,4 +117,4 @@ class ForeachPSGDKron(PSGDBase):
 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
 store_triu_as_line)
 g = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
-update_param_([p], self.clip_fn([g]), lr, weight_decay)
+update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])
heavyball/pure_psgd.py
CHANGED
@@ -5,9 +5,9 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-
 from heavyball.utils import identity
-
+
+from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, \
 line_to_triu, triu_to_line, promote
 
 
@@ -38,7 +38,8 @@ class ForeachPurePSGD(PSGDBase):
 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
-q_dtype='float32', stochastic_schedule: bool = True,
+q_dtype='float32', stochastic_schedule: bool = True, mars: bool = False, caution: bool = False,
+mars_gamma: float = 0.0025, #
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -49,11 +50,14 @@ class ForeachPurePSGD(PSGDBase):
 if clip_fn is None:
 clip_fn = identity
 
+assert not mars, "MARS is not supported in this optimizer"
+
 defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
 min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype
+split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, mars=mars, caution=caution,
+mars_gamma=mars_gamma)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
 def _step(self, group):
@@ -70,7 +74,7 @@ class ForeachPurePSGD(PSGDBase):
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group, should_promote=False):
+for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=0.0):
 state = self.state_(p)
 
 if 'Q' not in state:
@@ -98,4 +102,4 @@ class ForeachPurePSGD(PSGDBase):
 q32 = [promote(q_) for q_ in q]
 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
 psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
-update_param_([p], self.clip_fn([g]), lr, weight_decay)
+update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])
heavyball/schedule_free_palm_foreach_soap.py
CHANGED
@@ -1,12 +1,13 @@
 import random
 
 import torch
+from heavyball.utils import mars_correction
 
 from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
-beta_debias, schedule_free_, warmup, ScheduleFree,
+beta_debias, schedule_free_, warmup, ScheduleFree, copy_stochastic_list_, promote
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
 eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
 denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
@@ -44,15 +45,19 @@ class SFPaLMForeachSOAP(ScheduleFree):
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
 weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False,
-foreach: bool = True):
+foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
 if betas[0] is not None:
 beta = betas[0]
+
+assert not caution, "Caution is not implemented in ScheduleFree optimizers"
+
 defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
 "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
 "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
 "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
 'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1,
-'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split
+'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split, 'mars': mars,
+'caution': caution, 'mars_gamma': mars_gamma}
 super().__init__(params, defaults, foreach)
 self._data_format = data_format
 self.rng = random.Random(0x120983109)
@@ -61,6 +66,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
 vals = []
 max_precond_dim = group['max_precond_dim']
 precondition_1d = group['precondition_1d']
+mars = group['mars']
 
 step = group['step'] = group.get("step", 0) + 1
 
@@ -79,12 +85,14 @@ class SFPaLMForeachSOAP(ScheduleFree):
 
 vals = []
 
-for p, g in split_p_and_g_in_group(group):
+for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
 state = self.state_(p)
 
 if "z" not in state:
 state["z"] = torch.clone(p).float()
 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+if mars:
+state['mars_prev_grad'] = g.clone()
 init_preconditioner(g, state, max_precond_dim, precondition_1d)
 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
 continue # first step is skipped so that we never use the current gradients in the projection.
heavyball/utils.py
CHANGED
@@ -142,18 +142,26 @@ def beta_debias(beta, step):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def
-if isinstance(state, torch.Tensor):
-state.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-return torch.sqrt(state, out=out).clamp_(min=eps)
-
+def _compilable_exp_avg_sq_(state, grad, beta2, eps, out=None):
 torch._foreach_mul_(state, beta2)
 [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(state, grad)]
 denom = torch._foreach_sqrt(state)
-
+[denom.clamp_(min=eps) for denom in denom]
+if out is not None:
+copy_stochastic_list_(out, denom)
+return out
+
 return denom
 
 
+def exp_avg_sq_(state, grad, beta2, eps, out=None):
+state, grad = list_guard(state), list_guard(grad)
+if not isinstance(beta2, torch.Tensor):
+beta2 = torch.empty((), dtype=torch.float32, device=state[0].device).fill_(beta2)
+if not isinstance(eps, torch.Tensor):
+eps = torch.empty((), dtype=torch.float32, device=state[0].device).fill_(eps)
+return _compilable_exp_avg_sq_(state, grad, beta2, eps, out)
+
 def adaptive_gradient_clipping_(parameters: List[torch.Tensor], gradients: List[torch.Tensor], clip_val: float,
 minimum: float = 1e-3, eps: float = 1e-8):
 if clip_val <= 0:
@@ -168,12 +176,19 @@ def adaptive_gradient_clipping_(parameters: List[torch.Tensor], gradients: List[
 torch._foreach_mul_(gradients, p_norm)
 
 
+def is_compiling():
+try:
+return torch.compiler.is_compiling()
+except AttributeError:
+return True
+
+
 def set_(dst: torch.Tensor, src: torch.Tensor):
-if not
+if not is_compiling() and src.data_ptr() == dst.data_ptr():
 return
 if src.shape != dst.shape:
 src = src.reshape_as(dst)
-if not
+if not is_compiling() and src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
 dst.set_(src)
 else:
 dst.copy_(src)
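
The first hunk above splits `exp_avg_sq_` into a compiled list-only kernel (`_compilable_exp_avg_sq_`) and a thin eager wrapper that normalizes its inputs: single tensors are wrapped into lists via `list_guard`, and Python floats for `beta2`/`eps` are promoted to scalar tensors before entering the compiled path. A sketch of the intended call pattern, with illustrative shapes only:

    import torch
    from heavyball.utils import exp_avg_sq_

    state = [torch.zeros(8), torch.zeros(3, 3)]    # second-moment buffers, updated in place
    grads = [torch.randn(8), torch.randn(3, 3)]

    # beta2 and eps may be plain floats; the wrapper converts them to scalar tensors.
    denom = exp_avg_sq_(state, grads, 0.99, 1e-8)  # sqrt of the EMA of grad**2, clamped at eps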
@@ -338,11 +353,18 @@ def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a
 
 
 def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+x, y = list_guard(x), list_guard(y)
 if not isinstance(a, torch.Tensor):
 a = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(a)
 _compilable_stochastic_lerp_(x, y, a)
 
 
+def list_guard(x):
+if isinstance(x, (list, tuple)):
+return x
+return [x]
+
+
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
 for x_, y_ in zip(x, y):
@@ -353,6 +375,7 @@ def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], al
 
 
 def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+x, y = list_guard(x), list_guard(y)
 if not isinstance(alpha, torch.Tensor):
 alpha = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(alpha)
 _compilable_stochastic_add_(x, y, alpha)
@@ -463,6 +486,43 @@ class StatefulOptimizer(torch.optim.Optimizer):
     def state_(self, arg: torch.Tensor):
         return self.state[self.key(arg)]

+    def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
+        for p, g in zip(p_list, g_list):
+            state = self.state_(p)
+            if 'mars_old_grad' not in state:
+                state['mars_old_grad'] = torch.zeros_like(g)
+        old_gs = [self.state_(p)['mars_old_grad'] for p in p_list]
+        mars_correction(g_list, old_gs, mars_gamma, beta)
+
+    def split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_promote: bool = True,
+                               beta1: float = -1.0):
+        for p in group["params"]:
+            if skip_none and p.grad is None:
+                continue
+
+            if p.grad is None:
+                grad = None
+            else:
+                if should_promote:
+                    grad = promote(p.grad)
+                else:
+                    grad = p.grad
+                if beta1 >= 0 and group.get('mars', False):
+                    self.mars_correct_list(group, [p], [grad], group['mars_gamma'], beta1)
+
+            p.grad = None
+
+            p_views = merge_group(group, p)
+            if grad is not None:
+                grad = merge_group(group, grad)
+            if isinstance(p_views, torch.Tensor):
+                yield p_views, grad
+                continue
+            if grad is None:
+                yield from zip(p_views, [None] * len(p_views))
+                continue
+            yield from zip(p_views, grad)
+
     def state_size(self) -> int:
         total_bytes = 0

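`split_p_and_g_in_group` is now a method so it can reach the optimizer state, where `mars_correct_list` keeps one `mars_old_grad` buffer per parameter. A standalone plain-torch sketch of that bookkeeping (not heavyball's code; `mars_step`, the dict, and the `gamma`/`beta1` values are illustrative, and the actual correction lives in the `mars_correction` helper added further down):

import torch

state = {}   # stands in for the per-parameter optimizer state

def mars_step(name, grad, gamma=0.025, beta1=0.9):
    old = state.setdefault(name, torch.zeros_like(grad))   # the 'mars_old_grad' buffer
    corrected = grad + gamma * beta1 / (1 - beta1) * (grad - old)
    state[name] = grad.clone()   # becomes the "old" gradient for the next step
    return corrected

print(mars_step('w', torch.ones(3)))            # first step: the old gradient is all zeros
print(mars_step('w', torch.full((3,), 2.0)))    # extrapolates using the stored gradient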
@@ -472,7 +532,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
                 total_bytes += x.numel() * x.element_size()

         for group in self.param_groups:
-            for p, _ in split_p_and_g_in_group(group, skip_none=False):
+            for p, _ in self.split_p_and_g_in_group(group, skip_none=False):
                 tree_map(_add, self.state_(p))
         return total_bytes

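`state_size()` now walks parameters through the method form of `split_p_and_g_in_group`. A usage sketch (heavyball 0.22.0 assumed; the model and hyperparameters are placeholders):

import torch
from heavyball import ForeachAdamW

net = torch.nn.Linear(4, 4)
opt = ForeachAdamW(net.parameters(), lr=1e-3)
net(torch.randn(2, 4)).sum().backward()
opt.step()
print(opt.state_size())   # total bytes held in the optimizer's state buffers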
@@ -625,7 +685,7 @@ def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):


 def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
-    if not
+    if not is_compiling() and target.data_ptr() == source.data_ptr():
         return
     if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
         set_(target, source)
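`copy_stochastic_` (whose aliasing guard now uses the `is_compiling()` shim) writes higher-precision values into a bfloat16 buffer with stochastic rounding. A small sketch (heavyball 0.22.0 assumed; the constant is chosen to fall between two bfloat16 values):

import torch
from heavyball.utils import copy_stochastic_

src = torch.full((8,), 1.0 + 2.0 ** -10)          # not exactly representable in bfloat16
dst = torch.zeros(8, dtype=torch.bfloat16)
copy_stochastic_(dst, src)
print(dst)   # each entry rounds to one of the two neighbouring bfloat16 values at random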
@@ -633,14 +693,16 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):


 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_update_(p, u, decay, add_fn, lr):
+def _compilable_update_(p, u, decay, add_fn, lr, caution, g):
     u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
-    p32, u32 = [list(map(promote, x)) for x in [p, u]]
+    p32, u32, g32 = [list(map(promote, x)) for x in [p, u, g]]

     if decay > 0:
         torch._foreach_mul_(p32, 1 - decay * lr)

-    for p32_, u32_ in zip(p32, u32):  # lr is data-dependent -> can't compile a foreach
+    for p32_, u32_, g32_ in zip(p32, u32, g32):  # lr is data-dependent -> can't compile a foreach
+        if caution:
+            _compilable_cautioning_(g32_, u32_)
         if add_fn is None:
             p32_.add_(u32_, alpha=lr)
         else:
@@ -650,9 +712,12 @@ def _compilable_update_(p, u, decay, add_fn, lr):


 def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
-                  add_fn: callable = None):
+                  add_fn: callable = None, caution: bool = False, grad: List[torch.Tensor] = None):
     lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
-
+    param, update, grad = list_guard(param), list_guard(update), list_guard(grad)
+    if not caution:
+        grad = [None] * len(param)
+    _compilable_update_(param, update, decay, add_fn, lr_tensor, caution, grad)


 def precond_schedule(step, precond_scheduler, rng):
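`update_param_` now accepts `caution` and `grad`: with `caution=True`, update entries whose sign disagrees with the current gradient are zeroed and the survivors rescaled (see `_compilable_cautioning_` further down) before the step is applied. A small sketch (heavyball 0.22.0 assumed; the numbers are arbitrary):

import torch
from heavyball.utils import update_param_

param = [torch.zeros(4)]
update = [torch.tensor([1.0, -1.0, 1.0, -1.0])]
grad = [torch.tensor([1.0, 1.0, -1.0, -1.0])]

update_param_(param, update, 0.1, 0.0, caution=True, grad=grad)
print(param[0])   # only the entries where update and grad agree in sign have moved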
@@ -965,18 +1030,45 @@ class PSGDBase(StatefulOptimizer):
             psgd_balance_Q(q)


-
-
+# TODO: Figure out why this sometimes crashes
+# @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn, caution, grad):
     md = min_dtype(cached_q + [ea])
     new = torch.einsum(expr, *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
-    update_param_([param], clip_fn([new]), lr, weight_decay)
+    update_param_([param], clip_fn([new]), lr, weight_decay, caution=caution, grad=grad)


 def precond_grad_cached_(cached_q: List[torch.Tensor], ea: torch.Tensor, expr: str, param: torch.Tensor, lr: float,
-                         weight_decay: float, clip_fn):
+                         weight_decay: float, clip_fn, caution, grad):
     if isinstance(lr, float):
         lr = torch.empty((), dtype=torch.float32, device=param.device).fill_(lr)
-    _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn)
+    _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn, caution, grad)
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_mars_correction_(g, old_g, a):
+    g_copy = [g_.clone() for g_ in g]
+    _compilable_stochastic_lerp_(g, old_g, a)
+    copy_stochastic_list_(old_g, g_copy)
+
+
+def mars_correction(g, old_g, beta1, gamma):
+    a = -gamma * beta1 / (1 - beta1)
+    g, old_g = list_guard(g), list_guard(old_g)
+    a = torch.empty((), dtype=torch.float32, device=g[0].device).fill_(a)
+    _compilable_mars_correction_(g, old_g, a)
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_cautioning_(g, update):
+    mask = (g * update) > 0
+    update.masked_fill_(~mask, 0)
+    scale = mask.numel() / mask.sum().clamp(min=1)
+    update.mul_(scale)
+
+
+def caution(g, update):
+    _compilable_cautioning_(g, update)


 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
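The masking performed by `_compilable_cautioning_` above can be written out directly; this standalone plain-torch sketch (not heavyball's code; names and values are illustrative) shows its effect on a toy update:

import torch

def cautious(update, grad):
    mask = (grad * update) > 0                                  # keep entries whose signs agree
    update = update.masked_fill(~mask, 0)
    return update * (mask.numel() / mask.sum().clamp(min=1))    # rescale the survivors

update = torch.tensor([0.5, -0.5, 0.5, -0.5])
grad = torch.tensor([1.0, 1.0, -1.0, -1.0])
print(cautious(update, grad))   # tensor([1., 0., 0., -1.]): half survive, scaled by 2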
@@ -1013,29 +1105,3 @@ def merge_group(group, *tensors):
         append_or_extend(out, dim_merger(t, group['max_size_triangular'] if 'max_size_triangular' in group else group[
             'max_precond_dim'], group.get('split', False)))
     return out
-
-
-def split_p_and_g_in_group(group: dict, skip_none: bool = True, should_promote: bool = True):
-    for p in group["params"]:
-        if skip_none and p.grad is None:
-            continue
-
-        if p.grad is None:
-            grad = None
-        else:
-            if should_promote:
-                grad = promote(p.grad)
-            else:
-                grad = p.grad
-        p.grad = None
-
-        p_views = merge_group(group, p)
-        if grad is not None:
-            grad = merge_group(group, grad)
-        if isinstance(p_views, torch.Tensor):
-            yield p_views, grad
-            continue
-        if grad is None:
-            yield from zip(p_views, [None] * len(p_views))
-            continue
-        yield from zip(p_views, grad)
{heavyball-0.21.8.dist-info → heavyball-0.22.0.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.
+Version: 0.22.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.

-Currently (2024-11-
+Currently (2024-11-26, 0.22.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

 ## Features
heavyball-0.22.0.dist-info/RECORD ADDED

@@ -0,0 +1,24 @@
+heavyball/__init__.py,sha256=icHYN-MGsmHkLUlHCMcZkOlwY7GT63_ayR_a5iPKmzM,2226
+heavyball/cached_delayed_psgd_kron.py,sha256=n3wIOhrop0Ls4MZ0kXpwGuImp1jzPs6VGdxIlPyoYdQ,6827
+heavyball/cached_psgd_kron.py,sha256=KCLsfvj9qh_2FNwRTdWM3zjnt2oGHfsf4Y341rPcceI,6778
+heavyball/delayed_psgd.py,sha256=CaG17zqorLsCSDeGEePOyb6n9ugv8W6gyRQqeQNq-e8,6272
+heavyball/foreach_adamw.py,sha256=uawSbGGUD2E1RtcwspP83yQNElERdGX-diqCI5e8FqE,2825
+heavyball/foreach_adopt.py,sha256=DFEaPswVzdHcbxC-mirsf_okM_HR6r34PDUTty5CrUE,3547
+heavyball/foreach_laprop.py,sha256=J4Vms0nAOMh3GQtAOPyrYOe5WtpzokVv25b9oDnwc2A,2833
+heavyball/foreach_sfadamw.py,sha256=HWbLekY5BloHDIgrN2J0a7IolZCt8Ah2xkLAU_-5oSc,3079
+heavyball/foreach_soap.py,sha256=7B_dP2Hm_xqwpBQiPYkv_c6eoRnU1dV2VZfvSoa4uJ8,4729
+heavyball/p_adam.py,sha256=F-id4qOkAaDTJaKTSNhSsonX-Js5IzIu1Bdj1S4qE2g,6306
+heavyball/palm_foreach_sfadamw.py,sha256=E8raxrBIkSmTEGFzwnfWxKwDJjBQE2vdsmyqfc8aL_A,3375
+heavyball/palm_foreach_soap.py,sha256=IknGm_CzrqDIFEoCkejxjoZ4sfIy6RSoInqlMUOYLB4,6156
+heavyball/precond_schedule_foreach_soap.py,sha256=bJ2ifPFa8zEP9GO8eBpqZzsmP7p_iQkkCkllNeEMHPU,4892
+heavyball/precond_schedule_palm_foreach_soap.py,sha256=4dT9f134-Faq2KuCMCHzMtrkMO-es5p_DYS1of5yF-s,6428
+heavyball/precond_schedule_sfpsoap.py,sha256=FOR-axwlkSN7IHZWYYUVFfjSFCLxc_NdiTlb-n5gmgs,7530
+heavyball/psgd_kron.py,sha256=achB23mQUT3F00IGhjjVf_8YW7VOTHR6YdoCDRyWxsI,6039
+heavyball/pure_psgd.py,sha256=dbYgkunFFA6EsO6fJEhaJRxTH0smi7qLX3Np9XTQ9E4,5079
+heavyball/schedule_free_palm_foreach_soap.py,sha256=0WT_gvTKymqLQzYT6ewDgCmpDq-HgMAewipw1QvyQYA,7267
+heavyball/utils.py,sha256=TVpyev0oL4a78px4cvtaGoGPJqfpfTKE-xBWkRCmzkw,39785
+heavyball-0.22.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-0.22.0.dist-info/METADATA,sha256=LqVR3tUgxpk21zsmKxfJAQCKLPzmaQz2PQiKvlvpQe8,11926
+heavyball-0.22.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-0.22.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-0.22.0.dist-info/RECORD,,
heavyball-0.21.8.dist-info/RECORD DELETED

@@ -1,24 +0,0 @@
-heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
-heavyball/cached_delayed_psgd_kron.py,sha256=Nyxl-G-o6greKwDN-vLiw5W02GXO2LRvknc0OzvzFnE,6674
-heavyball/cached_psgd_kron.py,sha256=HzD6se0AYb-W5hpydUxcR9uqrpe_54PBwgL1VWX3DHU,6592
-heavyball/delayed_psgd.py,sha256=m4c-OvcLMrRxSAPYs2l6Up21uCyF2kvHvpcnfe3nzGs,6212
-heavyball/foreach_adamw.py,sha256=Rb5U80cgUcEqlEbUU250UTWdoqA7nyiqkV5w1U4bWX4,2445
-heavyball/foreach_adopt.py,sha256=ecdi1fKg9i087OGjtKWVbE_DD6Yf4pvpzv4ELCcusvQ,3211
-heavyball/foreach_laprop.py,sha256=vi6C_gfjXxw5uN0KHgzxI9itUI1dcgOf3ufoO_VVMp0,2471
-heavyball/foreach_sfadamw.py,sha256=rLZORmCIMu9G09FdDgMSiI6pNq34IVoxsPVWtmeDdbQ,2753
-heavyball/foreach_soap.py,sha256=4mWSMWYTdjgiXiboI5DwdigecruDtNGKylGAFAVhCRA,4562
-heavyball/p_adam.py,sha256=Xyxsavwtw-t0OyTHitYQXZSmF9UJlMDzDAURge-MbbQ,6047
-heavyball/palm_foreach_sfadamw.py,sha256=JbNrcoquBGGUI5XNMFouDjpNurVHUW9DbX1A3tSrtno,3025
-heavyball/palm_foreach_soap.py,sha256=GzAwM8kOt1X0QCmUZDTdHwPxbJwjH8ic43dyAK5BYCA,6015
-heavyball/precond_schedule_foreach_soap.py,sha256=HcObXLfSNN_lKNb4nmC6tkdHcqDIMNX6hILpHKScqLc,4744
-heavyball/precond_schedule_palm_foreach_soap.py,sha256=xZ7CJvIfdu2RNAZt2g1S7Xb0Jyy1hNC4MowOFU3nWkk,6283
-heavyball/precond_schedule_sfpsoap.py,sha256=PNneiOkrRyV1yIZn91lPmYofd1_OiLqJTDy75RLpXJk,7273
-heavyball/psgd_kron.py,sha256=RSLJi5FnSmYjvYBufNDClnYRm-eN_Kpa1Ar2tNP6-X0,5824
-heavyball/pure_psgd.py,sha256=LZK0qmvZkBF8g00evaVLtW-sIUJmdoxag1K7O26AqEo,4820
-heavyball/schedule_free_palm_foreach_soap.py,sha256=_AqrnChY6iDlQUkF2YUxS7eLjSWCIuvEUOHvMHVM1yY,6873
-heavyball/utils.py,sha256=xTDZEt2_DM57EYnJkRq7d7scTnro4eKPdMtEwPdLy-c,37218
-heavyball-0.21.8.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-0.21.8.dist-info/METADATA,sha256=nLyxHlENmhAGyU9GManYKKJJTykhsAMt7hkJNXPu_YY,11926
-heavyball-0.21.8.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-0.21.8.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-0.21.8.dist-info/RECORD,,
{heavyball-0.21.8.dist-info → heavyball-0.22.0.dist-info}/LICENSE: File without changes
{heavyball-0.21.8.dist-info → heavyball-0.22.0.dist-info}/WHEEL: File without changes
{heavyball-0.21.8.dist-info → heavyball-0.22.0.dist-info}/top_level.txt: File without changes