heavyball 1.6.3__tar.gz → 1.7.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. {heavyball-1.6.3 → heavyball-1.7.0}/LICENSE +1 -1
  2. {heavyball-1.6.3 → heavyball-1.7.0}/PKG-INFO +3 -2
  3. heavyball-1.7.0/heavyball/__init__.py +699 -0
  4. {heavyball-1.6.3 → heavyball-1.7.0}/heavyball/chainable.py +444 -155
  5. {heavyball-1.6.3 → heavyball-1.7.0}/heavyball/utils.py +326 -143
  6. {heavyball-1.6.3 → heavyball-1.7.0}/heavyball.egg-info/PKG-INFO +3 -2
  7. {heavyball-1.6.3 → heavyball-1.7.0}/heavyball.egg-info/SOURCES.txt +0 -1
  8. {heavyball-1.6.3 → heavyball-1.7.0}/pyproject.toml +1 -1
  9. {heavyball-1.6.3 → heavyball-1.7.0}/test/test_bf16_params.py +8 -7
  10. {heavyball-1.6.3 → heavyball-1.7.0}/test/test_bf16_q.py +4 -4
  11. {heavyball-1.6.3 → heavyball-1.7.0}/test/test_bf16_storage.py +5 -6
  12. {heavyball-1.6.3 → heavyball-1.7.0}/test/test_caution.py +7 -6
  13. {heavyball-1.6.3 → heavyball-1.7.0}/test/test_channels_last.py +8 -7
  14. {heavyball-1.6.3 → heavyball-1.7.0}/test/test_closure.py +12 -8
  15. {heavyball-1.6.3 → heavyball-1.7.0}/test/test_ema.py +2 -2
  16. {heavyball-1.6.3 → heavyball-1.7.0}/test/test_foreach.py +7 -6
  17. {heavyball-1.6.3 → heavyball-1.7.0}/test/test_hook.py +7 -6
  18. {heavyball-1.6.3 → heavyball-1.7.0}/test/test_mars.py +6 -5
  19. {heavyball-1.6.3 → heavyball-1.7.0}/test/test_memory.py +16 -12
  20. {heavyball-1.6.3 → heavyball-1.7.0}/test/test_merge.py +25 -10
  21. {heavyball-1.6.3 → heavyball-1.7.0}/test/test_no_grad.py +11 -5
  22. {heavyball-1.6.3 → heavyball-1.7.0}/test/test_soap.py +124 -70
  23. {heavyball-1.6.3 → heavyball-1.7.0}/test/test_stochastic_updates.py +8 -7
  24. heavyball-1.6.3/heavyball/__init__.py +0 -303
  25. heavyball-1.6.3/test/test_psgd.py +0 -66
  26. {heavyball-1.6.3 → heavyball-1.7.0}/README.md +0 -0
  27. {heavyball-1.6.3 → heavyball-1.7.0}/heavyball.egg-info/dependency_links.txt +0 -0
  28. {heavyball-1.6.3 → heavyball-1.7.0}/heavyball.egg-info/requires.txt +0 -0
  29. {heavyball-1.6.3 → heavyball-1.7.0}/heavyball.egg-info/top_level.txt +0 -0
  30. {heavyball-1.6.3 → heavyball-1.7.0}/setup.cfg +0 -0
@@ -22,4 +22,4 @@ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22
22
  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23
23
  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24
24
  OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: heavyball
3
- Version: 1.6.3
3
+ Version: 1.7.0
4
4
  Summary: Efficient Optimizers
5
5
  Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
6
6
  Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -27,6 +27,7 @@ Requires-Dist: seaborn; extra == "dev"
27
27
  Requires-Dist: hyperopt; extra == "dev"
28
28
  Requires-Dist: pandas; extra == "dev"
29
29
  Requires-Dist: typer; extra == "dev"
30
+ Dynamic: license-file
30
31
 
31
32
  # `heavyball`: Efficient Optimizers
32
33
 
@@ -0,0 +1,699 @@
1
+ import functools
2
+ from typing import Optional
3
+
4
+ from . import chainable as C
5
+ from . import utils
6
+
7
+
8
class ForeachAdamW(C.BaseOpt):
    """AdamW built on the chainable ``C.update_by_adam`` transform.

    All keyword arguments are captured wholesale into the per-group defaults
    dict via ``locals()``; ``C.use_default`` sentinels are resolved later by
    the base class.
    """

    def __init__(
        self,
        params,
        lr=0.0025,
        betas=(0.9, 0.99),
        eps=1e-8,
        weight_decay=0,
        warmup_steps=0,
        foreach: bool = True,
        storage_dtype: str = "float32",
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        palm: bool = C.use_default,
        beta2_scale: float = 0.8,
    ):
        # Snapshot every argument as the group defaults; `self` and `params`
        # are not hyperparameters. NOTE: relies on locals() being called
        # before any other local variable is created.
        defaults = locals()
        defaults.pop("self")
        params = defaults.pop("params")
        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
31
+
32
+
33
class ForeachRMSprop(C.BaseOpt):
    """
    Debiased RMSprop (not torch.optim.RMSprop)

    Scales updates by the debiased second-moment estimate
    (``C.scale_by_exp_avg_sq``); no first-moment momentum is applied.
    """

    def __init__(
        self,
        params,
        lr=0.0025,
        betas=(0.9, 0.99),
        eps=1e-6,
        weight_decay=0,
        warmup_steps=0,
        r=0.0,
        weight_lr_power=2.0,
        foreach: bool = True,
        storage_dtype: str = "float32",
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        palm: bool = C.use_default,
        beta2_scale: float = 0.8,
    ):
        # All keyword arguments except `self`/`params` become group defaults.
        defaults = locals()
        defaults.pop("self")
        params = defaults.pop("params")
        super().__init__(
            params,
            defaults,
            foreach,
            gradient_clipping,
            update_clipping,
            palm,
            C.scale_by_exp_avg_sq,
        )
70
+
71
+
72
class ForeachSFAdamW(C.ScheduleFree):
    """Schedule-Free AdamW: second-moment scaling chained into the
    schedule-free update (``C.scale_by_exp_avg_sq`` -> ``C.update_by_schedule_free``).

    ``r`` and ``weight_lr_power`` are stored in the group defaults;
    presumably they parameterize the schedule-free weighting — confirm in
    ``C.ScheduleFree``.
    """

    def __init__(
        self,
        params,
        lr=0.0025,
        betas=(0.9, 0.99),
        eps=1e-6,
        weight_decay=0,
        warmup_steps=0,
        r=0.0,
        weight_lr_power=2.0,
        foreach: bool = True,
        storage_dtype: str = "float32",
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        palm: bool = C.use_default,
        beta2_scale: float = 0.8,
    ):
        # Capture kwargs as group defaults; drop non-hyperparameters.
        defaults = locals()
        defaults.pop("self")
        params = defaults.pop("params")
        super().__init__(
            params,
            defaults,
            foreach,
            gradient_clipping,
            update_clipping,
            palm,
            C.scale_by_exp_avg_sq,
            C.update_by_schedule_free,
        )
106
+
107
+
108
class PaLMForeachSFAdamW(ForeachSFAdamW):
    """Schedule-Free AdamW with the PaLM beta2 schedule enabled by default."""

    # Class-level fallback for the `palm` argument — presumably resolved by
    # the base class when the constructor receives C.use_default; confirm in
    # C.BaseOpt.
    palm: bool = True
110
+
111
+
112
class ForeachADOPT(C.BaseOpt):
    """ADOPT optimizer built on the chainable ``C.update_by_adopt`` transform."""

    def __init__(
        self,
        params,
        lr=0.0025,
        betas=(0.9, 0.99),
        eps=1e-8,
        weight_decay=0,
        warmup_steps=0,
        foreach: bool = True,
        storage_dtype: str = "float32",
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        palm: bool = C.use_default,
        beta2_scale: float = 0.8,
    ):
        # Capture kwargs as group defaults; drop non-hyperparameters.
        defaults = locals()
        defaults.pop("self")
        params = defaults.pop("params")
        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
135
+
136
+
137
class ForeachMuon(C.BaseOpt):
    """Muon: momentum followed by orthogonalization of the update.

    ``nesterov`` selects between Nesterov and heavy-ball momentum; the
    momentum output is then passed through ``C.orthogonalize_update``.
    """

    def __init__(
        self,
        params,
        lr=0.0025,
        betas=(0.9, 0.99),
        eps=1e-8,
        weight_decay=0,
        warmup_steps=0,
        foreach: bool = True,
        storage_dtype: str = "float32",
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        palm: bool = C.use_default,
        beta2_scale: float = 0.8,
        nesterov: bool = True,
    ):
        # Capture kwargs as group defaults; drop non-hyperparameters.
        defaults = locals()
        defaults.pop("self")
        params = defaults.pop("params")
        super().__init__(
            params,
            defaults,
            foreach,
            gradient_clipping,
            update_clipping,
            palm,
            # Momentum flavor is chosen once at construction time.
            C.nesterov_momentum if nesterov else C.heavyball_momentum,
            C.orthogonalize_update,
        )
170
+
171
+
172
class ForeachLaProp(C.BaseOpt):
    """LaProp optimizer built on the chainable ``C.update_by_laprop`` transform."""

    def __init__(
        self,
        params,
        lr=0.0025,
        betas=(0.9, 0.99),
        eps=1e-8,
        weight_decay=0,
        warmup_steps=0,
        foreach: bool = True,
        storage_dtype: str = "float32",
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        palm: bool = C.use_default,
        beta2_scale: float = 0.8,
    ):
        # Capture kwargs as group defaults; drop non-hyperparameters.
        defaults = locals()
        defaults.pop("self")
        params = defaults.pop("params")
        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_laprop)
195
+
196
+
197
class MuonLaProp(C.BaseOpt):
    """Muon variant: LaProp scaling followed by update orthogonalization."""

    def __init__(
        self,
        params,
        lr=0.0025,
        betas=(0.9, 0.99),
        eps=1e-8,
        weight_decay=0,
        warmup_steps=0,
        foreach: bool = True,
        storage_dtype: str = "float32",
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        palm: bool = C.use_default,
        beta2_scale: float = 0.8,
    ):
        # Capture kwargs as group defaults; drop non-hyperparameters.
        defaults = locals()
        defaults.pop("self")
        params = defaults.pop("params")
        super().__init__(
            params,
            defaults,
            foreach,
            gradient_clipping,
            update_clipping,
            palm,
            # Transform order matters: scale first, then orthogonalize.
            C.scale_by_laprop,
            C.orthogonalize_update,
        )
229
+
230
+
231
class ForeachSOAP(C.BaseOpt):
    """
    ForeachSOAP

    Sources:
        Baseline SOAP:
            SOAP: Improving and Stabilizing Shampoo using Adam
            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
            https://arxiv.org/abs/2409.11321
            https://github.com/nikhilvyas/SOAP
    """

    # Class-level fallback for the `use_precond_schedule` argument; subclasses
    # override it (see PrecondScheduleForeachSOAP).
    use_precond_schedule: bool = False

    def __init__(
        self,
        params,
        lr: float = 3e-3,
        betas=(0.9, 0.95),
        shampoo_beta: float = 0.95,
        eps: float = 1e-8,
        weight_decay: float = 0.01,
        precondition_frequency: int = 2,
        max_precond_dim: int = 2048,
        merge_dims: bool = True,
        precondition_1d: bool = False,
        normalize_grads: bool = False,
        correct_bias: bool = True,
        warmup_steps: int = 0,
        split: bool = False,
        foreach: bool = True,
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        palm: bool = C.use_default,
        precond_scheduler=(1 / 3, 9),
        beta2_scale: float = 0.8,
        use_precond_schedule: bool = C.use_default,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        storage_dtype: str = "float32",
        stochastic_schedule: bool = False,
    ):
        # Resolve the sentinel against the class-level default first, so the
        # resolved value (not the sentinel) lands in `defaults` below.
        use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)

        defaults = locals()
        defaults.pop("self")
        params = defaults.pop("params")

        # Exactly one of the two preconditioning-cadence knobs survives in
        # the defaults dict; the other is consumed here.
        if use_precond_schedule:
            del defaults["precondition_frequency"]
            self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
        else:
            del defaults["precond_scheduler"]
            # NOTE(review): precondition_frequency=0 would raise
            # ZeroDivisionError here — confirm callers never pass 0.
            self.precond_schedule = 1 / defaults.pop("precondition_frequency")
        super().__init__(
            params,
            defaults,
            foreach,
            gradient_clipping,
            update_clipping,
            palm,
            C.scale_by_soap,
        )
295
+
296
+
297
class ForeachSignLaProp(C.BaseOpt):
    """Sign-LaProp: LaProp scaling followed by taking the sign of the update."""

    def __init__(
        self,
        params,
        lr=0.0025,
        betas=(0.9, 0.99),
        eps=1e-8,
        weight_decay=0,
        warmup_steps=0,
        foreach: bool = True,
        storage_dtype: str = "float32",
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        palm: bool = C.use_default,
        beta2_scale: float = 0.8,
    ):
        # Capture kwargs as group defaults; drop non-hyperparameters.
        defaults = locals()
        defaults.pop("self")
        params = defaults.pop("params")
        super().__init__(
            params,
            defaults,
            foreach,
            gradient_clipping,
            update_clipping,
            palm,
            C.scale_by_laprop,
            C.sign,
        )
329
+
330
+
331
class ForeachSOLP(C.BaseOpt):
    """
    ForeachSOLP

    SOAP with a LaProp inner optimizer
    (``C.scale_by_soap`` with ``inner="laprop"``).

    Sources:
        Baseline SOAP:
            SOAP: Improving and Stabilizing Shampoo using Adam
            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
            https://arxiv.org/abs/2409.11321
            https://github.com/nikhilvyas/SOAP
    """

    # Class-level fallback for the `use_precond_schedule` argument.
    use_precond_schedule: bool = False

    def __init__(
        self,
        params,
        lr: float = 3e-3,
        betas=(0.9, 0.95),
        shampoo_beta: float = 0.95,
        eps: float = 1e-8,
        weight_decay: float = 0.01,
        precondition_frequency: int = 2,
        max_precond_dim: int = 2048,
        merge_dims: bool = True,
        precondition_1d: bool = False,
        normalize_grads: bool = False,
        correct_bias: bool = True,
        warmup_steps: int = 0,
        split: bool = False,
        foreach: bool = True,
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        palm: bool = C.use_default,
        precond_scheduler=(1 / 3, 9),
        beta2_scale: float = 0.8,
        use_precond_schedule: bool = C.use_default,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        storage_dtype: str = "float32",
        stochastic_schedule: bool = False,
    ):
        # Resolve the sentinel before snapshotting locals().
        use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)

        defaults = locals()
        defaults.pop("self")
        params = defaults.pop("params")

        # Keep exactly one of the two preconditioning-cadence knobs.
        if use_precond_schedule:
            del defaults["precondition_frequency"]
            self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
        else:
            del defaults["precond_scheduler"]
            # NOTE(review): precondition_frequency=0 would raise here.
            self.precond_schedule = 1 / defaults.pop("precondition_frequency")
        super().__init__(
            params,
            defaults,
            foreach,
            gradient_clipping,
            update_clipping,
            palm,
            functools.partial(C.scale_by_soap, inner="laprop"),
        )
395
+
396
+
397
class PaLMForeachSOAP(ForeachSOAP):
    """SOAP with the PaLM beta2 schedule enabled by default."""

    use_precond_schedule: bool = False
    # Fallback for the `palm` argument — presumably resolved by the base
    # class when the constructor receives C.use_default; confirm in C.BaseOpt.
    palm: bool = True
400
+
401
+
402
class PrecondScheduleForeachSOAP(ForeachSOAP):
    """SOAP that uses the precond scheduler instead of a fixed frequency."""

    use_precond_schedule: bool = True
404
+
405
+
406
class PrecondSchedulePaLMForeachSOAP(ForeachSOAP):
    """SOAP with both the precond scheduler and the PaLM beta2 schedule."""

    use_precond_schedule: bool = True
    palm: bool = True
409
+
410
+
411
class OrthoLaProp(C.BaseOpt):
    """Orthogonalizes the gradient against the parameter, then applies LaProp."""

    def __init__(
        self,
        params,
        lr=0.0025,
        betas=(0.9, 0.99),
        eps=1e-8,
        weight_decay=0,
        warmup_steps=0,
        foreach: bool = True,
        storage_dtype: str = "float32",
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        palm: bool = C.use_default,
        beta2_scale: float = 0.8,
    ):
        # Capture kwargs as group defaults; drop non-hyperparameters.
        defaults = locals()
        defaults.pop("self")
        params = defaults.pop("params")
        super().__init__(
            params,
            defaults,
            foreach,
            gradient_clipping,
            update_clipping,
            palm,
            # Ortho first, LaProp second — mirror image of LaPropOrtho.
            C.orthogonalize_grad_to_param,
            C.scale_by_laprop,
        )
443
+
444
+
445
class LaPropOrtho(C.BaseOpt):
    """Applies LaProp scaling, then orthogonalizes against the parameter."""

    def __init__(
        self,
        params,
        lr=0.0025,
        betas=(0.9, 0.99),
        eps=1e-8,
        weight_decay=0,
        warmup_steps=0,
        foreach: bool = True,
        storage_dtype: str = "float32",
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        palm: bool = C.use_default,
        beta2_scale: float = 0.8,
    ):
        # Capture kwargs as group defaults; drop non-hyperparameters.
        defaults = locals()
        defaults.pop("self")
        params = defaults.pop("params")
        super().__init__(
            params,
            defaults,
            foreach,
            gradient_clipping,
            update_clipping,
            palm,
            # LaProp first, ortho second — mirror image of OrthoLaProp.
            C.scale_by_laprop,
            C.orthogonalize_grad_to_param,
        )
477
+
478
+
479
class ForeachPSGDKron(C.BaseOpt):
    """
    Originally from Evan Walters and Omead Pooladzandi, 2024
    Modified under Creative Commons Attribution 4.0 International
    Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
    """

    # Class-level fallbacks resolved against the matching constructor
    # sentinels; subclasses flip these (see ForeachCachedDelayedPSGDKron etc.).
    delayed: bool = False
    cached: bool = False
    exp_avg_input: bool = True

    def __init__(
        self,
        params,
        lr=0.001,
        beta=0.9,
        weight_decay=0.0,
        preconditioner_update_probability=None,
        max_size_triangular=2048,
        min_ndim_triangular=2,
        memory_save_mode=None,
        momentum_into_precond_update=True,
        warmup_steps: int = 0,
        merge_dims: bool = False,
        split: bool = False,
        store_triu_as_line: bool = True,
        foreach: bool = True,
        q_dtype="float32",
        stochastic_schedule: bool = False,
        storage_dtype: str = "float32",
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        delayed: Optional[bool] = C.use_default,
        cached: Optional[bool] = C.use_default,
        exp_avg_input: Optional[bool] = C.use_default,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        # expert parameters
        precond_init_scale=None,
        precond_init_scale_scale=1,
        precond_lr=0.1,
    ):
        # Snapshot kwargs before creating any other locals.
        defaults = locals()
        defaults.pop("self")
        # A falsy probability (None) falls back to the default schedule.
        self.precond_schedule = (
            defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
        )
        params = defaults.pop("params")

        # Resolve sentinels against class-level defaults.
        delayed = C.default(delayed, self.delayed)
        cached = C.default(cached, self.cached)
        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
        update_clipping = C.default(update_clipping, utils.trust_region_clip_)

        super().__init__(
            params,
            defaults,
            foreach,
            gradient_clipping,
            update_clipping,
            False,  # palm disabled for PSGD
            # Prepend the exp_avg transform only when exp_avg_input is truthy.
            *(C.exp_avg,) * exp_avg_input,
            functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
        )
544
+
545
+
546
class ForeachPurePSGD(ForeachPSGDKron):
    """PSGD Kron without the exponential-moving-average input transform."""

    exp_avg_input: bool = False
548
+
549
+
550
class ForeachCachedDelayedPSGDKron(ForeachPSGDKron):
    """PSGD Kron with delayed preconditioner application and caching."""

    delayed: bool = True
    cached: bool = True
553
+
554
+
555
class ForeachCachedPSGDKron(ForeachPSGDKron):
    """PSGD Kron with preconditioner caching."""

    cached: bool = True
557
+
558
+
559
class ForeachDelayedPSGD(ForeachPSGDKron):
    """PSGD Kron with delayed preconditioner application."""

    delayed: bool = True
561
+
562
+
563
class ForeachCachedNewtonPSGD(ForeachCachedPSGDKron):
    """Cached PSGD Kron with the Hessian-approximation flag enabled."""

    # Class-level flag read elsewhere — presumably by the base optimizer to
    # enable hessian-vector-product input; confirm in C.BaseOpt.
    hessian_approx = True
565
+
566
+
567
class ForeachPSGDLRA(C.BaseOpt):
    """
    Low-rank-approximation PSGD.

    Originally from Evan Walters and Omead Pooladzandi, 2024
    Modified under Creative Commons Attribution 4.0 International
    Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
    """

    # Class-level fallbacks for the matching constructor sentinels.
    delayed: bool = False
    exp_avg_input: bool = True

    def __init__(
        self,
        params,
        lr=0.001,
        beta=0.9,
        weight_decay=0.0,
        preconditioner_update_probability=None,
        momentum_into_precond_update=True,
        rank: int = 4,
        warmup_steps: int = 0,
        foreach: bool = True,
        q_dtype="float32",
        stochastic_schedule: bool = False,
        storage_dtype: str = "float32",
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        delayed: Optional[bool] = C.use_default,
        exp_avg_input: Optional[bool] = C.use_default,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        eps: float = 1e-8,
        # expert parameters
        precond_init_scale=None,
        precond_init_scale_scale=1,
        precond_lr=0.1,
    ):
        # Snapshot kwargs before creating any other locals.
        defaults = locals()
        defaults.pop("self")
        # A falsy probability (None) falls back to the default schedule.
        self.precond_schedule = (
            defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
        )
        params = defaults.pop("params")

        # Resolve sentinels against class-level defaults.
        delayed = C.default(delayed, self.delayed)
        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
        update_clipping = C.default(update_clipping, utils.trust_region_clip_)

        super().__init__(
            params,
            defaults,
            foreach,
            gradient_clipping,
            update_clipping,
            False,  # palm disabled for PSGD
            # Prepend the exp_avg transform only when exp_avg_input is truthy.
            *(C.exp_avg,) * exp_avg_input,
            C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra,
        )
625
+
626
+
627
class ForeachDelayedPSGDLRA(ForeachPSGDLRA):
    """PSGD-LRA with delayed preconditioner application."""

    delayed: bool = True
629
+
630
+
631
class ForeachNewtonPSGDLRA(ForeachPSGDLRA):
    """PSGD-LRA with the Hessian-approximation flag enabled."""

    # Presumably consumed by the base optimizer; confirm in C.BaseOpt.
    hessian_approx = True
633
+
634
+
635
# Short public aliases for the Foreach* implementations, kept for backwards
# compatibility (including the legacy "PalmForEachSoap" spelling).
AdamW = ForeachAdamW
RMSprop = ForeachRMSprop
LaProp = ForeachLaProp
ADOPT = ForeachADOPT
Muon = ForeachMuon
SignLaProp = ForeachSignLaProp
SFAdamW = ForeachSFAdamW
PaLMSFAdamW = PaLMForeachSFAdamW

SOAP = ForeachSOAP
PaLMSOAP = PaLMForeachSOAP
PalmForEachSoap = PaLMForeachSOAP
PrecondScheduleSOAP = PrecondScheduleForeachSOAP
PrecondSchedulePaLMSOAP = PrecondSchedulePaLMForeachSOAP

PSGDKron = ForeachPSGDKron
PurePSGD = ForeachPurePSGD
DelayedPSGD = ForeachDelayedPSGD
CachedPSGDKron = ForeachCachedPSGDKron
CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
PSGDLRA = ForeachPSGDLRA
DelayedPSGDLRA = ForeachDelayedPSGDLRA
NewtonPSGDLRA = ForeachNewtonPSGDLRA
656
+
657
# Public API. Fixed relative to 1.7.0: removed duplicate entries ("RMSprop",
# "PrecondSchedulePaLMSOAP", "ForeachPSGDLRA", "DelayedPSGD") and added
# names that are defined/aliased above but were missing from the export list
# ("AdamW", "SOAP", "SFAdamW", "DelayedPSGDLRA", "ForeachSOLP", the PaLM /
# PrecondSchedule SOAP classes, and "PaLMForeachSFAdamW").
__all__ = [
    # Short aliases
    "AdamW",
    "Muon",
    "RMSprop",
    "LaProp",
    "ADOPT",
    "SFAdamW",
    "PaLMSFAdamW",
    "SignLaProp",
    "SOAP",
    "PaLMSOAP",
    "PalmForEachSoap",
    "PrecondScheduleSOAP",
    "PrecondSchedulePaLMSOAP",
    "MuonLaProp",
    "PSGDKron",
    "PurePSGD",
    "DelayedPSGD",
    "CachedPSGDKron",
    "CachedDelayedPSGDKron",
    "PSGDLRA",
    "DelayedPSGDLRA",
    "NewtonPSGDLRA",
    # Full class names
    "ForeachAdamW",
    "ForeachRMSprop",
    "ForeachSFAdamW",
    "PaLMForeachSFAdamW",
    "ForeachLaProp",
    "ForeachADOPT",
    "ForeachMuon",
    "ForeachSignLaProp",
    "ForeachSOAP",
    "ForeachSOLP",
    "PaLMForeachSOAP",
    "PrecondScheduleForeachSOAP",
    "PrecondSchedulePaLMForeachSOAP",
    "OrthoLaProp",
    "LaPropOrtho",
    "ForeachPSGDKron",
    "ForeachPurePSGD",
    "ForeachDelayedPSGD",
    "ForeachCachedPSGDKron",
    "ForeachCachedDelayedPSGDKron",
    "ForeachCachedNewtonPSGD",
    "ForeachPSGDLRA",
    "ForeachDelayedPSGDLRA",
    "ForeachNewtonPSGDLRA",
]