heavyball 0.16.0__tar.gz → 0.17.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {heavyball-0.16.0 → heavyball-0.17.1}/PKG-INFO +17 -17
  2. {heavyball-0.16.0 → heavyball-0.17.1}/README.md +16 -16
  3. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/__init__.py +7 -6
  4. heavyball-0.17.1/heavyball/cached_delayed_psgd_kron.py +146 -0
  5. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/cached_psgd_kron.py +13 -8
  6. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/delayed_psgd.py +8 -7
  7. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/p_adam.py +9 -7
  8. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/psgd_kron.py +8 -7
  9. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/pure_psgd.py +9 -6
  10. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/utils.py +18 -13
  11. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball.egg-info/PKG-INFO +17 -17
  12. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball.egg-info/SOURCES.txt +2 -0
  13. {heavyball-0.16.0 → heavyball-0.17.1}/setup.py +1 -1
  14. heavyball-0.17.1/test/test_bf16_q.py +52 -0
  15. {heavyball-0.16.0 → heavyball-0.17.1}/test/test_closure.py +1 -1
  16. {heavyball-0.16.0 → heavyball-0.17.1}/test/test_memory.py +2 -2
  17. {heavyball-0.16.0 → heavyball-0.17.1}/test/test_merge.py +1 -1
  18. {heavyball-0.16.0 → heavyball-0.17.1}/test/test_psgd.py +3 -14
  19. {heavyball-0.16.0 → heavyball-0.17.1}/LICENSE +0 -0
  20. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/foreach_adamw.py +0 -0
  21. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/foreach_adopt.py +0 -0
  22. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/foreach_laprop.py +0 -0
  23. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/foreach_sfadamw.py +0 -0
  24. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/foreach_soap.py +0 -0
  25. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/palm_foreach_sfadamw.py +0 -0
  26. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/palm_foreach_soap.py +0 -0
  27. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/precond_schedule_foreach_soap.py +0 -0
  28. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  29. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/precond_schedule_sfpsoap.py +0 -0
  30. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  31. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball.egg-info/dependency_links.txt +0 -0
  32. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball.egg-info/requires.txt +0 -0
  33. {heavyball-0.16.0 → heavyball-0.17.1}/heavyball.egg-info/top_level.txt +0 -0
  34. {heavyball-0.16.0 → heavyball-0.17.1}/setup.cfg +0 -0
  35. {heavyball-0.16.0 → heavyball-0.17.1}/test/test_foreach.py +0 -0
  36. {heavyball-0.16.0 → heavyball-0.17.1}/test/test_no_grad.py +0 -0
  37. {heavyball-0.16.0 → heavyball-0.17.1}/test/test_soap.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.16.0
+ Version: 0.17.1
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -32,8 +32,8 @@ A simple package of efficient optimizers
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
  largely static alternative to `torch.optim` with more and better optimizers.

- Currently (2024-11-17, 0.15.0), the recommended stable optimizer is `PrecondSchedulePaLMForeachSOAP` (see below). The
- recommended experimental optimizer is `ForeachPSGDKron`.
+ Currently (2024-11-20, 0.17.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+ recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

  ## Features

@@ -62,7 +62,7 @@ import heavyball
  model = torch.nn.Linear(16, 1)

  # Create an optimizer
- optimizer = heavyball.PrecondSchedulePaLMForeachSOAP(model.parameters(), lr=1e-3)
+ optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)

  x = torch.randn(128, 16)
  y = torch.randn(128, 1)
@@ -76,19 +76,19 @@ for _ in range(1000):
76
76
 
77
77
  ## Optimizers
78
78
 
79
- | Name | Description | Advantages / Disadvantages |
80
- |--------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
81
- | **ForeachAdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
82
- | **ForeachLaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
83
- | **ForeachADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
84
- | **ForeachSFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
85
- | **PaLMForeachSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
86
- | **ForeachSOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
87
- | **PaLMForeachSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
88
- | **SFPaLMForeachSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
89
- | **PrecondScheduleSFPaLMForeachSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
90
- | **PrecondSchedulePaLMForeachSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
91
- | **PrecondScheduleForeachSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
79
+ | Name | Description | Advantages / Disadvantages |
80
+ |-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
81
+ | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
82
+ | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
83
+ | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
84
+ | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
85
+ | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
86
+ | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
87
+ | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
88
+ | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
89
+ | **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
90
+ | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
91
+ | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
92
92
 
93
93
  ## Precond Schedule
94
94
 
@@ -8,8 +8,8 @@ A simple package of efficient optimizers
8
8
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
9
9
  largely static alternative to `torch.optim` with more and better optimizers.
10
10
 
11
- Currently (2024-11-17, 0.15.0), the recommended stable optimizer is `PrecondSchedulePaLMForeachSOAP` (see below). The
12
- recommended experimental optimizer is `ForeachPSGDKron`.
11
+ Currently (2024-11-20, 0.17.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
12
+ recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
13
13
 
14
14
  ## Features
15
15
 
@@ -38,7 +38,7 @@ import heavyball
38
38
  model = torch.nn.Linear(16, 1)
39
39
 
40
40
  # Create an optimizer
41
- optimizer = heavyball.PrecondSchedulePaLMForeachSOAP(model.parameters(), lr=1e-3)
41
+ optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)
42
42
 
43
43
  x = torch.randn(128, 16)
44
44
  y = torch.randn(128, 1)
@@ -52,19 +52,19 @@ for _ in range(1000):
52
52
 
53
53
  ## Optimizers
54
54
 
55
- | Name | Description | Advantages / Disadvantages |
56
- |--------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
57
- | **ForeachAdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
58
- | **ForeachLaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
59
- | **ForeachADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
60
- | **ForeachSFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
61
- | **PaLMForeachSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
62
- | **ForeachSOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
63
- | **PaLMForeachSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
64
- | **SFPaLMForeachSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
65
- | **PrecondScheduleSFPaLMForeachSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
66
- | **PrecondSchedulePaLMForeachSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
67
- | **PrecondScheduleForeachSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
55
+ | Name | Description | Advantages / Disadvantages |
56
+ |-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
57
+ | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
58
+ | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
59
+ | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
60
+ | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
61
+ | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
62
+ | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
63
+ | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
64
+ | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
65
+ | **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
66
+ | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
67
+ | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
68
68
 
69
69
  ## Precond Schedule
70
70
 
@@ -14,31 +14,32 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
  from .psgd_kron import ForeachPSGDKron
  from .pure_psgd import ForeachPurePSGD
  from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
+ from .cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron

  PalmForEachSoap = PaLMForeachSOAP

  PaLMSOAP = PaLMForeachSOAP
  PaLMSFAdamW = PaLMForeachSFAdamW
  PaLMSFSoap = SFPaLMForeachSOAP
- PaLMForeachSOAP = PaLMForeachSOAP
  PrecondScheduleSFPaLMSOAP = PrecondScheduleSFPaLMSOAP
  SOAP = ForeachSOAP
  SFAdamW = ForeachSFAdamW
  LaProp = ForeachLaProp
  ADOPT = ForeachADOPT
- PrecondScheduleForeachSOAP = PrecondScheduleForeachSOAP
- PrecondSchedulePaLMForeachSOAP = PrecondSchedulePaLMForeachSOAP
+ PrecondScheduleSOAP = PrecondScheduleForeachSOAP
+ PrecondSchedulePaLMSOAP = PrecondSchedulePaLMForeachSOAP
  PSGDKron = ForeachPSGDKron
  AdamW = ForeachAdamW
  PurePSGD = ForeachPurePSGD
  PaLMPAdam = ForeachPaLMPAdam
  DelayedPSGD = ForeachDelayedPSGD
  CachedPSGDKron = ForeachCachedPSGDKron
+ CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron

  __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
             'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
             'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
-            'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', #
-            'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
+            'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron', #
+            'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP',
             'SOAP', 'SFAdamW', 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD',
-            'CachedPSGDKron']
+            'CachedPSGDKron', 'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP']
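Note on the `heavyball/__init__.py` hunk above: besides importing the new `ForeachCachedDelayedPSGDKron`, this release exports shorter aliases. A minimal sketch of what the aliases resolve to after upgrading (assumes heavyball 0.17.1 is installed; names are taken directly from the hunk above):

    import heavyball

    # The short names are plain module-level aliases of the Foreach implementations.
    assert heavyball.CachedDelayedPSGDKron is heavyball.ForeachCachedDelayedPSGDKron
    assert heavyball.PrecondScheduleSOAP is heavyball.PrecondScheduleForeachSOAP
    assert heavyball.PrecondSchedulePaLMSOAP is heavyball.PrecondSchedulePaLMForeachSOAP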
@@ -0,0 +1,146 @@
+ """
+ Originally from Evan Walters and Omead Pooladzandi, 2024
+ Modified under Creative Commons Attribution 4.0 International
+ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
+ """
+
+ from typing import Optional
+
+ import torch
+ from heavyball.utils import einsum_base
+
+ from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
+     precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote
+
+
+ class ForeachCachedDelayedPSGDKron(PSGDBase):
+     """
+     Implements PSGD with off-by-one preconditioning (akin to ADOPT and SOAP) with cached preconditioners.
+
+
+     Args:
+         params (iterable): Iterable of parameters to optimize or dicts defining
+             parameter groups.
+         lr (float): Learning rate.
+         b1 (float): Momentum parameter.
+         weight_decay (float): Weight decay (L2 penalty).
+         preconditioner_update_probability (callable or float, optional): Probability of
+             updating the preconditioner. If None, defaults to a schedule that anneals
+             from 1.0 to 0.03 by 4000 steps.
+         max_size_triangular (int): Max size for dim's preconditioner to be triangular.
+         min_ndim_triangular (int): Minimum number of dimensions a layer needs
+             to have triangular preconditioners.
+         memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
+             to set all preconditioners to be triangular, 'one_diag' sets the largest
+             or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
+             to be diagonal.
+         momentum_into_precond_update: (bool), whether to send momentum into preconditioner
+             update instead of raw gradients.
+     """
+
+     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
+                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
+                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
+                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
+                  foreach: bool = True, q_dtype='float32'):
+         if not 0.0 <= lr:
+             raise ValueError(f"Invalid learning rate: {lr}")
+         if not 0.0 <= beta < 1.0:
+             raise ValueError(f"Invalid beta parameter: {beta}")
+         if not 0.0 <= weight_decay:
+             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+         if preconditioner_update_probability is None:
+             preconditioner_update_probability = precond_update_prob_schedule()
+         if clip_fn is None:
+             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
+         self.preconditioner_update_probability = preconditioner_update_probability
+         self.clip_fn = clip_fn
+
+         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
+                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
+                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
+                         # precond lr hardcoded to 0.1
+                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
+                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                         store_triu_as_line=store_triu_as_line,
+                         q_dtype=q_dtype)
+         super().__init__(params, defaults, foreach)
+
+         self._prob_step = 0
+
+     def _step(self, group):
+         # update preconditioners all together
+         update_prob = self.preconditioner_update_probability
+         if callable(update_prob):
+             update_prob = update_prob(self._prob_step)
+         do_update = self.rng.random() < update_prob
+         self._prob_step += 1
+
+         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+         precond_init_scale = group['precond_init_scale']
+         max_size_triangular = group['max_size_triangular']
+         min_ndim_triangular = group['min_ndim_triangular']
+         memory_save_mode = group['memory_save_mode']
+         precond_lr = group['precond_lr']
+         weight_decay = group['weight_decay']
+         lr = group['lr']
+         beta = group['beta']
+         store_triu_as_line = group['store_triu_as_line']
+         q_dtype = getattr(torch, group['q_dtype'])
+
+         vals = []
+
+         for p, g in split_p_and_g_in_group(group):
+             state = self.state_(p)
+
+             if 'Q' not in state:
+                 state["exp_avg"] = torch.zeros_like(g)
+                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                  memory_save_mode, dtype=q_dtype)
+                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
+                 state['Q_cache'] = [torch.empty_like(q) for q in Q]
+
+                 expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(einsum_base, Q)]
+                 expr = ','.join(expr)
+                 grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
+                 out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
+                 expr = f'{expr},{grad_expr}->{out_expr}'
+
+                 state['cache_expr'] = expr
+
+             vals.append((p, g, state["exp_avg"], state["Q"], state['Q_cache']))
+
+         if not vals:
+             return
+
+         p_list, grad_list, exp_avg_list, Q_list, Q_cache_list = zip(*vals)
+         del vals
+
+         group["step"] += 1
+
+         torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+
+         grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
+             exp_avg_list)
+         for i, (p, g) in enumerate(zip(p_list, grad_list)):
+             cached_q = Q_cache_list.pop(0)
+             q_orig = Q_list.pop(0)
+             ea = exp_avg_list.pop(0)
+
+             if do_update:
+                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
+                 q32 = [promote(q_) for q_ in q]
+                 self.balance([g], [q32])
+                 self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                 for c_, q_ in zip(cached_q, q):
+                     if q_.ndim == 2:
+                         torch.matmul(q_.T.conj(), q_, out=c_)
+                     else:
+                         torch.mul(q_.conj(), q_, out=c_)
+
+             set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
+         grad_list = self.clip_fn(grad_list)
+
+         lr = -warmup(lr, group['step'], group['warmup_steps'])
+         update_param_(p_list, grad_list, lr, weight_decay)
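Note on the new `heavyball/cached_delayed_psgd_kron.py` file above: a rough usage sketch, assuming only the constructor arguments visible in the file (the model, batch size, and learning rate below are illustrative):

    import torch
    import heavyball

    model = torch.nn.Linear(16, 1)
    # Cached + delayed PSGD; the new q_dtype argument controls the dtype the
    # preconditioner Q is stored in ('float32' by default, 'bfloat16' to save memory).
    opt = heavyball.ForeachCachedDelayedPSGDKron(model.parameters(), lr=1e-3, q_dtype='bfloat16')

    for _ in range(10):
        loss = model(torch.randn(128, 16)).square().mean()
        loss.backward()
        opt.step()
        opt.zero_grad()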
@@ -10,7 +10,7 @@ import torch
10
10
  from heavyball.utils import einsum_base
11
11
 
12
12
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
13
- precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base
13
+ precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote
14
14
 
15
15
 
16
16
  class ForeachCachedPSGDKron(PSGDBase):
@@ -40,7 +40,7 @@ class ForeachCachedPSGDKron(PSGDBase):
40
40
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
41
41
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
42
42
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
43
- foreach: bool = True):
43
+ foreach: bool = True, q_dtype='float32'):
44
44
  if not 0.0 <= lr:
45
45
  raise ValueError(f"Invalid learning rate: {lr}")
46
46
  if not 0.0 <= beta < 1.0:
@@ -61,7 +61,8 @@ class ForeachCachedPSGDKron(PSGDBase):
61
61
  # precond lr hardcoded to 0.1
62
62
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
63
63
  step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
64
- store_triu_as_line=store_triu_as_line)
64
+ store_triu_as_line=store_triu_as_line,
65
+ q_dtype=q_dtype)
65
66
  super().__init__(params, defaults, foreach)
66
67
 
67
68
  self._prob_step = 0
@@ -84,6 +85,7 @@ class ForeachCachedPSGDKron(PSGDBase):
84
85
  lr = group['lr']
85
86
  beta = group['beta']
86
87
  store_triu_as_line = group['store_triu_as_line']
88
+ q_dtype = getattr(torch, group['q_dtype'])
87
89
 
88
90
  vals = []
89
91
 
@@ -93,7 +95,7 @@ class ForeachCachedPSGDKron(PSGDBase):
93
95
  if 'Q' not in state:
94
96
  state["exp_avg"] = torch.zeros_like(g)
95
97
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
96
- memory_save_mode, dtype=g.dtype)
98
+ memory_save_mode, dtype=q_dtype)
97
99
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
98
100
  state['Q_cache'] = [torch.empty_like(q) for q in Q]
99
101
 
@@ -124,18 +126,21 @@
              q_orig = Q_list.pop(0)
              ea = exp_avg_list.pop(0)

+             new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
+
              if do_update:
                  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-                 self.balance([g], [q])
-                 self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
-                                [q_orig] if store_triu_as_line else None)
+                 q32 = [promote(q_) for q_ in q]
+                 self.balance([g], [q32])
+                 self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
                  for c_, q_ in zip(cached_q, q):
                      if q_.ndim == 2:
                          torch.matmul(q_.T.conj(), q_, out=c_)
                      else:
                          torch.mul(q_.conj(), q_, out=c_)

-             set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
+             set_(g, new)
+

          grad_list = self.clip_fn(grad_list)
          lr = -warmup(lr, group['step'], group['warmup_steps'])
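Note on the `heavyball/cached_psgd_kron.py` hunks above: the new `q_dtype` is stored in the param-group defaults as a string and resolved with `getattr(torch, ...)` inside `_step`, while the preconditioner update itself runs on a float32 copy (`q32`). A hedged sketch of the intended call pattern (model and learning rate are illustrative):

    import torch
    import heavyball

    model = torch.nn.Linear(16, 1)
    # Keep Q in bfloat16; preconditioner updates are still computed in float32 via promote().
    opt = heavyball.ForeachCachedPSGDKron(model.parameters(), lr=1e-3, q_dtype='bfloat16')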
@@ -8,7 +8,7 @@ import torch
8
8
  from heavyball.utils import copy_stochastic_list_
9
9
 
10
10
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
11
- precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_
11
+ precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_, promote
12
12
 
13
13
 
14
14
  class ForeachDelayedPSGD(PSGDBase):
@@ -39,7 +39,7 @@ class ForeachDelayedPSGD(PSGDBase):
39
39
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
40
40
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
41
41
  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
42
- foreach: bool = True):
42
+ foreach: bool = True, q_dtype='float32'):
43
43
  if not 0.0 <= lr:
44
44
  raise ValueError(f"Invalid learning rate: {lr}")
45
45
  if not 0.0 <= beta < 1.0:
@@ -60,7 +60,7 @@ class ForeachDelayedPSGD(PSGDBase):
60
60
  # precond lr hardcoded to 0.1
61
61
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
62
62
  step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
63
- store_triu_as_line=store_triu_as_line)
63
+ store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
64
64
  super().__init__(params, defaults, foreach)
65
65
 
66
66
  self._prob_step = 0
@@ -83,6 +83,7 @@ class ForeachDelayedPSGD(PSGDBase):
83
83
  lr = group['lr']
84
84
  beta = group['beta']
85
85
  store_triu_as_line = group['store_triu_as_line']
86
+ q_dtype = getattr(torch, group['q_dtype'])
86
87
 
87
88
  vals = []
88
89
 
@@ -92,7 +93,7 @@ class ForeachDelayedPSGD(PSGDBase):
92
93
  if 'Q' not in state:
93
94
  state["exp_avg"] = torch.zeros_like(g)
94
95
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
95
- memory_save_mode, dtype=g.dtype)
96
+ memory_save_mode, dtype=q_dtype)
96
97
  state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
97
98
 
98
99
  vals.append((p, g, state["exp_avg"], state["Q"]))
@@ -114,9 +115,9 @@ class ForeachDelayedPSGD(PSGDBase):
114
115
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
115
116
  new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
116
117
  if do_update:
117
- self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
118
- [q_orig] if store_triu_as_line else None)
119
- self.balance([g], [q])
118
+ q32 = [promote(q_) for q_ in q]
119
+ self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
120
+ self.balance([g], [q32])
120
121
  set_(g, new)
121
122
 
122
123
  grad_list = self.clip_fn(grad_list)
@@ -8,7 +8,7 @@ import torch
8
8
  from heavyball.utils import triu_to_line, line_to_triu
9
9
 
10
10
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
11
- exp_avg_sq_, beta_debias, split_p_and_g_in_group
11
+ exp_avg_sq_, beta_debias, split_p_and_g_in_group, promote
12
12
 
13
13
 
14
14
  class ForeachPaLMPAdam(PSGDBase):
@@ -39,7 +39,7 @@ class ForeachPaLMPAdam(PSGDBase):
39
39
  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
40
40
  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
41
41
  store_triu_as_line: bool = True,
42
- foreach: bool = True):
42
+ foreach: bool = True, q_dtype='float32'):
43
43
  if not 0.0 <= lr:
44
44
  raise ValueError(f"Invalid learning rate: {lr}")
45
45
  if not 0.0 <= weight_decay:
@@ -60,7 +60,7 @@ class ForeachPaLMPAdam(PSGDBase):
60
60
  # precond lr hardcoded to 0.1
61
61
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
62
62
  step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
63
- split=split, store_triu_as_line=store_triu_as_line)
63
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
64
64
  super().__init__(params, defaults, foreach)
65
65
 
66
66
  self._prob_step = 0
@@ -81,6 +81,7 @@ class ForeachPaLMPAdam(PSGDBase):
81
81
  weight_decay = group['weight_decay']
82
82
  lr = group['lr']
83
83
  store_triu_as_line = group['store_triu_as_line']
84
+ q_dtype = getattr(torch, group['q_dtype'])
84
85
 
85
86
  vals = []
86
87
 
@@ -91,7 +92,7 @@ class ForeachPaLMPAdam(PSGDBase):
91
92
  state['exp_avg'] = torch.zeros_like(g)
92
93
  state['exp_avg_sq'] = torch.zeros_like(g)
93
94
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
94
- min_ndim_triangular, memory_save_mode, dtype=g.dtype)
95
+ min_ndim_triangular, memory_save_mode, dtype=q_dtype)
95
96
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
96
97
 
97
98
  vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
@@ -106,9 +107,10 @@ class ForeachPaLMPAdam(PSGDBase):
106
107
 
107
108
  Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
108
109
  if do_update:
109
- self.balance(grad_list, Q_triu)
110
- self.do_update(p_list, grad_list, Q_triu, precond_lr, Q_list if store_triu_as_line else None)
111
-
110
+ for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
111
+ q32 = [promote(qq_) for qq_ in q_]
112
+ self.balance([g], [q32])
113
+ self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
112
114
  torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
113
115
 
114
116
  beta2 = 1 - group['step'] ** -group['beta2_scale']
@@ -9,7 +9,7 @@ from typing import Optional
9
9
  import torch
10
10
 
11
11
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
12
- precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_
12
+ precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, promote
13
13
 
14
14
 
15
15
  class ForeachPSGDKron(PSGDBase):
@@ -39,7 +39,7 @@ class ForeachPSGDKron(PSGDBase):
39
39
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
40
40
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
41
41
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
42
- foreach: bool = True):
42
+ foreach: bool = True, q_dtype='float32'):
43
43
  if not 0.0 <= lr:
44
44
  raise ValueError(f"Invalid learning rate: {lr}")
45
45
  if not 0.0 <= beta < 1.0:
@@ -60,7 +60,7 @@ class ForeachPSGDKron(PSGDBase):
60
60
  # precond lr hardcoded to 0.1
61
61
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
62
62
  step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
63
- store_triu_as_line=store_triu_as_line)
63
+ store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
64
64
  super().__init__(params, defaults, foreach)
65
65
 
66
66
  self._prob_step = 0
@@ -83,6 +83,7 @@ class ForeachPSGDKron(PSGDBase):
83
83
  lr = group['lr']
84
84
  beta = group['beta']
85
85
  store_triu_as_line = group['store_triu_as_line']
86
+ q_dtype = getattr(torch, group['q_dtype'])
86
87
 
87
88
  vals = []
88
89
 
@@ -92,7 +93,7 @@ class ForeachPSGDKron(PSGDBase):
92
93
  if 'Q' not in state:
93
94
  state["exp_avg"] = torch.zeros_like(g)
94
95
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
95
- memory_save_mode, dtype=g.dtype)
96
+ memory_save_mode, dtype=q_dtype)
96
97
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
97
98
 
98
99
  vals.append((p, g, state["exp_avg"], state["Q"]))
@@ -114,9 +115,9 @@ class ForeachPSGDKron(PSGDBase):
114
115
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
115
116
 
116
117
  if do_update:
117
- self.balance([g], [q])
118
- self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
119
- [q_orig] if store_triu_as_line else None)
118
+ q32 = [promote(q_) for q_ in q]
119
+ self.balance([ea if momentum_into_precond_update else g], [q32])
120
+ self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
120
121
  set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
121
122
 
122
123
  grad_list = self.clip_fn(grad_list)
@@ -5,9 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
5
5
  """
6
6
 
7
7
  import torch
8
+ from heavyball.utils import copy_stochastic_list_
8
9
 
9
10
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
10
- split_p_and_g_in_group, line_to_triu, triu_to_line
11
+ split_p_and_g_in_group, line_to_triu, triu_to_line, promote
11
12
 
12
13
 
13
14
  class ForeachPurePSGD(PSGDBase):
@@ -37,7 +38,7 @@ class ForeachPurePSGD(PSGDBase):
37
38
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
38
39
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
39
40
  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
40
- foreach: bool = True):
41
+ foreach: bool = True, q_dtype='float32'):
41
42
  if not 0.0 <= lr:
42
43
  raise ValueError(f"Invalid learning rate: {lr}")
43
44
  if not 0.0 <= weight_decay:
@@ -56,7 +57,7 @@ class ForeachPurePSGD(PSGDBase):
56
57
  # precond lr hardcoded to 0.1
57
58
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
58
59
  step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
59
- store_triu_as_line=store_triu_as_line)
60
+ store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
60
61
  super().__init__(params, defaults, foreach)
61
62
 
62
63
  self._prob_step = 0
@@ -77,6 +78,7 @@ class ForeachPurePSGD(PSGDBase):
77
78
  weight_decay = group['weight_decay']
78
79
  lr = group['lr']
79
80
  store_triu_as_line = group['store_triu_as_line']
81
+ q_dtype = getattr(torch, group['q_dtype'])
80
82
 
81
83
  vals = []
82
84
 
@@ -85,7 +87,7 @@ class ForeachPurePSGD(PSGDBase):
85
87
 
86
88
  if 'Q' not in state:
87
89
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
88
- memory_save_mode, dtype=g.dtype)
90
+ memory_save_mode, dtype=q_dtype)
89
91
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
90
92
 
91
93
  vals.append((p, g, state["Q"]))
@@ -104,8 +106,9 @@ class ForeachPurePSGD(PSGDBase):
104
106
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
105
107
 
106
108
  if do_update:
107
- self.balance([g], [q])
108
- self.do_update([p], [g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
109
+ q32 = [promote(q_) for q_ in q]
110
+ self.balance([g], [q32])
111
+ self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
109
112
  psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
110
113
 
111
114
  grad_list = self.clip_fn(grad_list)
@@ -325,9 +325,9 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):


  def promote(x):
-     if x is (torch.bfloat16, torch.float16):
+     if x in (torch.bfloat16, torch.float16):
          return torch.float32
-     if x.dtype in (torch.bfloat16, torch.float16):
+     if hasattr(x, 'dtype') and x.dtype in (torch.bfloat16, torch.float16):
          return x.float()
      return x

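Note on the `heavyball/utils.py` hunk above: the old `x is (torch.bfloat16, torch.float16)` compared against a tuple literal and was always false; the fixed `promote` accepts either a dtype or a tensor. A small sanity sketch, assuming `heavyball.utils.promote` as shown in the hunk:

    import torch
    from heavyball.utils import promote

    assert promote(torch.bfloat16) is torch.float32   # dtype in, dtype out
    assert promote(torch.float16) is torch.float32
    q = torch.ones(4, dtype=torch.bfloat16)
    print(promote(q).dtype)                           # tensor path: upcast copy -> torch.float32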
@@ -468,15 +468,15 @@ class ScheduleFree(StatefulOptimizer):

  def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]):
      for t, s in zip(target, source):
-         if t.dtype == torch.bfloat16:
-             copy_stochastic_(t, s)
-         else:
-             set_(t, s)
+         copy_stochastic_(t, s)


  def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
      if target.data_ptr() == source.data_ptr():
          return
+     if target.dtype != torch.bfloat16:
+         set_(target, source)
+         return

      """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
      # create a random 16 bit integer
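Note on the `heavyball/utils.py` hunk above: `copy_stochastic_` now early-outs with a plain `set_` for non-bf16 targets, so `copy_stochastic_list_` can always delegate to it. For readers unfamiliar with the trick referenced in the docstring, a rough standalone sketch of stochastic rounding to bf16 (illustrative only, not heavyball's exact implementation):

    import torch

    def stochastic_round_to_bf16(source: torch.Tensor) -> torch.Tensor:
        # Reinterpret the fp32 bits as int32, add uniform noise below the bf16
        # mantissa, then zero the low 16 bits so the final cast is exact.
        bits = source.detach().float().view(dtype=torch.int32)
        noise = torch.randint_like(bits, 0, 1 << 16)
        rounded = (bits + noise) & -65536
        return rounded.view(dtype=torch.float32).bfloat16()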
@@ -555,7 +555,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
      for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
          if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d:
              # use diagonal matrix as preconditioner for this dim
-             Q.append(scale * torch.ones(size, dtype=dtype, device=t.device))
+             Q.append(scale * torch.ones(size, dtype=promote(dtype), device=t.device))

              piece1A.append(letters[i])
              piece2A = piece2A + letters[i]
@@ -669,11 +669,11 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
  @decorator
  def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
      """Precondition gradient G with preconditioner Q."""
-     out = torch.einsum(exprs[-1], *[q.conj() for q in Q], *Q, G)
+     out = torch.einsum(exprs[-1], *[q.conj() for q in Q], *Q, G.to(Q[0].dtype))
      if inplace:
          set_(G, out)
          return G
-     return out
+     return out.to(G.dtype)


  def norm_clip_(x, scale=None):
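Note on the `heavyball/utils.py` hunk above: with Q now optionally stored in a lower-precision `q_dtype`, the gradient is cast onto Q's dtype for the einsum and the result is cast back, since the matmul-style kernels behind `torch.einsum` expect matching operand dtypes. A toy illustration of that cast-compute-cast-back pattern (the einsum expression here is made up, not PSGD's actual `exprs[-1]`):

    import torch

    q = torch.randn(16, 16, dtype=torch.bfloat16)
    g = torch.randn(16, 8)  # float32 gradient
    out = torch.einsum('ba,bc,cd->ad', q.conj(), q, g.to(q.dtype)).to(g.dtype)
    assert out.dtype is torch.float32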
@@ -768,28 +768,33 @@ def line_to_triu(Q_list: List[Tuple[Optional[List[int]], torch.Tensor]]):
  def update_triu_(q_state, materialised):
      for (shape0, q), (shape1, m) in zip(q_state, triu_to_line(materialised)):
          assert shape0 == shape1
-         set_(q, m)
+         copy_stochastic_(q, m)


  class PSGDBase(StatefulOptimizer):
+     balance_probability: float = 0.01
+
      def __init__(self, parameters, groups, foreach: bool = True):
          super().__init__(parameters, groups, foreach)
          self.rng = random.Random(0x1923213)
          self._tiny = torch.finfo(torch.bfloat16).tiny

      def balance(self, grad_list, Q_list):
-         if self.rng.random() > 0.01:
+         if self.rng.random() > self.balance_probability:
              return

          for g, q in zip(grad_list, Q_list):
              if g.dim() > 1:
                  psgd_balance_Q(q)

-     def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None):
+     def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None, store_triu_as_line=False):
          for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
              psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
              if original_q:
-                 update_triu_(original_q[i], Q)
+                 if store_triu_as_line:
+                     update_triu_(original_q[i], Q)
+                 else:
+                     copy_stochastic_(original_q[i], Q)


  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
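Note on the `heavyball/utils.py` hunk above: the Q-balancing frequency is now the class attribute `PSGDBase.balance_probability` instead of a hard-coded `0.01`, so it can be overridden without patching the method (the value below is illustrative):

    from heavyball.utils import PSGDBase

    PSGDBase.balance_probability = 0.05  # balance Q on roughly 5% of preconditioner updates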
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.16.0
3
+ Version: 0.17.1
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -32,8 +32,8 @@ A simple package of efficient optimizers
32
32
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
33
33
  largely static alternative to `torch.optim` with more and better optimizers.
34
34
 
35
- Currently (2024-11-17, 0.15.0), the recommended stable optimizer is `PrecondSchedulePaLMForeachSOAP` (see below). The
36
- recommended experimental optimizer is `ForeachPSGDKron`.
35
+ Currently (2024-11-20, 0.17.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
36
+ recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
37
37
 
38
38
  ## Features
39
39
 
@@ -62,7 +62,7 @@ import heavyball
62
62
  model = torch.nn.Linear(16, 1)
63
63
 
64
64
  # Create an optimizer
65
- optimizer = heavyball.PrecondSchedulePaLMForeachSOAP(model.parameters(), lr=1e-3)
65
+ optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)
66
66
 
67
67
  x = torch.randn(128, 16)
68
68
  y = torch.randn(128, 1)
@@ -76,19 +76,19 @@ for _ in range(1000):
76
76
 
77
77
  ## Optimizers
78
78
 
79
- | Name | Description | Advantages / Disadvantages |
80
- |--------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
81
- | **ForeachAdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
82
- | **ForeachLaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
83
- | **ForeachADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
84
- | **ForeachSFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
85
- | **PaLMForeachSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
86
- | **ForeachSOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
87
- | **PaLMForeachSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
88
- | **SFPaLMForeachSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
89
- | **PrecondScheduleSFPaLMForeachSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
90
- | **PrecondSchedulePaLMForeachSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
91
- | **PrecondScheduleForeachSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
79
+ | Name | Description | Advantages / Disadvantages |
80
+ |-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
81
+ | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
82
+ | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
83
+ | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
84
+ | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
85
+ | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
86
+ | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
87
+ | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
88
+ | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
89
+ | **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
90
+ | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
91
+ | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
92
92
 
93
93
  ## Precond Schedule
94
94
 
@@ -2,6 +2,7 @@ LICENSE
2
2
  README.md
3
3
  setup.py
4
4
  heavyball/__init__.py
5
+ heavyball/cached_delayed_psgd_kron.py
5
6
  heavyball/cached_psgd_kron.py
6
7
  heavyball/delayed_psgd.py
7
8
  heavyball/foreach_adamw.py
@@ -24,6 +25,7 @@ heavyball.egg-info/SOURCES.txt
24
25
  heavyball.egg-info/dependency_links.txt
25
26
  heavyball.egg-info/requires.txt
26
27
  heavyball.egg-info/top_level.txt
28
+ test/test_bf16_q.py
27
29
  test/test_closure.py
28
30
  test/test_foreach.py
29
31
  test/test_memory.py
@@ -10,7 +10,7 @@ setuptools.setup(
      name='heavyball',
      license='BSD',
      description='Efficient optimizers',
-     version='0.16.0',
+     version='0.17.1',
      long_description=README,
      url='https://github.com/clashluke/heavyball',
      packages=setuptools.find_packages(),
@@ -0,0 +1,52 @@
+ import heavyball
+ import heavyball.utils
+ import pytest
+ import torch
+ from benchmark.utils import get_optim
+ from heavyball.utils import clean, set_torch, PSGDBase
+ from torch import nn
+
+
+ def get_memory():
+     clean()
+     torch.cuda.synchronize()
+     clean()
+     torch.cuda.synchronize()
+     return torch.cuda.memory_allocated()
+
+
+ @pytest.mark.parametrize("opt", heavyball.__all__)
+ @pytest.mark.parametrize("size,depth", [(256, 2)])
+ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3):
+     set_torch()
+
+     opt = getattr(heavyball, opt)
+     if not issubclass(opt, PSGDBase):
+         raise pytest.skip('Only PSGD is supported')
+
+     peaks = []
+     losses = []
+
+     for q_dtype in ['float32', 'bfloat16']:
+         peaks.append([])
+         losses.append([])
+
+         for i in range(outer_iterations):
+             torch.manual_seed(0x2131290)
+             model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
+             o = get_optim(opt, model.parameters(), lr=1e-3, q_dtype=q_dtype)
+
+             for _ in range(iterations):
+                 loss = model(torch.randn((1024, size)).cuda()).square().mean()
+                 loss.backward()
+                 o.step()
+                 o.zero_grad()
+                 losses[-1].append(loss.detach())
+
+             del model, o
+             clean()
+
+
+     for i, (l0, l1) in enumerate(zip(*losses)):
+         print(i, l0.item(), l1.item())
+         assert torch.allclose(l0, l1, rtol=0.1)
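Note on the new `test/test_bf16_q.py` above: it runs each PSGD-family optimizer once with `q_dtype='float32'` and once with `q_dtype='bfloat16'` and asserts the losses agree within 10% relative tolerance; running it requires a CUDA device and the repository's `benchmark.utils` helper. A small sketch for listing which exported optimizers the test actually exercises (the rest are skipped):

    import heavyball
    from heavyball.utils import PSGDBase

    psgd_opts = [name for name in heavyball.__all__
                 if issubclass(getattr(heavyball, name), PSGDBase)]
    print(psgd_opts)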
@@ -20,7 +20,7 @@ class Param(nn.Module):

  @pytest.mark.parametrize("opt", heavyball.__all__)
  @pytest.mark.parametrize("size", [(4, 4, 4, 4), ])
- def test_closre(opt, size: List[int], depth: int = 2, iterations: int = 5, outer_iterations: int = 3):
+ def test_closure(opt, size: List[int], depth: int = 2, iterations: int = 5, outer_iterations: int = 3):
      clean()
      set_torch()

@@ -25,14 +25,14 @@ expected_memory = {'adamw': {'after': 4, 'peak': 5.1}, 'soap': {'after': 7, 'pea
  @pytest.mark.parametrize("size,depth", [(8192, 1), (2048, 16)])
  def test_memory(opt, method, size, depth: int, iterations: int = 5, outer_iterations: int = 3):
      if 'soap' not in opt.lower() and method != 'qr':
-         return
+         raise pytest.skip('Only SOAP supports `method` argument')
      set_torch()

      for k, v in expected_memory.items():
          if k in opt.lower():
              break
      else:
-         raise ValueError(f'Unknown optimizer {opt}')
+         raise pytest.skip(f'Opt {opt} not supported')

      opt = getattr(heavyball, opt)
      heavyball.utils.zeroth_power_mode = method
@@ -26,7 +26,7 @@ class Param(nn.Module):
  def test_merge(opt, method, size: List[int], merge, split, depth: int = 2, iterations: int = 5,
                 outer_iterations: int = 3):
      if 'soap' not in opt.lower() and method != 'qr':
-         return
+         raise pytest.skip('Only SOAP supports `method` argument')
      clean()
      set_torch()

@@ -1,11 +1,10 @@
- import pytest
- import torch
- from torch import nn
-
  import heavyball
  import heavyball.utils
+ import pytest
+ import torch
  from benchmark.utils import get_optim
  from heavyball.utils import clean, set_torch
+ from torch import nn


  def get_memory():
@@ -16,10 +15,6 @@ def get_memory():
      return torch.cuda.memory_allocated()


- expected_memory = {'adamw': {'after': 4, 'peak': 5.1}, 'soap': {'after': 7, 'peak': 14},
-                    'psgd': {'after': 4, 'peak': 11.5}, 'padam': {'after': 5, 'peak': 11.4}}
-
-
  @pytest.mark.parametrize("opt", ['ForeachPSGDKron', 'ForeachPaLMPAdam', 'ForeachPurePSGD', 'ForeachDelayedPSGD'])
  @pytest.mark.parametrize("method",
                           ['norm_clip_', 'mu_law_compress', 'a_law_compress', 'trust_region_clip_', 'identity'])
@@ -27,12 +22,6 @@ expected_memory = {'adamw': {'after': 4, 'peak': 5.1}, 'soap': {'after': 7, 'pea
  def test_clip(opt, method, size, depth: int, iterations: int = 100, outer_iterations: int = 3):
      set_torch()

-     for k, v in expected_memory.items():
-         if k in opt.lower():
-             break
-     else:
-         raise ValueError(f'Unknown optimizer {opt}')
-
      opt = getattr(heavyball, opt)

      for i in range(outer_iterations):