heavyball 1.7.1__py3-none-any.whl → 2.0.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +193 -16
- heavyball/chainable.py +338 -190
- heavyball/helpers.py +804 -0
- heavyball/utils.py +813 -252
- heavyball-2.0.0.dev0.dist-info/METADATA +109 -0
- heavyball-2.0.0.dev0.dist-info/RECORD +9 -0
- {heavyball-1.7.1.dist-info → heavyball-2.0.0.dev0.dist-info}/WHEEL +1 -1
- heavyball/optimizations/__init__.py +0 -38
- heavyball/optimizations/integrator.py +0 -169
- heavyball/optimizations/optimizations.py +0 -329
- heavyball-1.7.1.dist-info/METADATA +0 -939
- heavyball-1.7.1.dist-info/RECORD +0 -11
- {heavyball-1.7.1.dist-info → heavyball-2.0.0.dev0.dist-info}/licenses/LICENSE +0 -0
- {heavyball-1.7.1.dist-info → heavyball-2.0.0.dev0.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -2,10 +2,45 @@ import functools
|
|
2
2
|
import math
|
3
3
|
from typing import Optional
|
4
4
|
|
5
|
+
import torch.optim
|
6
|
+
|
5
7
|
from . import chainable as C
|
6
8
|
from . import utils
|
7
9
|
|
8
10
|
|
11
|
+
class SAMWrapper(torch.optim.Optimizer):
    """Sharpness-Aware Minimization (SAM) wrapper around a HeavyBall optimizer.

    Each ``step`` first evaluates ``closure`` at the current parameters. It then
    perturbs every param group with ``utils.sam_step`` using the group's ``"ball"``
    radius, and finally runs ``wrapped_optimizer.step(closure)``. While that step
    runs, the wrapped optimizer's ``_handle_closure`` is temporarily patched.
    Right after the closure has been re-evaluated at the perturbed point, the
    saved parameters are copied back, so the update is applied to the original
    weights.

    Args:
        params: parameters (or param groups) to perturb. Presumably these are the
            same tensors the wrapped optimizer manages — TODO confirm at call sites.
        wrapped_optimizer: HeavyBall ``utils.StatefulOptimizer`` performing the update.
        ball: perturbation radius, stored per param group under ``"ball"``.

    Raises:
        ValueError: if ``wrapped_optimizer`` is not a ``utils.StatefulOptimizer``.
    """

    def __init__(self, params, wrapped_optimizer: utils.StatefulOptimizer, ball: float = 0.1):
        if not isinstance(wrapped_optimizer, utils.StatefulOptimizer):
            raise ValueError(f"{wrapped_optimizer.__class__.__name__} is not a HeavyBall optimizer")
        super().__init__(params, {"ball": ball})
        self.wrapped_optimizer = wrapped_optimizer

    @torch.no_grad()
    def step(self, closure=None):
        """Run one SAM step and return whatever ``wrapped_optimizer.step`` returns.

        Raises:
            ValueError: if ``closure`` is None (SAM needs to evaluate the loss twice).
        """
        if closure is None:
            raise ValueError("SAM requires closure")

        # The first evaluation populates .grad at the unperturbed point; sam_step
        # then perturbs the parameters in place.
        with torch.enable_grad():
            closure()
        # utils.sam_step's return value is what gets copied back below.
        # Presumably it holds the pre-perturbation parameters.
        old_params = [utils.sam_step(group["params"], group["ball"]) for group in self.param_groups]

        opt = self.wrapped_optimizer
        # Remember whether the instance already had its own override, so the
        # original state can be restored exactly afterwards.
        had_instance_override = "_handle_closure" in vars(opt)
        original_handle_closure = opt._handle_closure

        def _handle_closure(inner_closure):
            # Evaluate loss/grads at the perturbed point, then move the parameters
            # back before the wrapped optimizer applies its update. Restoration
            # happens even if the closure raises. The original return value (the
            # loss) is forwarded so wrapped_optimizer.step() can still report it.
            try:
                return original_handle_closure(inner_closure)
            finally:
                for group, old in zip(self.param_groups, old_params):
                    utils.copy_stochastic_list_(group["params"], old)

        try:
            opt._handle_closure = _handle_closure
            loss = opt.step(closure)
        finally:
            if had_instance_override:
                opt._handle_closure = original_handle_closure
            else:
                # Drop the temporary instance attribute instead of pinning a bound
                # method (a reference cycle back to `opt`) in its __dict__.
                del opt._handle_closure
        return loss

    def zero_grad(self, set_to_none: bool = True):
        """Clear gradients through the wrapped optimizer, honoring ``set_to_none``."""
        self.wrapped_optimizer.zero_grad(set_to_none=set_to_none)
|
42
|
+
|
43
|
+
|
9
44
|
class ForeachAdamW(C.BaseOpt):
|
10
45
|
def __init__(
|
11
46
|
self,
|
@@ -24,10 +59,16 @@ class ForeachAdamW(C.BaseOpt):
|
|
24
59
|
update_clipping: C.str_or_fn = C.use_default,
|
25
60
|
palm: bool = C.use_default,
|
26
61
|
beta2_scale: float = 0.8,
|
62
|
+
**kwargs,
|
27
63
|
):
|
28
64
|
defaults = locals()
|
29
65
|
defaults.pop("self")
|
30
66
|
params = defaults.pop("params")
|
67
|
+
defaults.update(defaults.pop("kwargs"))
|
68
|
+
|
69
|
+
if kwargs:
|
70
|
+
utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
|
71
|
+
|
31
72
|
super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
|
32
73
|
|
33
74
|
|
@@ -55,10 +96,16 @@ class ForeachRMSprop(C.BaseOpt):
|
|
55
96
|
update_clipping: C.str_or_fn = C.use_default,
|
56
97
|
palm: bool = C.use_default,
|
57
98
|
beta2_scale: float = 0.8,
|
99
|
+
**kwargs,
|
58
100
|
):
|
59
101
|
defaults = locals()
|
60
102
|
defaults.pop("self")
|
61
103
|
params = defaults.pop("params")
|
104
|
+
defaults.update(defaults.pop("kwargs"))
|
105
|
+
|
106
|
+
if kwargs:
|
107
|
+
utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
|
108
|
+
|
62
109
|
super().__init__(
|
63
110
|
params,
|
64
111
|
defaults,
|
@@ -90,10 +137,16 @@ class ForeachSFAdamW(C.ScheduleFree):
|
|
90
137
|
update_clipping: C.str_or_fn = C.use_default,
|
91
138
|
palm: bool = C.use_default,
|
92
139
|
beta2_scale: float = 0.8,
|
140
|
+
**kwargs,
|
93
141
|
):
|
94
142
|
defaults = locals()
|
95
143
|
defaults.pop("self")
|
96
144
|
params = defaults.pop("params")
|
145
|
+
defaults.update(defaults.pop("kwargs"))
|
146
|
+
|
147
|
+
if kwargs:
|
148
|
+
utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
|
149
|
+
|
97
150
|
super().__init__(
|
98
151
|
params,
|
99
152
|
defaults,
|
@@ -106,6 +159,49 @@ class ForeachSFAdamW(C.ScheduleFree):
|
|
106
159
|
)
|
107
160
|
|
108
161
|
|
162
|
+
class MSAMLaProp(C.MSAM):
    """``C.MSAM`` optimizer that chains ``C.scale_by_exp_avg_sq`` and ``C.update_by_msam``.

    Every constructor argument (including ``sam_step_size``) ends up in the
    ``defaults`` dict handed to ``C.MSAM``. Unrecognised keyword arguments are
    merged into ``defaults`` as well, with a one-time warning.
    """

    def __init__(
        self,
        params,
        lr=0.0025,
        betas=(0.9, 0.99),
        eps=1e-6,
        weight_decay=0,
        warmup_steps=0,
        r=0.0,
        weight_lr_power=2.0,
        foreach: bool = True,
        storage_dtype: str = "float32",
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        palm: bool = C.use_default,
        beta2_scale: float = 0.8,
        sam_step_size: float = 0.1,
        **kwargs,
    ):
        # Must stay the first statement: locals() snapshots exactly the
        # constructor arguments at this point, so no helper locals may precede it.
        defaults = locals()
        del defaults["self"]
        params = defaults.pop("params")

        if kwargs:
            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
        # Flatten any extra keyword arguments into the top level of `defaults`.
        defaults.update(defaults.pop("kwargs"))

        super().__init__(
            params, defaults, foreach, gradient_clipping, update_clipping, palm,
            C.scale_by_exp_avg_sq, C.update_by_msam,
        )
|
203
|
+
|
204
|
+
|
109
205
|
class PaLMForeachSFAdamW(ForeachSFAdamW):
|
110
206
|
palm: bool = True
|
111
207
|
|
@@ -128,10 +224,16 @@ class ForeachADOPT(C.BaseOpt):
|
|
128
224
|
update_clipping: C.str_or_fn = C.use_default,
|
129
225
|
palm: bool = C.use_default,
|
130
226
|
beta2_scale: float = 0.8,
|
227
|
+
**kwargs,
|
131
228
|
):
|
132
229
|
defaults = locals()
|
133
230
|
defaults.pop("self")
|
134
231
|
params = defaults.pop("params")
|
232
|
+
defaults.update(defaults.pop("kwargs"))
|
233
|
+
|
234
|
+
if kwargs:
|
235
|
+
utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
|
236
|
+
|
135
237
|
super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
|
136
238
|
|
137
239
|
|
@@ -154,10 +256,16 @@ class ForeachMuon(C.BaseOpt):
|
|
154
256
|
palm: bool = C.use_default,
|
155
257
|
beta2_scale: float = 0.8,
|
156
258
|
nesterov: bool = True,
|
259
|
+
**kwargs,
|
157
260
|
):
|
158
261
|
defaults = locals()
|
159
262
|
defaults.pop("self")
|
160
263
|
params = defaults.pop("params")
|
264
|
+
defaults.update(defaults.pop("kwargs"))
|
265
|
+
|
266
|
+
if kwargs:
|
267
|
+
utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
|
268
|
+
|
161
269
|
super().__init__(
|
162
270
|
params,
|
163
271
|
defaults,
|
@@ -165,7 +273,7 @@ class ForeachMuon(C.BaseOpt):
|
|
165
273
|
gradient_clipping,
|
166
274
|
update_clipping,
|
167
275
|
palm,
|
168
|
-
C.
|
276
|
+
C.nesterov_ema if nesterov else C.exp_avg,
|
169
277
|
C.orthogonalize_update,
|
170
278
|
)
|
171
279
|
|
@@ -188,10 +296,16 @@ class ForeachLaProp(C.BaseOpt):
|
|
188
296
|
update_clipping: C.str_or_fn = C.use_default,
|
189
297
|
palm: bool = C.use_default,
|
190
298
|
beta2_scale: float = 0.8,
|
299
|
+
**kwargs,
|
191
300
|
):
|
192
301
|
defaults = locals()
|
193
302
|
defaults.pop("self")
|
194
303
|
params = defaults.pop("params")
|
304
|
+
defaults.update(defaults.pop("kwargs"))
|
305
|
+
|
306
|
+
if kwargs:
|
307
|
+
utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
|
308
|
+
|
195
309
|
super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_laprop)
|
196
310
|
|
197
311
|
|
@@ -213,10 +327,16 @@ class MuonLaProp(C.BaseOpt):
|
|
213
327
|
update_clipping: C.str_or_fn = C.use_default,
|
214
328
|
palm: bool = C.use_default,
|
215
329
|
beta2_scale: float = 0.8,
|
330
|
+
**kwargs,
|
216
331
|
):
|
217
332
|
defaults = locals()
|
218
333
|
defaults.pop("self")
|
219
334
|
params = defaults.pop("params")
|
335
|
+
defaults.update(defaults.pop("kwargs"))
|
336
|
+
|
337
|
+
if kwargs:
|
338
|
+
utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
|
339
|
+
|
220
340
|
super().__init__(
|
221
341
|
params,
|
222
342
|
defaults,
|
@@ -271,12 +391,18 @@ class ForeachSOAP(C.BaseOpt):
|
|
271
391
|
update_clipping: C.str_or_fn = C.use_default,
|
272
392
|
storage_dtype: str = "float32",
|
273
393
|
stochastic_schedule: bool = False,
|
394
|
+
precond_grad_accum: bool = False,
|
395
|
+
**kwargs,
|
274
396
|
):
|
275
397
|
use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
|
276
398
|
|
277
399
|
defaults = locals()
|
278
400
|
defaults.pop("self")
|
279
401
|
params = defaults.pop("params")
|
402
|
+
defaults.update(defaults.pop("kwargs"))
|
403
|
+
|
404
|
+
if kwargs:
|
405
|
+
utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
|
280
406
|
|
281
407
|
if use_precond_schedule:
|
282
408
|
del defaults["precondition_frequency"]
|
@@ -313,10 +439,16 @@ class ForeachSignLaProp(C.BaseOpt):
|
|
313
439
|
update_clipping: C.str_or_fn = C.use_default,
|
314
440
|
palm: bool = C.use_default,
|
315
441
|
beta2_scale: float = 0.8,
|
442
|
+
**kwargs,
|
316
443
|
):
|
317
444
|
defaults = locals()
|
318
445
|
defaults.pop("self")
|
319
446
|
params = defaults.pop("params")
|
447
|
+
defaults.update(defaults.pop("kwargs"))
|
448
|
+
|
449
|
+
if kwargs:
|
450
|
+
utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
|
451
|
+
|
320
452
|
super().__init__(
|
321
453
|
params,
|
322
454
|
defaults,
|
@@ -371,12 +503,17 @@ class ForeachSOLP(C.BaseOpt):
|
|
371
503
|
update_clipping: C.str_or_fn = C.use_default,
|
372
504
|
storage_dtype: str = "float32",
|
373
505
|
stochastic_schedule: bool = False,
|
506
|
+
**kwargs,
|
374
507
|
):
|
375
508
|
use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
|
376
509
|
|
377
510
|
defaults = locals()
|
378
511
|
defaults.pop("self")
|
379
512
|
params = defaults.pop("params")
|
513
|
+
defaults.update(defaults.pop("kwargs"))
|
514
|
+
|
515
|
+
if kwargs:
|
516
|
+
utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
|
380
517
|
|
381
518
|
if use_precond_schedule:
|
382
519
|
del defaults["precondition_frequency"]
|
@@ -427,10 +564,15 @@ class OrthoLaProp(C.BaseOpt):
|
|
427
564
|
update_clipping: C.str_or_fn = C.use_default,
|
428
565
|
palm: bool = C.use_default,
|
429
566
|
beta2_scale: float = 0.8,
|
567
|
+
**kwargs,
|
430
568
|
):
|
431
569
|
defaults = locals()
|
432
570
|
defaults.pop("self")
|
433
571
|
params = defaults.pop("params")
|
572
|
+
defaults.update(defaults.pop("kwargs"))
|
573
|
+
|
574
|
+
if kwargs:
|
575
|
+
utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
|
434
576
|
super().__init__(
|
435
577
|
params,
|
436
578
|
defaults,
|
@@ -461,10 +603,15 @@ class LaPropOrtho(C.BaseOpt):
|
|
461
603
|
update_clipping: C.str_or_fn = C.use_default,
|
462
604
|
palm: bool = C.use_default,
|
463
605
|
beta2_scale: float = 0.8,
|
606
|
+
**kwargs,
|
464
607
|
):
|
465
608
|
defaults = locals()
|
466
609
|
defaults.pop("self")
|
467
610
|
params = defaults.pop("params")
|
611
|
+
defaults.update(defaults.pop("kwargs"))
|
612
|
+
|
613
|
+
if kwargs:
|
614
|
+
utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
|
468
615
|
super().__init__(
|
469
616
|
params,
|
470
617
|
defaults,
|
@@ -487,12 +634,14 @@ class ForeachPSGDKron(C.BaseOpt):
|
|
487
634
|
delayed: bool = False
|
488
635
|
cached: bool = False
|
489
636
|
exp_avg_input: bool = True
|
637
|
+
quad: bool = False
|
490
638
|
|
491
639
|
def __init__(
|
492
640
|
self,
|
493
641
|
params,
|
494
642
|
lr=0.001,
|
495
|
-
beta=
|
643
|
+
beta=None,
|
644
|
+
betas=(0.9, 0.999),
|
496
645
|
weight_decay=0.0,
|
497
646
|
preconditioner_update_probability=None,
|
498
647
|
max_size_triangular=2048,
|
@@ -515,23 +664,38 @@ class ForeachPSGDKron(C.BaseOpt):
|
|
515
664
|
exp_avg_input: Optional[bool] = C.use_default,
|
516
665
|
gradient_clipping: C.str_or_fn = C.use_default,
|
517
666
|
update_clipping: C.str_or_fn = C.use_default, #
|
667
|
+
adaptive: bool = False,
|
668
|
+
ortho_method: Optional[str] = None, # If None, no orthogonalization
|
669
|
+
precond_grad_accum: bool = False,
|
670
|
+
lower_bound_beta: float = 0.9, # 0.0 recovers pre-2.0.0 PSGD
|
671
|
+
inverse_free: bool = C.use_default,
|
672
|
+
dampening: float = 2**-13,
|
673
|
+
precond_update_power_iterations: int = 2,
|
518
674
|
# expert parameters
|
519
675
|
precond_init_scale=None,
|
520
|
-
precond_init_scale_scale=1,
|
521
|
-
|
676
|
+
precond_init_scale_scale: float = 1,
|
677
|
+
precond_init_scale_power: Optional[float] = None,
|
678
|
+
precond_lr: float = 0.1,
|
679
|
+
**kwargs,
|
522
680
|
):
|
681
|
+
delayed = C.default(delayed, self.delayed)
|
682
|
+
cached = C.default(cached, self.cached)
|
683
|
+
exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
|
684
|
+
update_clipping = C.default(update_clipping, utils.trust_region_clip_)
|
685
|
+
inverse_free = C.default(inverse_free, self.quad)
|
686
|
+
|
523
687
|
defaults = locals()
|
524
688
|
defaults.pop("self")
|
689
|
+
defaults.update(defaults.pop("kwargs"))
|
690
|
+
|
691
|
+
if kwargs:
|
692
|
+
utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
|
693
|
+
|
525
694
|
self.precond_schedule = (
|
526
695
|
defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
|
527
696
|
)
|
528
697
|
params = defaults.pop("params")
|
529
698
|
|
530
|
-
delayed = C.default(delayed, self.delayed)
|
531
|
-
cached = C.default(cached, self.cached)
|
532
|
-
exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
|
533
|
-
update_clipping = C.default(update_clipping, utils.trust_region_clip_)
|
534
|
-
|
535
699
|
super().__init__(
|
536
700
|
params,
|
537
701
|
defaults,
|
@@ -569,6 +733,11 @@ class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
|
|
569
733
|
hvp_interval = 2
|
570
734
|
|
571
735
|
|
736
|
+
class QUAD(ForeachPSGDKron):
    """``ForeachPSGDKron`` preset with the ``cached`` and ``quad`` class defaults switched on.

    ``ForeachPSGDKron.__init__`` resolves ``cached`` via
    ``C.default(cached, self.cached)``. It resolves ``inverse_free`` via
    ``C.default(inverse_free, self.quad)``. Overriding these two class
    attributes therefore changes the fallback values used when callers leave
    those arguments at their defaults.
    """

    cached = True  # fallback for the `cached` constructor argument
    quad = True  # fallback for the `inverse_free` constructor argument
|
739
|
+
|
740
|
+
|
572
741
|
class ForeachPSGDLRA(C.BaseOpt):
|
573
742
|
"""
|
574
743
|
Originally from Evan Walters and Omead Pooladzandi, 2024
|
@@ -601,13 +770,24 @@ class ForeachPSGDLRA(C.BaseOpt):
|
|
601
770
|
gradient_clipping: C.str_or_fn = C.use_default,
|
602
771
|
update_clipping: C.str_or_fn = C.use_default,
|
603
772
|
eps: float = 1e-8, #
|
604
|
-
# expert parameters
|
773
|
+
precond_grad_accum: bool = False, # expert parameters
|
605
774
|
precond_init_scale=None,
|
606
|
-
precond_init_scale_scale=1,
|
607
|
-
|
775
|
+
precond_init_scale_scale: float = 1,
|
776
|
+
precond_init_scale_power: Optional[float] = None,
|
777
|
+
precond_lr: float = 0.1,
|
778
|
+
**kwargs,
|
608
779
|
):
|
780
|
+
delayed = C.default(delayed, self.delayed)
|
781
|
+
exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
|
782
|
+
update_clipping = C.default(update_clipping, utils.trust_region_clip_)
|
783
|
+
|
609
784
|
defaults = locals()
|
610
785
|
defaults.pop("self")
|
786
|
+
defaults.update(defaults.pop("kwargs"))
|
787
|
+
|
788
|
+
if kwargs:
|
789
|
+
utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
|
790
|
+
|
611
791
|
self.precond_schedule = (
|
612
792
|
defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
|
613
793
|
)
|
@@ -621,10 +801,6 @@ class ForeachPSGDLRA(C.BaseOpt):
|
|
621
801
|
defaults["rank"] = round(math.log2(sum(p.numel() for p in params)))
|
622
802
|
utils.warn_once(f"rank was set to {defaults['rank']}")
|
623
803
|
|
624
|
-
delayed = C.default(delayed, self.delayed)
|
625
|
-
exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
|
626
|
-
update_clipping = C.default(update_clipping, utils.trust_region_clip_)
|
627
|
-
|
628
804
|
super().__init__(
|
629
805
|
params,
|
630
806
|
defaults,
|
@@ -715,4 +891,5 @@ __all__ = [
|
|
715
891
|
"NewtonPSGDLRA",
|
716
892
|
"NewtonHybrid2PSGDLRA",
|
717
893
|
"NewtonHybrid2PSGDKron",
|
894
|
+
"MSAMLaProp",
|
718
895
|
]
|