heavyball-1.7.2-py3-none-any.whl → heavyball-2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +276 -37
- heavyball/chainable.py +419 -206
- heavyball/helpers.py +808 -0
- heavyball/utils.py +1062 -315
- heavyball-2.0.0.dist-info/METADATA +122 -0
- heavyball-2.0.0.dist-info/RECORD +9 -0
- {heavyball-1.7.2.dist-info → heavyball-2.0.0.dist-info}/WHEEL +1 -1
- heavyball-1.7.2.dist-info/METADATA +0 -939
- heavyball-1.7.2.dist-info/RECORD +0 -8
- {heavyball-1.7.2.dist-info → heavyball-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {heavyball-1.7.2.dist-info → heavyball-2.0.0.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -2,10 +2,73 @@ import functools
 import math
 from typing import Optional
 
+import torch.optim
+
 from . import chainable as C
 from . import utils
 
 
+class SAMWrapper(torch.optim.Optimizer):
+    def __init__(self, params, wrapped_optimizer: utils.StatefulOptimizer, ball: float = 0.1):
+        if not isinstance(wrapped_optimizer, utils.StatefulOptimizer):
+            raise ValueError(f"{wrapped_optimizer.__class__.__name__} is not a HeavyBall optimizer")
+        super().__init__(params, {"ball": ball})
+        self.wrapped_optimizer = wrapped_optimizer
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        if closure is None:
+            raise ValueError("SAM requires closure")
+        with torch.enable_grad():
+            closure()
+        old_params = [utils.sam_step(group["params"], group["ball"]) for group in self.param_groups]
+
+        originaL_handle_closure = self.wrapped_optimizer._handle_closure
+
+        def _handle_closure(closure):
+            originaL_handle_closure(closure)
+            for group, old in zip(self.param_groups, old_params):
+                utils.copy_stochastic_list_(group["params"], old)
+
+        try:
+            self.wrapped_optimizer._handle_closure = _handle_closure
+            loss = self.wrapped_optimizer.step(closure)
+        finally:
+            self.wrapped_optimizer._handle_closure = originaL_handle_closure
+        return loss
+
+    def zero_grad(self, set_to_none: bool = True):
+        self.wrapped_optimizer.zero_grad()
+
+
+class SGD(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        beta=0.9,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, fns=(C.heavyball_momentum,))
+
+
 class ForeachAdamW(C.BaseOpt):
     def __init__(
         self,
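
Usage sketch (illustrative; the toy model, data, and hyperparameters are not from the package). SAMWrapper evaluates the closure once, perturbs the parameters via utils.sam_step, and restores them through the patched _handle_closure before the wrapped optimizer applies its update:

    import torch
    from torch import nn

    import heavyball

    model = nn.Linear(4, 2)
    x, y = torch.randn(8, 4), torch.randn(8, 2)

    params = list(model.parameters())
    inner = heavyball.ForeachAdamW(params, lr=1e-3)  # any HeavyBall StatefulOptimizer
    opt = heavyball.SAMWrapper(params, wrapped_optimizer=inner, ball=0.05)

    def closure():
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    loss = opt.step(closure)  # step() without a closure raises ValueError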
@@ -24,11 +87,55 @@ class ForeachAdamW(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adam,))
+
+
+class ForeachAdamC(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        max_lr: float | None = None,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        **kwargs,
+    ):
+        if max_lr is None:
+            utils.warn_once(
+                "max_lr was not set. setting it to the current learning rate, under the assumption that it strictly decreases"
+            )
+            max_lr = lr
+
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adamc,))
 
 
 class ForeachRMSprop(C.BaseOpt):
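
ForeachAdamC is new in 2.0.0: it carries an explicit max_lr next to lr, and when max_lr is omitted the constructor warns and falls back to the current lr, which is only sound for strictly decreasing schedules. A minimal construction sketch (toy parameter; values illustrative):

    import torch

    import heavyball

    p = torch.nn.Parameter(torch.zeros(10))

    # Passing max_lr explicitly (e.g. the peak of a decaying schedule)
    # avoids the "max_lr was not set" warning shown in the hunk above.
    opt = heavyball.ForeachAdamC([p], lr=3e-4, max_lr=3e-4, weight_decay=0.01)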
@@ -55,10 +162,16 @@ class ForeachRMSprop(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(
             params,
             defaults,
@@ -66,7 +179,7 @@ class ForeachRMSprop(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_exp_avg_sq,
+            fns=(C.scale_by_exp_avg_sq,),
         )
 
 
@@ -90,10 +203,58 @@ class ForeachSFAdamW(C.ScheduleFree):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(
+            params,
+            defaults,
+            foreach,
+            gradient_clipping,
+            update_clipping,
+            palm,
+            fns=(C.scale_by_exp_avg_sq, C.update_by_schedule_free),
+        )
+
+
+class MSAMLaProp(C.MSAM):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-6,
+        weight_decay=0,
+        warmup_steps=0,
+        r=0.0,
+        weight_lr_power=2.0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        sam_step_size: float = 0.1,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(
             params,
             defaults,
@@ -101,8 +262,7 @@ class ForeachSFAdamW(C.ScheduleFree):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_exp_avg_sq,
-            C.update_by_schedule_free,
+            fns=(C.scale_by_exp_avg_sq, C.update_by_msam),
         )
 
 
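
The change that repeats across every optimizer above: 1.7.2 appended the transform chain to super().__init__ as trailing positional arguments, while 2.0.0 passes it as the keyword-only fns tuple. Shown as comments (the calls only make sense inside the optimizer subclasses), using ForeachSFAdamW's call from this hunk:

    # heavyball 1.7.2 -- transforms passed positionally after `palm`:
    #   super().__init__(params, defaults, foreach, gradient_clipping,
    #                    update_clipping, palm,
    #                    C.scale_by_exp_avg_sq, C.update_by_schedule_free)
    #
    # heavyball 2.0.0 -- the same chain as an explicit keyword tuple:
    #   super().__init__(params, defaults, foreach, gradient_clipping,
    #                    update_clipping, palm,
    #                    fns=(C.scale_by_exp_avg_sq, C.update_by_schedule_free))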
@@ -128,11 +288,17 @@ class ForeachADOPT(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adopt,))
 
 
 class ForeachMuon(C.BaseOpt):
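
The other recurring addition is the **kwargs capture: unknown keyword arguments are merged into the param-group defaults and reported once via utils.warn_once, instead of raising TypeError as in 1.7.2. A self-contained sketch of the mechanic (warn_once here is an illustrative stand-in):

    import warnings

    def warn_once(msg):  # stand-in for heavyball.utils.warn_once
        warnings.warn(msg)

    class Demo:
        def __init__(self, lr=0.0025, betas=(0.9, 0.99), **kwargs):
            defaults = locals()
            defaults.pop("self")
            defaults.update(defaults.pop("kwargs"))  # unknown kwargs land in defaults
            if kwargs:
                warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
            self.defaults = defaults

    Demo(lr=1e-3, my_flag=True)  # warns; defaults now include my_flag=True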
@@ -154,10 +320,16 @@ class ForeachMuon(C.BaseOpt):
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
         nesterov: bool = True,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(
             params,
             defaults,
@@ -165,8 +337,7 @@ class ForeachMuon(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.nesterov_momentum if nesterov else C.heavyball_momentum,
-            C.orthogonalize_update,
+            fns=(C.nesterov_ema if nesterov else C.exp_avg, C.orthogonalize_update),
         )
 
 
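
ForeachMuon's chain is now spelled out explicitly: an EMA-style smoothing step followed by C.orthogonalize_update, with nesterov selecting between C.nesterov_ema and plain C.exp_avg. A construction sketch (toy module; values illustrative):

    import torch

    import heavyball

    net = torch.nn.Linear(16, 16)

    # nesterov=True (the default) smooths with C.nesterov_ema;
    # nesterov=False uses C.exp_avg before the orthogonalized update.
    opt = heavyball.ForeachMuon(net.parameters(), lr=1e-3, nesterov=False)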
@@ -188,11 +359,17 @@ class ForeachLaProp(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_laprop)
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_laprop,))
 
 
 class MuonLaProp(C.BaseOpt):
@@ -213,10 +390,16 @@ class MuonLaProp(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(
             params,
             defaults,
@@ -224,8 +407,7 @@ class MuonLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.orthogonalize_update,
+            fns=(C.scale_by_laprop, C.orthogonalize_update),
         )
 
 
@@ -271,12 +453,18 @@ class ForeachSOAP(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         storage_dtype: str = "float32",
         stochastic_schedule: bool = False,
+        precond_grad_accum: bool = False,
+        **kwargs,
     ):
         use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
 
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
         if use_precond_schedule:
             del defaults["precondition_frequency"]
@@ -291,7 +479,7 @@ class ForeachSOAP(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,  #
-            C.scale_by_soap,
+            fns=(C.scale_by_soap,),
         )
 
 
@@ -313,10 +501,16 @@ class ForeachSignLaProp(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(
             params,
             defaults,
@@ -324,8 +518,7 @@ class ForeachSignLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.sign,
+            fns=(C.scale_by_laprop, C.sign),
         )
 
 
@@ -371,12 +564,17 @@ class ForeachSOLP(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         storage_dtype: str = "float32",
         stochastic_schedule: bool = False,
+        **kwargs,
     ):
         use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
 
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
         if use_precond_schedule:
             del defaults["precondition_frequency"]
@@ -391,7 +589,7 @@ class ForeachSOLP(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,  #
-            functools.partial(C.scale_by_soap, inner="laprop"),
+            fns=(functools.partial(C.scale_by_soap, inner="laprop"),),
         )
 
 
@@ -427,10 +625,15 @@ class OrthoLaProp(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
         super().__init__(
             params,
             defaults,
@@ -438,8 +641,7 @@ class OrthoLaProp(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.orthogonalize_grad_to_param,
-            C.scale_by_laprop,
+            fns=(C.orthogonalize_grad_to_param, C.scale_by_laprop),
         )
 
 
@@ -461,10 +663,15 @@ class LaPropOrtho(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
         super().__init__(
             params,
             defaults,
@@ -472,8 +679,7 @@ class LaPropOrtho(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.scale_by_laprop,
-            C.orthogonalize_grad_to_param,
+            fns=(C.scale_by_laprop, C.orthogonalize_grad_to_param),
        )
 
 
@@ -487,12 +693,14 @@ class ForeachPSGDKron(C.BaseOpt):
     delayed: bool = False
     cached: bool = False
     exp_avg_input: bool = True
+    quad: bool = False
 
     def __init__(
         self,
         params,
         lr=0.001,
-        beta=0.9,
+        beta=None,
+        betas=(0.9, 0.999),
         weight_decay=0.0,
         preconditioner_update_probability=None,
         max_size_triangular=2048,
@@ -515,23 +723,40 @@ class ForeachPSGDKron(C.BaseOpt):
         exp_avg_input: Optional[bool] = C.use_default,
         gradient_clipping: C.str_or_fn = C.use_default,
         update_clipping: C.str_or_fn = C.use_default,  #
+        adaptive: bool = False,
+        ortho_method: Optional[str] = None,  # If None, no orthogonalization
+        precond_grad_accum: bool = False,
+        lower_bound_beta: float = 0.9,  # 0.0 recovers pre-2.0.0 PSGD
+        inverse_free: bool = C.use_default,
+        dampening: float = 2**-13,
+        precond_update_power_iterations: int = 2,
         # expert parameters
         precond_init_scale=None,
-        precond_init_scale_scale=1,
-        precond_lr=0.1,
+        precond_init_scale_scale: float = 1,
+        precond_init_scale_power: Optional[float] = None,
+        precond_lr: float = 0.1,
+        **kwargs,
     ):
+        delayed = C.default(delayed, self.delayed)
+        cached = C.default(cached, self.cached)
+        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
+        update_clipping = C.default(update_clipping, utils.trust_region_clip_)
+        inverse_free = C.default(inverse_free, self.quad)
+        if inverse_free:
+            raise ValueError("inverse_free (i.e., PSGD-QUAD) is not supported at the moment. Consider using https://github.com/evanatyourservice/quad_torch")
+
         defaults = locals()
         defaults.pop("self")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         self.precond_schedule = (
             defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
         )
         params = defaults.pop("params")
 
-        delayed = C.default(delayed, self.delayed)
-        cached = C.default(cached, self.cached)
-        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
-        update_clipping = C.default(update_clipping, utils.trust_region_clip_)
-
         super().__init__(
             params,
             defaults,
@@ -539,8 +764,10 @@ class ForeachPSGDKron(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             False,  #
-            *(C.exp_avg,) * exp_avg_input,
-            functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
+            fns=(
+                *(C.exp_avg,) * exp_avg_input,
+                functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
+            ),
         )
 
 
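
ForeachPSGDKron moves from a single beta (kept, now defaulting to None) to a betas pair, grows several new knobs (adaptive, ortho_method, precond_grad_accum, lower_bound_beta, dampening, precond_update_power_iterations), and rejects inverse-free/QUAD mode outright. A sketch of both paths (toy module; values illustrative):

    import torch

    import heavyball

    net = torch.nn.Linear(32, 32)

    # 2.0.0 style: smoothing coefficients are passed as a `betas` pair.
    opt = heavyball.ForeachPSGDKron(net.parameters(), lr=1e-3, betas=(0.9, 0.999))

    # PSGD-QUAD is rejected at construction time:
    try:
        heavyball.ForeachPSGDKron(net.parameters(), inverse_free=True)
    except ValueError as err:
        print(err)  # points to https://github.com/evanatyourservice/quad_torch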
@@ -569,6 +796,7 @@ class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
     hvp_interval = 2
 
 
+
 class ForeachPSGDLRA(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -601,13 +829,24 @@ class ForeachPSGDLRA(C.BaseOpt):
         gradient_clipping: C.str_or_fn = C.use_default,
         update_clipping: C.str_or_fn = C.use_default,
         eps: float = 1e-8,  #
-        # expert parameters
+        precond_grad_accum: bool = False,  # expert parameters
         precond_init_scale=None,
-        precond_init_scale_scale=1,
-        precond_lr=0.1,
+        precond_init_scale_scale: float = 1,
+        precond_init_scale_power: Optional[float] = None,
+        precond_lr: float = 0.1,
+        **kwargs,
     ):
+        delayed = C.default(delayed, self.delayed)
+        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
+        update_clipping = C.default(update_clipping, utils.trust_region_clip_)
+
        defaults = locals()
        defaults.pop("self")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         self.precond_schedule = (
             defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
         )
@@ -621,10 +860,6 @@ class ForeachPSGDLRA(C.BaseOpt):
             defaults["rank"] = round(math.log2(sum(p.numel() for p in params)))
             utils.warn_once(f"rank was set to {defaults['rank']}")
 
-        delayed = C.default(delayed, self.delayed)
-        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
-        update_clipping = C.default(update_clipping, utils.trust_region_clip_)
-
         super().__init__(
             params,
             defaults,
@@ -632,8 +867,7 @@ class ForeachPSGDLRA(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             False,  #
-            *(C.exp_avg,) * exp_avg_input,
-            C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra,
+            fns=(*(C.exp_avg,) * exp_avg_input, C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra),
         )
 
 
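
The rank heuristic kept as context in this hunk picks rank = round(log2(total parameter count)) when rank is not given, then warns with the chosen value. For scale:

    import math

    # ForeachPSGDLRA's default when `rank` is not provided:
    #   defaults["rank"] = round(math.log2(sum(p.numel() for p in params)))
    total_params = 25_000_000
    print(round(math.log2(total_params)))  # 25 -> "rank was set to 25"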
@@ -670,6 +904,7 @@ SignLaProp = ForeachSignLaProp
 DelayedPSGDLRA = ForeachDelayedPSGDLRA
 PSGDLRA = ForeachPSGDLRA
 NewtonPSGDLRA = ForeachNewtonPSGDLRA
+NewtonPSGDKron = ForeachCachedNewtonPSGD
 
 __all__ = [
     "Muon",
@@ -715,4 +950,8 @@ __all__ = [
     "NewtonPSGDLRA",
     "NewtonHybrid2PSGDLRA",
     "NewtonHybrid2PSGDKron",
+    "MSAMLaProp",
+    "NewtonPSGDKron",
+    "ForeachAdamC",
+    "SGD"
 ]