heavyball 0.18.3__tar.gz → 0.18.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.18.3 → heavyball-0.18.5}/PKG-INFO +1 -1
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/foreach_adamw.py +2 -2
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/utils.py +44 -26
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball.egg-info/SOURCES.txt +1 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/setup.py +1 -1
- heavyball-0.18.5/test/test_bf16_params.py +52 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/test/test_bf16_q.py +8 -5
- {heavyball-0.18.3 → heavyball-0.18.5}/test/test_foreach.py +1 -1
- {heavyball-0.18.3 → heavyball-0.18.5}/test/test_memory.py +1 -1
- {heavyball-0.18.3 → heavyball-0.18.5}/test/test_stochastic_updates.py +1 -1
- {heavyball-0.18.3 → heavyball-0.18.5}/LICENSE +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/README.md +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/__init__.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/cached_delayed_psgd_kron.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/cached_psgd_kron.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/delayed_psgd.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/p_adam.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/psgd_kron.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/pure_psgd.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/setup.cfg +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/test/test_closure.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/test/test_merge.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/test/test_no_grad.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/test/test_psgd.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.5}/test/test_soap.py +0 -0
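The substantive changes below target bfloat16 training: optimizer math runs on float32 copies and is written back to the low-precision parameters with stochastic rounding, and a new test compares float32 and bfloat16 runs. As a rough usage sketch (mirroring the new test rather than an official example; assumes a CUDA device and that ForeachAdamW is exported as shown in the hunks below):

import torch
from torch import nn

import heavyball

model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256)).cuda().to(torch.bfloat16)
opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-3)

for _ in range(8):
    # bf16 parameters and activations; the optimizer keeps its update math in fp32 internally
    loss = model(torch.randn((32, 256), device='cuda', dtype=torch.bfloat16)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()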
@@ -34,9 +34,9 @@ class ForeachAdamW(StatefulOptimizer):
 
         # Decay the first and second moment running average coefficient
         torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
-        denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
+        denom = list(exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps))
 
         # Normalize grad in-place for memory efficiency
         lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-        update_param_(y, exp_avg, lr, decay, lambda p, e, l:
+        update_param_(y, exp_avg, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l))
         group['k'] = k + 1
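In the hunk above, the denominators previously fed a foreach op directly; update_param_ (changed further down in utils.py) now visits parameters one at a time, so the denominators are materialized as a list and the callback pops one entry per parameter. A small self-contained illustration of that pop-per-parameter pattern (hypothetical tensors, not heavyball code):

import torch

params = [torch.zeros(3), torch.zeros(5)]
updates = [torch.ones(3), torch.ones(5)]
denom = [torch.full((3,), 2.0), torch.full((5,), 2.0)]  # one denominator per parameter

for p, e in zip(params, updates):
    # each parameter consumes its own denominator from the front of the list
    p.addcdiv_(e, denom.pop(0), value=-1e-3)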
@@ -38,6 +38,18 @@ def warmup(lr: float, step: int, warmup_steps: int):
     return lr * step / warmup_steps
 
 
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_schedule_free_(p, z, ckp1, grad, lr, beta1):
+    p32 = p.float()
+    z32 = z.float()
+    p32.lerp_(end=z32, weight=1 - ckp1)
+    p32.add_(grad, alpha=lr * (beta1 * (1 - ckp1) - 1))
+    _compilable_copy_stochastic_(p, p32)
+
+    z32.add_(grad, alpha=-lr)
+    _compilable_copy_stochastic_(z, z32)
+
+
 def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[torch.Tensor],
                    z: List[torch.Tensor], grad: list[torch.Tensor], r: float = 0.0, step: int = 0):
     weight = lr ** weight_lr_power * max(step, 1) ** r
@@ -50,15 +62,10 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
 
     # These operations update y in-place,
     # without computing x explicitly.
-
-
-
-
-    copy_stochastic_list_(parameters, p32)
-
-    # z step
-    torch._foreach_sub_(z, grad, alpha=lr)
-    copy_stochastic_list_(z, z32)
+    lr_tensor = torch.empty((), dtype=torch.float32, device=parameters[0].device).fill_(lr)
+    ckp1_tensor = torch.empty((), dtype=torch.float32, device=parameters[0].device).fill_(ckp1)
+    for p, z_, g in zip(parameters, z, grad):
+        _compilable_schedule_free_(p, z_, ckp1_tensor, g, lr_tensor, beta1)
     return weight_sum
 
 
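A plausible reason lr and ckp1 are wrapped in 0-d tensors here (an inference from the code, not stated in the diff): the helper is compiled with dynamic=True, and a Python float argument would be baked into the compiled graph as a constant, so a changing learning rate or ckp1 could force retracing on every step, while a scalar tensor keeps one graph. Sketch of the pattern:

import torch

lr_tensor = torch.empty((), dtype=torch.float32, device='cuda').fill_(1e-3)
# later steps can change the value in place without producing a new graph
lr_tensor.fill_(5e-4)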
@@ -479,7 +486,7 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
         copy_stochastic_(t, s)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
 def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
     # create a random 16 bit integer
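For context, _compilable_copy_stochastic_ rounds float32 values to bfloat16 stochastically (per the linked PyTorch issue): random noise is added to the 16 low mantissa bits that truncation would discard, so the expected rounding error is zero. A minimal eager-mode sketch of the same idea (an illustration, not the library's exact kernel):

import torch

def copy_stochastic_sketch_(target: torch.Tensor, source: torch.Tensor):
    assert source.dtype == torch.float32 and target.dtype == torch.bfloat16
    bits = source.view(dtype=torch.int32).clone()
    bits += torch.randint_like(bits, low=0, high=1 << 16)  # perturb the bits truncation drops
    bits &= -65536  # keep sign, exponent and the top 7 mantissa bits (the bf16 payload)
    target.copy_(bits.view(dtype=torch.float32))  # exact cast: the low bits are already zero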
@@ -504,17 +511,24 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     _compilable_copy_stochastic_(target, source)
 
 
-
-
-
-
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_update_one_(p, u, decay, add_fn, lr):
+    p32 = p.float()
+    u32 = u.view(p.shape).float()
     if decay > 0:
-
+        p32.mul_(1 - decay * lr)
     if add_fn is None:
-
+        p32.add_(u32, alpha=lr)
     else:
-        add_fn(
-
+        add_fn(p32, u32, lr)
+    _compilable_copy_stochastic_(p, p32)
+
+
+def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
+                  add_fn: callable = None):
+    lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
+    for p, u in zip(param, update):
+        _compilable_update_one_(p, u, decay, add_fn, lr_tensor)
 
 
 def precond_schedule(step, precond_scheduler, rng):
@@ -815,19 +829,23 @@ class PSGDBase(StatefulOptimizer):
 
     def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
                   store_triu_as_line=False):
-
+        if original_q:
+            if store_triu_as_line:
+                update_fn = update_triu_
+            else:
+                update_fn = copy_stochastic_list_
+        else:
+            update_fn = lambda x, y: None
+        for i, (p, grad, Q, oq) in enumerate(zip(p_list, grad_list, q_list, original_q)):
             psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
+            update_fn(oq, Q)
 
-        for g, q in zip(grad_list, q_list):
+        for g, q in zip(grad_list, original_q if original_q else q_list):
             if g.dim() > 1:
-                psgd_balance_Q(q)
-
-        if original_q:
-            for q in q_list:
                 if store_triu_as_line:
-
+                    psgd_balance_Q([q_ for _, q_ in q])
                 else:
-
+                    psgd_balance_Q(q)
 
 
 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
@@ -0,0 +1,52 @@
+import pytest
+import torch
+from torch import nn
+from torch._dynamo import config
+
+import heavyball
+import heavyball.utils
+from benchmark.utils import get_optim
+from heavyball.utils import clean, set_torch
+
+config.cache_size_limit = 128
+
+def get_memory():
+    clean()
+    torch.cuda.synchronize()
+    clean()
+    torch.cuda.synchronize()
+    return torch.cuda.memory_allocated()
+
+
+@pytest.mark.parametrize("opt", heavyball.__all__)
+@pytest.mark.parametrize("size,depth", [(256, 2)])
+def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3):
+    set_torch()
+
+    opt = getattr(heavyball, opt)
+
+    peaks = []
+    losses = []
+
+    for dtype in [torch.float32, torch.bfloat16]:
+        torch.manual_seed(0x2131290)
+        peaks.append([])
+        losses.append([])
+
+        for i in range(outer_iterations):
+            model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda().to(dtype)
+            o = get_optim(opt, model.parameters(), lr=1e-3)
+
+            for _ in range(iterations):
+                loss = model(torch.randn((1024, size), device='cuda', dtype=dtype)).square().mean()
+                loss.backward()
+                o.step()
+                o.zero_grad()
+                losses[-1].append(loss.detach())
+
+            del model, o
+            clean()
+
+    for i, (l0, l1) in enumerate(zip(*losses)):
+        print(i, l0.item(), l1.item())
+        assert torch.allclose(l0.float(), l1.float(), rtol=0.1)
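The new test runs every optimizer in heavyball.__all__ on the same model in float32 and bfloat16 and asserts the recorded losses agree within 10%. One way to run a single parametrization from Python (assumes a CUDA machine, since the test puts everything on 'cuda'):

import pytest

# -k filters the parametrized test ids down to the ForeachAdamW cases
pytest.main(["test/test_bf16_params.py", "-k", "ForeachAdamW", "-x"])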
@@ -1,10 +1,14 @@
-import heavyball
-import heavyball.utils
 import pytest
 import torch
+from torch import nn
+from torch._dynamo import config
+
+import heavyball
+import heavyball.utils
 from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch, PSGDBase
-
+
+config.cache_size_limit = 128
 
 
 def get_memory():
@@ -37,7 +41,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations:
         o = get_optim(opt, model.parameters(), lr=1e-3, q_dtype=q_dtype)
 
         for _ in range(iterations):
-            loss = model(torch.randn((1024, size)
+            loss = model(torch.randn((1024, size), device='cuda')).square().mean()
             loss.backward()
             o.step()
             o.zero_grad()
@@ -46,7 +50,6 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations:
         del model, o
         clean()
 
-
     for i, (l0, l1) in enumerate(zip(*losses)):
         print(i, l0.item(), l1.item())
         assert torch.allclose(l0, l1, rtol=0.1)
@@ -45,7 +45,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 5, outer_iterations: i
     clean()
 
     for _ in range(iterations):
-        loss = model(torch.randn((1, size)
+        loss = model(torch.randn((1, size), device='cuda')).sum()
         loss.backward()
         o.step()
         o.zero_grad()
@@ -48,7 +48,7 @@ def test_memory(opt, method, size, depth: int, iterations: int = 5, outer_iterat
     model_allocated = get_memory()
     o = get_optim(opt, model.parameters(), lr=1e-3)
     for _ in range(iterations):
-        model(torch.randn((1, size)
+        model(torch.randn((1, size), device='cuda')).sum().backward()
     o.step()
 
     opt_allocated = get_memory()
@@ -38,7 +38,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations
     o = get_optim(opt, model.parameters(), lr=1e-3, stochastic_schedule=stochastic)
 
     for _ in range(iterations):
-        loss = model(torch.randn((128, size)
+        loss = model(torch.randn((128, size), device='cuda')).square().mean()
         loss.backward()
         o.step()
        o.zero_grad()