heavyball 2.0.0.dev0__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
heavyball/__init__.py CHANGED
@@ -41,6 +41,34 @@ class SAMWrapper(torch.optim.Optimizer):
         self.wrapped_optimizer.zero_grad()
 
 
+class SGD(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        beta=0.9,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, fns=(C.heavyball_momentum,))
+
+
 class ForeachAdamW(C.BaseOpt):
     def __init__(
         self,
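The new SGD optimizer routes plain heavyball momentum through the same C.BaseOpt chain-of-functions machinery as the other optimizers. A minimal usage sketch follows; the model and hyperparameters are illustrative, not taken from the package:

    import torch
    import heavyball

    model = torch.nn.Linear(16, 4)
    # beta is the momentum coefficient; defaults mirror the signature above
    opt = heavyball.SGD(model.parameters(), lr=0.0025, beta=0.9)

    loss = model(torch.randn(8, 16)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()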
@@ -69,7 +97,110 @@ class ForeachAdamW(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adam,))
+
+
+class UnscaledAdamW(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(
+            params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_unscaled_adam,)
+        )
+
+
+class SUDSAdamW(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        precond_lr: float = 1e-2,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_suds,))
+
+
+class ForeachAdamC(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        max_lr: float | None = None,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        **kwargs,
+    ):
+        if max_lr is None:
+            utils.warn_once(
+                "max_lr was not set. setting it to the current learning rate, under the assumption that it strictly decreases"
+            )
+            max_lr = lr
+
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adamc,))
 
 
 class ForeachRMSprop(C.BaseOpt):
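Two things happen in this hunk: the super().__init__ calls now pass their transform chain through the keyword-only fns= tuple instead of trailing positional arguments (a migration repeated in every hunk below), and three optimizers are added, including ForeachAdamC, which warns and falls back to max_lr = lr when max_lr is omitted. A hedged construction sketch, with illustrative model and values:

    import torch
    import heavyball

    model = torch.nn.Linear(16, 4)
    # passing max_lr explicitly avoids the warn_once fallback shown above
    opt = heavyball.ForeachAdamC(model.parameters(), lr=1e-3, max_lr=1e-3, betas=(0.9, 0.99))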
@@ -113,7 +244,7 @@ class ForeachRMSprop(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_exp_avg_sq,
+            fns=(C.scale_by_exp_avg_sq,),
         )
 
 
@@ -154,8 +285,7 @@ class ForeachSFAdamW(C.ScheduleFree):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_exp_avg_sq,
-            C.update_by_schedule_free,
+            fns=(C.scale_by_exp_avg_sq, C.update_by_schedule_free),
         )
 
 
@@ -197,8 +327,7 @@ class MSAMLaProp(C.MSAM):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_exp_avg_sq,
-            C.update_by_msam,
+            fns=(C.scale_by_exp_avg_sq, C.update_by_msam),
         )
 
 
@@ -234,7 +363,7 @@ class ForeachADOPT(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adopt,))
 
 
 class ForeachMuon(C.BaseOpt):
@@ -256,6 +385,7 @@ class ForeachMuon(C.BaseOpt):
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
         nesterov: bool = True,
+        heavyball_momentum: bool = False,
         **kwargs,
     ):
         defaults = locals()
@@ -266,6 +396,16 @@ class ForeachMuon(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
+        if heavyball_momentum:
+            if nesterov:
+                ema = C.nesterov_momentum
+            else:
+                ema = C.heavyball_momentum
+        elif nesterov:
+            ema = C.nesterov_ema
+        else:
+            ema = C.exp_avg
+
         super().__init__(
             params,
             defaults,
@@ -273,8 +413,7 @@ class ForeachMuon(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.nesterov_ema if nesterov else C.exp_avg,
-            C.orthogonalize_update,
+            fns=(ema, C.orthogonalize_update),
         )
 
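With the new heavyball_momentum flag, ForeachMuon selects one of four momentum transforms (C.nesterov_momentum, C.heavyball_momentum, C.nesterov_ema, or C.exp_avg) before orthogonalizing the update; nesterov=True with heavyball_momentum=False reproduces the previous default, C.nesterov_ema. A usage sketch, with illustrative model and values:

    import torch
    import heavyball

    model = torch.nn.Linear(16, 4)
    # per the branch above, this selects C.nesterov_momentum
    opt = heavyball.ForeachMuon(model.parameters(), lr=0.0025, nesterov=True, heavyball_momentum=True)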
 
@@ -306,7 +445,7 @@ class ForeachLaProp(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_laprop)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_laprop,))
 
 
 class MuonLaProp(C.BaseOpt):
@@ -344,8 +483,7 @@ class MuonLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.orthogonalize_update,
+            fns=(C.scale_by_laprop, C.orthogonalize_update),
         )
 
 
@@ -417,7 +555,7 @@ class ForeachSOAP(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,  #
-            C.scale_by_soap,
+            fns=(C.scale_by_soap,),
         )
 
 
@@ -456,8 +594,7 @@ class ForeachSignLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.sign,
+            fns=(C.scale_by_laprop, C.sign),
         )
 
 
@@ -528,7 +665,7 @@ class ForeachSOLP(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,  #
-            functools.partial(C.scale_by_soap, inner="laprop"),
+            fns=(functools.partial(C.scale_by_soap, inner="laprop"),),
         )
 
 
@@ -580,8 +717,7 @@ class OrthoLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.orthogonalize_grad_to_param,
-            C.scale_by_laprop,
+            fns=(C.orthogonalize_grad_to_param, C.scale_by_laprop),
         )
 
 
@@ -619,8 +755,7 @@ class LaPropOrtho(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.orthogonalize_grad_to_param,
+            fns=(C.scale_by_laprop, C.orthogonalize_grad_to_param),
         )
 
@@ -683,6 +818,10 @@ class ForeachPSGDKron(C.BaseOpt):
         exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
         update_clipping = C.default(update_clipping, utils.trust_region_clip_)
         inverse_free = C.default(inverse_free, self.quad)
+        if inverse_free:
+            raise ValueError(
+                "inverse_free (i.e., PSGD-QUAD) is not supported at the moment. Consider using https://github.com/evanatyourservice/quad_torch"
+            )
 
         defaults = locals()
         defaults.pop("self")
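Together with the deletion of the QUAD subclass below, this guard removes PSGD-QUAD from 2.1.0: any configuration that resolves inverse_free to a truthy value now fails at construction and points to https://github.com/evanatyourservice/quad_torch. A sketch of the new failure mode, assuming inverse_free is accepted as a constructor argument as the C.default call suggests:

    import torch
    import heavyball

    model = torch.nn.Linear(16, 4)
    try:
        heavyball.ForeachPSGDKron(model.parameters(), inverse_free=True)
    except ValueError as err:
        print(err)  # message redirects to the quad_torch repository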
@@ -703,8 +842,10 @@ class ForeachPSGDKron(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             False,  #
-            *(C.exp_avg,) * exp_avg_input,  #
-            functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
+            fns=(
+                *(C.exp_avg,) * exp_avg_input,
+                functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
+            ),
         )
 
 
@@ -733,11 +874,6 @@ class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
     hvp_interval = 2
 
 
-class QUAD(ForeachPSGDKron):
-    quad = True
-    cached = True
-
-
 class ForeachPSGDLRA(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -808,8 +944,7 @@ class ForeachPSGDLRA(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             False,  #
-            *(C.exp_avg,) * exp_avg_input,  #
-            C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra,
+            fns=(*(C.exp_avg,) * exp_avg_input, C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra),
         )
 
 
@@ -846,6 +981,7 @@ SignLaProp = ForeachSignLaProp
 DelayedPSGDLRA = ForeachDelayedPSGDLRA
 PSGDLRA = ForeachPSGDLRA
 NewtonPSGDLRA = ForeachNewtonPSGDLRA
+NewtonPSGDKron = ForeachCachedNewtonPSGD
 
 __all__ = [
     "Muon",
@@ -892,4 +1028,7 @@ __all__ = [
     "NewtonHybrid2PSGDLRA",
     "NewtonHybrid2PSGDKron",
     "MSAMLaProp",
+    "NewtonPSGDKron",
+    "ForeachAdamC",
+    "SGD",
 ]
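The three new exports round out the public API, with NewtonPSGDKron serving as a shorter alias for ForeachCachedNewtonPSGD. All are importable directly from the package root:

    from heavyball import SGD, ForeachAdamC, NewtonPSGDKron  # new in 2.1.0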