heavyball 1.6.3__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +496 -100
- heavyball/chainable.py +444 -155
- heavyball/utils.py +326 -143
- {heavyball-1.6.3.dist-info → heavyball-1.7.0.dist-info}/METADATA +3 -2
- heavyball-1.7.0.dist-info/RECORD +8 -0
- {heavyball-1.6.3.dist-info → heavyball-1.7.0.dist-info}/WHEEL +1 -1
- {heavyball-1.6.3.dist-info → heavyball-1.7.0.dist-info/licenses}/LICENSE +1 -1
- heavyball-1.6.3.dist-info/RECORD +0 -8
- {heavyball-1.6.3.dist-info → heavyball-1.7.0.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -6,10 +6,24 @@ from . import utils
|
|
6
6
|
|
7
7
|
|
8
8
|
class ForeachAdamW(C.BaseOpt):
|
9
|
-
def __init__(
|
10
|
-
|
11
|
-
|
12
|
-
|
9
|
+
def __init__(
|
10
|
+
self,
|
11
|
+
params,
|
12
|
+
lr=0.0025,
|
13
|
+
betas=(0.9, 0.99),
|
14
|
+
eps=1e-8,
|
15
|
+
weight_decay=0,
|
16
|
+
warmup_steps=0,
|
17
|
+
foreach: bool = True,
|
18
|
+
storage_dtype: str = "float32",
|
19
|
+
mars: bool = False,
|
20
|
+
caution: bool = False,
|
21
|
+
mars_gamma: float = 0.0025,
|
22
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
23
|
+
update_clipping: C.str_or_fn = C.use_default,
|
24
|
+
palm: bool = C.use_default,
|
25
|
+
beta2_scale: float = 0.8,
|
26
|
+
):
|
13
27
|
defaults = locals()
|
14
28
|
defaults.pop("self")
|
15
29
|
params = defaults.pop("params")
|
@@ -21,26 +35,74 @@ class ForeachRMSprop(C.BaseOpt):
|
|
21
35
|
Debiased RMSprop (not torch.optim.RMSprop)
|
22
36
|
"""
|
23
37
|
|
24
|
-
def __init__(
|
25
|
-
|
26
|
-
|
27
|
-
|
38
|
+
def __init__(
|
39
|
+
self,
|
40
|
+
params,
|
41
|
+
lr=0.0025,
|
42
|
+
betas=(0.9, 0.99),
|
43
|
+
eps=1e-6,
|
44
|
+
weight_decay=0,
|
45
|
+
warmup_steps=0,
|
46
|
+
r=0.0,
|
47
|
+
weight_lr_power=2.0,
|
48
|
+
foreach: bool = True,
|
49
|
+
storage_dtype: str = "float32",
|
50
|
+
mars: bool = False,
|
51
|
+
caution: bool = False,
|
52
|
+
mars_gamma: float = 0.0025,
|
53
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
54
|
+
update_clipping: C.str_or_fn = C.use_default,
|
55
|
+
palm: bool = C.use_default,
|
56
|
+
beta2_scale: float = 0.8,
|
57
|
+
):
|
28
58
|
defaults = locals()
|
29
59
|
defaults.pop("self")
|
30
60
|
params = defaults.pop("params")
|
31
|
-
super().__init__(
|
61
|
+
super().__init__(
|
62
|
+
params,
|
63
|
+
defaults,
|
64
|
+
foreach,
|
65
|
+
gradient_clipping,
|
66
|
+
update_clipping,
|
67
|
+
palm,
|
68
|
+
C.scale_by_exp_avg_sq,
|
69
|
+
)
|
32
70
|
|
33
71
|
|
34
72
|
class ForeachSFAdamW(C.ScheduleFree):
|
35
|
-
def __init__(
|
36
|
-
|
37
|
-
|
38
|
-
|
73
|
+
def __init__(
|
74
|
+
self,
|
75
|
+
params,
|
76
|
+
lr=0.0025,
|
77
|
+
betas=(0.9, 0.99),
|
78
|
+
eps=1e-6,
|
79
|
+
weight_decay=0,
|
80
|
+
warmup_steps=0,
|
81
|
+
r=0.0,
|
82
|
+
weight_lr_power=2.0,
|
83
|
+
foreach: bool = True,
|
84
|
+
storage_dtype: str = "float32",
|
85
|
+
mars: bool = False,
|
86
|
+
caution: bool = False,
|
87
|
+
mars_gamma: float = 0.0025,
|
88
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
89
|
+
update_clipping: C.str_or_fn = C.use_default,
|
90
|
+
palm: bool = C.use_default,
|
91
|
+
beta2_scale: float = 0.8,
|
92
|
+
):
|
39
93
|
defaults = locals()
|
40
94
|
defaults.pop("self")
|
41
95
|
params = defaults.pop("params")
|
42
|
-
super().__init__(
|
43
|
-
|
96
|
+
super().__init__(
|
97
|
+
params,
|
98
|
+
defaults,
|
99
|
+
foreach,
|
100
|
+
gradient_clipping,
|
101
|
+
update_clipping,
|
102
|
+
palm,
|
103
|
+
C.scale_by_exp_avg_sq,
|
104
|
+
C.update_by_schedule_free,
|
105
|
+
)
|
44
106
|
|
45
107
|
|
46
108
|
class PaLMForeachSFAdamW(ForeachSFAdamW):
|
@@ -48,10 +110,24 @@ class PaLMForeachSFAdamW(ForeachSFAdamW):
|
|
48
110
|
|
49
111
|
|
50
112
|
class ForeachADOPT(C.BaseOpt):
|
51
|
-
def __init__(
|
52
|
-
|
53
|
-
|
54
|
-
|
113
|
+
def __init__(
|
114
|
+
self,
|
115
|
+
params,
|
116
|
+
lr=0.0025,
|
117
|
+
betas=(0.9, 0.99),
|
118
|
+
eps=1e-8,
|
119
|
+
weight_decay=0,
|
120
|
+
warmup_steps=0,
|
121
|
+
foreach: bool = True,
|
122
|
+
storage_dtype: str = "float32",
|
123
|
+
mars: bool = False,
|
124
|
+
caution: bool = False,
|
125
|
+
mars_gamma: float = 0.0025,
|
126
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
127
|
+
update_clipping: C.str_or_fn = C.use_default,
|
128
|
+
palm: bool = C.use_default,
|
129
|
+
beta2_scale: float = 0.8,
|
130
|
+
):
|
55
131
|
defaults = locals()
|
56
132
|
defaults.pop("self")
|
57
133
|
params = defaults.pop("params")
|
@@ -59,23 +135,59 @@ class ForeachADOPT(C.BaseOpt):
|
|
59
135
|
|
60
136
|
|
61
137
|
class ForeachMuon(C.BaseOpt):
|
62
|
-
def __init__(
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
138
|
+
def __init__(
|
139
|
+
self,
|
140
|
+
params,
|
141
|
+
lr=0.0025,
|
142
|
+
betas=(0.9, 0.99),
|
143
|
+
eps=1e-8,
|
144
|
+
weight_decay=0,
|
145
|
+
warmup_steps=0,
|
146
|
+
foreach: bool = True,
|
147
|
+
storage_dtype: str = "float32",
|
148
|
+
mars: bool = False,
|
149
|
+
caution: bool = False,
|
150
|
+
mars_gamma: float = 0.0025,
|
151
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
152
|
+
update_clipping: C.str_or_fn = C.use_default,
|
153
|
+
palm: bool = C.use_default,
|
154
|
+
beta2_scale: float = 0.8,
|
155
|
+
nesterov: bool = True,
|
156
|
+
):
|
67
157
|
defaults = locals()
|
68
158
|
defaults.pop("self")
|
69
159
|
params = defaults.pop("params")
|
70
|
-
super().__init__(
|
71
|
-
|
160
|
+
super().__init__(
|
161
|
+
params,
|
162
|
+
defaults,
|
163
|
+
foreach,
|
164
|
+
gradient_clipping,
|
165
|
+
update_clipping,
|
166
|
+
palm,
|
167
|
+
C.nesterov_momentum if nesterov else C.heavyball_momentum,
|
168
|
+
C.orthogonalize_update,
|
169
|
+
)
|
72
170
|
|
73
171
|
|
74
172
|
class ForeachLaProp(C.BaseOpt):
|
75
|
-
def __init__(
|
76
|
-
|
77
|
-
|
78
|
-
|
173
|
+
def __init__(
|
174
|
+
self,
|
175
|
+
params,
|
176
|
+
lr=0.0025,
|
177
|
+
betas=(0.9, 0.99),
|
178
|
+
eps=1e-8,
|
179
|
+
weight_decay=0,
|
180
|
+
warmup_steps=0,
|
181
|
+
foreach: bool = True,
|
182
|
+
storage_dtype: str = "float32",
|
183
|
+
mars: bool = False,
|
184
|
+
caution: bool = False,
|
185
|
+
mars_gamma: float = 0.0025,
|
186
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
187
|
+
update_clipping: C.str_or_fn = C.use_default,
|
188
|
+
palm: bool = C.use_default,
|
189
|
+
beta2_scale: float = 0.8,
|
190
|
+
):
|
79
191
|
defaults = locals()
|
80
192
|
defaults.pop("self")
|
81
193
|
params = defaults.pop("params")
|
@@ -83,15 +195,37 @@ class ForeachLaProp(C.BaseOpt):
|
|
83
195
|
|
84
196
|
|
85
197
|
class MuonLaProp(C.BaseOpt):
|
86
|
-
def __init__(
|
87
|
-
|
88
|
-
|
89
|
-
|
198
|
+
def __init__(
|
199
|
+
self,
|
200
|
+
params,
|
201
|
+
lr=0.0025,
|
202
|
+
betas=(0.9, 0.99),
|
203
|
+
eps=1e-8,
|
204
|
+
weight_decay=0,
|
205
|
+
warmup_steps=0,
|
206
|
+
foreach: bool = True,
|
207
|
+
storage_dtype: str = "float32",
|
208
|
+
mars: bool = False,
|
209
|
+
caution: bool = False,
|
210
|
+
mars_gamma: float = 0.0025,
|
211
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
212
|
+
update_clipping: C.str_or_fn = C.use_default,
|
213
|
+
palm: bool = C.use_default,
|
214
|
+
beta2_scale: float = 0.8,
|
215
|
+
):
|
90
216
|
defaults = locals()
|
91
217
|
defaults.pop("self")
|
92
218
|
params = defaults.pop("params")
|
93
|
-
super().__init__(
|
94
|
-
|
219
|
+
super().__init__(
|
220
|
+
params,
|
221
|
+
defaults,
|
222
|
+
foreach,
|
223
|
+
gradient_clipping,
|
224
|
+
update_clipping,
|
225
|
+
palm,
|
226
|
+
C.scale_by_laprop,
|
227
|
+
C.orthogonalize_update,
|
228
|
+
)
|
95
229
|
|
96
230
|
|
97
231
|
class ForeachSOAP(C.BaseOpt):
|
@@ -105,16 +239,38 @@ class ForeachSOAP(C.BaseOpt):
|
|
105
239
|
https://arxiv.org/abs/2409.11321
|
106
240
|
https://github.com/nikhilvyas/SOAP
|
107
241
|
"""
|
242
|
+
|
108
243
|
use_precond_schedule: bool = False
|
109
244
|
|
110
|
-
def __init__(
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
245
|
+
def __init__(
|
246
|
+
self,
|
247
|
+
params,
|
248
|
+
lr: float = 3e-3,
|
249
|
+
betas=(0.9, 0.95),
|
250
|
+
shampoo_beta: float = 0.95,
|
251
|
+
eps: float = 1e-8,
|
252
|
+
weight_decay: float = 0.01,
|
253
|
+
precondition_frequency: int = 2,
|
254
|
+
max_precond_dim: int = 2048, #
|
255
|
+
merge_dims: bool = True,
|
256
|
+
precondition_1d: bool = False,
|
257
|
+
normalize_grads: bool = False,
|
258
|
+
correct_bias: bool = True,
|
259
|
+
warmup_steps: int = 0,
|
260
|
+
split: bool = False,
|
261
|
+
foreach: bool = True,
|
262
|
+
mars: bool = False,
|
263
|
+
caution: bool = False,
|
264
|
+
mars_gamma: float = 0.0025,
|
265
|
+
palm: bool = C.use_default,
|
266
|
+
precond_scheduler=(1 / 3, 9),
|
267
|
+
beta2_scale: float = 0.8,
|
268
|
+
use_precond_schedule: bool = C.use_default,
|
269
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
270
|
+
update_clipping: C.str_or_fn = C.use_default,
|
271
|
+
storage_dtype: str = "float32",
|
272
|
+
stochastic_schedule: bool = False,
|
273
|
+
):
|
118
274
|
use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
|
119
275
|
|
120
276
|
defaults = locals()
|
@@ -122,24 +278,54 @@ class ForeachSOAP(C.BaseOpt):
|
|
122
278
|
params = defaults.pop("params")
|
123
279
|
|
124
280
|
if use_precond_schedule:
|
125
|
-
del defaults[
|
281
|
+
del defaults["precondition_frequency"]
|
126
282
|
self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
|
127
283
|
else:
|
128
|
-
del defaults[
|
284
|
+
del defaults["precond_scheduler"]
|
129
285
|
self.precond_schedule = 1 / defaults.pop("precondition_frequency")
|
130
|
-
super().__init__(
|
131
|
-
|
286
|
+
super().__init__(
|
287
|
+
params,
|
288
|
+
defaults,
|
289
|
+
foreach,
|
290
|
+
gradient_clipping,
|
291
|
+
update_clipping,
|
292
|
+
palm, #
|
293
|
+
C.scale_by_soap,
|
294
|
+
)
|
132
295
|
|
133
296
|
|
134
297
|
class ForeachSignLaProp(C.BaseOpt):
|
135
|
-
def __init__(
|
136
|
-
|
137
|
-
|
138
|
-
|
298
|
+
def __init__(
|
299
|
+
self,
|
300
|
+
params,
|
301
|
+
lr=0.0025,
|
302
|
+
betas=(0.9, 0.99),
|
303
|
+
eps=1e-8,
|
304
|
+
weight_decay=0,
|
305
|
+
warmup_steps=0,
|
306
|
+
foreach: bool = True,
|
307
|
+
storage_dtype: str = "float32",
|
308
|
+
mars: bool = False,
|
309
|
+
caution: bool = False,
|
310
|
+
mars_gamma: float = 0.0025,
|
311
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
312
|
+
update_clipping: C.str_or_fn = C.use_default,
|
313
|
+
palm: bool = C.use_default,
|
314
|
+
beta2_scale: float = 0.8,
|
315
|
+
):
|
139
316
|
defaults = locals()
|
140
317
|
defaults.pop("self")
|
141
318
|
params = defaults.pop("params")
|
142
|
-
super().__init__(
|
319
|
+
super().__init__(
|
320
|
+
params,
|
321
|
+
defaults,
|
322
|
+
foreach,
|
323
|
+
gradient_clipping,
|
324
|
+
update_clipping,
|
325
|
+
palm,
|
326
|
+
C.scale_by_laprop,
|
327
|
+
C.sign,
|
328
|
+
)
|
143
329
|
|
144
330
|
|
145
331
|
class ForeachSOLP(C.BaseOpt):
|
@@ -153,16 +339,38 @@ class ForeachSOLP(C.BaseOpt):
|
|
153
339
|
https://arxiv.org/abs/2409.11321
|
154
340
|
https://github.com/nikhilvyas/SOAP
|
155
341
|
"""
|
342
|
+
|
156
343
|
use_precond_schedule: bool = False
|
157
344
|
|
158
|
-
def __init__(
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
345
|
+
def __init__(
|
346
|
+
self,
|
347
|
+
params,
|
348
|
+
lr: float = 3e-3,
|
349
|
+
betas=(0.9, 0.95),
|
350
|
+
shampoo_beta: float = 0.95,
|
351
|
+
eps: float = 1e-8,
|
352
|
+
weight_decay: float = 0.01,
|
353
|
+
precondition_frequency: int = 2,
|
354
|
+
max_precond_dim: int = 2048, #
|
355
|
+
merge_dims: bool = True,
|
356
|
+
precondition_1d: bool = False,
|
357
|
+
normalize_grads: bool = False,
|
358
|
+
correct_bias: bool = True,
|
359
|
+
warmup_steps: int = 0,
|
360
|
+
split: bool = False,
|
361
|
+
foreach: bool = True,
|
362
|
+
mars: bool = False,
|
363
|
+
caution: bool = False,
|
364
|
+
mars_gamma: float = 0.0025,
|
365
|
+
palm: bool = C.use_default,
|
366
|
+
precond_scheduler=(1 / 3, 9),
|
367
|
+
beta2_scale: float = 0.8,
|
368
|
+
use_precond_schedule: bool = C.use_default,
|
369
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
370
|
+
update_clipping: C.str_or_fn = C.use_default,
|
371
|
+
storage_dtype: str = "float32",
|
372
|
+
stochastic_schedule: bool = False,
|
373
|
+
):
|
166
374
|
use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
|
167
375
|
|
168
376
|
defaults = locals()
|
@@ -170,13 +378,20 @@ class ForeachSOLP(C.BaseOpt):
|
|
170
378
|
params = defaults.pop("params")
|
171
379
|
|
172
380
|
if use_precond_schedule:
|
173
|
-
del defaults[
|
381
|
+
del defaults["precondition_frequency"]
|
174
382
|
self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
|
175
383
|
else:
|
176
|
-
del defaults[
|
384
|
+
del defaults["precond_scheduler"]
|
177
385
|
self.precond_schedule = 1 / defaults.pop("precondition_frequency")
|
178
|
-
super().__init__(
|
179
|
-
|
386
|
+
super().__init__(
|
387
|
+
params,
|
388
|
+
defaults,
|
389
|
+
foreach,
|
390
|
+
gradient_clipping,
|
391
|
+
update_clipping,
|
392
|
+
palm, #
|
393
|
+
functools.partial(C.scale_by_soap, inner="laprop"),
|
394
|
+
)
|
180
395
|
|
181
396
|
|
182
397
|
class PaLMForeachSOAP(ForeachSOAP):
|
@@ -194,27 +409,71 @@ class PrecondSchedulePaLMForeachSOAP(ForeachSOAP):
|
|
194
409
|
|
195
410
|
|
196
411
|
class OrthoLaProp(C.BaseOpt):
|
197
|
-
def __init__(
|
198
|
-
|
199
|
-
|
200
|
-
|
412
|
+
def __init__(
|
413
|
+
self,
|
414
|
+
params,
|
415
|
+
lr=0.0025,
|
416
|
+
betas=(0.9, 0.99),
|
417
|
+
eps=1e-8,
|
418
|
+
weight_decay=0,
|
419
|
+
warmup_steps=0,
|
420
|
+
foreach: bool = True,
|
421
|
+
storage_dtype: str = "float32",
|
422
|
+
mars: bool = False,
|
423
|
+
caution: bool = False,
|
424
|
+
mars_gamma: float = 0.0025,
|
425
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
426
|
+
update_clipping: C.str_or_fn = C.use_default,
|
427
|
+
palm: bool = C.use_default,
|
428
|
+
beta2_scale: float = 0.8,
|
429
|
+
):
|
201
430
|
defaults = locals()
|
202
431
|
defaults.pop("self")
|
203
432
|
params = defaults.pop("params")
|
204
|
-
super().__init__(
|
205
|
-
|
433
|
+
super().__init__(
|
434
|
+
params,
|
435
|
+
defaults,
|
436
|
+
foreach,
|
437
|
+
gradient_clipping,
|
438
|
+
update_clipping,
|
439
|
+
palm,
|
440
|
+
C.orthogonalize_grad_to_param,
|
441
|
+
C.scale_by_laprop,
|
442
|
+
)
|
206
443
|
|
207
444
|
|
208
445
|
class LaPropOrtho(C.BaseOpt):
|
209
|
-
def __init__(
|
210
|
-
|
211
|
-
|
212
|
-
|
446
|
+
def __init__(
|
447
|
+
self,
|
448
|
+
params,
|
449
|
+
lr=0.0025,
|
450
|
+
betas=(0.9, 0.99),
|
451
|
+
eps=1e-8,
|
452
|
+
weight_decay=0,
|
453
|
+
warmup_steps=0,
|
454
|
+
foreach: bool = True,
|
455
|
+
storage_dtype: str = "float32",
|
456
|
+
mars: bool = False,
|
457
|
+
caution: bool = False,
|
458
|
+
mars_gamma: float = 0.0025,
|
459
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
460
|
+
update_clipping: C.str_or_fn = C.use_default,
|
461
|
+
palm: bool = C.use_default,
|
462
|
+
beta2_scale: float = 0.8,
|
463
|
+
):
|
213
464
|
defaults = locals()
|
214
465
|
defaults.pop("self")
|
215
466
|
params = defaults.pop("params")
|
216
|
-
super().__init__(
|
217
|
-
|
467
|
+
super().__init__(
|
468
|
+
params,
|
469
|
+
defaults,
|
470
|
+
foreach,
|
471
|
+
gradient_clipping,
|
472
|
+
update_clipping,
|
473
|
+
palm,
|
474
|
+
C.scale_by_laprop,
|
475
|
+
C.orthogonalize_grad_to_param,
|
476
|
+
)
|
218
477
|
|
219
478
|
|
220
479
|
class ForeachPSGDKron(C.BaseOpt):
|
@@ -228,20 +487,43 @@ class ForeachPSGDKron(C.BaseOpt):
|
|
228
487
|
cached: bool = False
|
229
488
|
exp_avg_input: bool = True
|
230
489
|
|
231
|
-
def __init__(
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
490
|
+
def __init__(
|
491
|
+
self,
|
492
|
+
params,
|
493
|
+
lr=0.001,
|
494
|
+
beta=0.9,
|
495
|
+
weight_decay=0.0,
|
496
|
+
preconditioner_update_probability=None,
|
497
|
+
max_size_triangular=2048,
|
498
|
+
min_ndim_triangular=2,
|
499
|
+
memory_save_mode=None,
|
500
|
+
momentum_into_precond_update=True,
|
501
|
+
warmup_steps: int = 0,
|
502
|
+
merge_dims: bool = False,
|
503
|
+
split: bool = False,
|
504
|
+
store_triu_as_line: bool = True,
|
505
|
+
foreach: bool = True,
|
506
|
+
q_dtype="float32",
|
507
|
+
stochastic_schedule: bool = False,
|
508
|
+
storage_dtype: str = "float32",
|
509
|
+
mars: bool = False,
|
510
|
+
caution: bool = False,
|
511
|
+
mars_gamma: float = 0.0025,
|
512
|
+
delayed: Optional[bool] = C.use_default,
|
513
|
+
cached: Optional[bool] = C.use_default,
|
514
|
+
exp_avg_input: Optional[bool] = C.use_default,
|
515
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
516
|
+
update_clipping: C.str_or_fn = C.use_default, #
|
517
|
+
# expert parameters
|
518
|
+
precond_init_scale=None,
|
519
|
+
precond_init_scale_scale=1,
|
520
|
+
precond_lr=0.1,
|
521
|
+
):
|
241
522
|
defaults = locals()
|
242
523
|
defaults.pop("self")
|
243
|
-
self.precond_schedule =
|
244
|
-
"preconditioner_update_probability") or utils.precond_update_prob_schedule()
|
524
|
+
self.precond_schedule = (
|
525
|
+
defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
|
526
|
+
)
|
245
527
|
params = defaults.pop("params")
|
246
528
|
|
247
529
|
delayed = C.default(delayed, self.delayed)
|
@@ -249,9 +531,16 @@ class ForeachPSGDKron(C.BaseOpt):
|
|
249
531
|
exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
|
250
532
|
update_clipping = C.default(update_clipping, utils.trust_region_clip_)
|
251
533
|
|
252
|
-
super().__init__(
|
253
|
-
|
254
|
-
|
534
|
+
super().__init__(
|
535
|
+
params,
|
536
|
+
defaults,
|
537
|
+
foreach,
|
538
|
+
gradient_clipping,
|
539
|
+
update_clipping,
|
540
|
+
False, #
|
541
|
+
*(C.exp_avg,) * exp_avg_input, #
|
542
|
+
functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
|
543
|
+
)
|
255
544
|
|
256
545
|
|
257
546
|
class ForeachPurePSGD(ForeachPSGDKron):
|
@@ -275,6 +564,74 @@ class ForeachCachedNewtonPSGD(ForeachCachedPSGDKron):
|
|
275
564
|
hessian_approx = True
|
276
565
|
|
277
566
|
|
567
|
+
class ForeachPSGDLRA(C.BaseOpt):
|
568
|
+
"""
|
569
|
+
Originally from Evan Walters and Omead Pooladzandi, 2024
|
570
|
+
Modified under Creative Commons Attribution 4.0 International
|
571
|
+
Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
|
572
|
+
"""
|
573
|
+
|
574
|
+
delayed: bool = False
|
575
|
+
exp_avg_input: bool = True
|
576
|
+
|
577
|
+
def __init__(
|
578
|
+
self,
|
579
|
+
params,
|
580
|
+
lr=0.001,
|
581
|
+
beta=0.9,
|
582
|
+
weight_decay=0.0,
|
583
|
+
preconditioner_update_probability=None,
|
584
|
+
momentum_into_precond_update=True,
|
585
|
+
rank: int = 4,
|
586
|
+
warmup_steps: int = 0,
|
587
|
+
foreach: bool = True,
|
588
|
+
q_dtype="float32",
|
589
|
+
stochastic_schedule: bool = False,
|
590
|
+
storage_dtype: str = "float32",
|
591
|
+
mars: bool = False,
|
592
|
+
caution: bool = False,
|
593
|
+
mars_gamma: float = 0.0025,
|
594
|
+
delayed: Optional[bool] = C.use_default,
|
595
|
+
exp_avg_input: Optional[bool] = C.use_default,
|
596
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
597
|
+
update_clipping: C.str_or_fn = C.use_default,
|
598
|
+
eps: float = 1e-8, #
|
599
|
+
# expert parameters
|
600
|
+
precond_init_scale=None,
|
601
|
+
precond_init_scale_scale=1,
|
602
|
+
precond_lr=0.1,
|
603
|
+
):
|
604
|
+
defaults = locals()
|
605
|
+
defaults.pop("self")
|
606
|
+
self.precond_schedule = (
|
607
|
+
defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
|
608
|
+
)
|
609
|
+
params = defaults.pop("params")
|
610
|
+
|
611
|
+
delayed = C.default(delayed, self.delayed)
|
612
|
+
exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
|
613
|
+
update_clipping = C.default(update_clipping, utils.trust_region_clip_)
|
614
|
+
|
615
|
+
super().__init__(
|
616
|
+
params,
|
617
|
+
defaults,
|
618
|
+
foreach,
|
619
|
+
gradient_clipping,
|
620
|
+
update_clipping,
|
621
|
+
False, #
|
622
|
+
*(C.exp_avg,) * exp_avg_input, #
|
623
|
+
C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra,
|
624
|
+
)
|
625
|
+
|
626
|
+
|
627
|
+
class ForeachDelayedPSGDLRA(ForeachPSGDLRA):
|
628
|
+
delayed: bool = True
|
629
|
+
|
630
|
+
|
631
|
+
class ForeachNewtonPSGDLRA(ForeachPSGDLRA):
|
632
|
+
hessian_approx = True
|
633
|
+
|
634
|
+
|
278
635
|
PalmForEachSoap = PaLMForeachSOAP
|
279
636
|
PaLMSOAP = PaLMForeachSOAP
|
280
637
|
PaLMSFAdamW = PaLMForeachSFAdamW
|
@@ -293,11 +650,50 @@ CachedPSGDKron = ForeachCachedPSGDKron
|
|
293
650
|
CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
|
294
651
|
Muon = ForeachMuon
|
295
652
|
SignLaProp = ForeachSignLaProp
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
653
|
+
DelayedPSGDLRA = ForeachDelayedPSGDLRA
|
654
|
+
PSGDLRA = ForeachPSGDLRA
|
655
|
+
NewtonPSGDLRA = ForeachNewtonPSGDLRA
|
656
|
+
|
657
|
+
__all__ = [
|
658
|
+
"Muon",
|
659
|
+
"RMSprop",
|
660
|
+
"PrecondSchedulePaLMSOAP",
|
661
|
+
"PSGDKron",
|
662
|
+
"PurePSGD",
|
663
|
+
"DelayedPSGD",
|
664
|
+
"CachedPSGDKron",
|
665
|
+
"CachedDelayedPSGDKron",
|
666
|
+
"PalmForEachSoap",
|
667
|
+
"PaLMSOAP",
|
668
|
+
"PaLMSFAdamW",
|
669
|
+
"LaProp",
|
670
|
+
"ADOPT",
|
671
|
+
"PrecondScheduleSOAP",
|
672
|
+
"PrecondSchedulePaLMSOAP",
|
673
|
+
"RMSprop",
|
674
|
+
"MuonLaProp",
|
675
|
+
"ForeachSignLaProp",
|
676
|
+
"ForeachDelayedPSGDLRA",
|
677
|
+
"ForeachPSGDLRA",
|
678
|
+
"ForeachPSGDLRA",
|
679
|
+
"ForeachNewtonPSGDLRA", #
|
680
|
+
"ForeachAdamW",
|
681
|
+
"ForeachSFAdamW",
|
682
|
+
"ForeachLaProp",
|
683
|
+
"ForeachADOPT",
|
684
|
+
"ForeachSOAP",
|
685
|
+
"ForeachPSGDKron",
|
686
|
+
"ForeachPurePSGD",
|
687
|
+
"ForeachDelayedPSGD",
|
688
|
+
"ForeachCachedPSGDKron",
|
689
|
+
"ForeachCachedDelayedPSGDKron",
|
690
|
+
"ForeachRMSprop",
|
691
|
+
"ForeachMuon",
|
692
|
+
"ForeachCachedNewtonPSGD",
|
693
|
+
"OrthoLaProp",
|
694
|
+
"LaPropOrtho",
|
695
|
+
"SignLaProp",
|
696
|
+
"DelayedPSGD",
|
697
|
+
"PSGDLRA",
|
698
|
+
"NewtonPSGDLRA",
|
699
|
+
]
|