heavyball 1.7.2__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -2,10 +2,73 @@ import functools
 import math
 from typing import Optional

+import torch.optim
+
 from . import chainable as C
 from . import utils


+class SAMWrapper(torch.optim.Optimizer):
+    def __init__(self, params, wrapped_optimizer: utils.StatefulOptimizer, ball: float = 0.1):
+        if not isinstance(wrapped_optimizer, utils.StatefulOptimizer):
+            raise ValueError(f"{wrapped_optimizer.__class__.__name__} is not a HeavyBall optimizer")
+        super().__init__(params, {"ball": ball})
+        self.wrapped_optimizer = wrapped_optimizer
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        if closure is None:
+            raise ValueError("SAM requires closure")
+        with torch.enable_grad():
+            closure()
+        old_params = [utils.sam_step(group["params"], group["ball"]) for group in self.param_groups]
+
+        original_handle_closure = self.wrapped_optimizer._handle_closure
+
+        def _handle_closure(closure):
+            original_handle_closure(closure)
+            for group, old in zip(self.param_groups, old_params):
+                utils.copy_stochastic_list_(group["params"], old)
+
+        try:
+            self.wrapped_optimizer._handle_closure = _handle_closure
+            loss = self.wrapped_optimizer.step(closure)
+        finally:
+            self.wrapped_optimizer._handle_closure = original_handle_closure
+        return loss
+
+    def zero_grad(self, set_to_none: bool = True):
+        self.wrapped_optimizer.zero_grad()
+
+
+class SGD(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        beta=0.9,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, fns=(C.heavyball_momentum,))
+
+
 class ForeachAdamW(C.BaseOpt):
     def __init__(
         self,
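The new SAMWrapper wraps an already-constructed HeavyBall optimizer: its step requires a closure, perturbs the parameters via utils.sam_step, and restores them around the wrapped update. A minimal usage sketch (the model, data, and loss below are placeholders, not part of the package):

```python
import torch
import heavyball

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
x, y = torch.randn(8, 16), torch.randn(8, 1)

inner = heavyball.ForeachAdamW(model.parameters(), lr=1e-3)
opt = heavyball.SAMWrapper(model.parameters(), wrapped_optimizer=inner, ball=0.1)

def closure():
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = opt.step(closure)  # step() raises ValueError when no closure is passed
```

The new SGD follows the same constructor pattern as the other optimizers here and chains a single C.heavyball_momentum transform, e.g. heavyball.SGD(model.parameters(), lr=0.0025, beta=0.9).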
@@ -24,11 +87,55 @@ class ForeachAdamW(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adam,))
+
+
+class ForeachAdamC(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        max_lr: float | None = None,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        **kwargs,
+    ):
+        if max_lr is None:
+            utils.warn_once(
+                "max_lr was not set. setting it to the current learning rate, under the assumption that it strictly decreases"
+            )
+            max_lr = lr
+
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adamc,))


 class ForeachRMSprop(C.BaseOpt):
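ForeachAdamC is new in 2.0.0 and adds a max_lr argument; when it is omitted, the constructor warns via utils.warn_once and falls back to the initial lr, assuming the learning rate only decreases from there. A hedged construction sketch (the module is a stand-in):

```python
import torch
from heavyball import ForeachAdamC

model = torch.nn.Linear(10, 10)  # placeholder module
# Passing max_lr explicitly avoids the warn-and-fallback path shown above.
opt = ForeachAdamC(model.parameters(), lr=2.5e-3, max_lr=2.5e-3, betas=(0.9, 0.99), weight_decay=0.01)
```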
@@ -55,10 +162,16 @@ class ForeachRMSprop(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(
             params,
             defaults,
@@ -66,7 +179,7 @@ class ForeachRMSprop(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_exp_avg_sq,
+            fns=(C.scale_by_exp_avg_sq,),
         )


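Every constructor touched in this release gains a trailing **kwargs: unknown keyword arguments are merged into the per-group defaults and reported once through utils.warn_once instead of raising a TypeError. For example (the extra keyword below is purely illustrative):

```python
import torch
from heavyball import ForeachRMSprop

params = [torch.nn.Parameter(torch.randn(4, 4))]
# "my_extra_flag" is not a declared argument; 2.0.0 stores it in the defaults
# and warns once about "uncaptured keyword arguments", where 1.7.2's fixed
# signature would have rejected it with a TypeError.
opt = ForeachRMSprop(params, lr=1e-3, my_extra_flag=True)
```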
@@ -90,10 +203,58 @@ class ForeachSFAdamW(C.ScheduleFree):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(
+            params,
+            defaults,
+            foreach,
+            gradient_clipping,
+            update_clipping,
+            palm,
+            fns=(C.scale_by_exp_avg_sq, C.update_by_schedule_free),
+        )
+
+
+class MSAMLaProp(C.MSAM):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-6,
+        weight_decay=0,
+        warmup_steps=0,
+        r=0.0,
+        weight_lr_power=2.0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        sam_step_size: float = 0.1,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(
             params,
             defaults,
@@ -101,8 +262,7 @@ class ForeachSFAdamW(C.ScheduleFree):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_exp_avg_sq,
-            C.update_by_schedule_free,
+            fns=(C.scale_by_exp_avg_sq, C.update_by_msam),
         )


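MSAMLaProp is a new optimizer built on C.MSAM: it pairs the exp-avg-sq scaling with a momentum-SAM update (C.update_by_msam) and exposes the perturbation radius as sam_step_size, while ForeachSFAdamW keeps its schedule-free chain (now passed via fns=). A construction-only sketch with a placeholder parameter list; the step/closure protocol is inherited from the C.MSAM base class and is not shown in this diff:

```python
import torch
from heavyball import MSAMLaProp

params = [torch.nn.Parameter(torch.randn(8, 8))]
opt = MSAMLaProp(params, lr=2.5e-3, betas=(0.9, 0.99), sam_step_size=0.1)
```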
@@ -128,11 +288,17 @@ class ForeachADOPT(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adopt,))


 class ForeachMuon(C.BaseOpt):
@@ -154,10 +320,16 @@ class ForeachMuon(C.BaseOpt):
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
         nesterov: bool = True,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(
             params,
             defaults,
@@ -165,8 +337,7 @@ class ForeachMuon(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.nesterov_momentum if nesterov else C.heavyball_momentum,
-            C.orthogonalize_update,
+            fns=(C.nesterov_ema if nesterov else C.exp_avg, C.orthogonalize_update),
         )


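ForeachMuon (and MuonLaProp below) now passes its transform chain through the fns= keyword; for Muon the momentum stage switches from C.nesterov_momentum / C.heavyball_momentum to C.nesterov_ema / C.exp_avg, still selected by the unchanged nesterov flag. A construction sketch, assuming the usual lr keyword:

```python
import torch
from heavyball import ForeachMuon

weight = torch.nn.Parameter(torch.randn(64, 64))  # Muon-style updates target 2D weights
opt = ForeachMuon([weight], lr=1e-3, nesterov=True)  # nesterov=False falls back to C.exp_avg
```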
@@ -188,11 +359,17 @@ class ForeachLaProp(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_laprop)
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_laprop,))


 class MuonLaProp(C.BaseOpt):
@@ -213,10 +390,16 @@ class MuonLaProp(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(
             params,
             defaults,
@@ -224,8 +407,7 @@ class MuonLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.orthogonalize_update,
+            fns=(C.scale_by_laprop, C.orthogonalize_update),
         )


@@ -271,12 +453,18 @@ class ForeachSOAP(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         storage_dtype: str = "float32",
         stochastic_schedule: bool = False,
+        precond_grad_accum: bool = False,
+        **kwargs,
     ):
         use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)

         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")

         if use_precond_schedule:
             del defaults["precondition_frequency"]
@@ -291,7 +479,7 @@ class ForeachSOAP(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,  #
-            C.scale_by_soap,
+            fns=(C.scale_by_soap,),
         )


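ForeachSOAP gains a precond_grad_accum flag (defaulting to False, i.e. the 1.7.x behaviour) next to the existing stochastic_schedule option, and its SOAP transform now goes through fns=. A construction sketch with a placeholder parameter list; the learning rate value is illustrative:

```python
import torch
from heavyball import ForeachSOAP

params = [torch.nn.Parameter(torch.randn(32, 32))]
opt = ForeachSOAP(params, lr=3e-3, precond_grad_accum=True)  # new flag in 2.0.0
```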
@@ -313,10 +501,16 @@ class ForeachSignLaProp(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(
             params,
             defaults,
@@ -324,8 +518,7 @@ class ForeachSignLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.sign,
+            fns=(C.scale_by_laprop, C.sign),
         )


@@ -371,12 +564,17 @@ class ForeachSOLP(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         storage_dtype: str = "float32",
         stochastic_schedule: bool = False,
+        **kwargs,
     ):
         use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)

         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")

         if use_precond_schedule:
             del defaults["precondition_frequency"]
@@ -391,7 +589,7 @@ class ForeachSOLP(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,  #
-            functools.partial(C.scale_by_soap, inner="laprop"),
+            fns=(functools.partial(C.scale_by_soap, inner="laprop"),),
         )


@@ -427,10 +625,15 @@ class OrthoLaProp(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
         super().__init__(
             params,
             defaults,
@@ -438,8 +641,7 @@ class OrthoLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.orthogonalize_grad_to_param,
-            C.scale_by_laprop,
+            fns=(C.orthogonalize_grad_to_param, C.scale_by_laprop),
         )


@@ -461,10 +663,15 @@ class LaPropOrtho(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
         super().__init__(
             params,
             defaults,
@@ -472,8 +679,7 @@ class LaPropOrtho(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.orthogonalize_grad_to_param,
+            fns=(C.scale_by_laprop, C.orthogonalize_grad_to_param),
         )


@@ -487,12 +693,14 @@ class ForeachPSGDKron(C.BaseOpt):
     delayed: bool = False
     cached: bool = False
     exp_avg_input: bool = True
+    quad: bool = False

     def __init__(
         self,
         params,
         lr=0.001,
-        beta=0.9,
+        beta=None,
+        betas=(0.9, 0.999),
         weight_decay=0.0,
         preconditioner_update_probability=None,
         max_size_triangular=2048,
@@ -515,23 +723,40 @@ class ForeachPSGDKron(C.BaseOpt):
         exp_avg_input: Optional[bool] = C.use_default,
         gradient_clipping: C.str_or_fn = C.use_default,
         update_clipping: C.str_or_fn = C.use_default,  #
+        adaptive: bool = False,
+        ortho_method: Optional[str] = None,  # If None, no orthogonalization
+        precond_grad_accum: bool = False,
+        lower_bound_beta: float = 0.9,  # 0.0 recovers pre-2.0.0 PSGD
+        inverse_free: bool = C.use_default,
+        dampening: float = 2**-13,
+        precond_update_power_iterations: int = 2,
         # expert parameters
         precond_init_scale=None,
-        precond_init_scale_scale=1,
-        precond_lr=0.1,
+        precond_init_scale_scale: float = 1,
+        precond_init_scale_power: Optional[float] = None,
+        precond_lr: float = 0.1,
+        **kwargs,
     ):
+        delayed = C.default(delayed, self.delayed)
+        cached = C.default(cached, self.cached)
+        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
+        update_clipping = C.default(update_clipping, utils.trust_region_clip_)
+        inverse_free = C.default(inverse_free, self.quad)
+        if inverse_free:
+            raise ValueError("inverse_free (i.e., PSGD-QUAD) is not supported at the moment. Consider using https://github.com/evanatyourservice/quad_torch")
+
         defaults = locals()
         defaults.pop("self")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         self.precond_schedule = (
             defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
         )
         params = defaults.pop("params")

-        delayed = C.default(delayed, self.delayed)
-        cached = C.default(cached, self.cached)
-        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
-        update_clipping = C.default(update_clipping, utils.trust_region_clip_)
-
         super().__init__(
             params,
             defaults,
@@ -539,8 +764,10 @@ class ForeachPSGDKron(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             False,  #
-            *(C.exp_avg,) * exp_avg_input,  #
-            functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
+            fns=(
+                *(C.exp_avg,) * exp_avg_input,
+                functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
+            ),
         )


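ForeachPSGDKron replaces the single beta argument with betas=(0.9, 0.999) (keeping beta=None for backward compatibility), adds several preconditioner options (adaptive, ortho_method, lower_bound_beta, dampening, precond_update_power_iterations, precond_grad_accum), and rejects inverse_free=True with a ValueError pointing at quad_torch. A migration sketch with a placeholder parameter list:

```python
import torch
from heavyball import ForeachPSGDKron

params = [torch.nn.Parameter(torch.randn(128, 64))]

# 1.7.2:  ForeachPSGDKron(params, lr=1e-3, beta=0.9)
# 2.0.0:  pass the pair of EMA factors instead.
opt = ForeachPSGDKron(params, lr=1e-3, betas=(0.9, 0.999))
```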
@@ -569,6 +796,7 @@ class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
     hvp_interval = 2


+
 class ForeachPSGDLRA(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -601,13 +829,24 @@ class ForeachPSGDLRA(C.BaseOpt):
         gradient_clipping: C.str_or_fn = C.use_default,
         update_clipping: C.str_or_fn = C.use_default,
         eps: float = 1e-8,  #
-        # expert parameters
+        precond_grad_accum: bool = False,  # expert parameters
         precond_init_scale=None,
-        precond_init_scale_scale=1,
-        precond_lr=0.1,
+        precond_init_scale_scale: float = 1,
+        precond_init_scale_power: Optional[float] = None,
+        precond_lr: float = 0.1,
+        **kwargs,
     ):
+        delayed = C.default(delayed, self.delayed)
+        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
+        update_clipping = C.default(update_clipping, utils.trust_region_clip_)
+
         defaults = locals()
         defaults.pop("self")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         self.precond_schedule = (
             defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
         )
@@ -621,10 +860,6 @@ class ForeachPSGDLRA(C.BaseOpt):
             defaults["rank"] = round(math.log2(sum(p.numel() for p in params)))
             utils.warn_once(f"rank was set to {defaults['rank']}")

-        delayed = C.default(delayed, self.delayed)
-        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
-        update_clipping = C.default(update_clipping, utils.trust_region_clip_)
-
         super().__init__(
             params,
             defaults,
@@ -632,8 +867,7 @@ class ForeachPSGDLRA(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             False,  #
-            *(C.exp_avg,) * exp_avg_input,  #
-            C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra,
+            fns=(*(C.exp_avg,) * exp_avg_input, C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra),
         )


@@ -670,6 +904,7 @@ SignLaProp = ForeachSignLaProp
 DelayedPSGDLRA = ForeachDelayedPSGDLRA
 PSGDLRA = ForeachPSGDLRA
 NewtonPSGDLRA = ForeachNewtonPSGDLRA
+NewtonPSGDKron = ForeachCachedNewtonPSGD

 __all__ = [
     "Muon",
@@ -715,4 +950,8 @@ __all__ = [
     "NewtonPSGDLRA",
     "NewtonHybrid2PSGDLRA",
     "NewtonHybrid2PSGDKron",
+    "MSAMLaProp",
+    "NewtonPSGDKron",
+    "ForeachAdamC",
+    "SGD"
 ]
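With the expanded __all__, the new and newly aliased names can be imported directly from the package root:

```python
from heavyball import SGD, ForeachAdamC, MSAMLaProp, NewtonPSGDKron
```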