heavyball 2.0.0.dev0__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +168 -29
- heavyball/chainable.py +165 -63
- heavyball/helpers.py +5 -1
- heavyball/utils.py +490 -124
- {heavyball-2.0.0.dev0.dist-info → heavyball-2.1.0.dist-info}/METADATA +19 -7
- heavyball-2.1.0.dist-info/RECORD +9 -0
- {heavyball-2.0.0.dev0.dist-info → heavyball-2.1.0.dist-info}/WHEEL +1 -1
- heavyball-2.0.0.dev0.dist-info/RECORD +0 -9
- {heavyball-2.0.0.dev0.dist-info → heavyball-2.1.0.dist-info}/licenses/LICENSE +0 -0
- {heavyball-2.0.0.dev0.dist-info → heavyball-2.1.0.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
```diff
@@ -41,6 +41,34 @@ class SAMWrapper(torch.optim.Optimizer):
         self.wrapped_optimizer.zero_grad()
 
 
+class SGD(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        beta=0.9,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, fns=(C.heavyball_momentum,))
+
+
 class ForeachAdamW(C.BaseOpt):
     def __init__(
         self,
```
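The new `SGD` class (heavy-ball momentum via `C.heavyball_momentum`) keeps the same constructor shape as the other optimizers in this file, so it should drop into an ordinary PyTorch training loop. A minimal usage sketch, assuming heavyball optimizers expose the standard `torch.optim.Optimizer` interface; the model and data below are placeholders:

```python
import torch
import heavyball

# Toy model; any torch.nn.Module works the same way.
model = torch.nn.Linear(16, 1)
opt = heavyball.SGD(model.parameters(), lr=0.0025, beta=0.9, weight_decay=0)

for _ in range(10):
    x = torch.randn(32, 16)
    loss = model(x).square().mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
```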
```diff
@@ -69,7 +97,110 @@ class ForeachAdamW(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adam,))
+
+
+class UnscaledAdamW(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(
+            params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_unscaled_adam,)
+        )
+
+
+class SUDSAdamW(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        precond_lr: float = 1e-2,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_suds,))
+
+
+class ForeachAdamC(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        max_lr: float | None = None,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        **kwargs,
+    ):
+        if max_lr is None:
+            utils.warn_once(
+                "max_lr was not set. setting it to the current learning rate, under the assumption that it strictly decreases"
+            )
+            max_lr = lr
+
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adamc,))
 
 
 class ForeachRMSprop(C.BaseOpt):
```
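Among the new classes, `ForeachAdamC` adds a `max_lr` argument; when it is omitted, the constructor warns and assumes the current `lr` is the peak of a strictly decreasing schedule. A hedged sketch of how it might be paired with a decaying schedule (the scheduler choice is illustrative and not part of the diff; it only assumes heavyball optimizers behave like standard `torch.optim.Optimizer` instances):

```python
import torch
import heavyball

model = torch.nn.Linear(16, 1)
# max_lr should be the peak learning rate of the schedule; here lr starts at that peak.
opt = heavyball.ForeachAdamC(model.parameters(), lr=1e-3, max_lr=1e-3, betas=(0.9, 0.99))
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=1000)

for _ in range(1000):
    loss = model(torch.randn(8, 16)).square().mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    sched.step()
```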
```diff
@@ -113,7 +244,7 @@ class ForeachRMSprop(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_exp_avg_sq,
+            fns=(C.scale_by_exp_avg_sq,),
         )
 
 
@@ -154,8 +285,7 @@ class ForeachSFAdamW(C.ScheduleFree):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_exp_avg_sq,
-            C.update_by_schedule_free,
+            fns=(C.scale_by_exp_avg_sq, C.update_by_schedule_free),
         )
 
 
```
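These hunks, and the matching ones further down, all make the same mechanical change: the chainable transforms are now passed to the base optimizer's `__init__` as an explicit `fns=` tuple instead of trailing positional arguments. Following that pattern, a user-defined composition might look like the sketch below. The class name and the particular transform pairing are hypothetical, and the `BaseOpt` signature is inferred from the calls in this diff rather than from separate documentation:

```python
from heavyball import chainable as C
from heavyball import utils


class SignAdamW(C.BaseOpt):  # hypothetical composition, mirroring the built-in classes above
    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0,
                 warmup_steps=0, foreach: bool = True, storage_dtype: str = "float32",
                 mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
                 gradient_clipping: C.str_or_fn = C.use_default,
                 update_clipping: C.str_or_fn = C.use_default,
                 palm: bool = C.use_default, beta2_scale: float = 0.8, **kwargs):
        defaults = locals()
        defaults.pop("self")
        params = defaults.pop("params")
        defaults.update(defaults.pop("kwargs"))

        if kwargs:
            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")

        # Transforms run in order: second-moment scaling first, then the sign of the update.
        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
                         fns=(C.scale_by_exp_avg_sq, C.sign))
```

Whether this particular pairing trains well is untested; the point is only the new `fns=` keyword.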
```diff
@@ -197,8 +327,7 @@ class MSAMLaProp(C.MSAM):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_exp_avg_sq,
-            C.update_by_msam,
+            fns=(C.scale_by_exp_avg_sq, C.update_by_msam),
         )
 
 
@@ -234,7 +363,7 @@ class ForeachADOPT(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adopt,))
 
 
 class ForeachMuon(C.BaseOpt):
@@ -256,6 +385,7 @@ class ForeachMuon(C.BaseOpt):
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
         nesterov: bool = True,
+        heavyball_momentum: bool = False,
         **kwargs,
     ):
         defaults = locals()
```
```diff
@@ -266,6 +396,16 @@ class ForeachMuon(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
+        if heavyball_momentum:
+            if nesterov:
+                ema = C.nesterov_momentum
+            else:
+                ema = C.heavyball_momentum
+        elif nesterov:
+            ema = C.nesterov_ema
+        else:
+            ema = C.exp_avg
+
         super().__init__(
             params,
             defaults,
@@ -273,8 +413,7 @@ class ForeachMuon(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-
-            C.orthogonalize_update,
+            fns=(ema, C.orthogonalize_update),
         )
 
 
```
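`ForeachMuon` now selects its momentum transform from the new `heavyball_momentum` flag together with the existing `nesterov` flag before orthogonalizing the update. A small sketch of the four combinations, taken directly from the branch added above (the model is a placeholder):

```python
import torch
import heavyball

model = torch.nn.Linear(8, 8)

# nesterov=True,  heavyball_momentum=False -> C.nesterov_ema       (default, as before)
# nesterov=False, heavyball_momentum=False -> C.exp_avg
# nesterov=True,  heavyball_momentum=True  -> C.nesterov_momentum
# nesterov=False, heavyball_momentum=True  -> C.heavyball_momentum
opt = heavyball.ForeachMuon(model.parameters(), nesterov=False, heavyball_momentum=True)
```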
```diff
@@ -306,7 +445,7 @@ class ForeachLaProp(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_laprop)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_laprop,))
 
 
 class MuonLaProp(C.BaseOpt):
@@ -344,8 +483,7 @@ class MuonLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.orthogonalize_update,
+            fns=(C.scale_by_laprop, C.orthogonalize_update),
         )
 
 
@@ -417,7 +555,7 @@ class ForeachSOAP(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,  #
-            C.scale_by_soap,
+            fns=(C.scale_by_soap,),
         )
 
 
@@ -456,8 +594,7 @@ class ForeachSignLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.sign,
+            fns=(C.scale_by_laprop, C.sign),
         )
 
 
@@ -528,7 +665,7 @@ class ForeachSOLP(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,  #
-            functools.partial(C.scale_by_soap, inner="laprop"),
+            fns=(functools.partial(C.scale_by_soap, inner="laprop"),),
         )
 
 
@@ -580,8 +717,7 @@ class OrthoLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.orthogonalize_grad_to_param,
-            C.scale_by_laprop,
+            fns=(C.orthogonalize_grad_to_param, C.scale_by_laprop),
         )
 
 
@@ -619,8 +755,7 @@ class LaPropOrtho(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.orthogonalize_grad_to_param,
+            fns=(C.scale_by_laprop, C.orthogonalize_grad_to_param),
         )
 
 
```
```diff
@@ -683,6 +818,10 @@ class ForeachPSGDKron(C.BaseOpt):
         exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
         update_clipping = C.default(update_clipping, utils.trust_region_clip_)
         inverse_free = C.default(inverse_free, self.quad)
+        if inverse_free:
+            raise ValueError(
+                "inverse_free (i.e., PSGD-QUAD) is not supported at the moment. Consider using https://github.com/evanatyourservice/quad_torch"
+            )
 
         defaults = locals()
         defaults.pop("self")
```
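`ForeachPSGDKron` now rejects `inverse_free=True` outright instead of running the QUAD path (the `QUAD` subclass itself is removed in a later hunk), directing users to the external quad_torch project. A minimal sketch of the new behavior, assuming the constructor is otherwise called with its defaults:

```python
import torch
import heavyball

model = torch.nn.Linear(8, 8)
try:
    heavyball.ForeachPSGDKron(model.parameters(), inverse_free=True)
except ValueError as e:
    print(e)  # points at https://github.com/evanatyourservice/quad_torch
```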
```diff
@@ -703,8 +842,10 @@ class ForeachPSGDKron(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             False,  #
-
-
+            fns=(
+                *(C.exp_avg,) * exp_avg_input,
+                functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
+            ),
         )
 
 
```
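The `*(C.exp_avg,) * exp_avg_input` expression inside the new `fns=` tuple uses Python's sequence repetition: multiplying a one-element tuple by a `bool` yields zero or one copies, so the exponential-moving-average stage is included only when `exp_avg_input` is truthy. A standalone illustration with stand-in functions:

```python
def exp_avg():  # stand-in for C.exp_avg
    pass


def scale_by_psgd():  # stand-in for the preconditioning stage
    pass


for exp_avg_input in (False, True):
    fns = (*(exp_avg,) * exp_avg_input, scale_by_psgd)
    print(exp_avg_input, [f.__name__ for f in fns])
# False ['scale_by_psgd']
# True ['exp_avg', 'scale_by_psgd']
```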
```diff
@@ -733,11 +874,6 @@ class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
     hvp_interval = 2
 
 
-class QUAD(ForeachPSGDKron):
-    quad = True
-    cached = True
-
-
 class ForeachPSGDLRA(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -808,8 +944,7 @@ class ForeachPSGDLRA(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             False,  #
-            *(C.exp_avg,) * exp_avg_input,
-            C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra,
+            fns=(*(C.exp_avg,) * exp_avg_input, C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra),
         )
 
 
@@ -846,6 +981,7 @@ SignLaProp = ForeachSignLaProp
 DelayedPSGDLRA = ForeachDelayedPSGDLRA
 PSGDLRA = ForeachPSGDLRA
 NewtonPSGDLRA = ForeachNewtonPSGDLRA
+NewtonPSGDKron = ForeachCachedNewtonPSGD
 
 __all__ = [
     "Muon",
```
```diff
@@ -892,4 +1028,7 @@ __all__ = [
     "NewtonHybrid2PSGDLRA",
     "NewtonHybrid2PSGDKron",
     "MSAMLaProp",
+    "NewtonPSGDKron",
+    "ForeachAdamC",
+    "SGD",
 ]
```
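With `SGD`, `ForeachAdamC`, and `NewtonPSGDKron` added to `__all__` (and `NewtonPSGDKron` aliased to `ForeachCachedNewtonPSGD` earlier in the diff), the new names can be imported directly from the package root. A quick sketch:

```python
from heavyball import SGD, ForeachAdamC, NewtonPSGDKron

print(NewtonPSGDKron.__name__)  # prints "ForeachCachedNewtonPSGD" because NewtonPSGDKron is an alias
```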