heavyball 0.21.7.tar.gz → 0.21.8.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.21.7 → heavyball-0.21.8}/PKG-INFO +1 -1
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/utils.py +9 -9
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.21.7 → heavyball-0.21.8}/setup.py +1 -1
- {heavyball-0.21.7 → heavyball-0.21.8}/test/test_bf16_params.py +10 -8
- {heavyball-0.21.7 → heavyball-0.21.8}/LICENSE +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/README.md +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/__init__.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/cached_delayed_psgd_kron.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/cached_psgd_kron.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/delayed_psgd.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/p_adam.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/psgd_kron.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/pure_psgd.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/setup.cfg +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/test/test_bf16_q.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/test/test_bf16_storage.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/test/test_closure.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/test/test_ema.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/test/test_foreach.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/test/test_memory.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/test/test_merge.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/test/test_no_grad.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/test/test_psgd.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/test/test_soap.py +0 -0
- {heavyball-0.21.7 → heavyball-0.21.8}/test/test_stochastic_updates.py +0 -0
heavyball/utils.py

@@ -38,7 +38,7 @@ def warmup(lr: float, step: int, warmup_steps: int):
     return lr * step / warmup_steps
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_schedule_free_(p, z, ckp1, grad, lr, beta1):
     p32 = promote(p)
     z32 = promote(z)

@@ -141,7 +141,7 @@ def beta_debias(beta, step):
     return 1 - (1 - beta) / (1 - beta ** step)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def exp_avg_sq_(state, grad, beta2, eps, out=None):
     if isinstance(state, torch.Tensor):
         state.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

@@ -328,7 +328,7 @@ def get_orthogonal_matrix(mat):
     return final
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
     for x_, y_ in zip(x, y):
         x32 = promote(x_)

@@ -343,7 +343,7 @@ def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[floa
     _compilable_stochastic_lerp_(x, y, a)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
     for x_, y_ in zip(x, y):
         x32 = promote(x_)

@@ -581,7 +581,7 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
         copy_stochastic_(t, s)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)

@@ -632,7 +632,7 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     _compilable_copy_stochastic_(target, source)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_update_(p, u, decay, add_fn, lr):
     u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
     p32, u32 = [list(map(promote, x)) for x in [p, u]]

@@ -788,7 +788,7 @@ def psgd_lb(A, max_abs):
     return x
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
     """Update Kronecker product preconditioner Q with pair (V, G)."""
     exprA, exprGs, _ = exprs

@@ -821,7 +821,7 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
         stochastic_add_([o], [term1], -1)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
     """Precondition gradient G with preconditioner Q."""
     md = min_dtype(Q)

@@ -965,7 +965,7 @@ class PSGDBase(StatefulOptimizer):
         psgd_balance_Q(q)
 
 
-
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn):
     md = min_dtype(cached_q + [ea])
     new = torch.einsum(expr, *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
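The functional change in heavyball/utils.py is confined to the @torch.compile decorators, which now pin dynamic=False so each compiled helper is specialized to the concrete tensor shapes it sees instead of being traced with symbolic (dynamic) shapes. A minimal sketch of the flag's effect, using an illustrative helper that is not part of heavyball:

import torch

# Illustrative helper only; heavyball's own compiled functions differ.
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
def scaled_add_(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # With dynamic=False, each new input shape triggers its own specialized
    # compilation; no shape-generic (symbolic-shape) kernel is produced.
    return x.add_(y, alpha=alpha)

a, b = torch.randn(256, 256), torch.randn(256, 256)
scaled_add_(a, b, 0.1)  # compiles a kernel specialized to shape (256, 256)
scaled_add_(a, b, 0.1)  # reuses the cached, shape-specialized kernel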
test/test_bf16_params.py

@@ -1,15 +1,15 @@
-import pytest
-import torch
-from torch import nn
-from torch._dynamo import config
-
 import heavyball
 import heavyball.utils
+import pytest
+import torch
 from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
+from torch import nn
+from torch._dynamo import config
 
 config.cache_size_limit = 128
 
+
 def get_memory():
     clean()
     torch.cuda.synchronize()

@@ -22,11 +22,10 @@ def get_memory():
 @pytest.mark.parametrize("size,depth", [(256, 2)])
 def test_foreach(opt, size, depth: int, iterations: int = 512, outer_iterations: int = 3):
     set_torch()
-    if 'psgd' not in opt.lower()
+    if 'psgd' not in opt.lower():
         raise pytest.skip('Only PSGD and PaLMPAdam are supported')
     opt = getattr(heavyball, opt)
 
-
     peaks = []
     losses = []
 

@@ -37,7 +36,10 @@ def test_foreach(opt, size, depth: int, iterations: int = 512, outer_iterations:
 
     for i in range(outer_iterations):
         model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda().to(dtype)
-        o = get_optim(opt, model.parameters(), lr=1e-3
+        o = get_optim(opt, model.parameters(), lr=1e-3, weight_decay=1e-4, warmup_steps=16,
+                      max_size_triangular=2048, merge_dims=True, split=False, memory_save_mode='one_diag',
+                      store_triu_as_line=False, stochastic_schedule=False, storage_dtype='bfloat16',
+                      q_dtype='bfloat16')
 
         for _ in range(iterations):
             loss = model(torch.randn((1024, size), device='cuda', dtype=dtype)).square().mean()
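In test/test_bf16_params.py the imports are regrouped, and the optimizer is now built with an explicit PSGD configuration (bfloat16 storage and Q dtypes, merged dims, 'one_diag' memory-save mode, and so on) instead of relying on get_optim defaults. A rough sketch of the loop this test exercises; the optimizer class name, the direct-constructor kwargs, and the backward/step calls below are assumptions, since the test routes everything through benchmark.utils.get_optim and the visible hunk context stops at the loss:

import torch
from torch import nn

import heavyball

dtype = torch.bfloat16
size, depth = 256, 2
model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda().to(dtype)

# Assumed class and kwargs; the test itself constructs the optimizer via get_optim.
opt = heavyball.ForeachPSGDKron(model.parameters(), lr=1e-3, weight_decay=1e-4,
                                storage_dtype='bfloat16', q_dtype='bfloat16')

for _ in range(8):  # the test runs 512 iterations per outer loop
    loss = model(torch.randn((1024, size), device='cuda', dtype=dtype)).square().mean()
    loss.backward()   # assumed; not shown in the diff context
    opt.step()
    opt.zero_grad()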