heavyball 0.18.7__tar.gz → 0.19.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. {heavyball-0.18.7 → heavyball-0.19.0}/PKG-INFO +18 -16
  2. {heavyball-0.18.7 → heavyball-0.19.0}/README.md +17 -15
  3. heavyball-0.19.0/heavyball/foreach_adamw.py +56 -0
  4. heavyball-0.19.0/heavyball/foreach_adopt.py +78 -0
  5. heavyball-0.19.0/heavyball/foreach_laprop.py +61 -0
  6. heavyball-0.19.0/heavyball/foreach_sfadamw.py +63 -0
  7. heavyball-0.19.0/heavyball/palm_foreach_sfadamw.py +69 -0
  8. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball/psgd_kron.py +2 -1
  9. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball/utils.py +28 -20
  10. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball.egg-info/PKG-INFO +18 -16
  11. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball.egg-info/SOURCES.txt +1 -0
  12. {heavyball-0.18.7 → heavyball-0.19.0}/setup.py +1 -1
  13. {heavyball-0.18.7 → heavyball-0.19.0}/test/test_bf16_params.py +0 -1
  14. heavyball-0.19.0/test/test_bf16_storage.py +60 -0
  15. heavyball-0.18.7/heavyball/foreach_adamw.py +0 -42
  16. heavyball-0.18.7/heavyball/foreach_adopt.py +0 -52
  17. heavyball-0.18.7/heavyball/foreach_laprop.py +0 -47
  18. heavyball-0.18.7/heavyball/foreach_sfadamw.py +0 -54
  19. heavyball-0.18.7/heavyball/palm_foreach_sfadamw.py +0 -57
  20. {heavyball-0.18.7 → heavyball-0.19.0}/LICENSE +0 -0
  21. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball/__init__.py +0 -0
  22. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball/cached_delayed_psgd_kron.py +0 -0
  23. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball/cached_psgd_kron.py +0 -0
  24. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball/delayed_psgd.py +0 -0
  25. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball/foreach_soap.py +0 -0
  26. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball/p_adam.py +0 -0
  27. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball/palm_foreach_soap.py +0 -0
  28. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball/precond_schedule_foreach_soap.py +0 -0
  29. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  30. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball/precond_schedule_sfpsoap.py +0 -0
  31. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball/pure_psgd.py +0 -0
  32. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  33. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball.egg-info/dependency_links.txt +0 -0
  34. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball.egg-info/requires.txt +0 -0
  35. {heavyball-0.18.7 → heavyball-0.19.0}/heavyball.egg-info/top_level.txt +0 -0
  36. {heavyball-0.18.7 → heavyball-0.19.0}/setup.cfg +0 -0
  37. {heavyball-0.18.7 → heavyball-0.19.0}/test/test_bf16_q.py +0 -0
  38. {heavyball-0.18.7 → heavyball-0.19.0}/test/test_closure.py +0 -0
  39. {heavyball-0.18.7 → heavyball-0.19.0}/test/test_foreach.py +0 -0
  40. {heavyball-0.18.7 → heavyball-0.19.0}/test/test_memory.py +0 -0
  41. {heavyball-0.18.7 → heavyball-0.19.0}/test/test_merge.py +0 -0
  42. {heavyball-0.18.7 → heavyball-0.19.0}/test/test_no_grad.py +0 -0
  43. {heavyball-0.18.7 → heavyball-0.19.0}/test/test_psgd.py +0 -0
  44. {heavyball-0.18.7 → heavyball-0.19.0}/test/test_soap.py +0 -0
  45. {heavyball-0.18.7 → heavyball-0.19.0}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.18.7
3
+ Version: 0.19.0
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
32
32
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
33
33
  largely static alternative to `torch.optim` with more and better optimizers.
34
34
 
35
- Currently (2024-11-21, 0.18.6), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
35
+ Currently (2024-11-22, 0.19), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
36
36
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
37
37
 
38
38
  ## Features
@@ -45,8 +45,10 @@ recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psg
45
45
  * **ScheduleFree**: No learning rate schedule, but better convergence
46
46
  * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
47
47
  better step-per-second in late convergence (explained below)
48
- * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
49
- bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
48
+ * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) and `q_dtype` to trade off memory
49
+ usage for memory
50
+ bandwidth; Other optimizers have `storage_dtype`, supporting lower-precision EMAs at no(?) performance drop via
51
+ stochastic rounding
50
52
 
51
53
  ## Getting started
52
54
 
@@ -76,19 +78,19 @@ for _ in range(1000):
76
78
 
77
79
  ## Optimizers
78
80
 
79
- | Name | Description | Advantages / Disadvantages |
80
- |-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
81
- | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
82
- | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
83
- | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
84
- | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
85
- | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
86
- | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
87
- | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
88
- | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
81
+ | Name | Description | Advantages / Disadvantages |
82
+ |-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
83
+ | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
84
+ | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
85
+ | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
86
+ | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
87
+ | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
88
+ | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
89
+ | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
90
+ | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
89
91
  | **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
90
- | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
91
- | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
92
+ | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
93
+ | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
92
94
 
93
95
  ## Precond Schedule
94
96
 
@@ -8,7 +8,7 @@ A simple package of efficient optimizers
8
8
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
9
9
  largely static alternative to `torch.optim` with more and better optimizers.
10
10
 
11
- Currently (2024-11-21, 0.18.6), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
11
+ Currently (2024-11-22, 0.19), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
12
12
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
13
13
 
14
14
  ## Features
@@ -21,8 +21,10 @@ recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psg
21
21
  * **ScheduleFree**: No learning rate schedule, but better convergence
22
22
  * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
23
23
  better step-per-second in late convergence (explained below)
24
- * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
25
- bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
24
+ * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) and `q_dtype` to trade off memory
25
+ usage for memory
26
+ bandwidth; Other optimizers have `storage_dtype`, supporting lower-precision EMAs at no(?) performance drop via
27
+ stochastic rounding
26
28
 
27
29
  ## Getting started
28
30
 
@@ -52,19 +54,19 @@ for _ in range(1000):
52
54
 
53
55
  ## Optimizers
54
56
 
55
- | Name | Description | Advantages / Disadvantages |
56
- |-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
57
- | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
58
- | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
59
- | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
60
- | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
61
- | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
62
- | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
63
- | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
64
- | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
57
+ | Name | Description | Advantages / Disadvantages |
58
+ |-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
59
+ | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
60
+ | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
61
+ | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
62
+ | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
63
+ | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
64
+ | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
65
+ | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
66
+ | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
65
67
  | **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
66
- | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
67
- | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
68
+ | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
69
+ | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
68
70
 
69
71
  ## Precond Schedule
70
72
 
@@ -0,0 +1,56 @@
1
+ import torch
2
+ import torch.optim
3
+
4
+ from heavyball.utils import copy_stochastic_list_
5
+ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote
6
+
7
+
8
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
    """Fused AdamW update over parallel tensor lists, compiled as one graph.

    `y`/`grad`/`exp_avg`/`exp_avg_sq` are parallel sequences of tensors; the
    EMA states may live in a lower-precision storage dtype.  `step` and `lr`
    arrive as 0-dim tensors so the compiled graph stays shape-stable; the
    caller passes `lr` already negated, so `addcdiv_` subtracts the update.
    """
    # Do all math on fp32 working copies of the (possibly bf16) inputs.
    g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]

    # First-moment EMA; the bias correction is folded into the lerp weight.
    torch._foreach_lerp_(exp_avg32, g32, 1 - beta_debias(beta1, step + 1))
    # Second-moment EMA; the helper returns the per-tensor denominators
    # (presumably sqrt(EMA) + eps — see utils.exp_avg_sq_; TODO confirm).
    denom = list(exp_avg_sq_(exp_avg_sq32, g32, beta_debias(beta2, step + 1), eps))

    # NOTE: denom.pop(0) relies on update_param_ invoking the callback exactly
    # once per parameter, in list order — do not reorder or parallelize this.
    update_param_(y, exp_avg32, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l))

    # Write fp32 results back into the storage-dtype state tensors
    # (stochastic rounding when the storage dtype is lower precision).
    copy_stochastic_list_(exp_avg, exp_avg32)
    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
19
+
20
+
21
class ForeachAdamW(StatefulOptimizer):
    """Foreach-based AdamW with a selectable storage dtype for the EMA state.

    `storage_dtype` names a torch dtype (e.g. 'float32', 'bfloat16') used for
    the first/second-moment buffers; math is done in fp32 inside the compiled
    kernel and written back with stochastic rounding.
    """

    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                 foreach: bool = True, storage_dtype: str = 'float32'):
        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
        super().__init__(params, defaults, foreach)

    def _step(self, group):
        """Run one AdamW step for a single parameter group."""
        if not group['train_mode']:
            raise Exception("Not in train mode!")

        params_with_grad = [p for p in group['params'] if p.grad is not None]
        if not params_with_grad:
            return

        step_count = group['k']
        state_dtype = getattr(torch, group['storage_dtype'])

        # Lazily allocate the EMA buffers in the requested storage dtype.
        for p in params_with_grad:
            state = self.state_(p)
            if 'exp_avg' not in state:
                state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype)
                state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype)

        y = [p.data for p in params_with_grad]
        grad = [p.grad for p in params_with_grad]
        exp_avg_sq = [self.state_(p)['exp_avg_sq'] for p in params_with_grad]
        exp_avg = [self.state_(p)['exp_avg'] for p in params_with_grad]

        # Negate lr so the compiled kernel's addcdiv_ subtracts the update;
        # scalars become 0-dim tensors to keep the compiled graph stable.
        device = y[0].device
        neg_lr = -warmup(group['lr'], step_count + 1, group['warmup_steps'])
        lr_t = torch.empty((), dtype=torch.float32, device=device).fill_(neg_lr)
        step_t = torch.empty((), dtype=torch.int32, device=device).fill_(step_count)
        beta1, beta2 = group['betas']
        _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step_t, lr_t, group['eps'],
                          group['weight_decay'])

        group['k'] = step_count + 1
@@ -0,0 +1,78 @@
1
+ import torch
2
+ import torch.optim
3
+ from heavyball.utils import copy_stochastic_list_
4
+
5
+ from .utils import warmup, beta_debias, update_param_, StatefulOptimizer, promote
6
+
7
+
8
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
    """Fused ADOPT update (arXiv:2411.02853), compiled as one graph.

    ADOPT's defining ordering: parameters are updated with the *previous*
    step's momentum first, and only then are the momentum and second-moment
    EMAs advanced with the new gradient — do not reorder statements here.
    """
    # fp32 working copies of the (possibly lower-precision) inputs.
    g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
    # Apply last step's momentum (still the untouched storage tensors) to y.
    update_param_(y, exp_avg, lr, decay)

    # Momentum update: m <- beta1*m + (1-beta1) * g / max(sqrt(v), eps).
    beta1 = beta_debias(beta1, step)
    denom = torch._foreach_sqrt(exp_avg_sq32)
    torch._foreach_maximum_(denom, eps)
    torch._foreach_mul_(exp_avg32, beta1)
    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]

    # Second-moment EMA, debiased one step ahead of the momentum.
    beta2 = beta_debias(beta2, step + 1)
    torch._foreach_mul_(exp_avg_sq32, beta2)
    [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]

    # Stochastic-rounded write-back into the storage-dtype state tensors.
    copy_stochastic_list_(exp_avg, exp_avg32)
    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
25
+
26
+
27
class ForeachADOPT(StatefulOptimizer):
    """Foreach-based ADOPT (arXiv:2411.02853) with selectable EMA storage dtype.

    The first two calls (k == 0 and k == 1) only prime the optimizer state
    eagerly; real parameter updates start at k > 1 via the compiled
    `_compilable_step_`.
    """

    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                 foreach: bool = True, storage_dtype: str = 'float32'):
        # storage_dtype names a torch dtype (e.g. 'bfloat16') for the EMA buffers.
        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
        super().__init__(params, defaults, foreach)

    def _step(self, group):
        """Run one optimizer step for a single parameter group."""
        eps = group['eps']
        decay = group['weight_decay']
        k = group['k']  # number of steps already taken for this group

        if not group['train_mode']:
            raise Exception("Not in train mode!")

        active_p = [p for p in group['params'] if p.grad is not None]

        if not active_p:
            return

        storage_dtype = getattr(torch, group['storage_dtype'])

        # Lazily allocate the EMA state in the requested storage dtype.
        for p in active_p:
            if 'exp_avg' not in self.state_(p):
                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=storage_dtype)
                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype)

        y, grad, exp_avg_sq, exp_avg = zip(
            *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])

        group['k'] = k + 1

        if k > 1:
            # Normal path: fused compiled update. lr is negated so the kernel
            # subtracts; warmup uses k - 1 to offset the two priming steps.
            lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
            lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
            k = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k)
            _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], k, lr, eps, decay)
            return

        # State-priming path (k == 0 or k == 1): advance the EMAs eagerly
        # without touching the parameters. NOTE(review): these in-place
        # foreach ops write fp32 results directly into the storage-dtype
        # tensors (no stochastic rounding here, unlike the compiled path) —
        # confirm this is intended for bf16 storage.
        grad = [promote(g) for g in grad]
        if k > 0:
            # Second priming step: seed the momentum from the normalized gradient.
            beta1 = beta_debias(group['betas'][0], k)
            denom = torch._foreach_sqrt(exp_avg_sq)
            torch._foreach_maximum_(denom, eps)
            torch._foreach_mul_(exp_avg, beta1)
            torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)

        # Both priming steps: advance the second-moment EMA.
        beta2 = beta_debias(group['betas'][1], k + 1)
        torch._foreach_mul_(exp_avg_sq, beta2)
        torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)
        del grad
@@ -0,0 +1,61 @@
1
+ import torch
2
+ import torch.optim
3
+
4
+ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote, copy_stochastic_list_
5
+
6
+
7
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
    """Fused LaProp update (arXiv:2002.04839), compiled as one graph.

    Unlike Adam, LaProp normalizes the gradient by the second-moment
    estimate *before* it enters the momentum EMA; the parameter step then
    uses the momentum directly.  The caller passes `lr` already negated.
    """
    # fp32 working copies of the (possibly lower-precision) inputs.
    g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]

    # Second-moment EMA; the helper returns the per-tensor denominators.
    denom = exp_avg_sq_(exp_avg_sq32, g32, beta_debias(beta2, step), eps)

    # Momentum over the normalized gradient: m <- beta1*m + (1-beta1)*g/denom.
    beta1 = beta_debias(beta1, step)
    torch._foreach_mul_(exp_avg32, beta1)
    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]

    # Apply the momentum to the parameters (lr is negative, so this descends).
    update_param_(y, exp_avg32, lr, decay)

    # Stochastic-rounded write-back into the storage-dtype state tensors.
    copy_stochastic_list_(exp_avg, exp_avg32)
    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
21
+
22
+
23
class ForeachLaProp(StatefulOptimizer):
    """Foreach-based LaProp with a selectable storage dtype for the EMA state.

    `storage_dtype` names a torch dtype (e.g. 'float32', 'bfloat16') for the
    moment buffers; math happens in fp32 inside the compiled kernel and is
    written back with stochastic rounding.
    """

    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
                 foreach: bool = True, storage_dtype: str = 'float32'):
        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
        super().__init__(params, defaults, foreach)

    def _step(self, group):
        """Run one LaProp step for a single parameter group."""
        if not group['train_mode']:
            raise Exception("Not in train mode!")

        params_with_grad = [p for p in group['params'] if p.grad is not None]
        if not params_with_grad:
            return

        step_count = group['k']
        state_dtype = getattr(torch, group['storage_dtype'])

        # Lazily allocate the EMA buffers in the requested storage dtype.
        for p in params_with_grad:
            state = self.state_(p)
            if 'exp_avg' not in state:
                state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype)
                state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype)

        y = [p.data for p in params_with_grad]
        grad = [p.grad for p in params_with_grad]
        exp_avg_sq = [self.state_(p)['exp_avg_sq'] for p in params_with_grad]
        exp_avg = [self.state_(p)['exp_avg'] for p in params_with_grad]

        # Negate lr so the compiled kernel descends; scalars become 0-dim
        # tensors to keep the compiled graph shape-stable.
        device = y[0].device
        neg_lr = -warmup(group['lr'], step_count + 1, group['warmup_steps'])
        lr_t = torch.empty((), dtype=torch.float32, device=device).fill_(neg_lr)
        step_t = torch.empty((), dtype=torch.int32, device=device).fill_(step_count + 1)
        beta1, beta2 = group['betas']

        _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step_t, lr_t, group['eps'],
                          group['weight_decay'])

        group['k'] = step_count + 1
@@ -0,0 +1,63 @@
1
+ import torch
2
+ import torch.optim
3
+ from heavyball.utils import get_ckp1, copy_stochastic_list_
4
+
5
+ from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, promote, _compilable_schedule_free_
6
+
7
+
8
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
    """Fused schedule-free AdamW step, compiled as one graph.

    `y` are the displayed (interpolated) parameters and `z` the schedule-free
    iterates; `ckp1` is the interpolation weight the caller computed via
    `get_ckp1`.  `step`, `ckp1` and `lr` arrive as 0-dim tensors so the
    compiled graph stays shape-stable across iterations.
    """
    old_debiased2 = beta_debias(beta2, step)

    # fp32 working copies (the second-moment state may be lower precision).
    g32 = [promote(g_) for g_ in grad]
    exp_avg_sq32 = [promote(e_) for e_ in exp_avg_sq]

    # Adam-style normalization of the gradient by the second-moment EMA.
    denom = exp_avg_sq_(exp_avg_sq32, g32, old_debiased2, eps)
    torch._foreach_div_(g32, denom)
    # Decoupled weight decay, applied against the displayed parameters y.
    if decay != 0:
        torch._foreach_add_(g32, y, alpha=decay)
    # Coupled schedule-free update of (y, z) per parameter.
    for p, z_, g in zip(y, z, g32):
        _compilable_schedule_free_(p, z_, ckp1, g, lr, beta1)

    # Stochastic-rounded write-back of the second-moment state.
    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
23
+
24
class ForeachSFAdamW(ScheduleFree):
    """Schedule-free AdamW (foreach) with a selectable storage dtype for the
    second-moment EMA; the schedule-free iterate `z` stays in param dtype."""

    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
                 weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32'):

        defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
                        weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
                        foreach=foreach, storage_dtype=storage_dtype)
        super().__init__(params, defaults, foreach)

    def _step(self, group):
        """Run one schedule-free AdamW step for a single parameter group."""
        if not group['train_mode']:
            raise Exception("Not in train mode!")

        params_with_grad = [p for p in group['params'] if p.grad is not None]
        if not params_with_grad:
            return

        step_count = group['k']
        state_dtype = getattr(torch, group['storage_dtype'])

        # Lazily create state: `z` (copy of the iterate) plus the second-moment EMA.
        for p in params_with_grad:
            state = self.state_(p)
            if 'z' not in state:
                state['z'] = torch.clone(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype)

        y = [p.data for p in params_with_grad]
        grad = [p.grad for p in params_with_grad]
        exp_avg_sq = [self.state_(p)['exp_avg_sq'] for p in params_with_grad]
        z = [self.state_(p)['z'] for p in params_with_grad]

        # Interpolation weight for the schedule-free update; also advances
        # the running weight_sum stored on the group.
        lr_val = warmup(group['lr'], step_count + 1, group['warmup_steps'])
        ckp1_val, group['weight_sum'] = get_ckp1(lr_val, group['weight_lr_power'], group['weight_sum'],
                                                 group['r'], step_count + 1)

        # Scalars become 0-dim tensors to keep the compiled graph shape-stable.
        device = y[0].device
        step_t = torch.empty((), dtype=torch.int32, device=device).fill_(step_count + 1)
        ckp1_t = torch.empty((), dtype=torch.float32, device=device).fill_(ckp1_val)
        lr_t = torch.empty((), dtype=torch.float32, device=device).fill_(lr_val)
        beta1, beta2 = group['betas']
        _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step_t, ckp1_t, group['eps'],
                          group['weight_decay'], lr_t)
        group['k'] = step_count + 1
@@ -0,0 +1,69 @@
1
+ import torch
2
+ import torch.optim
3
+
4
+ from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, get_ckp1, promote, \
5
+ _compilable_schedule_free_, copy_stochastic_list_
6
+
7
+
8
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
9
+ def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
10
+ old_debiased2 = beta_debias(beta2, step)
11
+
12
+ g32 = [promote(g_) for g_ in grad]
13
+ exp_avg_sq32 = [promote(e_) for e_ in exp_avg_sq]
14
+
15
+ denom = exp_avg_sq_(exp_avg_sq32, g32, old_debiased2, eps)
16
+ torch._foreach_div_(g32, denom)
17
+ if decay != 0:
18
+ torch._foreach_add_(g32, y, alpha=decay)
19
+ for p, z_, g in zip(y, z, g32):
20
+ _compilable_schedule_free_(p, z_, ckp1, g, lr, beta1)
21
+
22
+ copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
23
+
24
+
25
+ class PaLMForeachSFAdamW(ScheduleFree):
26
+ def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
27
+ weight_lr_power=2.0, beta2_scale: float = 0.8, foreach: bool = True, storage_dtype: str = 'float32'):
28
+ if betas[0] is not None:
29
+ beta = betas[0]
30
+ defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
31
+ lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
32
+ beta2_scale=beta2_scale, storage_dtype=storage_dtype)
33
+ super().__init__(params, defaults, foreach)
34
+
35
+ def _step(self, group):
36
+ eps = group['eps']
37
+ decay = group['weight_decay']
38
+ k = group['k']
39
+
40
+ if not group['train_mode']:
41
+ raise Exception("Not in train mode!")
42
+
43
+ active_p = [p for p in group['params'] if p.grad is not None]
44
+
45
+ if not active_p:
46
+ return
47
+
48
+ storage_dtype = getattr(torch, group['storage_dtype'])
49
+
50
+ for p in active_p:
51
+ if 'z' not in self.state_(p):
52
+ self.state_(p)['z'] = torch.clone(p.data)
53
+ self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype)
54
+
55
+ # Decay the first moment running average coefficient
56
+ beta2 = 1 - (k + 1) ** -group['beta2_scale']
57
+
58
+ y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z']) #
59
+ for p in active_p])
60
+
61
+ lr = warmup(group['lr'], k + 1, group['warmup_steps'])
62
+ ckp1, group['weight_sum'] = get_ckp1(lr, group['weight_lr_power'], group['weight_sum'], group['r'], k + 1)
63
+
64
+ step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k + 1)
65
+ ckp1 = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(ckp1)
66
+ beta2 = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(beta2)
67
+ lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
68
+ _compilable_step_(y, grad, exp_avg_sq, z, group['beta'], beta2, step, ckp1, eps, decay, lr)
69
+ group['k'] = k + 1
@@ -104,7 +104,8 @@ class ForeachPSGDKron(PSGDBase):
104
104
 
105
105
  if should_update:
106
106
  q32 = [promote(q_) for q_ in q]
107
- self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
107
+ self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
108
+ store_triu_as_line)
108
109
  set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
109
110
 
110
111
  grad_list = self.clip_fn(grad_list)
@@ -40,14 +40,25 @@ def warmup(lr: float, step: int, warmup_steps: int):
40
40
 
41
41
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
42
42
  def _compilable_schedule_free_(p, z, ckp1, grad, lr, beta1):
43
- p32 = p.float()
44
- z32 = z.float()
45
- p32.lerp_(end=z32, weight=1 - ckp1)
43
+ p32 = promote(p)
44
+ z32 = promote(z)
45
+ p32.lerp_(end=z32, weight=ckp1)
46
46
  p32.add_(grad, alpha=lr * (beta1 * (1 - ckp1) - 1))
47
- _guarded_copy_stochastic(p, p32)
47
+ copy_stochastic_(p, p32)
48
48
 
49
49
  z32.add_(grad, alpha=-lr)
50
- _guarded_copy_stochastic(z, z32)
50
+ copy_stochastic_(z, z32)
51
+
52
+
53
+ def get_ckp1(lr, weight_lr_power, weight_sum, r, step):
54
+ weight = lr ** weight_lr_power * max(step, 1) ** r
55
+ weight_sum = weight_sum + weight
56
+
57
+ try:
58
+ ckp1 = weight / weight_sum
59
+ except ZeroDivisionError:
60
+ ckp1 = 0
61
+ return ckp1, weight_sum
51
62
 
52
63
 
53
64
  def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[torch.Tensor],
@@ -136,7 +147,7 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):
136
147
  return torch.sqrt(state, out=out).clamp_(min=eps)
137
148
 
138
149
  torch._foreach_mul_(state, beta2)
139
- torch._foreach_addcmul_(state, grad, grad, value=1 - beta2)
150
+ [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(state, grad)]
140
151
  denom = torch._foreach_sqrt(state)
141
152
  torch._foreach_maximum_(denom, eps)
142
153
  return denom
@@ -332,9 +343,9 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
332
343
 
333
344
 
334
345
  def promote(x):
335
- if x in (torch.bfloat16, torch.float16):
346
+ if isinstance(x, torch.dtype) and x in (torch.bfloat16, torch.float16):
336
347
  return torch.float32
337
- if hasattr(x, 'dtype') and x.dtype in (torch.bfloat16, torch.float16):
348
+ if isinstance(x, torch.Tensor) and x.dtype in (torch.bfloat16, torch.float16):
338
349
  return x.float()
339
350
  return x
340
351
 
@@ -486,13 +497,8 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
486
497
  copy_stochastic_(t, s)
487
498
 
488
499
 
489
- def _guarded_copy_stochastic(target: torch.Tensor, source: torch.Tensor):
490
- if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
491
- set_(target, source)
492
- _compilable_copy_stochastic_(target, source)
493
-
494
-
495
- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
500
+ # this can be dynamic for most optimizers - just not for PSGD. So, it's disabled for all
501
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True)
496
502
  def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
497
503
  """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
498
504
  # create a random 16 bit integer
@@ -509,22 +515,24 @@ def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
509
515
 
510
516
 
511
517
  def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
512
- if target.data_ptr() == source.data_ptr():
518
+ if not torch.compiler.is_compiling() and target.data_ptr() == source.data_ptr():
513
519
  return
514
- _guarded_copy_stochastic(target, source)
520
+ if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
521
+ set_(target, source)
522
+ _compilable_copy_stochastic_(target, source)
515
523
 
516
524
 
517
525
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
518
526
  def _compilable_update_one_(p, u, decay, add_fn, lr):
519
- p32 = p.float()
520
- u32 = u.view(p.shape).float()
527
+ p32 = promote(p)
528
+ u32 = promote(u.view(p.shape))
521
529
  if decay > 0:
522
530
  p32.mul_(1 - decay * lr)
523
531
  if add_fn is None:
524
532
  p32.add_(u32, alpha=lr)
525
533
  else:
526
534
  add_fn(p32, u32, lr)
527
- _guarded_copy_stochastic(p, p32)
535
+ copy_stochastic_(p, p32)
528
536
 
529
537
 
530
538
  def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.18.7
3
+ Version: 0.19.0
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
32
32
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
33
33
  largely static alternative to `torch.optim` with more and better optimizers.
34
34
 
35
- Currently (2024-11-21, 0.18.6), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
35
+ Currently (2024-11-22, 0.19), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
36
36
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
37
37
 
38
38
  ## Features
@@ -45,8 +45,10 @@ recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psg
45
45
  * **ScheduleFree**: No learning rate schedule, but better convergence
46
46
  * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
47
47
  better step-per-second in late convergence (explained below)
48
- * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
49
- bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
48
+ * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) and `q_dtype` to trade off memory
49
+ usage for memory
50
+ bandwidth; Other optimizers have `storage_dtype`, supporting lower-precision EMAs at no(?) performance drop via
51
+ stochastic rounding
50
52
 
51
53
  ## Getting started
52
54
 
@@ -76,19 +78,19 @@ for _ in range(1000):
76
78
 
77
79
  ## Optimizers
78
80
 
79
- | Name | Description | Advantages / Disadvantages |
80
- |-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
81
- | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
82
- | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
83
- | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
84
- | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
85
- | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
86
- | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
87
- | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
88
- | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
81
+ | Name | Description | Advantages / Disadvantages |
82
+ |-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
83
+ | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
84
+ | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
85
+ | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
86
+ | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
87
+ | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
88
+ | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
89
+ | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
90
+ | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
89
91
  | **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
90
- | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
91
- | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
92
+ | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
93
+ | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
92
94
 
93
95
  ## Precond Schedule
94
96
 
@@ -27,6 +27,7 @@ heavyball.egg-info/requires.txt
27
27
  heavyball.egg-info/top_level.txt
28
28
  test/test_bf16_params.py
29
29
  test/test_bf16_q.py
30
+ test/test_bf16_storage.py
30
31
  test/test_closure.py
31
32
  test/test_foreach.py
32
33
  test/test_memory.py
@@ -10,7 +10,7 @@ setuptools.setup(
10
10
  name='heavyball',
11
11
  license='BSD',
12
12
  description='Efficient optimizers',
13
- version='0.18.7',
13
+ version='0.19.0',
14
14
  long_description=README,
15
15
  url='https://github.com/clashluke/heavyball',
16
16
  packages=setuptools.find_packages(),
@@ -22,7 +22,6 @@ def get_memory():
22
22
  @pytest.mark.parametrize("size,depth", [(256, 2)])
23
23
  def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3):
24
24
  set_torch()
25
-
26
25
  opt = getattr(heavyball, opt)
27
26
 
28
27
  peaks = []
@@ -0,0 +1,60 @@
1
+ import pytest
2
+ import torch
3
+ from torch import nn
4
+ from torch._dynamo import config
5
+
6
+ import heavyball
7
+ import heavyball.utils
8
+ from benchmark.utils import get_optim
9
+ from heavyball.utils import clean, set_torch, PSGDBase
10
+
11
+ config.cache_size_limit = 128
12
+
13
+ def get_memory():
14
+ clean()
15
+ torch.cuda.synchronize()
16
+ clean()
17
+ torch.cuda.synchronize()
18
+ return torch.cuda.memory_allocated()
19
+
20
+
21
+ @pytest.mark.parametrize("opt", heavyball.__all__)
22
+ @pytest.mark.parametrize("size,depth", [(256, 2)])
23
+ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3):
24
+ set_torch()
25
+
26
+ if 'soap' in opt.lower():
27
+ raise pytest.skip('soap is not supported')
28
+
29
+ opt = getattr(heavyball, opt)
30
+
31
+ if PSGDBase in opt.__mro__:
32
+ raise pytest.skip('PSGD is not supported')
33
+
34
+ peaks = []
35
+ losses = []
36
+
37
+ for dtype_name in ["float32", "bfloat16"]:
38
+ torch.manual_seed(0x2131290)
39
+ peaks.append([])
40
+ losses.append([])
41
+
42
+ dtype = getattr(torch, dtype_name)
43
+
44
+ for i in range(outer_iterations):
45
+ model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda().to(dtype)
46
+ o = get_optim(opt, model.parameters(), lr=1e-3, storage_dtype=dtype_name)
47
+
48
+ for _ in range(iterations):
49
+ loss = model(torch.randn((1024, size), device='cuda', dtype=dtype)).square().mean()
50
+ loss.backward()
51
+ o.step()
52
+ o.zero_grad()
53
+ losses[-1].append(loss.detach())
54
+
55
+ del model, o
56
+ clean()
57
+
58
+ for i, (l0, l1) in enumerate(zip(*losses)):
59
+ print(i, l0.item(), l1.item())
60
+ assert torch.allclose(l0.float(), l1.float(), rtol=0.1)
@@ -1,42 +0,0 @@
1
- import torch
2
- import torch.optim
3
-
4
- from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer
5
-
6
-
7
- class ForeachAdamW(StatefulOptimizer):
8
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
9
- foreach: bool = True):
10
- defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
11
- lr_max=-1.0, weight_decay=weight_decay)
12
- super().__init__(params, defaults, foreach)
13
-
14
- def _step(self, group):
15
- eps = group['eps']
16
- decay = group['weight_decay']
17
- k = group['k']
18
-
19
- if not group['train_mode']:
20
- raise Exception("Not in train mode!")
21
-
22
- active_p = [p for p in group['params'] if p.grad is not None]
23
-
24
- if not active_p:
25
- return
26
-
27
- for p in active_p:
28
- if 'exp_avg' not in self.state_(p):
29
- self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
30
- self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
31
-
32
- y, grad, exp_avg_sq, exp_avg = zip(
33
- *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
34
-
35
- # Decay the first and second moment running average coefficient
36
- torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
37
- denom = list(exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps))
38
-
39
- # Normalize grad in-place for memory efficiency
40
- lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
41
- update_param_(y, exp_avg, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l))
42
- group['k'] = k + 1
@@ -1,52 +0,0 @@
1
- import torch
2
- import torch.optim
3
-
4
- from .utils import warmup, beta_debias, update_param_, StatefulOptimizer
5
-
6
-
7
- class ForeachADOPT(StatefulOptimizer):
8
-
9
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
10
- foreach: bool = True):
11
- defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
12
- lr_max=-1.0, weight_decay=weight_decay)
13
- super().__init__(params, defaults, foreach)
14
-
15
- def _step(self, group):
16
- eps = group['eps']
17
- decay = group['weight_decay']
18
- k = group['k']
19
-
20
- if not group['train_mode']:
21
- raise Exception("Not in train mode!")
22
-
23
- active_p = [p for p in group['params'] if p.grad is not None]
24
-
25
- if not active_p:
26
- return
27
-
28
- for p in active_p:
29
- if 'exp_avg' not in self.state_(p):
30
- self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
31
- self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
32
-
33
- y, grad, exp_avg_sq, exp_avg = zip(
34
- *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
35
-
36
- if k > 1:
37
- lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
38
-
39
- update_param_(y, exp_avg, lr, decay)
40
- if k > 0:
41
- beta1 = beta_debias(group['betas'][0], k)
42
- denom = torch._foreach_sqrt(exp_avg_sq)
43
- torch._foreach_maximum_(denom, eps)
44
- torch._foreach_mul_(exp_avg, beta1)
45
- torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
46
-
47
- beta2 = beta_debias(group['betas'][1], k + 1)
48
- torch._foreach_mul_(exp_avg_sq, beta2)
49
- torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)
50
- del grad
51
-
52
- group['k'] = k + 1
@@ -1,47 +0,0 @@
1
- import torch
2
- import torch.optim
3
-
4
- from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer
5
-
6
-
7
- class ForeachLaProp(StatefulOptimizer):
8
-
9
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
10
- foreach: bool = True):
11
- defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
12
- lr_max=-1.0, weight_decay=weight_decay)
13
- super().__init__(params, defaults, foreach)
14
-
15
- def _step(self, group):
16
- eps = group['eps']
17
- decay = group['weight_decay']
18
- k = group['k']
19
-
20
- if not group['train_mode']:
21
- raise Exception("Not in train mode!")
22
-
23
- active_p = [p for p in group['params'] if p.grad is not None]
24
-
25
- if not active_p:
26
- return
27
-
28
- for p in active_p:
29
- if 'exp_avg' not in self.state_(p):
30
- self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
31
- self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
32
-
33
- y, grad, exp_avg_sq, exp_avg = zip(
34
- *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
35
-
36
- # Decay the first and second moment running average coefficient
37
- denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
38
- beta1 = beta_debias(group['betas'][0], k + 1)
39
- torch._foreach_mul_(exp_avg, beta1)
40
- torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
41
- del grad
42
-
43
- # Normalize grad in-place for memory efficiency
44
- lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
45
- update_param_(y, exp_avg, lr, decay)
46
-
47
- group['k'] = k + 1
@@ -1,54 +0,0 @@
1
- import torch
2
- import torch.optim
3
-
4
- from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debias
5
-
6
-
7
- class ForeachSFAdamW(ScheduleFree):
8
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
9
- weight_lr_power=2.0, foreach: bool = True):
10
-
11
- defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
12
- weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
13
- foreach=foreach)
14
- super().__init__(params, defaults, foreach)
15
-
16
- def _step(self, group):
17
- eps = group['eps']
18
- decay = group['weight_decay']
19
- k = group['k']
20
-
21
- if not group['train_mode']:
22
- raise Exception("Not in train mode!")
23
-
24
- active_p = [p for p in group['params'] if p.grad is not None]
25
-
26
- if not active_p:
27
- return
28
-
29
- for p in active_p:
30
- if 'z' not in self.state_(p):
31
- self.state_(p)['z'] = torch.clone(p.data)
32
- self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
33
-
34
- y, grad, exp_avg_sq, z = zip(
35
- *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
36
-
37
- # Decay the first moment running average coefficient
38
- old_debiased = beta_debias(group['betas'][1], k + 1)
39
-
40
- # Decay the first and second moment running average coefficient
41
- denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
42
-
43
- # Normalize grad in-place for memory efficiency
44
- torch._foreach_div_(grad, denom)
45
-
46
- # Weight decay calculated at y
47
- if decay != 0:
48
- torch._foreach_add_(grad, y, alpha=decay)
49
-
50
- lr = warmup(group['lr'], k + 1, group['warmup_steps'])
51
- group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0], y, z,
52
- grad, group['r'], k + 1)
53
-
54
- group['k'] = k + 1
@@ -1,57 +0,0 @@
1
- import torch
2
- import torch.optim
3
-
4
- from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debias
5
-
6
-
7
- class PaLMForeachSFAdamW(ScheduleFree):
8
- def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
9
- weight_lr_power=2.0, beta2_scale: float = 0.8,
10
- foreach: bool = True):
11
- if betas[0] is not None:
12
- beta = betas[0]
13
- defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
14
- lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
15
- beta2_scale=beta2_scale)
16
- super().__init__(params, defaults, foreach)
17
-
18
- def _step(self, group):
19
- eps = group['eps']
20
- decay = group['weight_decay']
21
- k = group['k']
22
-
23
- if not group['train_mode']:
24
- raise Exception("Not in train mode!")
25
-
26
- active_p = [p for p in group['params'] if p.grad is not None]
27
-
28
- if not active_p:
29
- return
30
-
31
- for p in active_p:
32
- if 'z' not in self.state_(p):
33
- self.state_(p)['z'] = torch.clone(p.data)
34
- self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
35
-
36
- y, grad, exp_avg_sq, z = zip(
37
- *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
38
-
39
- # Decay the first moment running average coefficient
40
- beta2 = 1 - (k + 1) ** -group['beta2_scale']
41
- old_debiased = beta_debias(beta2, k + 1)
42
-
43
- # Decay the first and second moment running average coefficient
44
- denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
45
-
46
- # Normalize grad in-place for memory efficiency
47
- torch._foreach_div_(grad, denom)
48
-
49
- # Weight decay calculated at y
50
- if decay != 0:
51
- torch._foreach_add_(grad, y, alpha=decay)
52
-
53
- lr = warmup(group['lr'], k + 1, group['warmup_steps'])
54
- group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], y, z,
55
- grad, group['r'], k + 1)
56
-
57
- group['k'] = k + 1
File without changes
File without changes
File without changes
File without changes