heavyball 1.6.3__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +515 -100
- heavyball/chainable.py +487 -156
- heavyball/optimizations/__init__.py +38 -0
- heavyball/optimizations/integrator.py +169 -0
- heavyball/optimizations/optimizations.py +329 -0
- heavyball/utils.py +780 -241
- {heavyball-1.6.3.dist-info → heavyball-1.7.1.dist-info}/METADATA +3 -2
- heavyball-1.7.1.dist-info/RECORD +11 -0
- {heavyball-1.6.3.dist-info → heavyball-1.7.1.dist-info}/WHEEL +1 -1
- {heavyball-1.6.3.dist-info → heavyball-1.7.1.dist-info/licenses}/LICENSE +1 -1
- heavyball-1.6.3.dist-info/RECORD +0 -8
- {heavyball-1.6.3.dist-info → heavyball-1.7.1.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
import functools
|
2
|
+
import math
|
2
3
|
from typing import Optional
|
3
4
|
|
4
5
|
from . import chainable as C
|
@@ -6,10 +7,24 @@ from . import utils
|
|
6
7
|
|
7
8
|
|
8
9
|
class ForeachAdamW(C.BaseOpt):
|
9
|
-
def __init__(
|
10
|
-
|
11
|
-
|
12
|
-
|
10
|
+
def __init__(
|
11
|
+
self,
|
12
|
+
params,
|
13
|
+
lr=0.0025,
|
14
|
+
betas=(0.9, 0.99),
|
15
|
+
eps=1e-8,
|
16
|
+
weight_decay=0,
|
17
|
+
warmup_steps=0,
|
18
|
+
foreach: bool = True,
|
19
|
+
storage_dtype: str = "float32",
|
20
|
+
mars: bool = False,
|
21
|
+
caution: bool = False,
|
22
|
+
mars_gamma: float = 0.0025,
|
23
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
24
|
+
update_clipping: C.str_or_fn = C.use_default,
|
25
|
+
palm: bool = C.use_default,
|
26
|
+
beta2_scale: float = 0.8,
|
27
|
+
):
|
13
28
|
defaults = locals()
|
14
29
|
defaults.pop("self")
|
15
30
|
params = defaults.pop("params")
|
@@ -21,26 +36,74 @@ class ForeachRMSprop(C.BaseOpt):
|
|
21
36
|
Debiased RMSprop (not torch.optim.RMSprop)
|
22
37
|
"""
|
23
38
|
|
24
|
-
def __init__(
|
25
|
-
|
26
|
-
|
27
|
-
|
39
|
+
def __init__(
|
40
|
+
self,
|
41
|
+
params,
|
42
|
+
lr=0.0025,
|
43
|
+
betas=(0.9, 0.99),
|
44
|
+
eps=1e-6,
|
45
|
+
weight_decay=0,
|
46
|
+
warmup_steps=0,
|
47
|
+
r=0.0,
|
48
|
+
weight_lr_power=2.0,
|
49
|
+
foreach: bool = True,
|
50
|
+
storage_dtype: str = "float32",
|
51
|
+
mars: bool = False,
|
52
|
+
caution: bool = False,
|
53
|
+
mars_gamma: float = 0.0025,
|
54
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
55
|
+
update_clipping: C.str_or_fn = C.use_default,
|
56
|
+
palm: bool = C.use_default,
|
57
|
+
beta2_scale: float = 0.8,
|
58
|
+
):
|
28
59
|
defaults = locals()
|
29
60
|
defaults.pop("self")
|
30
61
|
params = defaults.pop("params")
|
31
|
-
super().__init__(
|
62
|
+
super().__init__(
|
63
|
+
params,
|
64
|
+
defaults,
|
65
|
+
foreach,
|
66
|
+
gradient_clipping,
|
67
|
+
update_clipping,
|
68
|
+
palm,
|
69
|
+
C.scale_by_exp_avg_sq,
|
70
|
+
)
|
32
71
|
|
33
72
|
|
34
73
|
class ForeachSFAdamW(C.ScheduleFree):
|
35
|
-
def __init__(
|
36
|
-
|
37
|
-
|
38
|
-
|
74
|
+
def __init__(
|
75
|
+
self,
|
76
|
+
params,
|
77
|
+
lr=0.0025,
|
78
|
+
betas=(0.9, 0.99),
|
79
|
+
eps=1e-6,
|
80
|
+
weight_decay=0,
|
81
|
+
warmup_steps=0,
|
82
|
+
r=0.0,
|
83
|
+
weight_lr_power=2.0,
|
84
|
+
foreach: bool = True,
|
85
|
+
storage_dtype: str = "float32",
|
86
|
+
mars: bool = False,
|
87
|
+
caution: bool = False,
|
88
|
+
mars_gamma: float = 0.0025,
|
89
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
90
|
+
update_clipping: C.str_or_fn = C.use_default,
|
91
|
+
palm: bool = C.use_default,
|
92
|
+
beta2_scale: float = 0.8,
|
93
|
+
):
|
39
94
|
defaults = locals()
|
40
95
|
defaults.pop("self")
|
41
96
|
params = defaults.pop("params")
|
42
|
-
super().__init__(
|
43
|
-
|
97
|
+
super().__init__(
|
98
|
+
params,
|
99
|
+
defaults,
|
100
|
+
foreach,
|
101
|
+
gradient_clipping,
|
102
|
+
update_clipping,
|
103
|
+
palm,
|
104
|
+
C.scale_by_exp_avg_sq,
|
105
|
+
C.update_by_schedule_free,
|
106
|
+
)
|
44
107
|
|
45
108
|
|
46
109
|
class PaLMForeachSFAdamW(ForeachSFAdamW):
|
@@ -48,10 +111,24 @@ class PaLMForeachSFAdamW(ForeachSFAdamW):
|
|
48
111
|
|
49
112
|
|
50
113
|
class ForeachADOPT(C.BaseOpt):
|
51
|
-
def __init__(
|
52
|
-
|
53
|
-
|
54
|
-
|
114
|
+
def __init__(
|
115
|
+
self,
|
116
|
+
params,
|
117
|
+
lr=0.0025,
|
118
|
+
betas=(0.9, 0.99),
|
119
|
+
eps=1e-8,
|
120
|
+
weight_decay=0,
|
121
|
+
warmup_steps=0,
|
122
|
+
foreach: bool = True,
|
123
|
+
storage_dtype: str = "float32",
|
124
|
+
mars: bool = False,
|
125
|
+
caution: bool = False,
|
126
|
+
mars_gamma: float = 0.0025,
|
127
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
128
|
+
update_clipping: C.str_or_fn = C.use_default,
|
129
|
+
palm: bool = C.use_default,
|
130
|
+
beta2_scale: float = 0.8,
|
131
|
+
):
|
55
132
|
defaults = locals()
|
56
133
|
defaults.pop("self")
|
57
134
|
params = defaults.pop("params")
|
@@ -59,23 +136,59 @@ class ForeachADOPT(C.BaseOpt):
|
|
59
136
|
|
60
137
|
|
61
138
|
class ForeachMuon(C.BaseOpt):
|
62
|
-
def __init__(
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
139
|
+
def __init__(
|
140
|
+
self,
|
141
|
+
params,
|
142
|
+
lr=0.0025,
|
143
|
+
betas=(0.9, 0.99),
|
144
|
+
eps=1e-8,
|
145
|
+
weight_decay=0,
|
146
|
+
warmup_steps=0,
|
147
|
+
foreach: bool = True,
|
148
|
+
storage_dtype: str = "float32",
|
149
|
+
mars: bool = False,
|
150
|
+
caution: bool = False,
|
151
|
+
mars_gamma: float = 0.0025,
|
152
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
153
|
+
update_clipping: C.str_or_fn = C.use_default,
|
154
|
+
palm: bool = C.use_default,
|
155
|
+
beta2_scale: float = 0.8,
|
156
|
+
nesterov: bool = True,
|
157
|
+
):
|
67
158
|
defaults = locals()
|
68
159
|
defaults.pop("self")
|
69
160
|
params = defaults.pop("params")
|
70
|
-
super().__init__(
|
71
|
-
|
161
|
+
super().__init__(
|
162
|
+
params,
|
163
|
+
defaults,
|
164
|
+
foreach,
|
165
|
+
gradient_clipping,
|
166
|
+
update_clipping,
|
167
|
+
palm,
|
168
|
+
C.nesterov_momentum if nesterov else C.heavyball_momentum,
|
169
|
+
C.orthogonalize_update,
|
170
|
+
)
|
72
171
|
|
73
172
|
|
74
173
|
class ForeachLaProp(C.BaseOpt):
|
75
|
-
def __init__(
|
76
|
-
|
77
|
-
|
78
|
-
|
174
|
+
def __init__(
|
175
|
+
self,
|
176
|
+
params,
|
177
|
+
lr=0.0025,
|
178
|
+
betas=(0.9, 0.99),
|
179
|
+
eps=1e-8,
|
180
|
+
weight_decay=0,
|
181
|
+
warmup_steps=0,
|
182
|
+
foreach: bool = True,
|
183
|
+
storage_dtype: str = "float32",
|
184
|
+
mars: bool = False,
|
185
|
+
caution: bool = False,
|
186
|
+
mars_gamma: float = 0.0025,
|
187
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
188
|
+
update_clipping: C.str_or_fn = C.use_default,
|
189
|
+
palm: bool = C.use_default,
|
190
|
+
beta2_scale: float = 0.8,
|
191
|
+
):
|
79
192
|
defaults = locals()
|
80
193
|
defaults.pop("self")
|
81
194
|
params = defaults.pop("params")
|
@@ -83,15 +196,37 @@ class ForeachLaProp(C.BaseOpt):
|
|
83
196
|
|
84
197
|
|
85
198
|
class MuonLaProp(C.BaseOpt):
|
86
|
-
def __init__(
|
87
|
-
|
88
|
-
|
89
|
-
|
199
|
+
def __init__(
|
200
|
+
self,
|
201
|
+
params,
|
202
|
+
lr=0.0025,
|
203
|
+
betas=(0.9, 0.99),
|
204
|
+
eps=1e-8,
|
205
|
+
weight_decay=0,
|
206
|
+
warmup_steps=0,
|
207
|
+
foreach: bool = True,
|
208
|
+
storage_dtype: str = "float32",
|
209
|
+
mars: bool = False,
|
210
|
+
caution: bool = False,
|
211
|
+
mars_gamma: float = 0.0025,
|
212
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
213
|
+
update_clipping: C.str_or_fn = C.use_default,
|
214
|
+
palm: bool = C.use_default,
|
215
|
+
beta2_scale: float = 0.8,
|
216
|
+
):
|
90
217
|
defaults = locals()
|
91
218
|
defaults.pop("self")
|
92
219
|
params = defaults.pop("params")
|
93
|
-
super().__init__(
|
94
|
-
|
220
|
+
super().__init__(
|
221
|
+
params,
|
222
|
+
defaults,
|
223
|
+
foreach,
|
224
|
+
gradient_clipping,
|
225
|
+
update_clipping,
|
226
|
+
palm,
|
227
|
+
C.scale_by_laprop,
|
228
|
+
C.orthogonalize_update,
|
229
|
+
)
|
95
230
|
|
96
231
|
|
97
232
|
class ForeachSOAP(C.BaseOpt):
|
@@ -105,16 +240,38 @@ class ForeachSOAP(C.BaseOpt):
|
|
105
240
|
https://arxiv.org/abs/2409.11321
|
106
241
|
https://github.com/nikhilvyas/SOAP
|
107
242
|
"""
|
243
|
+
|
108
244
|
use_precond_schedule: bool = False
|
109
245
|
|
110
|
-
def __init__(
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
246
|
+
def __init__(
|
247
|
+
self,
|
248
|
+
params,
|
249
|
+
lr: float = 3e-3,
|
250
|
+
betas=(0.9, 0.95),
|
251
|
+
shampoo_beta: float = 0.95,
|
252
|
+
eps: float = 1e-8,
|
253
|
+
weight_decay: float = 0.01,
|
254
|
+
precondition_frequency: int = 2,
|
255
|
+
max_precond_dim: int = 2048, #
|
256
|
+
merge_dims: bool = True,
|
257
|
+
precondition_1d: bool = False,
|
258
|
+
normalize_grads: bool = False,
|
259
|
+
correct_bias: bool = True,
|
260
|
+
warmup_steps: int = 0,
|
261
|
+
split: bool = False,
|
262
|
+
foreach: bool = True,
|
263
|
+
mars: bool = False,
|
264
|
+
caution: bool = False,
|
265
|
+
mars_gamma: float = 0.0025,
|
266
|
+
palm: bool = C.use_default,
|
267
|
+
precond_scheduler=(1 / 3, 9),
|
268
|
+
beta2_scale: float = 0.8,
|
269
|
+
use_precond_schedule: bool = C.use_default,
|
270
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
271
|
+
update_clipping: C.str_or_fn = C.use_default,
|
272
|
+
storage_dtype: str = "float32",
|
273
|
+
stochastic_schedule: bool = False,
|
274
|
+
):
|
118
275
|
use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
|
119
276
|
|
120
277
|
defaults = locals()
|
@@ -122,24 +279,54 @@ class ForeachSOAP(C.BaseOpt):
|
|
122
279
|
params = defaults.pop("params")
|
123
280
|
|
124
281
|
if use_precond_schedule:
|
125
|
-
del defaults[
|
282
|
+
del defaults["precondition_frequency"]
|
126
283
|
self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
|
127
284
|
else:
|
128
|
-
del defaults[
|
285
|
+
del defaults["precond_scheduler"]
|
129
286
|
self.precond_schedule = 1 / defaults.pop("precondition_frequency")
|
130
|
-
super().__init__(
|
131
|
-
|
287
|
+
super().__init__(
|
288
|
+
params,
|
289
|
+
defaults,
|
290
|
+
foreach,
|
291
|
+
gradient_clipping,
|
292
|
+
update_clipping,
|
293
|
+
palm, #
|
294
|
+
C.scale_by_soap,
|
295
|
+
)
|
132
296
|
|
133
297
|
|
134
298
|
class ForeachSignLaProp(C.BaseOpt):
|
135
|
-
def __init__(
|
136
|
-
|
137
|
-
|
138
|
-
|
299
|
+
def __init__(
|
300
|
+
self,
|
301
|
+
params,
|
302
|
+
lr=0.0025,
|
303
|
+
betas=(0.9, 0.99),
|
304
|
+
eps=1e-8,
|
305
|
+
weight_decay=0,
|
306
|
+
warmup_steps=0,
|
307
|
+
foreach: bool = True,
|
308
|
+
storage_dtype: str = "float32",
|
309
|
+
mars: bool = False,
|
310
|
+
caution: bool = False,
|
311
|
+
mars_gamma: float = 0.0025,
|
312
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
313
|
+
update_clipping: C.str_or_fn = C.use_default,
|
314
|
+
palm: bool = C.use_default,
|
315
|
+
beta2_scale: float = 0.8,
|
316
|
+
):
|
139
317
|
defaults = locals()
|
140
318
|
defaults.pop("self")
|
141
319
|
params = defaults.pop("params")
|
142
|
-
super().__init__(
|
320
|
+
super().__init__(
|
321
|
+
params,
|
322
|
+
defaults,
|
323
|
+
foreach,
|
324
|
+
gradient_clipping,
|
325
|
+
update_clipping,
|
326
|
+
palm,
|
327
|
+
C.scale_by_laprop,
|
328
|
+
C.sign,
|
329
|
+
)
|
143
330
|
|
144
331
|
|
145
332
|
class ForeachSOLP(C.BaseOpt):
|
@@ -153,16 +340,38 @@ class ForeachSOLP(C.BaseOpt):
|
|
153
340
|
https://arxiv.org/abs/2409.11321
|
154
341
|
https://github.com/nikhilvyas/SOAP
|
155
342
|
"""
|
343
|
+
|
156
344
|
use_precond_schedule: bool = False
|
157
345
|
|
158
|
-
def __init__(
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
346
|
+
def __init__(
|
347
|
+
self,
|
348
|
+
params,
|
349
|
+
lr: float = 3e-3,
|
350
|
+
betas=(0.9, 0.95),
|
351
|
+
shampoo_beta: float = 0.95,
|
352
|
+
eps: float = 1e-8,
|
353
|
+
weight_decay: float = 0.01,
|
354
|
+
precondition_frequency: int = 2,
|
355
|
+
max_precond_dim: int = 2048, #
|
356
|
+
merge_dims: bool = True,
|
357
|
+
precondition_1d: bool = False,
|
358
|
+
normalize_grads: bool = False,
|
359
|
+
correct_bias: bool = True,
|
360
|
+
warmup_steps: int = 0,
|
361
|
+
split: bool = False,
|
362
|
+
foreach: bool = True,
|
363
|
+
mars: bool = False,
|
364
|
+
caution: bool = False,
|
365
|
+
mars_gamma: float = 0.0025,
|
366
|
+
palm: bool = C.use_default,
|
367
|
+
precond_scheduler=(1 / 3, 9),
|
368
|
+
beta2_scale: float = 0.8,
|
369
|
+
use_precond_schedule: bool = C.use_default,
|
370
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
371
|
+
update_clipping: C.str_or_fn = C.use_default,
|
372
|
+
storage_dtype: str = "float32",
|
373
|
+
stochastic_schedule: bool = False,
|
374
|
+
):
|
166
375
|
use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
|
167
376
|
|
168
377
|
defaults = locals()
|
@@ -170,13 +379,20 @@ class ForeachSOLP(C.BaseOpt):
|
|
170
379
|
params = defaults.pop("params")
|
171
380
|
|
172
381
|
if use_precond_schedule:
|
173
|
-
del defaults[
|
382
|
+
del defaults["precondition_frequency"]
|
174
383
|
self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
|
175
384
|
else:
|
176
|
-
del defaults[
|
385
|
+
del defaults["precond_scheduler"]
|
177
386
|
self.precond_schedule = 1 / defaults.pop("precondition_frequency")
|
178
|
-
super().__init__(
|
179
|
-
|
387
|
+
super().__init__(
|
388
|
+
params,
|
389
|
+
defaults,
|
390
|
+
foreach,
|
391
|
+
gradient_clipping,
|
392
|
+
update_clipping,
|
393
|
+
palm, #
|
394
|
+
functools.partial(C.scale_by_soap, inner="laprop"),
|
395
|
+
)
|
180
396
|
|
181
397
|
|
182
398
|
class PaLMForeachSOAP(ForeachSOAP):
|
@@ -194,27 +410,71 @@ class PrecondSchedulePaLMForeachSOAP(ForeachSOAP):
|
|
194
410
|
|
195
411
|
|
196
412
|
class OrthoLaProp(C.BaseOpt):
|
197
|
-
def __init__(
|
198
|
-
|
199
|
-
|
200
|
-
|
413
|
+
def __init__(
|
414
|
+
self,
|
415
|
+
params,
|
416
|
+
lr=0.0025,
|
417
|
+
betas=(0.9, 0.99),
|
418
|
+
eps=1e-8,
|
419
|
+
weight_decay=0,
|
420
|
+
warmup_steps=0,
|
421
|
+
foreach: bool = True,
|
422
|
+
storage_dtype: str = "float32",
|
423
|
+
mars: bool = False,
|
424
|
+
caution: bool = False,
|
425
|
+
mars_gamma: float = 0.0025,
|
426
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
427
|
+
update_clipping: C.str_or_fn = C.use_default,
|
428
|
+
palm: bool = C.use_default,
|
429
|
+
beta2_scale: float = 0.8,
|
430
|
+
):
|
201
431
|
defaults = locals()
|
202
432
|
defaults.pop("self")
|
203
433
|
params = defaults.pop("params")
|
204
|
-
super().__init__(
|
205
|
-
|
434
|
+
super().__init__(
|
435
|
+
params,
|
436
|
+
defaults,
|
437
|
+
foreach,
|
438
|
+
gradient_clipping,
|
439
|
+
update_clipping,
|
440
|
+
palm,
|
441
|
+
C.orthogonalize_grad_to_param,
|
442
|
+
C.scale_by_laprop,
|
443
|
+
)
|
206
444
|
|
207
445
|
|
208
446
|
class LaPropOrtho(C.BaseOpt):
|
209
|
-
def __init__(
|
210
|
-
|
211
|
-
|
212
|
-
|
447
|
+
def __init__(
|
448
|
+
self,
|
449
|
+
params,
|
450
|
+
lr=0.0025,
|
451
|
+
betas=(0.9, 0.99),
|
452
|
+
eps=1e-8,
|
453
|
+
weight_decay=0,
|
454
|
+
warmup_steps=0,
|
455
|
+
foreach: bool = True,
|
456
|
+
storage_dtype: str = "float32",
|
457
|
+
mars: bool = False,
|
458
|
+
caution: bool = False,
|
459
|
+
mars_gamma: float = 0.0025,
|
460
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
461
|
+
update_clipping: C.str_or_fn = C.use_default,
|
462
|
+
palm: bool = C.use_default,
|
463
|
+
beta2_scale: float = 0.8,
|
464
|
+
):
|
213
465
|
defaults = locals()
|
214
466
|
defaults.pop("self")
|
215
467
|
params = defaults.pop("params")
|
216
|
-
super().__init__(
|
217
|
-
|
468
|
+
super().__init__(
|
469
|
+
params,
|
470
|
+
defaults,
|
471
|
+
foreach,
|
472
|
+
gradient_clipping,
|
473
|
+
update_clipping,
|
474
|
+
palm,
|
475
|
+
C.scale_by_laprop,
|
476
|
+
C.orthogonalize_grad_to_param,
|
477
|
+
)
|
218
478
|
|
219
479
|
|
220
480
|
class ForeachPSGDKron(C.BaseOpt):
|
@@ -228,20 +488,43 @@ class ForeachPSGDKron(C.BaseOpt):
|
|
228
488
|
cached: bool = False
|
229
489
|
exp_avg_input: bool = True
|
230
490
|
|
231
|
-
def __init__(
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
491
|
+
def __init__(
|
492
|
+
self,
|
493
|
+
params,
|
494
|
+
lr=0.001,
|
495
|
+
beta=0.9,
|
496
|
+
weight_decay=0.0,
|
497
|
+
preconditioner_update_probability=None,
|
498
|
+
max_size_triangular=2048,
|
499
|
+
min_ndim_triangular=2,
|
500
|
+
memory_save_mode=None,
|
501
|
+
momentum_into_precond_update=True,
|
502
|
+
warmup_steps: int = 0,
|
503
|
+
merge_dims: bool = False,
|
504
|
+
split: bool = False,
|
505
|
+
store_triu_as_line: bool = True,
|
506
|
+
foreach: bool = True,
|
507
|
+
q_dtype="float32",
|
508
|
+
stochastic_schedule: bool = False,
|
509
|
+
storage_dtype: str = "float32",
|
510
|
+
mars: bool = False,
|
511
|
+
caution: bool = False,
|
512
|
+
mars_gamma: float = 0.0025,
|
513
|
+
delayed: Optional[bool] = C.use_default,
|
514
|
+
cached: Optional[bool] = C.use_default,
|
515
|
+
exp_avg_input: Optional[bool] = C.use_default,
|
516
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
517
|
+
update_clipping: C.str_or_fn = C.use_default, #
|
518
|
+
# expert parameters
|
519
|
+
precond_init_scale=None,
|
520
|
+
precond_init_scale_scale=1,
|
521
|
+
precond_lr=0.1,
|
522
|
+
):
|
241
523
|
defaults = locals()
|
242
524
|
defaults.pop("self")
|
243
|
-
self.precond_schedule =
|
244
|
-
"preconditioner_update_probability") or utils.precond_update_prob_schedule()
|
525
|
+
self.precond_schedule = (
|
526
|
+
defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
|
527
|
+
)
|
245
528
|
params = defaults.pop("params")
|
246
529
|
|
247
530
|
delayed = C.default(delayed, self.delayed)
|
@@ -249,9 +532,16 @@ class ForeachPSGDKron(C.BaseOpt):
|
|
249
532
|
exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
|
250
533
|
update_clipping = C.default(update_clipping, utils.trust_region_clip_)
|
251
534
|
|
252
|
-
super().__init__(
|
253
|
-
|
254
|
-
|
535
|
+
super().__init__(
|
536
|
+
params,
|
537
|
+
defaults,
|
538
|
+
foreach,
|
539
|
+
gradient_clipping,
|
540
|
+
update_clipping,
|
541
|
+
False, #
|
542
|
+
*(C.exp_avg,) * exp_avg_input, #
|
543
|
+
functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
|
544
|
+
)
|
255
545
|
|
256
546
|
|
257
547
|
class ForeachPurePSGD(ForeachPSGDKron):
|
@@ -275,6 +565,90 @@ class ForeachCachedNewtonPSGD(ForeachCachedPSGDKron):
|
|
275
565
|
hessian_approx = True
|
276
566
|
|
277
567
|
|
568
|
+
class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
|
569
|
+
hvp_interval = 2
|
570
|
+
|
571
|
+
|
572
|
+
class ForeachPSGDLRA(C.BaseOpt):
|
573
|
+
"""
|
574
|
+
Originally from Evan Walters and Omead Pooladzandi, 2024
|
575
|
+
Modified under Creative Commons Attribution 4.0 International
|
576
|
+
Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
|
577
|
+
"""
|
578
|
+
|
579
|
+
delayed: bool = False
|
580
|
+
exp_avg_input: bool = True
|
581
|
+
|
582
|
+
def __init__(
|
583
|
+
self,
|
584
|
+
params,
|
585
|
+
lr=0.001,
|
586
|
+
beta=0.9,
|
587
|
+
weight_decay=0.0,
|
588
|
+
preconditioner_update_probability=None,
|
589
|
+
momentum_into_precond_update=True,
|
590
|
+
rank: Optional[int] = None,
|
591
|
+
warmup_steps: int = 0,
|
592
|
+
foreach: bool = True,
|
593
|
+
q_dtype="float32",
|
594
|
+
stochastic_schedule: bool = False,
|
595
|
+
storage_dtype: str = "float32",
|
596
|
+
mars: bool = False,
|
597
|
+
caution: bool = False,
|
598
|
+
mars_gamma: float = 0.0025,
|
599
|
+
delayed: Optional[bool] = C.use_default,
|
600
|
+
exp_avg_input: Optional[bool] = C.use_default,
|
601
|
+
gradient_clipping: C.str_or_fn = C.use_default,
|
602
|
+
update_clipping: C.str_or_fn = C.use_default,
|
603
|
+
eps: float = 1e-8, #
|
604
|
+
# expert parameters
|
605
|
+
precond_init_scale=None,
|
606
|
+
precond_init_scale_scale=1,
|
607
|
+
precond_lr=0.1,
|
608
|
+
):
|
609
|
+
defaults = locals()
|
610
|
+
defaults.pop("self")
|
611
|
+
self.precond_schedule = (
|
612
|
+
defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
|
613
|
+
)
|
614
|
+
params = defaults.pop("params")
|
615
|
+
|
616
|
+
if rank is None:
|
617
|
+
utils.warn_once(
|
618
|
+
f"{rank=}. It will be set to log2(param_count). This requires `params` to be of type list. Currently, {type(params)=}"
|
619
|
+
)
|
620
|
+
params = list(params)
|
621
|
+
defaults["rank"] = round(math.log2(sum(p.numel() for p in params)))
|
622
|
+
utils.warn_once(f"rank was set to {defaults['rank']}")
|
623
|
+
|
624
|
+
delayed = C.default(delayed, self.delayed)
|
625
|
+
exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
|
626
|
+
update_clipping = C.default(update_clipping, utils.trust_region_clip_)
|
627
|
+
|
628
|
+
super().__init__(
|
629
|
+
params,
|
630
|
+
defaults,
|
631
|
+
foreach,
|
632
|
+
gradient_clipping,
|
633
|
+
update_clipping,
|
634
|
+
False, #
|
635
|
+
*(C.exp_avg,) * exp_avg_input, #
|
636
|
+
C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra,
|
637
|
+
)
|
638
|
+
|
639
|
+
|
640
|
+
class ForeachDelayedPSGDLRA(ForeachPSGDLRA):
|
641
|
+
delayed: bool = True
|
642
|
+
|
643
|
+
|
644
|
+
class ForeachNewtonPSGDLRA(ForeachPSGDLRA):
|
645
|
+
hessian_approx = True
|
646
|
+
|
647
|
+
|
648
|
+
class NewtonHybrid2PSGDLRA(ForeachNewtonPSGDLRA):
|
649
|
+
hvp_interval = 2
|
650
|
+
|
651
|
+
|
278
652
|
PalmForEachSoap = PaLMForeachSOAP
|
279
653
|
PaLMSOAP = PaLMForeachSOAP
|
280
654
|
PaLMSFAdamW = PaLMForeachSFAdamW
|
@@ -293,11 +667,52 @@ CachedPSGDKron = ForeachCachedPSGDKron
|
|
293
667
|
CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
|
294
668
|
Muon = ForeachMuon
|
295
669
|
SignLaProp = ForeachSignLaProp
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
670
|
+
DelayedPSGDLRA = ForeachDelayedPSGDLRA
|
671
|
+
PSGDLRA = ForeachPSGDLRA
|
672
|
+
NewtonPSGDLRA = ForeachNewtonPSGDLRA
|
673
|
+
|
674
|
+
__all__ = [
|
675
|
+
"Muon",
|
676
|
+
"RMSprop",
|
677
|
+
"PrecondSchedulePaLMSOAP",
|
678
|
+
"PSGDKron",
|
679
|
+
"PurePSGD",
|
680
|
+
"DelayedPSGD",
|
681
|
+
"CachedPSGDKron",
|
682
|
+
"CachedDelayedPSGDKron",
|
683
|
+
"PalmForEachSoap",
|
684
|
+
"PaLMSOAP",
|
685
|
+
"PaLMSFAdamW",
|
686
|
+
"LaProp",
|
687
|
+
"ADOPT",
|
688
|
+
"PrecondScheduleSOAP",
|
689
|
+
"PrecondSchedulePaLMSOAP",
|
690
|
+
"RMSprop",
|
691
|
+
"MuonLaProp",
|
692
|
+
"ForeachSignLaProp",
|
693
|
+
"ForeachDelayedPSGDLRA",
|
694
|
+
"ForeachPSGDLRA",
|
695
|
+
"ForeachPSGDLRA",
|
696
|
+
"ForeachNewtonPSGDLRA", #
|
697
|
+
"ForeachAdamW",
|
698
|
+
"ForeachSFAdamW",
|
699
|
+
"ForeachLaProp",
|
700
|
+
"ForeachADOPT",
|
701
|
+
"ForeachSOAP",
|
702
|
+
"ForeachPSGDKron",
|
703
|
+
"ForeachPurePSGD",
|
704
|
+
"ForeachDelayedPSGD",
|
705
|
+
"ForeachCachedPSGDKron",
|
706
|
+
"ForeachCachedDelayedPSGDKron",
|
707
|
+
"ForeachRMSprop",
|
708
|
+
"ForeachMuon",
|
709
|
+
"ForeachCachedNewtonPSGD",
|
710
|
+
"OrthoLaProp",
|
711
|
+
"LaPropOrtho",
|
712
|
+
"SignLaProp",
|
713
|
+
"DelayedPSGD",
|
714
|
+
"PSGDLRA",
|
715
|
+
"NewtonPSGDLRA",
|
716
|
+
"NewtonHybrid2PSGDLRA",
|
717
|
+
"NewtonHybrid2PSGDKron",
|
718
|
+
]
|