heavyball 0.21.6.tar.gz → 0.21.8.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {heavyball-0.21.6 → heavyball-0.21.8}/PKG-INFO +1 -1
  2. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/utils.py +8 -8
  3. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball.egg-info/PKG-INFO +1 -1
  4. {heavyball-0.21.6 → heavyball-0.21.8}/setup.py +1 -1
  5. {heavyball-0.21.6 → heavyball-0.21.8}/test/test_bf16_params.py +10 -8
  6. {heavyball-0.21.6 → heavyball-0.21.8}/LICENSE +0 -0
  7. {heavyball-0.21.6 → heavyball-0.21.8}/README.md +0 -0
  8. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/__init__.py +0 -0
  9. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/cached_delayed_psgd_kron.py +0 -0
  10. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/cached_psgd_kron.py +0 -0
  11. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/delayed_psgd.py +0 -0
  12. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/foreach_adamw.py +0 -0
  13. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/foreach_adopt.py +0 -0
  14. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/foreach_laprop.py +0 -0
  15. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/foreach_sfadamw.py +0 -0
  16. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/foreach_soap.py +0 -0
  17. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/p_adam.py +0 -0
  18. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/palm_foreach_sfadamw.py +0 -0
  19. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/palm_foreach_soap.py +0 -0
  20. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/precond_schedule_foreach_soap.py +0 -0
  21. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  22. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/precond_schedule_sfpsoap.py +0 -0
  23. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/psgd_kron.py +0 -0
  24. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/pure_psgd.py +0 -0
  25. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  26. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball.egg-info/SOURCES.txt +0 -0
  27. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball.egg-info/dependency_links.txt +0 -0
  28. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball.egg-info/requires.txt +0 -0
  29. {heavyball-0.21.6 → heavyball-0.21.8}/heavyball.egg-info/top_level.txt +0 -0
  30. {heavyball-0.21.6 → heavyball-0.21.8}/setup.cfg +0 -0
  31. {heavyball-0.21.6 → heavyball-0.21.8}/test/test_bf16_q.py +0 -0
  32. {heavyball-0.21.6 → heavyball-0.21.8}/test/test_bf16_storage.py +0 -0
  33. {heavyball-0.21.6 → heavyball-0.21.8}/test/test_closure.py +0 -0
  34. {heavyball-0.21.6 → heavyball-0.21.8}/test/test_ema.py +0 -0
  35. {heavyball-0.21.6 → heavyball-0.21.8}/test/test_foreach.py +0 -0
  36. {heavyball-0.21.6 → heavyball-0.21.8}/test/test_memory.py +0 -0
  37. {heavyball-0.21.6 → heavyball-0.21.8}/test/test_merge.py +0 -0
  38. {heavyball-0.21.6 → heavyball-0.21.8}/test/test_no_grad.py +0 -0
  39. {heavyball-0.21.6 → heavyball-0.21.8}/test/test_psgd.py +0 -0
  40. {heavyball-0.21.6 → heavyball-0.21.8}/test/test_soap.py +0 -0
  41. {heavyball-0.21.6 → heavyball-0.21.8}/test/test_stochastic_updates.py +0 -0
{heavyball-0.21.6 → heavyball-0.21.8}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.21.6
+Version: 0.21.8
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
{heavyball-0.21.6 → heavyball-0.21.8}/heavyball/utils.py
@@ -38,7 +38,7 @@ def warmup(lr: float, step: int, warmup_steps: int):
     return lr * step / warmup_steps


-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_schedule_free_(p, z, ckp1, grad, lr, beta1):
     p32 = promote(p)
     z32 = promote(z)
@@ -141,7 +141,7 @@ def beta_debias(beta, step):
     return 1 - (1 - beta) / (1 - beta ** step)


-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def exp_avg_sq_(state, grad, beta2, eps, out=None):
     if isinstance(state, torch.Tensor):
         state.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
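
Side note (not part of the diff): the context lines above show beta_debias, which returns 1 - (1 - beta) / (1 - beta ** step), and exp_avg_sq_, which feeds that debiased beta into a running average of squared gradients. A minimal standalone sketch of why that form is used — the names and loop below are illustrative only, not heavyball code:

import torch

def beta_debias(beta, step):
    return 1 - (1 - beta) / (1 - beta ** step)

ema = torch.zeros(())
for step in range(1, 6):
    grad = torch.ones(())                    # constant unit gradient
    b2 = beta_debias(0.999, step)            # per-step debiased decay
    ema = ema * b2 + (1 - b2) * grad * grad  # same update shape as exp_avg_sq_
print(float(ema))  # ~1.0: already equals Adam's bias-corrected second moment

Using the debiased beta at every step makes the running average bias-corrected on the fly, so no separate division by (1 - beta ** step) is needed.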
@@ -328,7 +328,7 @@ def get_orthogonal_matrix(mat):
     return final


-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
@@ -343,7 +343,7 @@ def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[floa
     _compilable_stochastic_lerp_(x, y, a)


-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
@@ -581,7 +581,7 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
         copy_stochastic_(t, s)


-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
@@ -632,7 +632,7 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     _compilable_copy_stochastic_(target, source)


-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_update_(p, u, decay, add_fn, lr):
     u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
     p32, u32 = [list(map(promote, x)) for x in [p, u]]
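
The hunk headers above reference copy_stochastic_ and copy_stochastic_list_, the helpers used for writing fp32 intermediates back into bfloat16 parameter and state storage. For orientation, a rough sketch of the usual stochastic-rounding recipe behind such a copy — illustrative only, and not necessarily heavyball's exact implementation:

import torch

def stochastic_copy_(target: torch.Tensor, source: torch.Tensor):
    # draw 16 random bits per element, i.e. the part bfloat16 would discard
    noise = torch.randint_like(source, 0, 1 << 16, dtype=torch.int32)
    # add them to the raw fp32 bit pattern so the value rounds up with
    # probability proportional to the discarded fraction
    bits = source.view(dtype=torch.int32) + noise
    # zero the low 16 bits (-65536 == 0xFFFF0000 as a signed int32)
    bits = bits & -65536
    # reinterpret as fp32; copying into a bfloat16 target is now exact
    target.copy_(bits.view(dtype=torch.float32))

p32 = torch.randn(4)
p16 = torch.empty(4, dtype=torch.bfloat16)
stochastic_copy_(p16, p32)

The point of the random offset is that the rounding error is zero in expectation, which keeps bf16-stored optimizer state from drifting the way plain truncation would.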
@@ -788,7 +788,7 @@ def psgd_lb(A, max_abs):
     return x


-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
     """Update Kronecker product preconditioner Q with pair (V, G)."""
     exprA, exprGs, _ = exprs
@@ -821,7 +821,7 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
         stochastic_add_([o], [term1], -1)


-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
     """Precondition gradient G with preconditioner Q."""
     md = min_dtype(Q)
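
Every utils.py hunk above makes the same one-flag change: the torch.compile decorators switch from dynamic=True to dynamic=False, so the compiled helpers specialize on the concrete tensor shapes they see (recompiling if a new shape appears) rather than tracing shape-polymorphic kernels. A minimal self-contained illustration of the flag; the decorated function is a hypothetical stand-in, not a heavyball helper:

import torch

@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
def fused_decay_add_(p: torch.Tensor, g: torch.Tensor, lr: float, decay: float):
    # hypothetical fused weight-decay + gradient step, shown only to exercise
    # the decorator; dynamic=False makes the compiled graph shape-specialized
    p.mul_(1 - lr * decay).add_(g, alpha=-lr)
    return p

p, g = torch.randn(64), torch.randn(64)
fused_decay_add_(p, g, lr=1e-3, decay=1e-2)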
{heavyball-0.21.6 → heavyball-0.21.8}/heavyball.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.21.6
+Version: 0.21.8
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
{heavyball-0.21.6 → heavyball-0.21.8}/setup.py
@@ -10,7 +10,7 @@ setuptools.setup(
     name='heavyball',
     license='BSD',
     description='Efficient optimizers',
-    version='0.21.6',
+    version='0.21.8',
     long_description=README,
     url='https://github.com/clashluke/heavyball',
     packages=setuptools.find_packages(),
{heavyball-0.21.6 → heavyball-0.21.8}/test/test_bf16_params.py
@@ -1,15 +1,15 @@
-import pytest
-import torch
-from torch import nn
-from torch._dynamo import config
-
 import heavyball
 import heavyball.utils
+import pytest
+import torch
 from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
+from torch import nn
+from torch._dynamo import config

 config.cache_size_limit = 128

+
 def get_memory():
     clean()
     torch.cuda.synchronize()
@@ -22,11 +22,10 @@ def get_memory():
 @pytest.mark.parametrize("size,depth", [(256, 2)])
 def test_foreach(opt, size, depth: int, iterations: int = 512, outer_iterations: int = 3):
     set_torch()
-    if 'psgd' not in opt.lower() and 'padam' not in opt.lower():
+    if 'psgd' not in opt.lower():
         raise pytest.skip('Only PSGD and PaLMPAdam are supported')
     opt = getattr(heavyball, opt)

-
     peaks = []
     losses = []

@@ -37,7 +36,10 @@ def test_foreach(opt, size, depth: int, iterations: int = 512, outer_iterations:

     for i in range(outer_iterations):
         model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda().to(dtype)
-        o = get_optim(opt, model.parameters(), lr=1e-3)
+        o = get_optim(opt, model.parameters(), lr=1e-3, weight_decay=1e-4, warmup_steps=16,
+                      max_size_triangular=2048, merge_dims=True, split=False, memory_save_mode='one_diag',
+                      store_triu_as_line=False, stochastic_schedule=False, storage_dtype='bfloat16',
+                      q_dtype='bfloat16')

         for _ in range(iterations):
             loss = model(torch.randn((1024, size), device='cuda', dtype=dtype)).square().mean()
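
The final hunk makes the test pass the full PSGD configuration (weight decay, warmup, preconditioner shape limits, bf16 storage and Q dtypes) explicitly instead of relying on defaults. For orientation, a hedged sketch of an equivalent direct construction; the class name ForeachPSGDKron and its exact keyword support are assumptions here, not verified against 0.21.8:

import torch
from torch import nn
import heavyball

model = nn.Linear(256, 256).cuda().to(torch.bfloat16)
opt = heavyball.ForeachPSGDKron(
    model.parameters(), lr=1e-3, weight_decay=1e-4, warmup_steps=16,
    max_size_triangular=2048, merge_dims=True, split=False,
    memory_save_mode='one_diag', store_triu_as_line=False,
    stochastic_schedule=False, storage_dtype='bfloat16', q_dtype='bfloat16')

for _ in range(8):
    # bf16 parameters and bf16 optimizer storage, matching the test's setup
    loss = model(torch.randn((1024, 256), device='cuda', dtype=torch.bfloat16)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()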