heavyball 0.18.6.tar.gz → 0.18.8.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.18.6 → heavyball-0.18.8}/PKG-INFO +1 -1
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/foreach_adamw.py +20 -8
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/foreach_adopt.py +29 -5
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/foreach_laprop.py +20 -11
- heavyball-0.18.8/heavyball/foreach_sfadamw.py +60 -0
- heavyball-0.18.8/heavyball/palm_foreach_sfadamw.py +64 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/psgd_kron.py +2 -1
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/utils.py +35 -26
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.18.6 → heavyball-0.18.8}/setup.py +1 -1
- {heavyball-0.18.6 → heavyball-0.18.8}/test/test_bf16_params.py +0 -1
- heavyball-0.18.6/heavyball/foreach_sfadamw.py +0 -54
- heavyball-0.18.6/heavyball/palm_foreach_sfadamw.py +0 -57
- {heavyball-0.18.6 → heavyball-0.18.8}/LICENSE +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/README.md +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/__init__.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/cached_delayed_psgd_kron.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/cached_psgd_kron.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/delayed_psgd.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/p_adam.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/pure_psgd.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/setup.cfg +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/test/test_bf16_q.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/test/test_closure.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/test/test_foreach.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/test/test_memory.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/test/test_merge.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/test/test_no_grad.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/test/test_psgd.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/test/test_soap.py +0 -0
- {heavyball-0.18.6 → heavyball-0.18.8}/test/test_stochastic_updates.py +0 -0
{heavyball-0.18.6 → heavyball-0.18.8}/heavyball/foreach_adamw.py

@@ -1,7 +1,21 @@
 import torch
 import torch.optim
+from heavyball.utils import copy_stochastic_list_
 
-from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer
+from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
+    g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
+
+    torch._foreach_lerp_(exp_avg32, g32, 1 - beta_debias(beta1, step + 1))
+    denom = list(exp_avg_sq_(exp_avg_sq32, g32, beta_debias(beta2, step + 1), eps))
+
+    update_param_(y, exp_avg32, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l))
+
+    copy_stochastic_list_(exp_avg, exp_avg32)
+    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
 
 
 class ForeachAdamW(StatefulOptimizer):
@@ -30,13 +44,11 @@ class ForeachAdamW(StatefulOptimizer):
                 self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
 
         y, grad, exp_avg_sq, exp_avg = zip(
-            *[(p.data, p.grad
-
-        # Decay the first and second moment running average coefficient
-        torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
-        denom = list(exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps))
+            *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
 
-        # Normalize grad in-place for memory efficiency
         lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-
+        lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
+        step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k)
+        _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay)
+
         group['k'] = k + 1
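A pattern worth noting in the hunk above: `lr` and `step` are wrapped in 0-dim tensors via `torch.empty((), ...).fill_(...)` before being handed to the `torch.compile`d helper. The snippet below is a minimal, self-contained sketch of that idea (not code from the package): with plain Python scalars, `torch.compile` tends to specialize on the concrete value, so a changing learning rate or step counter can force recompiles, whereas a 0-dim tensor lets one compiled graph be reused.

import torch

@torch.compile(fullgraph=True, dynamic=True)
def _sgd_like_step(params, grads, lr):
    # lr is a 0-dim float32 tensor; only its value changes between calls,
    # so the traced graph stays valid when the schedule updates it.
    for p, g in zip(params, grads):
        p.add_(g * -lr)

params = [torch.randn(4), torch.randn(3)]
grads = [torch.randn(4), torch.randn(3)]
lr = torch.empty((), dtype=torch.float32).fill_(1e-3)
_sgd_like_step(params, grads, lr)
lr.fill_(3e-4)            # new learning rate, same compiled graph
_sgd_like_step(params, grads, lr)
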
{heavyball-0.18.6 → heavyball-0.18.8}/heavyball/foreach_adopt.py

@@ -1,7 +1,27 @@
 import torch
 import torch.optim
+from heavyball.utils import copy_stochastic_list_
 
-from .utils import warmup, beta_debias, update_param_, StatefulOptimizer
+from .utils import warmup, beta_debias, update_param_, StatefulOptimizer, promote
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
+    g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
+    update_param_(y, exp_avg, lr, decay)
+
+    beta1 = beta_debias(beta1, step)
+    denom = torch._foreach_sqrt(exp_avg_sq32)
+    torch._foreach_maximum_(denom, eps)
+    torch._foreach_mul_(exp_avg32, beta1)
+    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+
+    beta2 = beta_debias(beta2, step + 1)
+    torch._foreach_mul_(exp_avg_sq32, beta2)
+    [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
+
+    copy_stochastic_list_(exp_avg, exp_avg32)
+    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
 
 
 class ForeachADOPT(StatefulOptimizer):
@@ -31,12 +51,18 @@ class ForeachADOPT(StatefulOptimizer):
                 self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
 
         y, grad, exp_avg_sq, exp_avg = zip(
-            *[(p.data, p.grad
+            *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
+
+        group['k'] = k + 1
 
         if k > 1:
             lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
+            lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
+            k = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k)
+            _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], k, lr, eps, decay)
+            return
 
-
+        grad = [promote(g) for g in grad]
         if k > 0:
             beta1 = beta_debias(group['betas'][0], k)
             denom = torch._foreach_sqrt(exp_avg_sq)
@@ -48,5 +74,3 @@ class ForeachADOPT(StatefulOptimizer):
             torch._foreach_mul_(exp_avg_sq, beta2)
             torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)
         del grad
-
-        group['k'] = k + 1
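For orientation when comparing the compiled ADOPT step above against Adam: the parameter moves with the *previous* first moment, and the incoming gradient is normalized by the *previous* second moment before it is folded into the momentum. A single-tensor restatement of that order of operations (illustrative only; bias correction via `beta_debias` and the bf16 copy-back are omitted):

import torch

def adopt_step(p, g, exp_avg, exp_avg_sq, beta1, beta2, lr, eps):
    p.add_(exp_avg, alpha=-lr)                 # move with last step's momentum
    denom = exp_avg_sq.sqrt().clamp_(min=eps)  # previous second moment
    exp_avg.mul_(beta1).addcdiv_(g, denom, value=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)
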
{heavyball-0.18.6 → heavyball-0.18.8}/heavyball/foreach_laprop.py

@@ -1,7 +1,20 @@
 import torch
 import torch.optim
 
-from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer
+from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
+    g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
+
+    denom = exp_avg_sq_(exp_avg_sq32, g32, beta_debias(beta2, step), eps)
+
+    beta1 = beta_debias(beta1, step)
+    torch._foreach_mul_(exp_avg32, beta1)
+    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+
+    update_param_(y, exp_avg32, lr, decay)
 
 
 class ForeachLaProp(StatefulOptimizer):
@@ -31,17 +44,13 @@ class ForeachLaProp(StatefulOptimizer):
                 self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
 
         y, grad, exp_avg_sq, exp_avg = zip(
-            *[(p.data, p.grad
-
-        # Decay the first and second moment running average coefficient
-        denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
-        beta1 = beta_debias(group['betas'][0], k + 1)
-        torch._foreach_mul_(exp_avg, beta1)
-        torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
-        del grad
+            *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg'])  #
+              for p in active_p])
 
-        # Normalize grad in-place for memory efficiency
         lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-
+        lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
+        step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k + 1)
+
+        _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay)
 
         group['k'] = k + 1
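As a quick summary of the LaProp step above, again as an illustrative single-tensor restatement with debiasing and weight decay left out: the raw gradient is divided by the second-moment RMS *before* it enters the first-moment EMA, and the parameter then moves along that EMA.

import torch

def laprop_step(p, g, exp_avg, exp_avg_sq, beta1, beta2, lr, eps):
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)
    denom = exp_avg_sq.sqrt().clamp_(min=eps)
    exp_avg.mul_(beta1).addcdiv_(g, denom, value=1 - beta1)
    p.add_(exp_avg, alpha=-lr)
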
heavyball-0.18.8/heavyball/foreach_sfadamw.py

@@ -0,0 +1,60 @@
+import torch
+import torch.optim
+from heavyball.utils import get_ckp1
+
+from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, promote, _compilable_schedule_free_
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
+    old_debiased2 = beta_debias(beta2, step)
+
+    g32 = [promote(g_) for g_ in grad]
+    exp_avg_sq32 = [promote(e_) for e_ in exp_avg_sq]
+
+    denom = exp_avg_sq_(exp_avg_sq32, g32, old_debiased2, eps)
+    torch._foreach_div_(g32, denom)
+    if decay != 0:
+        torch._foreach_add_(g32, y, alpha=decay)
+    for p, z_, g in zip(y, z, g32):
+        _compilable_schedule_free_(p, z_, ckp1, g, lr, beta1)
+
+
+class ForeachSFAdamW(ScheduleFree):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
+                 weight_lr_power=2.0, foreach: bool = True):
+
+        defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
+                        weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
+                        foreach=foreach)
+        super().__init__(params, defaults, foreach)
+
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
+
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
+
+        active_p = [p for p in group['params'] if p.grad is not None]
+
+        if not active_p:
+            return
+
+        for p in active_p:
+            if 'z' not in self.state_(p):
+                self.state_(p)['z'] = torch.clone(p.data)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+
+        y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z'])  #
+                                       for p in active_p])
+
+        lr = warmup(group['lr'], k + 1, group['warmup_steps'])
+        ckp1, group['weight_sum'] = get_ckp1(lr, group['weight_lr_power'], group['weight_sum'], group['r'], k + 1)
+
+        step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k + 1)
+        ckp1 = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(ckp1)
+        lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
+        _compilable_step_(y, grad, exp_avg_sq, z, group['betas'][0], group['betas'][1], step, ckp1, eps, decay, lr)
+        group['k'] = k + 1
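A hypothetical usage sketch for the rewritten schedule-free AdamW. Assumptions not shown in this diff: `ForeachSFAdamW` is re-exported from the top-level `heavyball` package like the other optimizers in the file list, and the `ScheduleFree` base class manages the `train_mode` flag that `_step` checks.

import torch
import heavyball

model = torch.nn.Linear(16, 1)
opt = heavyball.ForeachSFAdamW(model.parameters(), lr=2.5e-3, betas=(0.9, 0.99),
                               warmup_steps=100, weight_decay=1e-2)

for _ in range(10):
    loss = model(torch.randn(8, 16)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
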
heavyball-0.18.8/heavyball/palm_foreach_sfadamw.py

@@ -0,0 +1,64 @@
+import torch
+import torch.optim
+
+from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debias, get_ckp1, promote, _compilable_schedule_free_
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
+    old_debiased2 = beta_debias(beta2, step)
+
+    g32 = [promote(g_) for g_ in grad]
+    exp_avg_sq32 = [promote(e_) for e_ in exp_avg_sq]
+
+    denom = exp_avg_sq_(exp_avg_sq32, g32, old_debiased2, eps)
+    torch._foreach_div_(g32, denom)
+    if decay != 0:
+        torch._foreach_add_(g32, y, alpha=decay)
+    for p, z_, g in zip(y, z, g32):
+        _compilable_schedule_free_(p, z_, ckp1, g, lr, beta1)
+
+
+class PaLMForeachSFAdamW(ScheduleFree):
+    def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
+                 weight_lr_power=2.0, beta2_scale: float = 0.8, foreach: bool = True):
+        if betas[0] is not None:
+            beta = betas[0]
+        defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
+                        beta2_scale=beta2_scale)
+        super().__init__(params, defaults, foreach)
+
+    def _step(self, group):
+        eps = group['eps']
+        decay = group['weight_decay']
+        k = group['k']
+
+        if not group['train_mode']:
+            raise Exception("Not in train mode!")
+
+        active_p = [p for p in group['params'] if p.grad is not None]
+
+        if not active_p:
+            return
+
+        for p in active_p:
+            if 'z' not in self.state_(p):
+                self.state_(p)['z'] = torch.clone(p.data)
+                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+
+        # Decay the first moment running average coefficient
+        beta2 = 1 - (k + 1) ** -group['beta2_scale']
+
+        y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z'])  #
+                                       for p in active_p])
+
+        lr = warmup(group['lr'], k + 1, group['warmup_steps'])
+        ckp1, group['weight_sum'] = get_ckp1(lr, group['weight_lr_power'], group['weight_sum'], group['r'], k + 1)
+
+        step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k + 1)
+        ckp1 = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(ckp1)
+        beta2 = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(beta2)
+        lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
+        _compilable_step_(y, grad, exp_avg_sq, z, group['beta'], beta2, step, ckp1, eps, decay, lr)
+        group['k'] = k + 1
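The main difference from `ForeachSFAdamW` is the PaLM-style second-moment schedule `beta2 = 1 - (k + 1) ** -beta2_scale`, which starts small and approaches 1 over training. Approximate values for the default `beta2_scale=0.8` (my arithmetic, not package output):

for step in (1, 10, 100, 1000):       # step = k + 1
    beta2 = 1 - step ** -0.8
    print(step, round(beta2, 4))      # 0.0, 0.8415, 0.9749, 0.996
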
{heavyball-0.18.6 → heavyball-0.18.8}/heavyball/psgd_kron.py

@@ -104,7 +104,8 @@ class ForeachPSGDKron(PSGDBase):
 
             if should_update:
                 q32 = [promote(q_) for q_ in q]
-                self.do_update(group, [p], [g], [q32], precond_lr, [q_orig],
+                self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
+                               store_triu_as_line)
             set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
 
         grad_list = self.clip_fn(grad_list)
{heavyball-0.18.6 → heavyball-0.18.8}/heavyball/utils.py

@@ -40,14 +40,25 @@ def warmup(lr: float, step: int, warmup_steps: int):
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
 def _compilable_schedule_free_(p, z, ckp1, grad, lr, beta1):
-    p32 = p
-    z32 = z
-    p32.lerp_(end=z32, weight=
+    p32 = promote(p)
+    z32 = promote(z)
+    p32.lerp_(end=z32, weight=ckp1)
     p32.add_(grad, alpha=lr * (beta1 * (1 - ckp1) - 1))
-
+    copy_stochastic_(p, p32)
 
     z32.add_(grad, alpha=-lr)
-
+    copy_stochastic_(z, z32)
+
+
+def get_ckp1(lr, weight_lr_power, weight_sum, r, step):
+    weight = lr ** weight_lr_power * max(step, 1) ** r
+    weight_sum = weight_sum + weight
+
+    try:
+        ckp1 = weight / weight_sum
+    except ZeroDivisionError:
+        ckp1 = 0
+    return ckp1, weight_sum
 
 
 def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[torch.Tensor],
@@ -136,7 +147,7 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):
         return torch.sqrt(state, out=out).clamp_(min=eps)
 
     torch._foreach_mul_(state, beta2)
-
+    [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(state, grad)]
     denom = torch._foreach_sqrt(state)
     torch._foreach_maximum_(denom, eps)
     return denom
@@ -332,9 +343,9 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
 
 
 def promote(x):
-    if x in (torch.bfloat16, torch.float16):
+    if isinstance(x, torch.dtype) and x in (torch.bfloat16, torch.float16):
         return torch.float32
-    if
+    if isinstance(x, torch.Tensor) and x.dtype in (torch.bfloat16, torch.float16):
         return x.float()
     return x
 
@@ -486,13 +497,8 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
         copy_stochastic_(t, s)
 
 
-
-
-    set_(target, source)
-    _compilable_copy_stochastic_(target, source)
-
-
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+# this can be dynamic for most optimizers - just not for PSGD. So, it's disabled for all
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True)
 def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
     # create a random 16 bit integer
@@ -509,22 +515,24 @@ def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
 
 
 def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
-    if target.data_ptr() == source.data_ptr():
+    if not torch.compiler.is_compiling() and target.data_ptr() == source.data_ptr():
         return
-
+    if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
+        set_(target, source)
+    _compilable_copy_stochastic_(target, source)
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
 def _compilable_update_one_(p, u, decay, add_fn, lr):
-    p32 = p
-    u32 = u.view(p.shape)
+    p32 = promote(p)
+    u32 = promote(u.view(p.shape))
     if decay > 0:
         p32.mul_(1 - decay * lr)
     if add_fn is None:
         p32.add_(u32, alpha=lr)
     else:
         add_fn(p32, u32, lr)
-
+    copy_stochastic_(p, p32)
 
 
 def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
@@ -843,12 +851,13 @@ class PSGDBase(StatefulOptimizer):
             psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
             update_fn(oq, Q)
 
-
-
-        if
-
-
-
+        if self.should_update(group, self.balance_probability, "balance_prob"):
+            for g, q in zip(grad_list, original_q if original_q else q_list):
+                if g.dim() > 1:
+                    if store_triu_as_line:
+                        psgd_balance_Q([q_ for _, q_ in q])
+                    else:
+                        psgd_balance_Q(q)
 
 
 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
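To see what the new `get_ckp1` helper computes, here is the same function (copied from the hunk above) run standalone with a toy constant learning rate: with `r=0` the interpolation weight reduces to `1/step`, i.e. the schedule-free `y` sequence becomes a uniform running average of the `z` iterates.

def get_ckp1(lr, weight_lr_power, weight_sum, r, step):
    weight = lr ** weight_lr_power * max(step, 1) ** r
    weight_sum = weight_sum + weight
    try:
        ckp1 = weight / weight_sum
    except ZeroDivisionError:
        ckp1 = 0
    return ckp1, weight_sum

ws = 0.0
for step in range(1, 5):
    ckp1, ws = get_ckp1(1e-3, 2.0, ws, 0.0, step)
    print(step, round(ckp1, 3))   # 1.0, 0.5, 0.333, 0.25
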
heavyball-0.18.6/heavyball/foreach_sfadamw.py (removed)

@@ -1,54 +0,0 @@
-import torch
-import torch.optim
-
-from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debias
-
-
-class ForeachSFAdamW(ScheduleFree):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                 weight_lr_power=2.0, foreach: bool = True):
-
-        defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
-                        weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
-                        foreach=foreach)
-        super().__init__(params, defaults, foreach)
-
-    def _step(self, group):
-        eps = group['eps']
-        decay = group['weight_decay']
-        k = group['k']
-
-        if not group['train_mode']:
-            raise Exception("Not in train mode!")
-
-        active_p = [p for p in group['params'] if p.grad is not None]
-
-        if not active_p:
-            return
-
-        for p in active_p:
-            if 'z' not in self.state_(p):
-                self.state_(p)['z'] = torch.clone(p.data)
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
-
-        y, grad, exp_avg_sq, z = zip(
-            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
-
-        # Decay the first moment running average coefficient
-        old_debiased = beta_debias(group['betas'][1], k + 1)
-
-        # Decay the first and second moment running average coefficient
-        denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
-
-        # Normalize grad in-place for memory efficiency
-        torch._foreach_div_(grad, denom)
-
-        # Weight decay calculated at y
-        if decay != 0:
-            torch._foreach_add_(grad, y, alpha=decay)
-
-        lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0], y, z,
-                                             grad, group['r'], k + 1)
-
-        group['k'] = k + 1
heavyball-0.18.6/heavyball/palm_foreach_sfadamw.py (removed)

@@ -1,57 +0,0 @@
-import torch
-import torch.optim
-
-from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debias
-
-
-class PaLMForeachSFAdamW(ScheduleFree):
-    def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                 weight_lr_power=2.0, beta2_scale: float = 0.8,
-                 foreach: bool = True):
-        if betas[0] is not None:
-            beta = betas[0]
-        defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-                        lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
-                        beta2_scale=beta2_scale)
-        super().__init__(params, defaults, foreach)
-
-    def _step(self, group):
-        eps = group['eps']
-        decay = group['weight_decay']
-        k = group['k']
-
-        if not group['train_mode']:
-            raise Exception("Not in train mode!")
-
-        active_p = [p for p in group['params'] if p.grad is not None]
-
-        if not active_p:
-            return
-
-        for p in active_p:
-            if 'z' not in self.state_(p):
-                self.state_(p)['z'] = torch.clone(p.data)
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
-
-        y, grad, exp_avg_sq, z = zip(
-            *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
-
-        # Decay the first moment running average coefficient
-        beta2 = 1 - (k + 1) ** -group['beta2_scale']
-        old_debiased = beta_debias(beta2, k + 1)
-
-        # Decay the first and second moment running average coefficient
-        denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
-
-        # Normalize grad in-place for memory efficiency
-        torch._foreach_div_(grad, denom)
-
-        # Weight decay calculated at y
-        if decay != 0:
-            torch._foreach_add_(grad, y, alpha=decay)
-
-        lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], y, z,
-                                             grad, group['r'], k + 1)
-
-        group['k'] = k + 1
All remaining files listed above (+0 -0) are unchanged between 0.18.6 and 0.18.8.