heavyball 2.0.0.dev0.tar.gz → 2.1.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/PKG-INFO +19 -7
  2. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/README.md +13 -4
  3. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball/__init__.py +168 -29
  4. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball/chainable.py +165 -63
  5. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball/helpers.py +5 -1
  6. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball/utils.py +490 -124
  7. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball.egg-info/PKG-INFO +19 -7
  8. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball.egg-info/SOURCES.txt +6 -1
  9. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball.egg-info/requires.txt +5 -2
  10. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/pyproject.toml +3 -3
  11. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_bf16_params.py +1 -1
  12. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_bf16_q.py +3 -3
  13. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_bf16_storage.py +3 -3
  14. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_caution.py +1 -1
  15. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_channels_last.py +1 -1
  16. heavyball-2.1.0/test/test_clip.py +116 -0
  17. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_closure.py +1 -1
  18. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_ema.py +1 -1
  19. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_foreach.py +3 -3
  20. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_hook.py +1 -1
  21. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_mars.py +3 -3
  22. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_memory.py +1 -1
  23. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_memory_leak.py +1 -1
  24. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_merge.py +2 -2
  25. heavyball-2.1.0/test/test_migrate_cli.py +178 -0
  26. heavyball-2.1.0/test/test_nd_param.py +40 -0
  27. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_no_grad.py +2 -2
  28. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_save_restore.py +1 -1
  29. heavyball-2.1.0/test/test_singular_values.py +88 -0
  30. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_soap.py +0 -1
  31. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_stochastic_updates.py +3 -3
  32. heavyball-2.1.0/test/test_toy_training.py +130 -0
  33. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/LICENSE +0 -0
  34. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball.egg-info/dependency_links.txt +0 -0
  35. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball.egg-info/top_level.txt +0 -0
  36. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/setup.cfg +0 -0
  37. {heavyball-2.0.0.dev0 → heavyball-2.1.0}/test/test_psgd_precond_init_stability.py +0 -0
{heavyball-2.0.0.dev0 → heavyball-2.1.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: heavyball
-Version: 2.0.0.dev0
+Version: 2.1.0
 Summary: Efficient Optimizers
 Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
 Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -16,7 +16,7 @@ Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: opt-einsum>=3.4.0
-Requires-Dist: torch>=2.1.0
+Requires-Dist: torch>=2.7.0
 Requires-Dist: numpy
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
@@ -24,9 +24,12 @@ Requires-Dist: pytest; extra == "dev"
 Requires-Dist: ruff; extra == "dev"
 Requires-Dist: matplotlib; extra == "dev"
 Requires-Dist: seaborn; extra == "dev"
-Requires-Dist: hyperopt; extra == "dev"
 Requires-Dist: pandas; extra == "dev"
 Requires-Dist: typer; extra == "dev"
+Requires-Dist: optuna; extra == "dev"
+Requires-Dist: optunahub; extra == "dev"
+Requires-Dist: hebo; extra == "dev"
+Requires-Dist: lightbench; extra == "dev"
 Dynamic: license-file
 
 # heavyball
@@ -46,11 +49,11 @@ _High-performance, extensible, chainable optimizers for PyTorch._
 
 ## Key Features
 
-- Foreach-based optimizers: `ForeachAdamW`, `ForeachRMSprop`, `ForeachSFAdamW`, `Muon`, `ADOPT`, `MSAM`,
+- Foreach-based optimizers: `ForeachAdamW`, `ForeachRMSprop`, `ForeachSFAdamW`, `Muon`, `ADOPT`, `MSAM` (Momentum SAM),
 - Schedule-Free optimizers with dynamic learning rate adaptation.
 - Advanced update rules: MARS correction, cautious updates, PaLM beta2 scheduling.
 - Chainable transforms for custom optimization recipes.
-- Comprehensive benchmark suite (`benchmark/`).
+- Comprehensive benchmark suite packaged separately as LightBench (`../LightBench`).
 - Detailed documentation and example-driven tutorials.
 
 ## Quickstart
@@ -81,11 +84,20 @@ for data, target in dataloader:
 
 ## Benchmarks
 
-> Reproduce benchmarks with:
+> Reproduce benchmarks with LightBench (install it via `pip install -e ../LightBench` from the repo root):
 > ```bash
-> python3 -m benchmark.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
+> python3 -m lightbench.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
 > ```
 
+## Migrating from HeavyBall 1.x
+
+- Read the detailed [2.0.0 migration notes](docs/heavyball2.md) for an end-to-end checklist, including guidance for restoring legacy behaviour when needed.
+- Use `scripts/migrate_optimizer_state.py` to rewrite pre-2.0 optimizer checkpoints:
+  ```bash
+  python scripts/migrate_optimizer_state.py path/to/checkpoint.pt heavyball.ForeachAdamW --state-key optimizer
+  ```
+  The utility renames legacy state entries, fans them out per parameter view, and injects the HeavyBall metadata block expected by 2.0.0.
+
 
 ## Contributing
 
{heavyball-2.0.0.dev0 → heavyball-2.1.0}/README.md

@@ -15,11 +15,11 @@ _High-performance, extensible, chainable optimizers for PyTorch._
 
 ## Key Features
 
-- Foreach-based optimizers: `ForeachAdamW`, `ForeachRMSprop`, `ForeachSFAdamW`, `Muon`, `ADOPT`, `MSAM`,
+- Foreach-based optimizers: `ForeachAdamW`, `ForeachRMSprop`, `ForeachSFAdamW`, `Muon`, `ADOPT`, `MSAM` (Momentum SAM),
 - Schedule-Free optimizers with dynamic learning rate adaptation.
 - Advanced update rules: MARS correction, cautious updates, PaLM beta2 scheduling.
 - Chainable transforms for custom optimization recipes.
-- Comprehensive benchmark suite (`benchmark/`).
+- Comprehensive benchmark suite packaged separately as LightBench (`../LightBench`).
 - Detailed documentation and example-driven tutorials.
 
 ## Quickstart
@@ -50,11 +50,20 @@ for data, target in dataloader:
 
 ## Benchmarks
 
-> Reproduce benchmarks with:
+> Reproduce benchmarks with LightBench (install it via `pip install -e ../LightBench` from the repo root):
 > ```bash
-> python3 -m benchmark.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
+> python3 -m lightbench.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
 > ```
 
+## Migrating from HeavyBall 1.x
+
+- Read the detailed [2.0.0 migration notes](docs/heavyball2.md) for an end-to-end checklist, including guidance for restoring legacy behaviour when needed.
+- Use `scripts/migrate_optimizer_state.py` to rewrite pre-2.0 optimizer checkpoints:
+  ```bash
+  python scripts/migrate_optimizer_state.py path/to/checkpoint.pt heavyball.ForeachAdamW --state-key optimizer
+  ```
+  The utility renames legacy state entries, fans them out per parameter view, and injects the HeavyBall metadata block expected by 2.0.0.
+
 
 ## Contributing
 
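The migration script above rewrites the optimizer state in place; resuming afterwards is ordinary PyTorch checkpoint loading. A minimal sketch of that resume step — the `"model"` key, layer shape, and learning rate are illustrative assumptions; only `heavyball.ForeachAdamW`, the checkpoint path, and the `"optimizer"` key come from the command shown above:

```python
# Hypothetical resume after running scripts/migrate_optimizer_state.py.
import torch

import heavyball

model = torch.nn.Linear(16, 4)  # stand-in for the real model
opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-3)

# weights_only=False may be needed if the checkpoint embeds non-tensor metadata.
ckpt = torch.load("path/to/checkpoint.pt", map_location="cpu", weights_only=False)
model.load_state_dict(ckpt["model"])  # assumed key; adjust to your checkpoint layout
opt.load_state_dict(ckpt["optimizer"])  # migrated HeavyBall 2.x optimizer state
```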
{heavyball-2.0.0.dev0 → heavyball-2.1.0}/heavyball/__init__.py

@@ -41,6 +41,34 @@ class SAMWrapper(torch.optim.Optimizer):
         self.wrapped_optimizer.zero_grad()
 
 
+class SGD(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        beta=0.9,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, fns=(C.heavyball_momentum,))
+
+
 class ForeachAdamW(C.BaseOpt):
     def __init__(
         self,
@@ -69,7 +97,110 @@ class ForeachAdamW(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adam,))
+
+
+class UnscaledAdamW(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(
+            params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_unscaled_adam,)
+        )
+
+
+class SUDSAdamW(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        precond_lr: float = 1e-2,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_suds,))
+
+
+class ForeachAdamC(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        max_lr: float | None = None,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        **kwargs,
+    ):
+        if max_lr is None:
+            utils.warn_once(
+                "max_lr was not set. setting it to the current learning rate, under the assumption that it strictly decreases"
+            )
+            max_lr = lr
+
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adamc,))
 
 
 class ForeachRMSprop(C.BaseOpt):
@@ -113,7 +244,7 @@ class ForeachRMSprop(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_exp_avg_sq,
+            fns=(C.scale_by_exp_avg_sq,),
         )
 
 
@@ -154,8 +285,7 @@ class ForeachSFAdamW(C.ScheduleFree):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_exp_avg_sq,
-            C.update_by_schedule_free,
+            fns=(C.scale_by_exp_avg_sq, C.update_by_schedule_free),
         )
 
 
@@ -197,8 +327,7 @@ class MSAMLaProp(C.MSAM):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_exp_avg_sq,
-            C.update_by_msam,
+            fns=(C.scale_by_exp_avg_sq, C.update_by_msam),
        )
 
 
@@ -234,7 +363,7 @@ class ForeachADOPT(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adopt,))
 
 
 class ForeachMuon(C.BaseOpt):
@@ -256,6 +385,7 @@ class ForeachMuon(C.BaseOpt):
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
         nesterov: bool = True,
+        heavyball_momentum: bool = False,
         **kwargs,
     ):
         defaults = locals()
@@ -266,6 +396,16 @@ class ForeachMuon(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
+        if heavyball_momentum:
+            if nesterov:
+                ema = C.nesterov_momentum
+            else:
+                ema = C.heavyball_momentum
+        elif nesterov:
+            ema = C.nesterov_ema
+        else:
+            ema = C.exp_avg
+
         super().__init__(
             params,
             defaults,
@@ -273,8 +413,7 @@ class ForeachMuon(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.nesterov_ema if nesterov else C.exp_avg,
-            C.orthogonalize_update,
+            fns=(ema, C.orthogonalize_update),
         )
 
 
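The two `ForeachMuon` hunks above replace the old `nesterov_ema`/`exp_avg` ternary with a four-way selection driven by the new `heavyball_momentum` flag. A hedged sketch of the resulting constructor combinations — the flag-to-transform mapping is read directly from the diff, while the model is a placeholder:

```python
# Illustrative only: which momentum transform each flag pair selects in 2.1.0.
import torch

import heavyball

model = torch.nn.Linear(32, 32)

# nesterov=True, heavyball_momentum=False (defaults) -> C.nesterov_ema
opt = heavyball.ForeachMuon(model.parameters(), lr=0.0025)

# nesterov=False, heavyball_momentum=False -> C.exp_avg
# nesterov=True,  heavyball_momentum=True  -> C.nesterov_momentum
# nesterov=False, heavyball_momentum=True  -> C.heavyball_momentum
```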
@@ -306,7 +445,7 @@ class ForeachLaProp(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_laprop)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_laprop,))
 
 
 class MuonLaProp(C.BaseOpt):
@@ -344,8 +483,7 @@ class MuonLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.orthogonalize_update,
+            fns=(C.scale_by_laprop, C.orthogonalize_update),
         )
 
 
@@ -417,7 +555,7 @@ class ForeachSOAP(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,  #
-            C.scale_by_soap,
+            fns=(C.scale_by_soap,),
         )
 
 
@@ -456,8 +594,7 @@ class ForeachSignLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.sign,
+            fns=(C.scale_by_laprop, C.sign),
         )
 
 
@@ -528,7 +665,7 @@ class ForeachSOLP(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,  #
-            functools.partial(C.scale_by_soap, inner="laprop"),
+            fns=(functools.partial(C.scale_by_soap, inner="laprop"),),
         )
 
 
@@ -580,8 +717,7 @@ class OrthoLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.orthogonalize_grad_to_param,
-            C.scale_by_laprop,
+            fns=(C.orthogonalize_grad_to_param, C.scale_by_laprop),
         )
 
 
@@ -619,8 +755,7 @@ class LaPropOrtho(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.orthogonalize_grad_to_param,
+            fns=(C.scale_by_laprop, C.orthogonalize_grad_to_param),
         )
 
 
@@ -683,6 +818,10 @@ class ForeachPSGDKron(C.BaseOpt):
         exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
         update_clipping = C.default(update_clipping, utils.trust_region_clip_)
         inverse_free = C.default(inverse_free, self.quad)
+        if inverse_free:
+            raise ValueError(
+                "inverse_free (i.e., PSGD-QUAD) is not supported at the moment. Consider using https://github.com/evanatyourservice/quad_torch"
+            )
 
         defaults = locals()
         defaults.pop("self")
@@ -703,8 +842,10 @@ class ForeachPSGDKron(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             False,  #
-            *(C.exp_avg,) * exp_avg_input,  #
-            functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
+            fns=(
+                *(C.exp_avg,) * exp_avg_input,
+                functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
+            ),
         )
 
 
@@ -733,11 +874,6 @@ class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
     hvp_interval = 2
 
 
-class QUAD(ForeachPSGDKron):
-    quad = True
-    cached = True
-
-
 class ForeachPSGDLRA(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -808,8 +944,7 @@ class ForeachPSGDLRA(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             False,  #
-            *(C.exp_avg,) * exp_avg_input,  #
-            C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra,
+            fns=(*(C.exp_avg,) * exp_avg_input, C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra),
         )
 
 
@@ -846,6 +981,7 @@ SignLaProp = ForeachSignLaProp
 DelayedPSGDLRA = ForeachDelayedPSGDLRA
 PSGDLRA = ForeachPSGDLRA
 NewtonPSGDLRA = ForeachNewtonPSGDLRA
+NewtonPSGDKron = ForeachCachedNewtonPSGD
 
 __all__ = [
     "Muon",
@@ -892,4 +1028,7 @@ __all__ = [
     "NewtonHybrid2PSGDLRA",
     "NewtonHybrid2PSGDKron",
     "MSAMLaProp",
+    "NewtonPSGDKron",
+    "ForeachAdamC",
+    "SGD",
 ]
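Taken together, these hunks switch every `C.BaseOpt` subclass from passing transforms positionally to an explicit `fns=(...)` keyword, and export `SGD`, `ForeachAdamC`, and the `NewtonPSGDKron` alias. A minimal sketch of the new public entry points — constructor names and defaults are from the diff, while the toy data and loss are illustrative:

```python
# Hypothetical smoke test for the optimizers added in 2.1.0.
import torch

import heavyball

model = torch.nn.Linear(8, 2)
opt = heavyball.SGD(model.parameters(), lr=0.0025, beta=0.9)
# Alternatives added or aliased in this release:
#   heavyball.ForeachAdamC(model.parameters(), lr=1e-3, max_lr=1e-3)
#   heavyball.NewtonPSGDKron is heavyball.ForeachCachedNewtonPSGD

x, y = torch.randn(4, 8), torch.randn(4, 2)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()
```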