heavyball 0.18.8__tar.gz → 0.20.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. {heavyball-0.18.8 → heavyball-0.20.1}/PKG-INFO +18 -16
  2. {heavyball-0.18.8 → heavyball-0.20.1}/README.md +17 -15
  3. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/cached_delayed_psgd_kron.py +11 -11
  4. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/cached_psgd_kron.py +13 -12
  5. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/delayed_psgd.py +15 -18
  6. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/foreach_adamw.py +7 -5
  7. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/foreach_adopt.py +6 -4
  8. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/foreach_laprop.py +10 -5
  9. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/foreach_sfadamw.py +7 -4
  10. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/foreach_soap.py +4 -7
  11. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/p_adam.py +9 -9
  12. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/palm_foreach_sfadamw.py +9 -4
  13. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/palm_foreach_soap.py +6 -6
  14. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/precond_schedule_foreach_soap.py +6 -10
  15. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/precond_schedule_palm_foreach_soap.py +4 -4
  16. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/precond_schedule_sfpsoap.py +20 -10
  17. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/psgd_kron.py +15 -12
  18. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/pure_psgd.py +3 -6
  19. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/schedule_free_palm_foreach_soap.py +17 -8
  20. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/utils.py +146 -57
  21. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball.egg-info/PKG-INFO +18 -16
  22. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball.egg-info/SOURCES.txt +2 -0
  23. {heavyball-0.18.8 → heavyball-0.20.1}/setup.py +1 -1
  24. {heavyball-0.18.8 → heavyball-0.20.1}/test/test_bf16_params.py +2 -1
  25. heavyball-0.20.1/test/test_bf16_storage.py +60 -0
  26. heavyball-0.20.1/test/test_ema.py +61 -0
  27. {heavyball-0.18.8 → heavyball-0.20.1}/LICENSE +0 -0
  28. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball/__init__.py +0 -0
  29. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball.egg-info/dependency_links.txt +0 -0
  30. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball.egg-info/requires.txt +0 -0
  31. {heavyball-0.18.8 → heavyball-0.20.1}/heavyball.egg-info/top_level.txt +0 -0
  32. {heavyball-0.18.8 → heavyball-0.20.1}/setup.cfg +0 -0
  33. {heavyball-0.18.8 → heavyball-0.20.1}/test/test_bf16_q.py +0 -0
  34. {heavyball-0.18.8 → heavyball-0.20.1}/test/test_closure.py +0 -0
  35. {heavyball-0.18.8 → heavyball-0.20.1}/test/test_foreach.py +0 -0
  36. {heavyball-0.18.8 → heavyball-0.20.1}/test/test_memory.py +0 -0
  37. {heavyball-0.18.8 → heavyball-0.20.1}/test/test_merge.py +0 -0
  38. {heavyball-0.18.8 → heavyball-0.20.1}/test/test_no_grad.py +0 -0
  39. {heavyball-0.18.8 → heavyball-0.20.1}/test/test_psgd.py +0 -0
  40. {heavyball-0.18.8 → heavyball-0.20.1}/test/test_soap.py +0 -0
  41. {heavyball-0.18.8 → heavyball-0.20.1}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.18.8
3
+ Version: 0.20.1
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
32
32
  The goal is not to strive for completeness, full maintenance or abstraction, but instead to provide a simple
33
33
  largely static alternative to `torch.optim` with more and better optimizers.
34
34
 
35
- Currently (2024-11-21, 0.18.6), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
35
+ Currently (2024-11-22, 0.19.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
36
36
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
37
37
 
38
38
  ## Features
@@ -45,8 +45,10 @@ recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psg
45
45
  * **ScheduleFree**: No learning rate schedule, but better convergence
46
46
  * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
47
47
  better step-per-second in late convergence (explained below)
48
- * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
49
- bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
48
+ * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) and `q_dtype` to trade off memory
49
+ usage for memory
50
+ bandwidth; other optimizers have `storage_dtype`, supporting lower-precision EMAs at no(?) performance drop via
51
+ stochastic rounding
50
52
 
51
53
  ## Getting started
52
54
 
@@ -76,19 +78,19 @@ for _ in range(1000):
76
78
 
77
79
  ## Optimizers
78
80
 
79
- | Name | Description | Advantages / Disadvantages |
80
- |-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
81
- | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
82
- | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
83
- | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
84
- | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
85
- | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
86
- | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
87
- | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
88
- | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
81
+ | Name | Description | Advantages / Disadvantages |
82
+ |-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
83
+ | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
84
+ | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better convergence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
85
+ | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
86
+ | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
87
+ | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
88
+ | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be amortized; better loss-at-second) |
89
+ | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be amortized; better loss-at-second) |
90
+ | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more than AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be amortized) |
89
91
  | **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more than AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
90
- | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
91
- | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
92
+ | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
93
+ | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
92
94
 
93
95
  ## Precond Schedule
94
96
 
@@ -8,7 +8,7 @@ A simple package of efficient optimizers
8
8
  The goal is not to strive for completeness, full maintenance or abstraction, but instead to provide a simple
9
9
  largely static alternative to `torch.optim` with more and better optimizers.
10
10
 
11
- Currently (2024-11-21, 0.18.6), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
11
+ Currently (2024-11-22, 0.19.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
12
12
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
13
13
 
14
14
  ## Features
@@ -21,8 +21,10 @@ recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psg
21
21
  * **ScheduleFree**: No learning rate schedule, but better convergence
22
22
  * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
23
23
  better step-per-second in late convergence (explained below)
24
- * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
25
- bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
24
+ * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) and `q_dtype` to trade off memory
25
+ usage for memory
26
+ bandwidth; other optimizers have `storage_dtype`, supporting lower-precision EMAs at no(?) performance drop via
27
+ stochastic rounding
26
28
 
27
29
  ## Getting started
28
30
 
@@ -52,19 +54,19 @@ for _ in range(1000):
52
54
 
53
55
  ## Optimizers
54
56
 
55
- | Name | Description | Advantages / Disadvantages |
56
- |-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
57
- | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
58
- | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
59
- | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
60
- | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
61
- | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
62
- | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
63
- | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
64
- | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
57
+ | Name | Description | Advantages / Disadvantages |
58
+ |-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
59
+ | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
60
+ | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better convergence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
61
+ | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
62
+ | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
63
+ | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
64
+ | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be amortized; better loss-at-second) |
65
+ | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be amortized; better loss-at-second) |
66
+ | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more than AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be amortized) |
65
67
  | **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more than AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
66
- | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
67
- | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
68
+ | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
69
+ | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
68
70
 
69
71
  ## Precond Schedule
70
72
 
@@ -9,7 +9,7 @@ from typing import Optional
9
9
  import torch
10
10
 
11
11
  from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
12
- line_to_triu, triu_to_line, set_, einsum_base, promote
12
+ line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias
13
13
 
14
14
 
15
15
  class ForeachCachedDelayedPSGDKron(PSGDBase):
@@ -41,7 +41,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
41
41
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
42
42
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
43
43
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
44
- foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
44
+ foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
45
+ storage_dtype: str = 'float32', #
45
46
  # expert parameters
46
47
  precond_init_scale=1.0, precond_lr=0.1):
47
48
  if not 0.0 <= lr:
@@ -58,7 +59,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
58
59
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
59
60
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
60
61
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
61
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
62
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
62
63
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
63
64
 
64
65
  def _step(self, group):
@@ -74,14 +75,15 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
74
75
  beta = group['beta']
75
76
  store_triu_as_line = group['store_triu_as_line']
76
77
  q_dtype = getattr(torch, group['q_dtype'])
78
+ storage_dtype = getattr(torch, group['storage_dtype'])
77
79
 
78
80
  vals = []
79
81
 
80
- for p, g in split_p_and_g_in_group(group):
82
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
81
83
  state = self.state_(p)
82
84
 
83
85
  if 'Q' not in state:
84
- state["exp_avg"] = torch.zeros_like(g)
86
+ state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
85
87
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
86
88
  memory_save_mode, dtype=q_dtype)
87
89
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -105,7 +107,9 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
105
107
 
106
108
  group["step"] += 1
107
109
 
108
- torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
110
+ stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))
111
+
112
+ lr = -warmup(lr, group['step'], group['warmup_steps'])
109
113
 
110
114
  grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
111
115
  exp_avg_list)
@@ -127,8 +131,4 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
127
131
  else:
128
132
  torch.mul(q_.conj(), q_, out=c_)
129
133
 
130
- set_(g, new)
131
- grad_list = self.clip_fn(grad_list)
132
-
133
- lr = -warmup(lr, group['step'], group['warmup_steps'])
134
- update_param_(p_list, grad_list, lr, weight_decay)
134
+ update_param_([p], self.clip_fn([new]), lr, weight_decay)
@@ -9,7 +9,7 @@ from typing import Optional
9
9
  import torch
10
10
 
11
11
  from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
12
- line_to_triu, triu_to_line, set_, einsum_base, promote
12
+ line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias
13
13
 
14
14
 
15
15
  class ForeachCachedPSGDKron(PSGDBase):
@@ -39,7 +39,8 @@ class ForeachCachedPSGDKron(PSGDBase):
39
39
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
40
40
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
41
41
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
42
- foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
42
+ foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
43
+ storage_dtype: str = 'float32', #
43
44
  # expert parameters
44
45
  precond_init_scale=1.0, precond_lr=0.1):
45
46
  if not 0.0 <= lr:
@@ -56,7 +57,8 @@ class ForeachCachedPSGDKron(PSGDBase):
56
57
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
57
58
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
58
59
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
59
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
60
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
61
+ storage_dtype=storage_dtype)
60
62
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
61
63
 
62
64
  def _step(self, group):
@@ -71,15 +73,16 @@ class ForeachCachedPSGDKron(PSGDBase):
71
73
  beta = group['beta']
72
74
  store_triu_as_line = group['store_triu_as_line']
73
75
  q_dtype = getattr(torch, group['q_dtype'])
76
+ storage_dtype = getattr(torch, group['storage_dtype'])
74
77
  should_update = self.should_update(group)
75
78
 
76
79
  vals = []
77
80
 
78
- for p, g in split_p_and_g_in_group(group):
81
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
79
82
  state = self.state_(p)
80
83
 
81
84
  if 'Q' not in state:
82
- state["exp_avg"] = torch.zeros_like(g)
85
+ state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
83
86
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
84
87
  memory_save_mode, dtype=q_dtype)
85
88
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -103,7 +106,9 @@ class ForeachCachedPSGDKron(PSGDBase):
103
106
 
104
107
  group["step"] += 1
105
108
 
106
- torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
109
+ stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))
110
+
111
+ lr = -warmup(lr, group['step'], group['warmup_steps'])
107
112
 
108
113
  grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
109
114
  exp_avg_list)
@@ -123,9 +128,5 @@ class ForeachCachedPSGDKron(PSGDBase):
123
128
  else:
124
129
  torch.mul(q_.conj(), q_, out=c_)
125
130
 
126
- set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
127
-
128
- grad_list = self.clip_fn(grad_list)
129
-
130
- lr = -warmup(lr, group['step'], group['warmup_steps'])
131
- update_param_(p_list, grad_list, lr, weight_decay)
131
+ g = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
132
+ update_param_([p], self.clip_fn([g]), lr, weight_decay)
@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
5
5
  """
6
6
 
7
7
  import torch
8
- from heavyball.utils import copy_stochastic_list_
9
8
 
9
+ from heavyball.utils import stochastic_lerp_, beta_debias
10
10
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
11
- precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_, promote
11
+ split_p_and_g_in_group, triu_to_line, line_to_triu, promote
12
12
 
13
13
 
14
14
  class ForeachDelayedPSGD(PSGDBase):
@@ -38,8 +38,8 @@ class ForeachDelayedPSGD(PSGDBase):
38
38
  def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
39
39
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
40
40
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
41
- split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
42
- foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
41
+ split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
42
+ q_dtype='float32', stochastic_schedule: bool = True, storage_dtype:str='float32', #
43
43
  # expert parameters
44
44
  precond_init_scale=1.0, precond_lr=0.1):
45
45
  if not 0.0 <= lr:
@@ -55,12 +55,10 @@ class ForeachDelayedPSGD(PSGDBase):
55
55
  defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
56
56
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
57
57
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
58
- precond_init_scale=precond_init_scale,
59
- step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
60
- store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
58
+ precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
59
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
61
60
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
62
61
 
63
-
64
62
  def _step(self, group):
65
63
  should_update = self.should_update(group)
66
64
  momentum_into_precond_update = group.get("momentum_into_precond_update", True)
@@ -74,14 +72,15 @@ class ForeachDelayedPSGD(PSGDBase):
74
72
  beta = group['beta']
75
73
  store_triu_as_line = group['store_triu_as_line']
76
74
  q_dtype = getattr(torch, group['q_dtype'])
75
+ storage_dtype = getattr(torch, group['storage_dtype'])
77
76
 
78
77
  vals = []
79
78
 
80
- for p, g in split_p_and_g_in_group(group):
79
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
81
80
  state = self.state_(p)
82
81
 
83
82
  if 'Q' not in state:
84
- state["exp_avg"] = torch.zeros_like(g)
83
+ state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
85
84
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
86
85
  memory_save_mode, dtype=q_dtype)
87
86
  state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
@@ -96,7 +95,9 @@ class ForeachDelayedPSGD(PSGDBase):
96
95
 
97
96
  group["step"] += 1
98
97
 
99
- torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
98
+ stochastic_lerp_(exp_avg_list, grad_list, beta_debias(beta, group["step"]))
99
+
100
+ lr = -warmup(lr, group['step'], group['warmup_steps'])
100
101
 
101
102
  Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
102
103
  for i, (p, g) in enumerate(zip(p_list, grad_list)):
@@ -106,10 +107,6 @@ class ForeachDelayedPSGD(PSGDBase):
106
107
  new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
107
108
  if should_update:
108
109
  q32 = [promote(q_) for q_ in q]
109
- self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line)
110
- set_(g, new)
111
-
112
- grad_list = self.clip_fn(grad_list)
113
-
114
- lr = -warmup(lr, group['step'], group['warmup_steps'])
115
- update_param_(p_list, grad_list, lr, weight_decay)
110
+ self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
111
+ store_triu_as_line)
112
+ update_param_([p], self.clip_fn([new]), lr, weight_decay)
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  import torch.optim
3
- from heavyball.utils import copy_stochastic_list_
4
3
 
4
+ from heavyball.utils import copy_stochastic_list_
5
5
  from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote
6
6
 
7
7
 
@@ -20,9 +20,9 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
20
20
 
21
21
  class ForeachAdamW(StatefulOptimizer):
22
22
  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
23
- foreach: bool = True):
23
+ foreach: bool = True, storage_dtype: str = 'float32'):
24
24
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
25
- lr_max=-1.0, weight_decay=weight_decay)
25
+ lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
26
26
  super().__init__(params, defaults, foreach)
27
27
 
28
28
  def _step(self, group):
@@ -38,10 +38,12 @@ class ForeachAdamW(StatefulOptimizer):
38
38
  if not active_p:
39
39
  return
40
40
 
41
+ storage_dtype = getattr(torch, group['storage_dtype'])
42
+
41
43
  for p in active_p:
42
44
  if 'exp_avg' not in self.state_(p):
43
- self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
44
- self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
45
+ self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=storage_dtype)
46
+ self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype)
45
47
 
46
48
  y, grad, exp_avg_sq, exp_avg = zip(
47
49
  *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
@@ -27,9 +27,9 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
27
27
  class ForeachADOPT(StatefulOptimizer):
28
28
 
29
29
  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
30
- foreach: bool = True):
30
+ foreach: bool = True, storage_dtype: str = 'float32'):
31
31
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
32
- lr_max=-1.0, weight_decay=weight_decay)
32
+ lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
33
33
  super().__init__(params, defaults, foreach)
34
34
 
35
35
  def _step(self, group):
@@ -45,10 +45,12 @@ class ForeachADOPT(StatefulOptimizer):
45
45
  if not active_p:
46
46
  return
47
47
 
48
+ storage_dtype = getattr(torch, group['storage_dtype'])
49
+
48
50
  for p in active_p:
49
51
  if 'exp_avg' not in self.state_(p):
50
- self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
51
- self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
52
+ self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=storage_dtype)
53
+ self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype)
52
54
 
53
55
  y, grad, exp_avg_sq, exp_avg = zip(
54
56
  *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  import torch.optim
3
3
 
4
- from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote
4
+ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote, copy_stochastic_list_
5
5
 
6
6
 
7
7
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
@@ -16,13 +16,16 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
16
16
 
17
17
  update_param_(y, exp_avg32, lr, decay)
18
18
 
19
+ copy_stochastic_list_(exp_avg, exp_avg32)
20
+ copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
21
+
19
22
 
20
23
  class ForeachLaProp(StatefulOptimizer):
21
24
 
22
25
  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
23
- foreach: bool = True):
26
+ foreach: bool = True, storage_dtype: str = 'float32'):
24
27
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
25
- lr_max=-1.0, weight_decay=weight_decay)
28
+ lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
26
29
  super().__init__(params, defaults, foreach)
27
30
 
28
31
  def _step(self, group):
@@ -38,10 +41,12 @@ class ForeachLaProp(StatefulOptimizer):
38
41
  if not active_p:
39
42
  return
40
43
 
44
+ storage_dtype = getattr(torch, group['storage_dtype'])
45
+
41
46
  for p in active_p:
42
47
  if 'exp_avg' not in self.state_(p):
43
- self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
44
- self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
48
+ self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=storage_dtype)
49
+ self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype)
45
50
 
46
51
  y, grad, exp_avg_sq, exp_avg = zip(
47
52
  *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) #
@@ -1,6 +1,6 @@
1
1
  import torch
2
2
  import torch.optim
3
- from heavyball.utils import get_ckp1
3
+ from heavyball.utils import get_ckp1, copy_stochastic_list_
4
4
 
5
5
  from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, promote, _compilable_schedule_free_
6
6
 
@@ -19,14 +19,15 @@ def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, dec
19
19
  for p, z_, g in zip(y, z, g32):
20
20
  _compilable_schedule_free_(p, z_, ckp1, g, lr, beta1)
21
21
 
22
+ copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
22
23
 
23
24
  class ForeachSFAdamW(ScheduleFree):
24
25
  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
25
- weight_lr_power=2.0, foreach: bool = True):
26
+ weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32'):
26
27
 
27
28
  defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
28
29
  weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
29
- foreach=foreach)
30
+ foreach=foreach, storage_dtype=storage_dtype)
30
31
  super().__init__(params, defaults, foreach)
31
32
 
32
33
  def _step(self, group):
@@ -42,10 +43,12 @@ class ForeachSFAdamW(ScheduleFree):
42
43
  if not active_p:
43
44
  return
44
45
 
46
+ storage_dtype = getattr(torch, group['storage_dtype'])
47
+
45
48
  for p in active_p:
46
49
  if 'z' not in self.state_(p):
47
50
  self.state_(p)['z'] = torch.clone(p.data)
48
- self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
51
+ self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype)
49
52
 
50
53
  y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z']) #
51
54
  for p in active_p])
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
 
3
3
  from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
4
- split_p_and_g_in_group, StatefulOptimizer
4
+ split_p_and_g_in_group, StatefulOptimizer, exp_avg_
5
5
 
6
6
 
7
7
  class ForeachSOAP(StatefulOptimizer):
@@ -26,8 +26,7 @@ class ForeachSOAP(StatefulOptimizer):
26
26
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
27
27
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
28
28
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
29
- split: bool = False,
30
- foreach: bool = True):
29
+ split: bool = False, foreach: bool = True):
31
30
  defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
32
31
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
33
32
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
@@ -65,14 +64,12 @@ class ForeachSOAP(StatefulOptimizer):
65
64
  p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
66
65
  beta1, beta2 = group["betas"]
67
66
 
68
- old_debiased1 = beta_debias(beta1, step)
69
67
  old_debiased2 = beta_debias(beta2, step)
70
68
 
71
69
  # Decay the first and second moment running average coefficient
72
70
  # In-place operations to update the averages at the same time
73
- torch._foreach_mul_(exp_avg, old_debiased1)
74
- torch._foreach_add_(exp_avg, grad, alpha=1 - old_debiased1)
75
- denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
71
+ step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
72
+ denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
76
73
 
77
74
  for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
78
75
  state = self.state_(p)
@@ -39,7 +39,7 @@ class ForeachPaLMPAdam(PSGDBase):
39
39
  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
40
40
  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
41
41
  store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
42
- stochastic_schedule: bool = True, #
42
+ stochastic_schedule: bool = True, storage_dtype:str ='float32',#
43
43
  # expert parameters
44
44
  precond_init_scale=1.0, precond_lr=0.1):
45
45
  if not 0.0 <= lr:
@@ -57,7 +57,7 @@ class ForeachPaLMPAdam(PSGDBase):
57
57
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
58
58
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, beta=beta,
59
59
  beta2_scale=beta2_scale, merge_dims=merge_dims, split=split,
60
- store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
60
+ store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
61
61
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
62
62
 
63
63
  def _step(self, group):
@@ -71,15 +71,16 @@ class ForeachPaLMPAdam(PSGDBase):
71
71
  lr = group['lr']
72
72
  store_triu_as_line = group['store_triu_as_line']
73
73
  q_dtype = getattr(torch, group['q_dtype'])
74
+ storage_dtype = getattr(torch, group['storage_dtype'])
74
75
 
75
76
  vals = []
76
77
 
77
- for p, g in split_p_and_g_in_group(group):
78
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
78
79
  state = self.state_(p)
79
80
 
80
81
  if 'Q' not in state:
81
- state['exp_avg'] = torch.zeros_like(g)
82
- state['exp_avg_sq'] = torch.zeros_like(g)
82
+ state['exp_avg'] = torch.zeros_like(g, dtype=storage_dtype)
83
+ state['exp_avg_sq'] = torch.zeros_like(g, dtype=storage_dtype)
83
84
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
84
85
  memory_save_mode, dtype=q_dtype)
85
86
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -103,6 +104,8 @@ class ForeachPaLMPAdam(PSGDBase):
103
104
 
104
105
  beta2 = 1 - group['step'] ** -group['beta2_scale']
105
106
 
107
+ lr = -warmup(lr, group['step'], group['warmup_steps'])
108
+
106
109
  for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
107
110
  psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
108
111
  ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
@@ -112,8 +115,5 @@ class ForeachPaLMPAdam(PSGDBase):
112
115
  divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
113
116
  divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
114
117
  """
118
+ update_param_([p], self.clip_fn([g]), lr, weight_decay)
115
119
 
116
- grad_list = self.clip_fn(grad_list)
117
-
118
- lr = -warmup(lr, group['step'], group['warmup_steps'])
119
- update_param_(p_list, grad_list, lr, weight_decay)