heavyball 0.18.4__tar.gz → 0.18.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.18.4 → heavyball-0.18.6}/PKG-INFO +2 -2
- {heavyball-0.18.4 → heavyball-0.18.6}/README.md +1 -1
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/foreach_adamw.py +2 -2
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/utils.py +44 -24
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball.egg-info/PKG-INFO +2 -2
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball.egg-info/SOURCES.txt +1 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/setup.py +1 -1
- heavyball-0.18.6/test/test_bf16_params.py +52 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/test/test_bf16_q.py +8 -5
- {heavyball-0.18.4 → heavyball-0.18.6}/test/test_foreach.py +1 -1
- {heavyball-0.18.4 → heavyball-0.18.6}/test/test_memory.py +1 -1
- {heavyball-0.18.4 → heavyball-0.18.6}/test/test_stochastic_updates.py +1 -1
- {heavyball-0.18.4 → heavyball-0.18.6}/LICENSE +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/__init__.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/cached_delayed_psgd_kron.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/cached_psgd_kron.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/delayed_psgd.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/p_adam.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/psgd_kron.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/pure_psgd.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/setup.cfg +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/test/test_closure.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/test/test_merge.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/test/test_no_grad.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/test/test_psgd.py +0 -0
- {heavyball-0.18.4 → heavyball-0.18.6}/test/test_soap.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: heavyball
|
3
|
-
Version: 0.18.4
|
3
|
+
Version: 0.18.6
|
4
4
|
Summary: Efficient optimizers
|
5
5
|
Home-page: https://github.com/clashluke/heavyball
|
6
6
|
Author: Lucas Nestler
|
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
|
|
32
32
|
The goal is not to strive for completeness, full maintenance or abstraction, but instead to provide a simple
|
33
33
|
largely static alternative to `torch.optim` with more and better optimizers.
|
34
34
|
|
35
|
-
Currently (2024-11-
|
35
|
+
Currently (2024-11-21, 0.18.6), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
|
36
36
|
recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
|
37
37
|
|
38
38
|
## Features
|
@@ -8,7 +8,7 @@ A simple package of efficient optimizers
|
|
8
8
|
The goal is not to strive for completeness, full maintenance or abstraction, but instead to provide a simple
|
9
9
|
largely static alternative to `torch.optim` with more and better optimizers.
|
10
10
|
|
11
|
-
Currently (2024-11-
|
11
|
+
Currently (2024-11-21, 0.18.6), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
|
12
12
|
recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
|
13
13
|
|
14
14
|
## Features
|
@@ -34,9 +34,9 @@ class ForeachAdamW(StatefulOptimizer):
|
|
34
34
|
|
35
35
|
# Decay the first and second moment running average coefficient
|
36
36
|
torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
|
37
|
-
denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
|
37
|
+
denom = list(exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps))
|
38
38
|
|
39
39
|
# Normalize grad in-place for memory efficiency
|
40
40
|
lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
|
41
|
-
update_param_(y, exp_avg, lr, decay, lambda p, e, l:
|
41
|
+
update_param_(y, exp_avg, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l))
|
42
42
|
group['k'] = k + 1
|
@@ -38,6 +38,18 @@ def warmup(lr: float, step: int, warmup_steps: int):
|
|
38
38
|
return lr * step / warmup_steps
|
39
39
|
|
40
40
|
|
41
|
+
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
def _compilable_schedule_free_(p, z, ckp1, grad, lr, beta1):
    """Apply one schedule-free step to a single parameter tensor, computed in float32.

    ``p`` holds the interpolated sequence ``y`` (updated in place without materialising
    ``x``) and ``z`` holds the base iterate. With ``c = ckp1`` the update is::

        y <- (1 - c) * y + c * z + lr * (beta1 * (1 - c) - 1) * grad
        z <- z - lr * grad

    Both results are written back through ``_guarded_copy_stochastic`` so low-precision
    storage is handled the same way as elsewhere in this module.

    Args:
        p: parameter tensor (``y``), updated in place.
        z: the schedule-free ``z`` buffer belonging to ``p``, updated in place.
        ckp1: 0-dim float32 tensor holding the averaging coefficient ``c``.
        grad: gradient for ``p``.
        lr: 0-dim float32 tensor holding the learning rate.
        beta1: interpolation coefficient between ``z`` and the averaged iterate.
    """
    p32 = p.float()
    z32 = z.float()
    # The y-update must read z *before* the z-step below.
    # Expanding y' = (1 - beta1) z' + beta1 x' with x' = (1 - c) x + c z' and z' = z - lr*g
    # yields lerp(y, z, c) followed by adding lr * (beta1 * (1 - c) - 1) * g. The lerp
    # weight therefore has to be the same ``c`` the gradient coefficient uses; ``1 - ckp1``
    # would mix the two conventions.
    # NOTE(review): assumes ckp1 = weight / weight_sum as in the reference schedule-free
    # implementation (its definition is not visible here) — confirm.
    p32.lerp_(end=z32, weight=ckp1)
    p32.add_(grad, alpha=lr * (beta1 * (1 - ckp1) - 1))
    _guarded_copy_stochastic(p, p32)

    # Plain SGD step on z, done on the float32 copy so bf16 storage is rounded once on write-back.
    z32.add_(grad, alpha=-lr)
    _guarded_copy_stochastic(z, z32)
|
51
|
+
|
52
|
+
|
41
53
|
def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[torch.Tensor],
|
42
54
|
z: List[torch.Tensor], grad: list[torch.Tensor], r: float = 0.0, step: int = 0):
|
43
55
|
weight = lr ** weight_lr_power * max(step, 1) ** r
|
@@ -50,15 +62,10 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
|
|
50
62
|
|
51
63
|
# These operations update y in-place,
|
52
64
|
# without computing x explicitly.
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
copy_stochastic_list_(parameters, p32)
|
58
|
-
|
59
|
-
# z step
|
60
|
-
torch._foreach_sub_(z, grad, alpha=lr)
|
61
|
-
copy_stochastic_list_(z, z32)
|
65
|
+
lr_tensor = torch.empty((), dtype=torch.float32, device=parameters[0].device).fill_(lr)
|
66
|
+
ckp1_tensor = torch.empty((), dtype=torch.float32, device=parameters[0].device).fill_(ckp1)
|
67
|
+
for p, z_, g in zip(parameters, z, grad):
|
68
|
+
_compilable_schedule_free_(p, z_, ckp1_tensor, g, lr_tensor, beta1)
|
62
69
|
return weight_sum
|
63
70
|
|
64
71
|
|
@@ -150,11 +157,11 @@ def adaptive_gradient_clipping_(parameters: List[torch.Tensor], gradients: List[
|
|
150
157
|
|
151
158
|
|
152
159
|
def set_(dst: torch.Tensor, src: torch.Tensor):
|
153
|
-
if src.data_ptr() == dst.data_ptr():
|
160
|
+
if not torch.compiler.is_compiling() and src.data_ptr() == dst.data_ptr():
|
154
161
|
return
|
155
162
|
if src.shape != dst.shape:
|
156
163
|
src = src.reshape_as(dst)
|
157
|
-
if src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
|
164
|
+
if not torch.compiler.is_compiling() and src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
|
158
165
|
dst.set_(src)
|
159
166
|
else:
|
160
167
|
dst.copy_(src)
|
@@ -479,6 +486,12 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
|
|
479
486
|
copy_stochastic_(t, s)
|
480
487
|
|
481
488
|
|
489
|
+
def _guarded_copy_stochastic(target: torch.Tensor, source: torch.Tensor):
    """Copy ``source`` into ``target``, using the stochastic-rounding path only where it applies.

    The stochastic-rounding kernel ``_compilable_copy_stochastic_`` is used only when
    ``target`` is bfloat16 and ``source`` is float16/float32/float64. Every other dtype
    combination is handled by a plain ``set_`` and must stop there.

    Args:
        target: tensor to write into (modified in place).
        source: tensor holding the new values.
    """
    if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
        set_(target, source)
        # Without this return, non-bf16 targets (e.g. fp32 params) would be overwritten a
        # second time by the bf16 stochastic-rounding kernel, which this guard exists to avoid.
        return
    _compilable_copy_stochastic_(target, source)
|
493
|
+
|
494
|
+
|
482
495
|
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
|
483
496
|
def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
|
484
497
|
"""Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
|
@@ -498,23 +511,27 @@ def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
|
|
498
511
|
def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
    """Write ``source`` into ``target`` via ``_guarded_copy_stochastic``.

    Does nothing when both tensors already start at the same data pointer.
    """
    same_memory = target.data_ptr() == source.data_ptr()
    if not same_memory:
        _guarded_copy_stochastic(target, source)
|
505
515
|
|
506
516
|
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
517
|
+
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
def _compilable_update_one_(p, u, decay, add_fn, lr):
    """Update one parameter tensor in float32 and write the result back into ``p``.

    Steps: view ``u`` with ``p``'s shape, optionally scale the parameter by
    ``1 - decay * lr``, then either add ``lr * u`` or delegate to ``add_fn(p32, u32, lr)``.
    The float32 result goes back through ``_guarded_copy_stochastic``.

    Args:
        p: parameter tensor, modified in place.
        u: update tensor with the same number of elements as ``p``.
        decay: weight-decay coefficient; the scaling is skipped unless ``decay > 0``.
        add_fn: optional callable ``(param32, update32, lr)`` that applies the update in place.
        lr: learning rate (passed as a 0-dim tensor by ``update_param_``).

    NOTE(review): ForeachAdamW passes ``lr = -warmup(...)`` (a negative value). With
    ``decay > 0``, the factor ``1 - decay * lr`` then exceeds 1 unless ``decay`` is negated
    upstream — confirm the intended sign convention.
    """
    param_fp32 = p.float()
    update_fp32 = u.view(p.shape).float()
    if decay > 0:
        param_fp32.mul_(1 - decay * lr)
    if add_fn is not None:
        add_fn(param_fp32, update_fp32, lr)
    else:
        param_fp32.add_(update_fp32, alpha=lr)
    _guarded_copy_stochastic(p, param_fp32)
|
528
|
+
|
529
|
+
|
530
|
+
def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
                  add_fn: callable = None):
    """Apply ``update`` to each tensor in ``param`` in place, one tensor at a time.

    Args:
        param: parameter tensors to update.
        update: update tensors, paired with ``param`` via ``zip``.
        lr: learning rate (sign convention is the caller's).
        decay: weight-decay coefficient forwarded to ``_compilable_update_one_``.
        add_fn: optional ``(param32, update32, lr)`` callable. It is invoked once per tensor,
            in list order.
    """
    if not param:
        # Nothing to update; also avoids the IndexError from ``param[0]`` below.
        return
    # lr becomes a 0-dim float32 tensor on the parameters' device before it enters the
    # compiled helper (presumably so lr-schedule changes don't trigger recompiles — confirm).
    lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
    for p, u in zip(param, update):
        _compilable_update_one_(p, u, decay, add_fn, lr_tensor)
|
518
535
|
|
519
536
|
|
520
537
|
def precond_schedule(step, precond_scheduler, rng):
|
@@ -828,7 +845,10 @@ class PSGDBase(StatefulOptimizer):
|
|
828
845
|
|
829
846
|
for g, q in zip(grad_list, original_q if original_q else q_list):
|
830
847
|
if g.dim() > 1:
|
831
|
-
|
848
|
+
if store_triu_as_line:
|
849
|
+
psgd_balance_Q([q_ for _, q_ in q])
|
850
|
+
else:
|
851
|
+
psgd_balance_Q(q)
|
832
852
|
|
833
853
|
|
834
854
|
def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: heavyball
|
3
|
-
Version: 0.18.4
|
3
|
+
Version: 0.18.6
|
4
4
|
Summary: Efficient optimizers
|
5
5
|
Home-page: https://github.com/clashluke/heavyball
|
6
6
|
Author: Lucas Nestler
|
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
|
|
32
32
|
The goal is not to strive for completeness, full maintenance or abstraction, but instead to provide a simple
|
33
33
|
largely static alternative to `torch.optim` with more and better optimizers.
|
34
34
|
|
35
|
-
Currently (2024-11-
|
35
|
+
Currently (2024-11-21, 0.18.6), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
|
36
36
|
recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
|
37
37
|
|
38
38
|
## Features
|
@@ -0,0 +1,52 @@
|
|
1
|
+
import pytest
|
2
|
+
import torch
|
3
|
+
from torch import nn
|
4
|
+
from torch._dynamo import config
|
5
|
+
|
6
|
+
import heavyball
|
7
|
+
import heavyball.utils
|
8
|
+
from benchmark.utils import get_optim
|
9
|
+
from heavyball.utils import clean, set_torch
|
10
|
+
|
11
|
+
# Raise torch._dynamo's recompilation cache limit to 128. Presumably the torch.compile'd
# heavyball helpers recompile for the many optimizer/dtype combinations exercised here — TODO confirm.
config.cache_size_limit = 128
|
12
|
+
|
13
|
+
def get_memory():
    """Return ``torch.cuda.memory_allocated()`` after two rounds of ``clean()`` + synchronize."""
    for _ in range(2):
        clean()
        torch.cuda.synchronize()
    return torch.cuda.memory_allocated()
|
19
|
+
|
20
|
+
|
21
|
+
@pytest.mark.parametrize("opt", heavyball.__all__)
@pytest.mark.parametrize("size,depth", [(256, 2)])
def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3):
    """Train the same small MLP with fp32 and bf16 parameters and compare the loss curves.

    For each dtype, the RNG is reseeded and ``outer_iterations`` fresh models are trained
    for ``iterations`` steps each. Per-step losses from the two dtypes must agree within
    ``rtol=0.1``.
    """
    set_torch()

    opt = getattr(heavyball, opt)

    losses = []

    for dtype in [torch.float32, torch.bfloat16]:
        torch.manual_seed(0x2131290)
        losses.append([])

        for _ in range(outer_iterations):
            model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda().to(dtype)
            o = get_optim(opt, model.parameters(), lr=1e-3)

            for _ in range(iterations):
                loss = model(torch.randn((1024, size), device='cuda', dtype=dtype)).square().mean()
                loss.backward()
                o.step()
                o.zero_grad()
                losses[-1].append(loss.detach())

            del model, o
            clean()

    for i, (l0, l1) in enumerate(zip(*losses)):
        print(i, l0.item(), l1.item())
        # Compare in fp32 so bf16 losses are not checked in bf16 arithmetic.
        assert torch.allclose(l0.float(), l1.float(), rtol=0.1)
|
@@ -1,10 +1,14 @@
|
|
1
|
-
import heavyball
|
2
|
-
import heavyball.utils
|
3
1
|
import pytest
|
4
2
|
import torch
|
3
|
+
from torch import nn
|
4
|
+
from torch._dynamo import config
|
5
|
+
|
6
|
+
import heavyball
|
7
|
+
import heavyball.utils
|
5
8
|
from benchmark.utils import get_optim
|
6
9
|
from heavyball.utils import clean, set_torch, PSGDBase
|
7
|
-
|
10
|
+
|
11
|
+
# Raise torch._dynamo's recompilation cache limit to 128. Presumably the torch.compile'd
# heavyball helpers recompile for the many optimizer/q_dtype combinations exercised here — TODO confirm.
config.cache_size_limit = 128
|
8
12
|
|
9
13
|
|
10
14
|
def get_memory():
|
@@ -37,7 +41,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations:
|
|
37
41
|
o = get_optim(opt, model.parameters(), lr=1e-3, q_dtype=q_dtype)
|
38
42
|
|
39
43
|
for _ in range(iterations):
|
40
|
-
loss = model(torch.randn((1024, size)
|
44
|
+
loss = model(torch.randn((1024, size), device='cuda')).square().mean()
|
41
45
|
loss.backward()
|
42
46
|
o.step()
|
43
47
|
o.zero_grad()
|
@@ -46,7 +50,6 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations:
|
|
46
50
|
del model, o
|
47
51
|
clean()
|
48
52
|
|
49
|
-
|
50
53
|
for i, (l0, l1) in enumerate(zip(*losses)):
|
51
54
|
print(i, l0.item(), l1.item())
|
52
55
|
assert torch.allclose(l0, l1, rtol=0.1)
|
@@ -45,7 +45,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 5, outer_iterations: i
|
|
45
45
|
clean()
|
46
46
|
|
47
47
|
for _ in range(iterations):
|
48
|
-
loss = model(torch.randn((1, size)
|
48
|
+
loss = model(torch.randn((1, size), device='cuda')).sum()
|
49
49
|
loss.backward()
|
50
50
|
o.step()
|
51
51
|
o.zero_grad()
|
@@ -48,7 +48,7 @@ def test_memory(opt, method, size, depth: int, iterations: int = 5, outer_iterat
|
|
48
48
|
model_allocated = get_memory()
|
49
49
|
o = get_optim(opt, model.parameters(), lr=1e-3)
|
50
50
|
for _ in range(iterations):
|
51
|
-
model(torch.randn((1, size)
|
51
|
+
model(torch.randn((1, size), device='cuda')).sum().backward()
|
52
52
|
o.step()
|
53
53
|
|
54
54
|
opt_allocated = get_memory()
|
@@ -38,7 +38,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations
|
|
38
38
|
o = get_optim(opt, model.parameters(), lr=1e-3, stochastic_schedule=stochastic)
|
39
39
|
|
40
40
|
for _ in range(iterations):
|
41
|
-
loss = model(torch.randn((128, size)
|
41
|
+
loss = model(torch.randn((128, size), device-'cuda')).square().mean()
|
42
42
|
loss.backward()
|
43
43
|
o.step()
|
44
44
|
o.zero_grad()
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|