heavyball 1.6.2__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -6,10 +6,24 @@ from . import utils
6
6
 
7
7
 
8
8
  class ForeachAdamW(C.BaseOpt):
9
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
10
- foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
11
- mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
12
- update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
9
+ def __init__(
10
+ self,
11
+ params,
12
+ lr=0.0025,
13
+ betas=(0.9, 0.99),
14
+ eps=1e-8,
15
+ weight_decay=0,
16
+ warmup_steps=0,
17
+ foreach: bool = True,
18
+ storage_dtype: str = "float32",
19
+ mars: bool = False,
20
+ caution: bool = False,
21
+ mars_gamma: float = 0.0025,
22
+ gradient_clipping: C.str_or_fn = C.use_default,
23
+ update_clipping: C.str_or_fn = C.use_default,
24
+ palm: bool = C.use_default,
25
+ beta2_scale: float = 0.8,
26
+ ):
13
27
  defaults = locals()
14
28
  defaults.pop("self")
15
29
  params = defaults.pop("params")
@@ -21,26 +35,74 @@ class ForeachRMSprop(C.BaseOpt):
21
35
  Debiased RMSprop (not torch.optim.RMSprop)
22
36
  """
23
37
 
24
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0,
25
- weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
26
- caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
27
- update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
38
+ def __init__(
39
+ self,
40
+ params,
41
+ lr=0.0025,
42
+ betas=(0.9, 0.99),
43
+ eps=1e-6,
44
+ weight_decay=0,
45
+ warmup_steps=0,
46
+ r=0.0,
47
+ weight_lr_power=2.0,
48
+ foreach: bool = True,
49
+ storage_dtype: str = "float32",
50
+ mars: bool = False,
51
+ caution: bool = False,
52
+ mars_gamma: float = 0.0025,
53
+ gradient_clipping: C.str_or_fn = C.use_default,
54
+ update_clipping: C.str_or_fn = C.use_default,
55
+ palm: bool = C.use_default,
56
+ beta2_scale: float = 0.8,
57
+ ):
28
58
  defaults = locals()
29
59
  defaults.pop("self")
30
60
  params = defaults.pop("params")
31
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq)
61
+ super().__init__(
62
+ params,
63
+ defaults,
64
+ foreach,
65
+ gradient_clipping,
66
+ update_clipping,
67
+ palm,
68
+ C.scale_by_exp_avg_sq,
69
+ )
32
70
 
33
71
 
34
72
  class ForeachSFAdamW(C.ScheduleFree):
35
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0,
36
- weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
37
- caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
38
- update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
73
+ def __init__(
74
+ self,
75
+ params,
76
+ lr=0.0025,
77
+ betas=(0.9, 0.99),
78
+ eps=1e-6,
79
+ weight_decay=0,
80
+ warmup_steps=0,
81
+ r=0.0,
82
+ weight_lr_power=2.0,
83
+ foreach: bool = True,
84
+ storage_dtype: str = "float32",
85
+ mars: bool = False,
86
+ caution: bool = False,
87
+ mars_gamma: float = 0.0025,
88
+ gradient_clipping: C.str_or_fn = C.use_default,
89
+ update_clipping: C.str_or_fn = C.use_default,
90
+ palm: bool = C.use_default,
91
+ beta2_scale: float = 0.8,
92
+ ):
39
93
  defaults = locals()
40
94
  defaults.pop("self")
41
95
  params = defaults.pop("params")
42
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq,
43
- C.update_by_schedule_free)
96
+ super().__init__(
97
+ params,
98
+ defaults,
99
+ foreach,
100
+ gradient_clipping,
101
+ update_clipping,
102
+ palm,
103
+ C.scale_by_exp_avg_sq,
104
+ C.update_by_schedule_free,
105
+ )
44
106
 
45
107
 
46
108
  class PaLMForeachSFAdamW(ForeachSFAdamW):
@@ -48,10 +110,24 @@ class PaLMForeachSFAdamW(ForeachSFAdamW):
48
110
 
49
111
 
50
112
  class ForeachADOPT(C.BaseOpt):
51
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
52
- foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
53
- mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
54
- update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
113
+ def __init__(
114
+ self,
115
+ params,
116
+ lr=0.0025,
117
+ betas=(0.9, 0.99),
118
+ eps=1e-8,
119
+ weight_decay=0,
120
+ warmup_steps=0,
121
+ foreach: bool = True,
122
+ storage_dtype: str = "float32",
123
+ mars: bool = False,
124
+ caution: bool = False,
125
+ mars_gamma: float = 0.0025,
126
+ gradient_clipping: C.str_or_fn = C.use_default,
127
+ update_clipping: C.str_or_fn = C.use_default,
128
+ palm: bool = C.use_default,
129
+ beta2_scale: float = 0.8,
130
+ ):
55
131
  defaults = locals()
56
132
  defaults.pop("self")
57
133
  params = defaults.pop("params")
@@ -59,23 +135,59 @@ class ForeachADOPT(C.BaseOpt):
59
135
 
60
136
 
61
137
  class ForeachMuon(C.BaseOpt):
62
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
63
- foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
64
- mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
65
- update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8,
66
- nesterov: bool = True):
138
+ def __init__(
139
+ self,
140
+ params,
141
+ lr=0.0025,
142
+ betas=(0.9, 0.99),
143
+ eps=1e-8,
144
+ weight_decay=0,
145
+ warmup_steps=0,
146
+ foreach: bool = True,
147
+ storage_dtype: str = "float32",
148
+ mars: bool = False,
149
+ caution: bool = False,
150
+ mars_gamma: float = 0.0025,
151
+ gradient_clipping: C.str_or_fn = C.use_default,
152
+ update_clipping: C.str_or_fn = C.use_default,
153
+ palm: bool = C.use_default,
154
+ beta2_scale: float = 0.8,
155
+ nesterov: bool = True,
156
+ ):
67
157
  defaults = locals()
68
158
  defaults.pop("self")
69
159
  params = defaults.pop("params")
70
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
71
- C.nesterov_momentum if nesterov else C.heavyball_momentum, C.orthogonalize_update)
160
+ super().__init__(
161
+ params,
162
+ defaults,
163
+ foreach,
164
+ gradient_clipping,
165
+ update_clipping,
166
+ palm,
167
+ C.nesterov_momentum if nesterov else C.heavyball_momentum,
168
+ C.orthogonalize_update,
169
+ )
72
170
 
73
171
 
74
172
  class ForeachLaProp(C.BaseOpt):
75
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
76
- foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
77
- mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
78
- update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
173
+ def __init__(
174
+ self,
175
+ params,
176
+ lr=0.0025,
177
+ betas=(0.9, 0.99),
178
+ eps=1e-8,
179
+ weight_decay=0,
180
+ warmup_steps=0,
181
+ foreach: bool = True,
182
+ storage_dtype: str = "float32",
183
+ mars: bool = False,
184
+ caution: bool = False,
185
+ mars_gamma: float = 0.0025,
186
+ gradient_clipping: C.str_or_fn = C.use_default,
187
+ update_clipping: C.str_or_fn = C.use_default,
188
+ palm: bool = C.use_default,
189
+ beta2_scale: float = 0.8,
190
+ ):
79
191
  defaults = locals()
80
192
  defaults.pop("self")
81
193
  params = defaults.pop("params")
@@ -83,15 +195,37 @@ class ForeachLaProp(C.BaseOpt):
83
195
 
84
196
 
85
197
  class MuonLaProp(C.BaseOpt):
86
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
87
- foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
88
- mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
89
- update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
198
+ def __init__(
199
+ self,
200
+ params,
201
+ lr=0.0025,
202
+ betas=(0.9, 0.99),
203
+ eps=1e-8,
204
+ weight_decay=0,
205
+ warmup_steps=0,
206
+ foreach: bool = True,
207
+ storage_dtype: str = "float32",
208
+ mars: bool = False,
209
+ caution: bool = False,
210
+ mars_gamma: float = 0.0025,
211
+ gradient_clipping: C.str_or_fn = C.use_default,
212
+ update_clipping: C.str_or_fn = C.use_default,
213
+ palm: bool = C.use_default,
214
+ beta2_scale: float = 0.8,
215
+ ):
90
216
  defaults = locals()
91
217
  defaults.pop("self")
92
218
  params = defaults.pop("params")
93
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop,
94
- C.orthogonalize_update)
219
+ super().__init__(
220
+ params,
221
+ defaults,
222
+ foreach,
223
+ gradient_clipping,
224
+ update_clipping,
225
+ palm,
226
+ C.scale_by_laprop,
227
+ C.orthogonalize_update,
228
+ )
95
229
 
96
230
 
97
231
  class ForeachSOAP(C.BaseOpt):
@@ -105,16 +239,38 @@ class ForeachSOAP(C.BaseOpt):
105
239
  https://arxiv.org/abs/2409.11321
106
240
  https://github.com/nikhilvyas/SOAP
107
241
  """
242
+
108
243
  use_precond_schedule: bool = False
109
244
 
110
- def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
111
- weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
112
- merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
113
- correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
114
- mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
115
- precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
116
- gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
117
- storage_dtype: str = 'float32', stochastic_schedule: bool = False):
245
+ def __init__(
246
+ self,
247
+ params,
248
+ lr: float = 3e-3,
249
+ betas=(0.9, 0.95),
250
+ shampoo_beta: float = 0.95,
251
+ eps: float = 1e-8,
252
+ weight_decay: float = 0.01,
253
+ precondition_frequency: int = 2,
254
+ max_precond_dim: int = 2048, #
255
+ merge_dims: bool = True,
256
+ precondition_1d: bool = False,
257
+ normalize_grads: bool = False,
258
+ correct_bias: bool = True,
259
+ warmup_steps: int = 0,
260
+ split: bool = False,
261
+ foreach: bool = True,
262
+ mars: bool = False,
263
+ caution: bool = False,
264
+ mars_gamma: float = 0.0025,
265
+ palm: bool = C.use_default,
266
+ precond_scheduler=(1 / 3, 9),
267
+ beta2_scale: float = 0.8,
268
+ use_precond_schedule: bool = C.use_default,
269
+ gradient_clipping: C.str_or_fn = C.use_default,
270
+ update_clipping: C.str_or_fn = C.use_default,
271
+ storage_dtype: str = "float32",
272
+ stochastic_schedule: bool = False,
273
+ ):
118
274
  use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
119
275
 
120
276
  defaults = locals()
@@ -122,24 +278,54 @@ class ForeachSOAP(C.BaseOpt):
122
278
  params = defaults.pop("params")
123
279
 
124
280
  if use_precond_schedule:
125
- del defaults['precondition_frequency']
281
+ del defaults["precondition_frequency"]
126
282
  self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
127
283
  else:
128
- del defaults['precond_scheduler']
284
+ del defaults["precond_scheduler"]
129
285
  self.precond_schedule = 1 / defaults.pop("precondition_frequency")
130
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, #
131
- C.scale_by_soap)
286
+ super().__init__(
287
+ params,
288
+ defaults,
289
+ foreach,
290
+ gradient_clipping,
291
+ update_clipping,
292
+ palm, #
293
+ C.scale_by_soap,
294
+ )
132
295
 
133
296
 
134
297
  class ForeachSignLaProp(C.BaseOpt):
135
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
136
- foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
137
- mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
138
- update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
298
+ def __init__(
299
+ self,
300
+ params,
301
+ lr=0.0025,
302
+ betas=(0.9, 0.99),
303
+ eps=1e-8,
304
+ weight_decay=0,
305
+ warmup_steps=0,
306
+ foreach: bool = True,
307
+ storage_dtype: str = "float32",
308
+ mars: bool = False,
309
+ caution: bool = False,
310
+ mars_gamma: float = 0.0025,
311
+ gradient_clipping: C.str_or_fn = C.use_default,
312
+ update_clipping: C.str_or_fn = C.use_default,
313
+ palm: bool = C.use_default,
314
+ beta2_scale: float = 0.8,
315
+ ):
139
316
  defaults = locals()
140
317
  defaults.pop("self")
141
318
  params = defaults.pop("params")
142
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop, C.sign)
319
+ super().__init__(
320
+ params,
321
+ defaults,
322
+ foreach,
323
+ gradient_clipping,
324
+ update_clipping,
325
+ palm,
326
+ C.scale_by_laprop,
327
+ C.sign,
328
+ )
143
329
 
144
330
 
145
331
  class ForeachSOLP(C.BaseOpt):
@@ -153,16 +339,38 @@ class ForeachSOLP(C.BaseOpt):
153
339
  https://arxiv.org/abs/2409.11321
154
340
  https://github.com/nikhilvyas/SOAP
155
341
  """
342
+
156
343
  use_precond_schedule: bool = False
157
344
 
158
- def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
159
- weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
160
- merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
161
- correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
162
- mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
163
- precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
164
- gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
165
- storage_dtype: str = 'float32', stochastic_schedule: bool = False):
345
+ def __init__(
346
+ self,
347
+ params,
348
+ lr: float = 3e-3,
349
+ betas=(0.9, 0.95),
350
+ shampoo_beta: float = 0.95,
351
+ eps: float = 1e-8,
352
+ weight_decay: float = 0.01,
353
+ precondition_frequency: int = 2,
354
+ max_precond_dim: int = 2048, #
355
+ merge_dims: bool = True,
356
+ precondition_1d: bool = False,
357
+ normalize_grads: bool = False,
358
+ correct_bias: bool = True,
359
+ warmup_steps: int = 0,
360
+ split: bool = False,
361
+ foreach: bool = True,
362
+ mars: bool = False,
363
+ caution: bool = False,
364
+ mars_gamma: float = 0.0025,
365
+ palm: bool = C.use_default,
366
+ precond_scheduler=(1 / 3, 9),
367
+ beta2_scale: float = 0.8,
368
+ use_precond_schedule: bool = C.use_default,
369
+ gradient_clipping: C.str_or_fn = C.use_default,
370
+ update_clipping: C.str_or_fn = C.use_default,
371
+ storage_dtype: str = "float32",
372
+ stochastic_schedule: bool = False,
373
+ ):
166
374
  use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
167
375
 
168
376
  defaults = locals()
@@ -170,13 +378,20 @@ class ForeachSOLP(C.BaseOpt):
170
378
  params = defaults.pop("params")
171
379
 
172
380
  if use_precond_schedule:
173
- del defaults['precondition_frequency']
381
+ del defaults["precondition_frequency"]
174
382
  self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
175
383
  else:
176
- del defaults['precond_scheduler']
384
+ del defaults["precond_scheduler"]
177
385
  self.precond_schedule = 1 / defaults.pop("precondition_frequency")
178
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, #
179
- functools.partial(C.scale_by_soap, inner='laprop'))
386
+ super().__init__(
387
+ params,
388
+ defaults,
389
+ foreach,
390
+ gradient_clipping,
391
+ update_clipping,
392
+ palm, #
393
+ functools.partial(C.scale_by_soap, inner="laprop"),
394
+ )
180
395
 
181
396
 
182
397
  class PaLMForeachSOAP(ForeachSOAP):
@@ -194,27 +409,71 @@ class PrecondSchedulePaLMForeachSOAP(ForeachSOAP):
194
409
 
195
410
 
196
411
  class OrthoLaProp(C.BaseOpt):
197
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
198
- foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
199
- mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
200
- update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
412
+ def __init__(
413
+ self,
414
+ params,
415
+ lr=0.0025,
416
+ betas=(0.9, 0.99),
417
+ eps=1e-8,
418
+ weight_decay=0,
419
+ warmup_steps=0,
420
+ foreach: bool = True,
421
+ storage_dtype: str = "float32",
422
+ mars: bool = False,
423
+ caution: bool = False,
424
+ mars_gamma: float = 0.0025,
425
+ gradient_clipping: C.str_or_fn = C.use_default,
426
+ update_clipping: C.str_or_fn = C.use_default,
427
+ palm: bool = C.use_default,
428
+ beta2_scale: float = 0.8,
429
+ ):
201
430
  defaults = locals()
202
431
  defaults.pop("self")
203
432
  params = defaults.pop("params")
204
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
205
- C.orthogonalize_grad_to_param, C.scale_by_laprop)
433
+ super().__init__(
434
+ params,
435
+ defaults,
436
+ foreach,
437
+ gradient_clipping,
438
+ update_clipping,
439
+ palm,
440
+ C.orthogonalize_grad_to_param,
441
+ C.scale_by_laprop,
442
+ )
206
443
 
207
444
 
208
445
  class LaPropOrtho(C.BaseOpt):
209
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
210
- foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
211
- mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
212
- update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
446
+ def __init__(
447
+ self,
448
+ params,
449
+ lr=0.0025,
450
+ betas=(0.9, 0.99),
451
+ eps=1e-8,
452
+ weight_decay=0,
453
+ warmup_steps=0,
454
+ foreach: bool = True,
455
+ storage_dtype: str = "float32",
456
+ mars: bool = False,
457
+ caution: bool = False,
458
+ mars_gamma: float = 0.0025,
459
+ gradient_clipping: C.str_or_fn = C.use_default,
460
+ update_clipping: C.str_or_fn = C.use_default,
461
+ palm: bool = C.use_default,
462
+ beta2_scale: float = 0.8,
463
+ ):
213
464
  defaults = locals()
214
465
  defaults.pop("self")
215
466
  params = defaults.pop("params")
216
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop,
217
- C.orthogonalize_grad_to_param)
467
+ super().__init__(
468
+ params,
469
+ defaults,
470
+ foreach,
471
+ gradient_clipping,
472
+ update_clipping,
473
+ palm,
474
+ C.scale_by_laprop,
475
+ C.orthogonalize_grad_to_param,
476
+ )
218
477
 
219
478
 
220
479
  class ForeachPSGDKron(C.BaseOpt):
@@ -228,20 +487,43 @@ class ForeachPSGDKron(C.BaseOpt):
228
487
  cached: bool = False
229
488
  exp_avg_input: bool = True
230
489
 
231
- def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
232
- max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
233
- momentum_into_precond_update=True, warmup_steps: int = 0, merge_dims: bool = False,
234
- split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
235
- stochastic_schedule: bool = False, storage_dtype: str = 'float32', mars: bool = False,
236
- caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
237
- cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
238
- gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
239
- # expert parameters
240
- precond_init_scale=1.0, precond_lr=0.1):
490
+ def __init__(
491
+ self,
492
+ params,
493
+ lr=0.001,
494
+ beta=0.9,
495
+ weight_decay=0.0,
496
+ preconditioner_update_probability=None,
497
+ max_size_triangular=2048,
498
+ min_ndim_triangular=2,
499
+ memory_save_mode=None,
500
+ momentum_into_precond_update=True,
501
+ warmup_steps: int = 0,
502
+ merge_dims: bool = False,
503
+ split: bool = False,
504
+ store_triu_as_line: bool = True,
505
+ foreach: bool = True,
506
+ q_dtype="float32",
507
+ stochastic_schedule: bool = False,
508
+ storage_dtype: str = "float32",
509
+ mars: bool = False,
510
+ caution: bool = False,
511
+ mars_gamma: float = 0.0025,
512
+ delayed: Optional[bool] = C.use_default,
513
+ cached: Optional[bool] = C.use_default,
514
+ exp_avg_input: Optional[bool] = C.use_default,
515
+ gradient_clipping: C.str_or_fn = C.use_default,
516
+ update_clipping: C.str_or_fn = C.use_default, #
517
+ # expert parameters
518
+ precond_init_scale=None,
519
+ precond_init_scale_scale=1,
520
+ precond_lr=0.1,
521
+ ):
241
522
  defaults = locals()
242
523
  defaults.pop("self")
243
- self.precond_schedule = defaults.pop(
244
- "preconditioner_update_probability") or utils.precond_update_prob_schedule()
524
+ self.precond_schedule = (
525
+ defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
526
+ )
245
527
  params = defaults.pop("params")
246
528
 
247
529
  delayed = C.default(delayed, self.delayed)
@@ -249,9 +531,16 @@ class ForeachPSGDKron(C.BaseOpt):
249
531
  exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
250
532
  update_clipping = C.default(update_clipping, utils.trust_region_clip_)
251
533
 
252
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, False, #
253
- *(C.exp_avg,) * exp_avg_input, #
254
- functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached))
534
+ super().__init__(
535
+ params,
536
+ defaults,
537
+ foreach,
538
+ gradient_clipping,
539
+ update_clipping,
540
+ False, #
541
+ *(C.exp_avg,) * exp_avg_input, #
542
+ functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
543
+ )
255
544
 
256
545
 
257
546
  class ForeachPurePSGD(ForeachPSGDKron):
@@ -275,6 +564,74 @@ class ForeachCachedNewtonPSGD(ForeachCachedPSGDKron):
275
564
  hessian_approx = True
276
565
 
277
566
 
567
+ class ForeachPSGDLRA(C.BaseOpt):
568
+ """
569
+ Originally from Evan Walters and Omead Pooladzandi, 2024
570
+ Modified under Creative Commons Attribution 4.0 International
571
+ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
572
+ """
573
+
574
+ delayed: bool = False
575
+ exp_avg_input: bool = True
576
+
577
+ def __init__(
578
+ self,
579
+ params,
580
+ lr=0.001,
581
+ beta=0.9,
582
+ weight_decay=0.0,
583
+ preconditioner_update_probability=None,
584
+ momentum_into_precond_update=True,
585
+ rank: int = 4,
586
+ warmup_steps: int = 0,
587
+ foreach: bool = True,
588
+ q_dtype="float32",
589
+ stochastic_schedule: bool = False,
590
+ storage_dtype: str = "float32",
591
+ mars: bool = False,
592
+ caution: bool = False,
593
+ mars_gamma: float = 0.0025,
594
+ delayed: Optional[bool] = C.use_default,
595
+ exp_avg_input: Optional[bool] = C.use_default,
596
+ gradient_clipping: C.str_or_fn = C.use_default,
597
+ update_clipping: C.str_or_fn = C.use_default,
598
+ eps: float = 1e-8, #
599
+ # expert parameters
600
+ precond_init_scale=None,
601
+ precond_init_scale_scale=1,
602
+ precond_lr=0.1,
603
+ ):
604
+ defaults = locals()
605
+ defaults.pop("self")
606
+ self.precond_schedule = (
607
+ defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
608
+ )
609
+ params = defaults.pop("params")
610
+
611
+ delayed = C.default(delayed, self.delayed)
612
+ exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
613
+ update_clipping = C.default(update_clipping, utils.trust_region_clip_)
614
+
615
+ super().__init__(
616
+ params,
617
+ defaults,
618
+ foreach,
619
+ gradient_clipping,
620
+ update_clipping,
621
+ False, #
622
+ *(C.exp_avg,) * exp_avg_input, #
623
+ C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra,
624
+ )
625
+
626
+
627
+ class ForeachDelayedPSGDLRA(ForeachPSGDLRA):
628
+ delayed: bool = True
629
+
630
+
631
+ class ForeachNewtonPSGDLRA(ForeachPSGDLRA):
632
+ hessian_approx = True
633
+
634
+
278
635
  PalmForEachSoap = PaLMForeachSOAP
279
636
  PaLMSOAP = PaLMForeachSOAP
280
637
  PaLMSFAdamW = PaLMForeachSFAdamW
@@ -293,11 +650,50 @@ CachedPSGDKron = ForeachCachedPSGDKron
293
650
  CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
294
651
  Muon = ForeachMuon
295
652
  SignLaProp = ForeachSignLaProp
296
-
297
- __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
298
- "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
299
- "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', 'ForeachSignLaProp' #
300
- "ForeachAdamW", "ForeachSFAdamW",
301
- "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron", "ForeachPurePSGD", "ForeachDelayedPSGD",
302
- "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron", "ForeachRMSprop", "ForeachMuon",
303
- 'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho', 'SignLaProp']
653
+ DelayedPSGDLRA = ForeachDelayedPSGDLRA
654
+ PSGDLRA = ForeachPSGDLRA
655
+ NewtonPSGDLRA = ForeachNewtonPSGDLRA
656
+
657
+ __all__ = [
658
+ "Muon",
659
+ "RMSprop",
660
+ "PrecondSchedulePaLMSOAP",
661
+ "PSGDKron",
662
+ "PurePSGD",
663
+ "DelayedPSGD",
664
+ "CachedPSGDKron",
665
+ "CachedDelayedPSGDKron",
666
+ "PalmForEachSoap",
667
+ "PaLMSOAP",
668
+ "PaLMSFAdamW",
669
+ "LaProp",
670
+ "ADOPT",
671
+ "PrecondScheduleSOAP",
672
+ "PrecondSchedulePaLMSOAP",
673
+ "RMSprop",
674
+ "MuonLaProp",
675
+ "ForeachSignLaProp",
676
+ "ForeachDelayedPSGDLRA",
677
+ "ForeachPSGDLRA",
678
+ "ForeachPSGDLRA",
679
+ "ForeachNewtonPSGDLRA", #
680
+ "ForeachAdamW",
681
+ "ForeachSFAdamW",
682
+ "ForeachLaProp",
683
+ "ForeachADOPT",
684
+ "ForeachSOAP",
685
+ "ForeachPSGDKron",
686
+ "ForeachPurePSGD",
687
+ "ForeachDelayedPSGD",
688
+ "ForeachCachedPSGDKron",
689
+ "ForeachCachedDelayedPSGDKron",
690
+ "ForeachRMSprop",
691
+ "ForeachMuon",
692
+ "ForeachCachedNewtonPSGD",
693
+ "OrthoLaProp",
694
+ "LaPropOrtho",
695
+ "SignLaProp",
696
+ "DelayedPSGD",
697
+ "PSGDLRA",
698
+ "NewtonPSGDLRA",
699
+ ]