heavyball 0.18.4__tar.gz → 0.18.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.18.4 → heavyball-0.18.5}/PKG-INFO +1 -1
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/foreach_adamw.py +2 -2
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/utils.py +35 -18
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball.egg-info/SOURCES.txt +1 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/setup.py +1 -1
- heavyball-0.18.5/test/test_bf16_params.py +52 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/test/test_bf16_q.py +8 -5
- {heavyball-0.18.4 → heavyball-0.18.5}/test/test_foreach.py +1 -1
- {heavyball-0.18.4 → heavyball-0.18.5}/test/test_memory.py +1 -1
- {heavyball-0.18.4 → heavyball-0.18.5}/test/test_stochastic_updates.py +1 -1
- {heavyball-0.18.4 → heavyball-0.18.5}/LICENSE +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/README.md +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/__init__.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/cached_delayed_psgd_kron.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/cached_psgd_kron.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/delayed_psgd.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/p_adam.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/psgd_kron.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/pure_psgd.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/setup.cfg +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/test/test_closure.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/test/test_merge.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/test/test_no_grad.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/test/test_psgd.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.5}/test/test_soap.py +0 -0
heavyball/foreach_adamw.py
@@ -34,9 +34,9 @@ class ForeachAdamW(StatefulOptimizer):
 
        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
-       denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
+       denom = list(exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps))
 
        # Normalize grad in-place for memory efficiency
        lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-       update_param_(y, exp_avg, lr, decay, lambda p, e, l:
+       update_param_(y, exp_avg, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l))
        group['k'] = k + 1
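The denominator returned by exp_avg_sq_ is now materialized as a plain list because update_param_ (see the utils.py hunks below) applies the update one tensor at a time through a compiled kernel, so the closure consumes the matching denominator per call instead of performing a single foreach addcdiv over whole lists. A minimal illustrative sketch of that calling pattern, with toy tensors rather than heavyball's own state:

# Sketch only: why `denom` becomes a plain list that the closure pops from.
import torch

params = [torch.zeros(4), torch.zeros(3)]
exp_avg = [torch.ones(4), torch.ones(3)]
denom = [torch.full((4,), 2.0), torch.full((3,), 4.0)]  # stand-in for exp_avg_sq_ output
lr = -1e-3

add_fn = lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l)
for p, e in zip(params, exp_avg):  # mirrors the per-tensor loop inside update_param_
    add_fn(p, e, lr)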
heavyball/utils.py
@@ -38,6 +38,18 @@ def warmup(lr: float, step: int, warmup_steps: int):
     return lr * step / warmup_steps
 
 
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_schedule_free_(p, z, ckp1, grad, lr, beta1):
+    p32 = p.float()
+    z32 = z.float()
+    p32.lerp_(end=z32, weight=1 - ckp1)
+    p32.add_(grad, alpha=lr * (beta1 * (1 - ckp1) - 1))
+    _compilable_copy_stochastic_(p, p32)
+
+    z32.add_(grad, alpha=-lr)
+    _compilable_copy_stochastic_(z, z32)
+
+
 def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[torch.Tensor],
                    z: List[torch.Tensor], grad: list[torch.Tensor], r: float = 0.0, step: int = 0):
     weight = lr ** weight_lr_power * max(step, 1) ** r
heavyball/utils.py
@@ -50,15 +62,10 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
 
     # These operations update y in-place,
     # without computing x explicitly.
-
-
-
-
-    copy_stochastic_list_(parameters, p32)
-
-    # z step
-    torch._foreach_sub_(z, grad, alpha=lr)
-    copy_stochastic_list_(z, z32)
+    lr_tensor = torch.empty((), dtype=torch.float32, device=parameters[0].device).fill_(lr)
+    ckp1_tensor = torch.empty((), dtype=torch.float32, device=parameters[0].device).fill_(ckp1)
+    for p, z_, g in zip(parameters, z, grad):
+        _compilable_schedule_free_(p, z_, ckp1_tensor, g, lr_tensor, beta1)
     return weight_sum
 
 
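schedule_free_ now wraps lr and ckp1 in 0-dim float32 tensors before calling the compiled per-tensor kernel, presumably so that the scalar values, which change every step, are not specialized into the compiled graph. A minimal sketch of that scalar-as-tensor pattern, under that assumption; scaled_add is a made-up example function, not part of heavyball:

# Sketch: pass a changing scalar as a 0-d tensor so one compiled graph serves all values.
import torch

@torch.compile(fullgraph=True, dynamic=True)
def scaled_add(p, g, lr):
    p.add_(g * lr)  # lr is a 0-d tensor, so its value is data, not a graph constant

p = torch.zeros(8)
g = torch.ones(8)
for step_lr in (1e-3, 3e-4, 1e-4):
    lr_tensor = torch.empty((), dtype=torch.float32).fill_(step_lr)
    scaled_add(p, g, lr_tensor)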
heavyball/utils.py
@@ -504,17 +511,24 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     _compilable_copy_stochastic_(target, source)
 
 
-
-
-
-
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_update_one_(p, u, decay, add_fn, lr):
+    p32 = p.float()
+    u32 = u.view(p.shape).float()
     if decay > 0:
-
+        p32.mul_(1 - decay * lr)
     if add_fn is None:
-
+        p32.add_(u32, alpha=lr)
     else:
-        add_fn(
-
+        add_fn(p32, u32, lr)
+    _compilable_copy_stochastic_(p, p32)
+
+
+def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
+                  add_fn: callable = None):
+    lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
+    for p, u in zip(param, update):
+        _compilable_update_one_(p, u, decay, add_fn, lr_tensor)
 
 
 def precond_schedule(step, precond_scheduler, rng):
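Both compiled kernels do their arithmetic in float32 and write the result back through _compilable_copy_stochastic_, whose body is not part of this diff. The usual technique behind such a function is stochastic rounding: add random bits below bf16 precision before truncating, so rounding errors average out across steps instead of systematically biasing low-precision parameters. An illustrative sketch of that technique, not heavyball's implementation:

# Sketch of stochastic fp32 -> bf16 rounding (illustrative only).
import torch

def copy_stochastic_sketch(target: torch.Tensor, source: torch.Tensor):
    # Reinterpret the fp32 bits, add uniform noise to the 16 bits bf16 discards,
    # then zero them out; the value rounds up with probability proportional to
    # the discarded fraction.
    bits = source.clone().view(torch.int32)
    bits.add_(torch.randint_like(bits, 0, 1 << 16))
    bits.bitwise_and_(-65536)  # keep only the upper 16 bits (the bf16 part)
    target.copy_(bits.view(torch.float32))

param = torch.randn(4, dtype=torch.bfloat16)
update = param.float() + 1e-3 * torch.randn(4)
copy_stochastic_sketch(param, update)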
heavyball/utils.py
@@ -828,7 +842,10 @@ class PSGDBase(StatefulOptimizer):
 
        for g, q in zip(grad_list, original_q if original_q else q_list):
            if g.dim() > 1:
-
+               if store_triu_as_line:
+                   psgd_balance_Q([q_ for _, q_ in q])
+               else:
+                   psgd_balance_Q(q)
 
 
 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
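With store_triu_as_line each preconditioner factor is kept as an (index, flattened-triangle) pair, so the balancing call has to unpack the tensor half of each pair, which is exactly what the added lines do. For context, factor balancing generally rescales the Kronecker factors to comparable magnitude while leaving their product, and hence the preconditioner, unchanged; a sketch of that idea, illustrative only and not psgd_balance_Q's actual code:

# Sketch: equalize factor magnitudes while keeping the product of scales equal to 1.
import torch

def balance_sketch(q_list):
    norms = torch.stack([q.abs().max() for q in q_list])
    target = norms.log().mean().exp()  # geometric mean of the magnitudes
    for q, n in zip(q_list, norms):
        q.mul_(target / n)

factors = [torch.randn(8, 8) * 100.0, torch.randn(4, 4) * 0.01]
balance_sketch(factors)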
test/test_bf16_params.py (new file)
@@ -0,0 +1,52 @@
+import pytest
+import torch
+from torch import nn
+from torch._dynamo import config
+
+import heavyball
+import heavyball.utils
+from benchmark.utils import get_optim
+from heavyball.utils import clean, set_torch
+
+config.cache_size_limit = 128
+
+def get_memory():
+    clean()
+    torch.cuda.synchronize()
+    clean()
+    torch.cuda.synchronize()
+    return torch.cuda.memory_allocated()
+
+
+@pytest.mark.parametrize("opt", heavyball.__all__)
+@pytest.mark.parametrize("size,depth", [(256, 2)])
+def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3):
+    set_torch()
+
+    opt = getattr(heavyball, opt)
+
+    peaks = []
+    losses = []
+
+    for dtype in [torch.float32, torch.bfloat16]:
+        torch.manual_seed(0x2131290)
+        peaks.append([])
+        losses.append([])
+
+        for i in range(outer_iterations):
+            model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda().to(dtype)
+            o = get_optim(opt, model.parameters(), lr=1e-3)
+
+            for _ in range(iterations):
+                loss = model(torch.randn((1024, size), device='cuda', dtype=dtype)).square().mean()
+                loss.backward()
+                o.step()
+                o.zero_grad()
+                losses[-1].append(loss.detach())
+
+            del model, o
+            clean()
+
+    for i, (l0, l1) in enumerate(zip(*losses)):
+        print(i, l0.item(), l1.item())
+        assert torch.allclose(l0.float(), l1.float(), rtol=0.1)
test/test_bf16_q.py
@@ -1,10 +1,14 @@
-import heavyball
-import heavyball.utils
 import pytest
 import torch
+from torch import nn
+from torch._dynamo import config
+
+import heavyball
+import heavyball.utils
 from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch, PSGDBase
-
+
+config.cache_size_limit = 128
 
 
 def get_memory():
@@ -37,7 +41,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations:
            o = get_optim(opt, model.parameters(), lr=1e-3, q_dtype=q_dtype)
 
            for _ in range(iterations):
-               loss = model(torch.randn((1024, size)
+               loss = model(torch.randn((1024, size), device='cuda')).square().mean()
                loss.backward()
                o.step()
                o.zero_grad()
@@ -46,7 +50,6 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations:
            del model, o
            clean()
 
-
    for i, (l0, l1) in enumerate(zip(*losses)):
        print(i, l0.item(), l1.item())
        assert torch.allclose(l0, l1, rtol=0.1)
test/test_foreach.py
@@ -45,7 +45,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 5, outer_iterations: i
            clean()
 
            for _ in range(iterations):
-               loss = model(torch.randn((1, size)
+               loss = model(torch.randn((1, size), device='cuda')).sum()
                loss.backward()
                o.step()
                o.zero_grad()
test/test_memory.py
@@ -48,7 +48,7 @@ def test_memory(opt, method, size, depth: int, iterations: int = 5, outer_iterat
            model_allocated = get_memory()
            o = get_optim(opt, model.parameters(), lr=1e-3)
            for _ in range(iterations):
-               model(torch.randn((1, size)
+               model(torch.randn((1, size), device='cuda')).sum().backward()
                o.step()
 
            opt_allocated = get_memory()
test/test_stochastic_updates.py
@@ -38,7 +38,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations
            o = get_optim(opt, model.parameters(), lr=1e-3, stochastic_schedule=stochastic)
 
            for _ in range(iterations):
-               loss = model(torch.randn((128, size)
+               loss = model(torch.randn((128, size), device='cuda')).square().mean()
                loss.backward()
                o.step()
                o.zero_grad()