heavyball 0.18.4__tar.gz → 0.18.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {heavyball-0.18.4 → heavyball-0.18.6}/PKG-INFO +2 -2
  2. {heavyball-0.18.4 → heavyball-0.18.6}/README.md +1 -1
  3. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/foreach_adamw.py +2 -2
  4. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/utils.py +44 -24
  5. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball.egg-info/PKG-INFO +2 -2
  6. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball.egg-info/SOURCES.txt +1 -0
  7. {heavyball-0.18.4 → heavyball-0.18.6}/setup.py +1 -1
  8. heavyball-0.18.6/test/test_bf16_params.py +52 -0
  9. {heavyball-0.18.4 → heavyball-0.18.6}/test/test_bf16_q.py +8 -5
  10. {heavyball-0.18.4 → heavyball-0.18.6}/test/test_foreach.py +1 -1
  11. {heavyball-0.18.4 → heavyball-0.18.6}/test/test_memory.py +1 -1
  12. {heavyball-0.18.4 → heavyball-0.18.6}/test/test_stochastic_updates.py +1 -1
  13. {heavyball-0.18.4 → heavyball-0.18.6}/LICENSE +0 -0
  14. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/__init__.py +0 -0
  15. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/cached_delayed_psgd_kron.py +0 -0
  16. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/cached_psgd_kron.py +0 -0
  17. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/delayed_psgd.py +0 -0
  18. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/foreach_adopt.py +0 -0
  19. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/foreach_laprop.py +0 -0
  20. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/foreach_sfadamw.py +0 -0
  21. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/foreach_soap.py +0 -0
  22. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/p_adam.py +0 -0
  23. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/palm_foreach_sfadamw.py +0 -0
  24. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/palm_foreach_soap.py +0 -0
  25. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/precond_schedule_foreach_soap.py +0 -0
  26. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  27. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/precond_schedule_sfpsoap.py +0 -0
  28. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/psgd_kron.py +0 -0
  29. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/pure_psgd.py +0 -0
  30. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  31. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball.egg-info/dependency_links.txt +0 -0
  32. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball.egg-info/requires.txt +0 -0
  33. {heavyball-0.18.4 → heavyball-0.18.6}/heavyball.egg-info/top_level.txt +0 -0
  34. {heavyball-0.18.4 → heavyball-0.18.6}/setup.cfg +0 -0
  35. {heavyball-0.18.4 → heavyball-0.18.6}/test/test_closure.py +0 -0
  36. {heavyball-0.18.4 → heavyball-0.18.6}/test/test_merge.py +0 -0
  37. {heavyball-0.18.4 → heavyball-0.18.6}/test/test_no_grad.py +0 -0
  38. {heavyball-0.18.4 → heavyball-0.18.6}/test/test_psgd.py +0 -0
  39. {heavyball-0.18.4 → heavyball-0.18.6}/test/test_soap.py +0 -0
{heavyball-0.18.4 → heavyball-0.18.6}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.18.4
+Version: 0.18.6
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-20, 0.17.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+Currently (2024-11-21, 0.18.6), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
{heavyball-0.18.4 → heavyball-0.18.6}/README.md
@@ -8,7 +8,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-20, 0.17.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+Currently (2024-11-21, 0.18.6), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
{heavyball-0.18.4 → heavyball-0.18.6}/heavyball/foreach_adamw.py
@@ -34,9 +34,9 @@ class ForeachAdamW(StatefulOptimizer):
 
         # Decay the first and second moment running average coefficient
         torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
-        denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
+        denom = list(exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps))
 
         # Normalize grad in-place for memory efficiency
         lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-        update_param_(y, exp_avg, lr, decay, lambda p, e, l: torch._foreach_addcdiv_(p, e, denom, l))
+        update_param_(y, exp_avg, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l))
         group['k'] = k + 1
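
Note on the `ForeachAdamW` change: `update_param_` in utils.py (see the hunks below) now applies `add_fn` once per parameter inside a compiled per-tensor kernel instead of once per list, so the closure has to consume the denominator list one tensor at a time via `denom.pop(0)`. A minimal sketch of that calling pattern, using made-up tensors that stand in for the optimizer state:

    import torch

    # Hypothetical stand-ins for one parameter group's state (not heavyball's real state dict).
    params = [torch.randn(4, 4), torch.randn(4)]
    exp_avg = [torch.randn_like(p) for p in params]
    denom = [torch.rand_like(p).add_(1e-8) for p in params]  # a list, so .pop(0) yields one tensor per call
    lr = 1e-3

    # New behaviour: the callback runs once per (param, update) pair, so each call
    # pops exactly one matching denominator off the front of the list.
    add_fn = lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l)

    for p, e in zip(params, exp_avg):
        add_fn(p, e, -lr)  # sign and weight-decay handling elided; sketch only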
{heavyball-0.18.4 → heavyball-0.18.6}/heavyball/utils.py
@@ -38,6 +38,18 @@ def warmup(lr: float, step: int, warmup_steps: int):
     return lr * step / warmup_steps
 
 
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_schedule_free_(p, z, ckp1, grad, lr, beta1):
+    p32 = p.float()
+    z32 = z.float()
+    p32.lerp_(end=z32, weight=1 - ckp1)
+    p32.add_(grad, alpha=lr * (beta1 * (1 - ckp1) - 1))
+    _guarded_copy_stochastic(p, p32)
+
+    z32.add_(grad, alpha=-lr)
+    _guarded_copy_stochastic(z, z32)
+
+
 def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[torch.Tensor],
                    z: List[torch.Tensor], grad: list[torch.Tensor], r: float = 0.0, step: int = 0):
     weight = lr ** weight_lr_power * max(step, 1) ** r
@@ -50,15 +62,10 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
 
     # These operations update y in-place,
     # without computing x explicitly.
-    p32 = [promote(p) for p in parameters]
-    z32 = [promote(z_) for z_ in z]
-    torch._foreach_lerp_(p32, z32, weight=ckp1)
-    torch._foreach_add_(p32, grad, alpha=lr * (beta1 * (1 - ckp1) - 1))
-    copy_stochastic_list_(parameters, p32)
-
-    # z step
-    torch._foreach_sub_(z, grad, alpha=lr)
-    copy_stochastic_list_(z, z32)
+    lr_tensor = torch.empty((), dtype=torch.float32, device=parameters[0].device).fill_(lr)
+    ckp1_tensor = torch.empty((), dtype=torch.float32, device=parameters[0].device).fill_(ckp1)
+    for p, z_, g in zip(parameters, z, grad):
+        _compilable_schedule_free_(p, z_, ckp1_tensor, g, lr_tensor, beta1)
     return weight_sum
 
 
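The rewritten `schedule_free_` passes `lr` and `ckp1` to the compiled kernel as 0-dimensional fp32 tensors rather than Python floats. This is a common `torch.compile` pattern: scalar Python arguments get baked into the traced graph, so a new value would trigger a recompile, while a 0-d tensor keeps one graph across all values. A minimal, self-contained sketch of the idea (not heavyball's kernel; `sgd_step` is a hypothetical name):

    import torch

    @torch.compile(fullgraph=True, dynamic=True)
    def sgd_step(param, grad, lr):
        # lr arrives as a 0-d tensor: its value can change every call without retracing.
        param.sub_(grad * lr)

    param, grad = torch.randn(8), torch.randn(8)
    for lr in (1e-3, 3e-4, 1e-4):
        lr_t = torch.empty((), dtype=torch.float32).fill_(lr)  # same graph reused for each lr
        sgd_step(param, grad, lr_t)
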
@@ -150,11 +157,11 @@ def adaptive_gradient_clipping_(parameters: List[torch.Tensor], gradients: List[
 
 
 def set_(dst: torch.Tensor, src: torch.Tensor):
-    if src.data_ptr() == dst.data_ptr():
+    if not torch.compiler.is_compiling() and src.data_ptr() == dst.data_ptr():
         return
     if src.shape != dst.shape:
         src = src.reshape_as(dst)
-    if src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
+    if not torch.compiler.is_compiling() and src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
         dst.set_(src)
     else:
         dst.copy_(src)
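
The new `torch.compiler.is_compiling()` guards in `set_` skip the pointer-aliasing check and the `Tensor.set_` fast path when the function runs inside a `torch.compile` region, where `data_ptr()` comparisons and storage swapping do not trace cleanly; everything then falls through to a plain `copy_`. A hedged sketch of the same branching idea, with a hypothetical helper name:

    import torch

    def assign_(dst: torch.Tensor, src: torch.Tensor):
        # Fast aliasing checks only make sense (and only trace cleanly) in eager mode.
        if not torch.compiler.is_compiling() and dst.data_ptr() == src.data_ptr():
            return  # already the same storage, nothing to do
        dst.copy_(src)  # safe fallback, valid both in eager mode and under torch.compile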
@@ -479,6 +486,12 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
         copy_stochastic_(t, s)
 
 
+def _guarded_copy_stochastic(target: torch.Tensor, source: torch.Tensor):
+    if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
+        set_(target, source)
+    _compilable_copy_stochastic_(target, source)
+
+
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
 def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
@@ -498,23 +511,27 @@ def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
 def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     if target.data_ptr() == source.data_ptr():
         return
-    if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
-        set_(target, source)
-        return
-    _compilable_copy_stochastic_(target, source)
+    _guarded_copy_stochastic(target, source)
 
 
-def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
-                  add_fn: callable = None):
-    param32 = [promote(p) for p in param]
-    update32 = [promote(u.view(p.shape)) for u, p in zip(update, param)]
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_update_one_(p, u, decay, add_fn, lr):
+    p32 = p.float()
+    u32 = u.view(p.shape).float()
     if decay > 0:
-        torch._foreach_mul_(param32, 1 - decay * lr)
+        p32.mul_(1 - decay * lr)
     if add_fn is None:
-        torch._foreach_add_(param32, update32, alpha=lr)
+        p32.add_(u32, alpha=lr)
     else:
-        add_fn(param32, update32, lr)
-    copy_stochastic_list_(param, param32)
+        add_fn(p32, u32, lr)
+    _guarded_copy_stochastic(p, p32)
+
+
+def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
+                  add_fn: callable = None):
+    lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
+    for p, u in zip(param, update):
+        _compilable_update_one_(p, u, decay, add_fn, lr_tensor)
 
 
 def precond_schedule(step, precond_scheduler, rng):
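
With the per-tensor `_compilable_update_one_`, decoupled weight decay and the additive update are applied to a fp32 working copy of each parameter, which is then written back through the guarded stochastic copy. A short sketch of that arithmetic in isolation, mirroring the order in the hunk above (a sketch only, not the compiled kernel):

    import torch

    p = torch.randn(10, dtype=torch.bfloat16)
    update = torch.randn(10)
    lr, decay = 1e-3, 1e-2

    p32 = p.float()                # work in fp32, as _compilable_update_one_ does
    if decay > 0:
        p32.mul_(1 - decay * lr)   # decoupled decay, applied before the additive step
    p32.add_(update, alpha=lr)     # callers control the sign convention of lr
    p.copy_(p32)                   # heavyball writes back with the stochastic copy instead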
@@ -828,7 +845,10 @@ class PSGDBase(StatefulOptimizer):
 
         for g, q in zip(grad_list, original_q if original_q else q_list):
             if g.dim() > 1:
-                psgd_balance_Q(q)
+                if store_triu_as_line:
+                    psgd_balance_Q([q_ for _, q_ in q])
+                else:
+                    psgd_balance_Q(q)
 
 
 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
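
The new branch accounts for the `store_triu_as_line` layout: judging from the unpacking, each preconditioner factor is then stored as a pair whose second element holds the flattened triangular data, so balancing must be handed the tensors themselves. Balancing in PSGD-Kron conventionally rescales the Kronecker factors toward equal magnitude while leaving their product, and hence the preconditioner, unchanged; a hedged sketch of that operation, based on the reference PSGD implementation rather than heavyball's `psgd_balance_Q` (the helper name is hypothetical):

    import torch

    def balance_q_(Q: list[torch.Tensor]):
        # Rescale each factor toward the geometric mean of their max-abs values.
        # The product of the scale factors is 1, so the Kronecker product of the
        # factors (the preconditioner itself) is not altered.
        norms = torch.stack([q.abs().max() for q in Q])
        gmean = norms.log().mean().exp()
        for q, n in zip(Q, norms):
            q.mul_(gmean / n)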
{heavyball-0.18.4 → heavyball-0.18.6}/heavyball.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.18.4
+Version: 0.18.6
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-20, 0.17.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+Currently (2024-11-21, 0.18.6), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
{heavyball-0.18.4 → heavyball-0.18.6}/heavyball.egg-info/SOURCES.txt
@@ -25,6 +25,7 @@ heavyball.egg-info/SOURCES.txt
 heavyball.egg-info/dependency_links.txt
 heavyball.egg-info/requires.txt
 heavyball.egg-info/top_level.txt
+test/test_bf16_params.py
 test/test_bf16_q.py
 test/test_closure.py
 test/test_foreach.py
{heavyball-0.18.4 → heavyball-0.18.6}/setup.py
@@ -10,7 +10,7 @@ setuptools.setup(
     name='heavyball',
     license='BSD',
     description='Efficient optimizers',
-    version='0.18.4',
+    version='0.18.6',
     long_description=README,
     url='https://github.com/clashluke/heavyball',
     packages=setuptools.find_packages(),
heavyball-0.18.6/test/test_bf16_params.py
@@ -0,0 +1,52 @@
+import pytest
+import torch
+from torch import nn
+from torch._dynamo import config
+
+import heavyball
+import heavyball.utils
+from benchmark.utils import get_optim
+from heavyball.utils import clean, set_torch
+
+config.cache_size_limit = 128
+
+def get_memory():
+    clean()
+    torch.cuda.synchronize()
+    clean()
+    torch.cuda.synchronize()
+    return torch.cuda.memory_allocated()
+
+
+@pytest.mark.parametrize("opt", heavyball.__all__)
+@pytest.mark.parametrize("size,depth", [(256, 2)])
+def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3):
+    set_torch()
+
+    opt = getattr(heavyball, opt)
+
+    peaks = []
+    losses = []
+
+    for dtype in [torch.float32, torch.bfloat16]:
+        torch.manual_seed(0x2131290)
+        peaks.append([])
+        losses.append([])
+
+        for i in range(outer_iterations):
+            model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda().to(dtype)
+            o = get_optim(opt, model.parameters(), lr=1e-3)
+
+            for _ in range(iterations):
+                loss = model(torch.randn((1024, size), device='cuda', dtype=dtype)).square().mean()
+                loss.backward()
+                o.step()
+                o.zero_grad()
+                losses[-1].append(loss.detach())
+
+            del model, o
+            clean()
+
+    for i, (l0, l1) in enumerate(zip(*losses)):
+        print(i, l0.item(), l1.item())
+        assert torch.allclose(l0.float(), l1.float(), rtol=0.1)
{heavyball-0.18.4 → heavyball-0.18.6}/test/test_bf16_q.py
@@ -1,10 +1,14 @@
-import heavyball
-import heavyball.utils
 import pytest
 import torch
+from torch import nn
+from torch._dynamo import config
+
+import heavyball
+import heavyball.utils
 from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch, PSGDBase
-from torch import nn
+
+config.cache_size_limit = 128
 
 
 def get_memory():
@@ -37,7 +41,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations:
            o = get_optim(opt, model.parameters(), lr=1e-3, q_dtype=q_dtype)
 
            for _ in range(iterations):
-               loss = model(torch.randn((1024, size)).cuda()).square().mean()
+               loss = model(torch.randn((1024, size), device='cuda')).square().mean()
                loss.backward()
                o.step()
                o.zero_grad()
@@ -46,7 +50,6 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations:
            del model, o
            clean()
 
-
    for i, (l0, l1) in enumerate(zip(*losses)):
        print(i, l0.item(), l1.item())
        assert torch.allclose(l0, l1, rtol=0.1)
{heavyball-0.18.4 → heavyball-0.18.6}/test/test_foreach.py
@@ -45,7 +45,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 5, outer_iterations: i
        clean()
 
        for _ in range(iterations):
-           loss = model(torch.randn((1, size)).cuda()).sum()
+           loss = model(torch.randn((1, size), device='cuda')).sum()
            loss.backward()
            o.step()
            o.zero_grad()
{heavyball-0.18.4 → heavyball-0.18.6}/test/test_memory.py
@@ -48,7 +48,7 @@ def test_memory(opt, method, size, depth: int, iterations: int = 5, outer_iterat
        model_allocated = get_memory()
        o = get_optim(opt, model.parameters(), lr=1e-3)
        for _ in range(iterations):
-           model(torch.randn((1, size)).cuda()).sum().backward()
+           model(torch.randn((1, size), device='cuda')).sum().backward()
            o.step()
 
        opt_allocated = get_memory()
{heavyball-0.18.4 → heavyball-0.18.6}/test/test_stochastic_updates.py
@@ -38,7 +38,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations
        o = get_optim(opt, model.parameters(), lr=1e-3, stochastic_schedule=stochastic)
 
        for _ in range(iterations):
-           loss = model(torch.randn((128, size)).cuda()).square().mean()
+           loss = model(torch.randn((128, size), device='cuda')).square().mean()
            loss.backward()
            o.step()
            o.zero_grad()