heavyball 2.0.0.dev0__tar.gz → 2.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/PKG-INFO +19 -7
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/README.md +13 -4
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball/__init__.py +168 -29
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball/chainable.py +165 -63
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball/helpers.py +5 -1
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball/utils.py +490 -124
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball.egg-info/PKG-INFO +19 -7
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball.egg-info/SOURCES.txt +6 -1
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball.egg-info/requires.txt +5 -2
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/pyproject.toml +3 -3
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_bf16_params.py +1 -1
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_bf16_q.py +3 -3
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_bf16_storage.py +3 -3
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_caution.py +1 -1
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_channels_last.py +1 -1
- heavyball-2.1.0/test/test_clip.py +116 -0
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_closure.py +1 -1
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_ema.py +1 -1
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_foreach.py +3 -3
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_hook.py +1 -1
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_mars.py +3 -3
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_memory.py +1 -1
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_memory_leak.py +1 -1
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_merge.py +2 -2
- heavyball-2.1.0/test/test_migrate_cli.py +178 -0
- heavyball-2.1.0/test/test_nd_param.py +40 -0
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_no_grad.py +2 -2
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_save_restore.py +1 -1
- heavyball-2.1.0/test/test_singular_values.py +88 -0
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_soap.py +0 -1
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_stochastic_updates.py +3 -3
- heavyball-2.1.0/test/test_toy_training.py +130 -0
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/LICENSE +0 -0
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/setup.cfg +0 -0
- {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_psgd_precond_init_stability.py +0 -0
{heavyball-2.0.0.dev0 → heavyball-2.1.0}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: heavyball
-Version: 2.0.0.dev0
+Version: 2.1.0
 Summary: Efficient Optimizers
 Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
 Project-URL: source, https://github.com/HomebrewML/HeavyBall
```
```diff
@@ -16,7 +16,7 @@ Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: opt-einsum>=3.4.0
-Requires-Dist: torch>=2.
+Requires-Dist: torch>=2.7.0
 Requires-Dist: numpy
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
```
```diff
@@ -24,9 +24,12 @@ Requires-Dist: pytest; extra == "dev"
 Requires-Dist: ruff; extra == "dev"
 Requires-Dist: matplotlib; extra == "dev"
 Requires-Dist: seaborn; extra == "dev"
-Requires-Dist: hyperopt; extra == "dev"
 Requires-Dist: pandas; extra == "dev"
 Requires-Dist: typer; extra == "dev"
+Requires-Dist: optuna; extra == "dev"
+Requires-Dist: optunahub; extra == "dev"
+Requires-Dist: hebo; extra == "dev"
+Requires-Dist: lightbench; extra == "dev"
 Dynamic: license-file
 
 # heavyball
```
```diff
@@ -46,11 +49,11 @@ _High-performance, extensible, chainable optimizers for PyTorch._
 
 ## Key Features
 
-- Foreach-based optimizers: `ForeachAdamW`, `ForeachRMSprop`, `ForeachSFAdamW`, `Muon`, `ADOPT`, `MSAM
+- Foreach-based optimizers: `ForeachAdamW`, `ForeachRMSprop`, `ForeachSFAdamW`, `Muon`, `ADOPT`, `MSAM` (Momentum SAM), …
 - Schedule-Free optimizers with dynamic learning rate adaptation.
 - Advanced update rules: MARS correction, cautious updates, PaLM beta2 scheduling.
 - Chainable transforms for custom optimization recipes.
-- Comprehensive benchmark suite (`
+- Comprehensive benchmark suite packaged separately as LightBench (`../LightBench`).
 - Detailed documentation and example-driven tutorials.
 
 ## Quickstart
```
```diff
@@ -81,11 +84,20 @@ for data, target in dataloader:
 
 ## Benchmarks
 
-> Reproduce benchmarks with:
+> Reproduce benchmarks with LightBench (install it via `pip install -e ../LightBench` from the repo root):
 > ```bash
-> python3 -m
+> python3 -m lightbench.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
 > ```
 
+## Migrating from HeavyBall 1.x
+
+- Read the detailed [2.0.0 migration notes](docs/heavyball2.md) for an end-to-end checklist, including guidance for restoring legacy behaviour when needed.
+- Use `scripts/migrate_optimizer_state.py` to rewrite pre-2.0 optimizer checkpoints:
+  ```bash
+  python scripts/migrate_optimizer_state.py path/to/checkpoint.pt heavyball.ForeachAdamW --state-key optimizer
+  ```
+  The utility renames legacy state entries, fans them out per parameter view, and injects the HeavyBall metadata block expected by 2.0.0.
+
 
 ## Contributing
 
```
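The migration command above only rewrites the checkpoint file. A hedged sketch of consuming the result afterwards (the path, model, and optimizer choice are illustrative; `ckpt["optimizer"]` mirrors the `--state-key optimizer` flag):

```python
import torch

import heavyball

# Any torch.nn.Module whose parameters match the original training run.
model = torch.nn.Linear(16, 1)
opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-3)

# Checkpoint previously rewritten by scripts/migrate_optimizer_state.py.
# weights_only=False because optimizer checkpoints may hold non-tensor objects.
ckpt = torch.load("path/to/checkpoint.pt", map_location="cpu", weights_only=False)
opt.load_state_dict(ckpt["optimizer"])  # same key as --state-key above
```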
{heavyball-2.0.0.dev0 → heavyball-2.1.0}/README.md

```diff
@@ -15,11 +15,11 @@ _High-performance, extensible, chainable optimizers for PyTorch._
 
 ## Key Features
 
-- Foreach-based optimizers: `ForeachAdamW`, `ForeachRMSprop`, `ForeachSFAdamW`, `Muon`, `ADOPT`, `MSAM
+- Foreach-based optimizers: `ForeachAdamW`, `ForeachRMSprop`, `ForeachSFAdamW`, `Muon`, `ADOPT`, `MSAM` (Momentum SAM), …
 - Schedule-Free optimizers with dynamic learning rate adaptation.
 - Advanced update rules: MARS correction, cautious updates, PaLM beta2 scheduling.
 - Chainable transforms for custom optimization recipes.
-- Comprehensive benchmark suite (`
+- Comprehensive benchmark suite packaged separately as LightBench (`../LightBench`).
 - Detailed documentation and example-driven tutorials.
 
 ## Quickstart
```
```diff
@@ -50,11 +50,20 @@ for data, target in dataloader:
 
 ## Benchmarks
 
-> Reproduce benchmarks with:
+> Reproduce benchmarks with LightBench (install it via `pip install -e ../LightBench` from the repo root):
 > ```bash
-> python3 -m
+> python3 -m lightbench.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
 > ```
 
+## Migrating from HeavyBall 1.x
+
+- Read the detailed [2.0.0 migration notes](docs/heavyball2.md) for an end-to-end checklist, including guidance for restoring legacy behaviour when needed.
+- Use `scripts/migrate_optimizer_state.py` to rewrite pre-2.0 optimizer checkpoints:
+  ```bash
+  python scripts/migrate_optimizer_state.py path/to/checkpoint.pt heavyball.ForeachAdamW --state-key optimizer
+  ```
+  The utility renames legacy state entries, fans them out per parameter view, and injects the HeavyBall metadata block expected by 2.0.0.
+
 
 ## Contributing
 
```
{heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball/__init__.py

```diff
@@ -41,6 +41,34 @@ class SAMWrapper(torch.optim.Optimizer):
         self.wrapped_optimizer.zero_grad()
 
 
+class SGD(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        beta=0.9,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, fns=(C.heavyball_momentum,))
+
+
 class ForeachAdamW(C.BaseOpt):
     def __init__(
         self,
```
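The new `SGD` routes a single transform, `C.heavyball_momentum`, through the chainable base class, so it follows the standard `torch.optim` loop. A minimal sketch under the signature above (model, data, and hyperparameters are illustrative):

```python
import torch
import torch.nn.functional as F

import heavyball

model = torch.nn.Linear(16, 1)
opt = heavyball.SGD(model.parameters(), lr=0.0025, beta=0.9, weight_decay=0.01)

x, y = torch.randn(8, 16), torch.randn(8, 1)
loss = F.mse_loss(model(x), y)
loss.backward()
opt.step()  # heavyball-momentum SGD step
opt.zero_grad()
```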
```diff
@@ -69,7 +97,110 @@ class ForeachAdamW(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adam,))
+
+
+class UnscaledAdamW(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(
+            params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_unscaled_adam,)
+        )
+
+
+class SUDSAdamW(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        precond_lr: float = 1e-2,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_suds,))
+
+
+class ForeachAdamC(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        max_lr: float | None = None,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        **kwargs,
+    ):
+        if max_lr is None:
+            utils.warn_once(
+                "max_lr was not set. setting it to the current learning rate, under the assumption that it strictly decreases"
+            )
+            max_lr = lr
+
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adamc,))
 
 
 class ForeachRMSprop(C.BaseOpt):
```
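The recurring change in the remaining hunks of this file is mechanical: transform chains that 1.x passed as trailing positional arguments to `C.BaseOpt.__init__` are now passed as the keyword-only tuple `fns=(...)`, applied in order. A hedged sketch of a custom optimizer against the 2.1 calling convention, modeled on `ForeachSignLaProp` below (the class name is illustrative):

```python
from heavyball import chainable as C
from heavyball import utils


class LaPropSign(C.BaseOpt):
    """Illustrative composition: LaProp scaling followed by a sign update."""

    def __init__(
        self,
        params,
        lr=0.0025,
        betas=(0.9, 0.99),
        eps=1e-8,
        weight_decay=0,
        warmup_steps=0,
        foreach: bool = True,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        palm: bool = C.use_default,
        **kwargs,
    ):
        defaults = locals()
        defaults.pop("self")
        params = defaults.pop("params")
        defaults.update(defaults.pop("kwargs"))

        if kwargs:
            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")

        # 2.1 API: the transform chain is the keyword-only `fns` tuple.
        super().__init__(
            params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_laprop, C.sign)
        )
```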
```diff
@@ -113,7 +244,7 @@ class ForeachRMSprop(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_exp_avg_sq,
+            fns=(C.scale_by_exp_avg_sq,),
         )
 
 
```
```diff
@@ -154,8 +285,7 @@ class ForeachSFAdamW(C.ScheduleFree):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_exp_avg_sq,
-            C.update_by_schedule_free,
+            fns=(C.scale_by_exp_avg_sq, C.update_by_schedule_free),
         )
 
 
```
```diff
@@ -197,8 +327,7 @@ class MSAMLaProp(C.MSAM):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_exp_avg_sq,
-            C.update_by_msam,
+            fns=(C.scale_by_exp_avg_sq, C.update_by_msam),
         )
 
 
```
```diff
@@ -234,7 +363,7 @@ class ForeachADOPT(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adopt,))
 
 
 class ForeachMuon(C.BaseOpt):
```
```diff
@@ -256,6 +385,7 @@ class ForeachMuon(C.BaseOpt):
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
         nesterov: bool = True,
+        heavyball_momentum: bool = False,
         **kwargs,
     ):
         defaults = locals()
```
```diff
@@ -266,6 +396,16 @@ class ForeachMuon(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
+        if heavyball_momentum:
+            if nesterov:
+                ema = C.nesterov_momentum
+            else:
+                ema = C.heavyball_momentum
+        elif nesterov:
+            ema = C.nesterov_ema
+        else:
+            ema = C.exp_avg
+
         super().__init__(
             params,
             defaults,
```
```diff
@@ -273,8 +413,7 @@ class ForeachMuon(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-
-            C.orthogonalize_update,
+            fns=(ema, C.orthogonalize_update),
         )
 
 
```
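For reference, the two hunks above make `ForeachMuon` pick its momentum transform from the flag pair and feed it into `fns` ahead of `C.orthogonalize_update`. A short sketch (instantiation illustrative):

```python
import torch

import heavyball

# (heavyball_momentum, nesterov) -> momentum transform, per the hunk above:
#   (False, True)  -> C.nesterov_ema   (the defaults)
#   (False, False) -> C.exp_avg
#   (True,  True)  -> C.nesterov_momentum
#   (True,  False) -> C.heavyball_momentum
model = torch.nn.Linear(16, 1)
opt = heavyball.ForeachMuon(model.parameters(), nesterov=True, heavyball_momentum=True)
```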
```diff
@@ -306,7 +445,7 @@ class ForeachLaProp(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_laprop)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_laprop,))
 
 
 class MuonLaProp(C.BaseOpt):
```
```diff
@@ -344,8 +483,7 @@ class MuonLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.orthogonalize_update,
+            fns=(C.scale_by_laprop, C.orthogonalize_update),
         )
 
 
```
```diff
@@ -417,7 +555,7 @@ class ForeachSOAP(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,  #
-            C.scale_by_soap,
+            fns=(C.scale_by_soap,),
         )
 
 
```
```diff
@@ -456,8 +594,7 @@ class ForeachSignLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.sign,
+            fns=(C.scale_by_laprop, C.sign),
         )
 
 
```
```diff
@@ -528,7 +665,7 @@ class ForeachSOLP(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,  #
-            functools.partial(C.scale_by_soap, inner="laprop"),
+            fns=(functools.partial(C.scale_by_soap, inner="laprop"),),
         )
 
 
```
```diff
@@ -580,8 +717,7 @@ class OrthoLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.orthogonalize_grad_to_param,
-            C.scale_by_laprop,
+            fns=(C.orthogonalize_grad_to_param, C.scale_by_laprop),
         )
 
 
```
```diff
@@ -619,8 +755,7 @@ class LaPropOrtho(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.orthogonalize_grad_to_param,
+            fns=(C.scale_by_laprop, C.orthogonalize_grad_to_param),
         )
 
 
```
```diff
@@ -683,6 +818,10 @@ class ForeachPSGDKron(C.BaseOpt):
         exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
         update_clipping = C.default(update_clipping, utils.trust_region_clip_)
         inverse_free = C.default(inverse_free, self.quad)
+        if inverse_free:
+            raise ValueError(
+                "inverse_free (i.e., PSGD-QUAD) is not supported at the moment. Consider using https://github.com/evanatyourservice/quad_torch"
+            )
 
         defaults = locals()
         defaults.pop("self")
```
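With this guard, the QUAD path fails fast at construction instead of silently misconfiguring the preconditioner. A sketch of the observable behaviour, assuming the constructor's `inverse_free` argument as used in the hunk above (model is illustrative):

```python
import torch

import heavyball

model = torch.nn.Linear(16, 1)
try:
    heavyball.ForeachPSGDKron(model.parameters(), inverse_free=True)
except ValueError as err:
    print(err)  # points to https://github.com/evanatyourservice/quad_torch instead
```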
```diff
@@ -703,8 +842,10 @@ class ForeachPSGDKron(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             False,  #
-
-
+            fns=(
+                *(C.exp_avg,) * exp_avg_input,
+                functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
+            ),
         )
 
 
```
```diff
@@ -733,11 +874,6 @@ class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
     hvp_interval = 2
 
 
-class QUAD(ForeachPSGDKron):
-    quad = True
-    cached = True
-
-
 class ForeachPSGDLRA(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
```
```diff
@@ -808,8 +944,7 @@ class ForeachPSGDLRA(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             False,  #
-            *(C.exp_avg,) * exp_avg_input,
-            C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra,
+            fns=(*(C.exp_avg,) * exp_avg_input, C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra),
         )
 
 
```
```diff
@@ -846,6 +981,7 @@ SignLaProp = ForeachSignLaProp
 DelayedPSGDLRA = ForeachDelayedPSGDLRA
 PSGDLRA = ForeachPSGDLRA
 NewtonPSGDLRA = ForeachNewtonPSGDLRA
+NewtonPSGDKron = ForeachCachedNewtonPSGD
 
 __all__ = [
     "Muon",
```
```diff
@@ -892,4 +1028,7 @@ __all__ = [
     "NewtonHybrid2PSGDLRA",
     "NewtonHybrid2PSGDKron",
     "MSAMLaProp",
+    "NewtonPSGDKron",
+    "ForeachAdamC",
+    "SGD",
 ]
```