heavyball-1.6.3-py3-none-any.whl → heavyball-1.7.1-py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
heavyball/__init__.py CHANGED
@@ -1,4 +1,5 @@
 import functools
+import math
 from typing import Optional
 
 from . import chainable as C
@@ -6,10 +7,24 @@ from . import utils
 
 
 class ForeachAdamW(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+    ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
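Everything in the two hunks above is mechanical reformatting (one argument per line, double quotes) plus the new `import math`, which is used further down by ForeachPSGDLRA's rank default; the ForeachAdamW signature and defaults are unchanged. A minimal usage sketch, assuming torch is installed and heavyball's optimizers follow the standard torch.optim step/zero_grad interface (the model and loop are illustrative, not from the package):

import torch

from heavyball import ForeachAdamW

model = torch.nn.Linear(32, 2)
opt = ForeachAdamW(model.parameters(), lr=0.0025, betas=(0.9, 0.99))

for _ in range(3):
    loss = model(torch.randn(4, 32)).square().mean()
    loss.backward()  # populate .grad as usual
    opt.step()  # apply the AdamW update
    opt.zero_grad()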
@@ -21,26 +36,74 @@ class ForeachRMSprop(C.BaseOpt):
     Debiased RMSprop (not torch.optim.RMSprop)
     """
 
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0,
-                 weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
-                 caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-6,
+        weight_decay=0,
+        warmup_steps=0,
+        r=0.0,
+        weight_lr_power=2.0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+    ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq)
+        super().__init__(
+            params,
+            defaults,
+            foreach,
+            gradient_clipping,
+            update_clipping,
+            palm,
+            C.scale_by_exp_avg_sq,
+        )
 
 
 class ForeachSFAdamW(C.ScheduleFree):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0,
-                 weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
-                 caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-6,
+        weight_decay=0,
+        warmup_steps=0,
+        r=0.0,
+        weight_lr_power=2.0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+    ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq,
-                         C.update_by_schedule_free)
+        super().__init__(
+            params,
+            defaults,
+            foreach,
+            gradient_clipping,
+            update_clipping,
+            palm,
+            C.scale_by_exp_avg_sq,
+            C.update_by_schedule_free,
+        )
 
 
 class PaLMForeachSFAdamW(ForeachSFAdamW):
@@ -48,10 +111,24 @@ class PaLMForeachSFAdamW(ForeachSFAdamW):
 
 
 class ForeachADOPT(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+    ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
@@ -59,23 +136,59 @@ class ForeachADOPT(C.BaseOpt):
 
 
 class ForeachMuon(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8,
-                 nesterov: bool = True):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        nesterov: bool = True,
+    ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
-                         C.nesterov_momentum if nesterov else C.heavyball_momentum, C.orthogonalize_update)
+        super().__init__(
+            params,
+            defaults,
+            foreach,
+            gradient_clipping,
+            update_clipping,
+            palm,
+            C.nesterov_momentum if nesterov else C.heavyball_momentum,
+            C.orthogonalize_update,
+        )
 
 
 class ForeachLaProp(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+    ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
@@ -83,15 +196,37 @@ class ForeachLaProp(C.BaseOpt):
 
 
 class MuonLaProp(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+    ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop,
-                         C.orthogonalize_update)
+        super().__init__(
+            params,
+            defaults,
+            foreach,
+            gradient_clipping,
+            update_clipping,
+            palm,
+            C.scale_by_laprop,
+            C.orthogonalize_update,
+        )
 
 
 class ForeachSOAP(C.BaseOpt):
@@ -105,16 +240,38 @@ class ForeachSOAP(C.BaseOpt):
     https://arxiv.org/abs/2409.11321
     https://github.com/nikhilvyas/SOAP
     """
+
     use_precond_schedule: bool = False
 
-    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
-                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
-                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                 correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
-                 mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
-                 precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
-                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
-                 storage_dtype: str = 'float32', stochastic_schedule: bool = False):
+    def __init__(
+        self,
+        params,
+        lr: float = 3e-3,
+        betas=(0.9, 0.95),
+        shampoo_beta: float = 0.95,
+        eps: float = 1e-8,
+        weight_decay: float = 0.01,
+        precondition_frequency: int = 2,
+        max_precond_dim: int = 2048,  #
+        merge_dims: bool = True,
+        precondition_1d: bool = False,
+        normalize_grads: bool = False,
+        correct_bias: bool = True,
+        warmup_steps: int = 0,
+        split: bool = False,
+        foreach: bool = True,
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        palm: bool = C.use_default,
+        precond_scheduler=(1 / 3, 9),
+        beta2_scale: float = 0.8,
+        use_precond_schedule: bool = C.use_default,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        storage_dtype: str = "float32",
+        stochastic_schedule: bool = False,
+    ):
         use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
 
         defaults = locals()
@@ -122,24 +279,54 @@ class ForeachSOAP(C.BaseOpt):
         params = defaults.pop("params")
 
         if use_precond_schedule:
-            del defaults['precondition_frequency']
+            del defaults["precondition_frequency"]
             self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
         else:
-            del defaults['precond_scheduler']
+            del defaults["precond_scheduler"]
             self.precond_schedule = 1 / defaults.pop("precondition_frequency")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,  #
-                         C.scale_by_soap)
+        super().__init__(
+            params,
+            defaults,
+            foreach,
+            gradient_clipping,
+            update_clipping,
+            palm,  #
+            C.scale_by_soap,
+        )
 
 
 class ForeachSignLaProp(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+    ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop, C.sign)
+        super().__init__(
+            params,
+            defaults,
+            foreach,
+            gradient_clipping,
+            update_clipping,
+            palm,
+            C.scale_by_laprop,
+            C.sign,
+        )
 
 
 class ForeachSOLP(C.BaseOpt):
@@ -153,16 +340,38 @@ class ForeachSOLP(C.BaseOpt):
     https://arxiv.org/abs/2409.11321
     https://github.com/nikhilvyas/SOAP
     """
+
     use_precond_schedule: bool = False
 
-    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
-                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
-                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                 correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
-                 mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
-                 precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
-                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
-                 storage_dtype: str = 'float32', stochastic_schedule: bool = False):
+    def __init__(
+        self,
+        params,
+        lr: float = 3e-3,
+        betas=(0.9, 0.95),
+        shampoo_beta: float = 0.95,
+        eps: float = 1e-8,
+        weight_decay: float = 0.01,
+        precondition_frequency: int = 2,
+        max_precond_dim: int = 2048,  #
+        merge_dims: bool = True,
+        precondition_1d: bool = False,
+        normalize_grads: bool = False,
+        correct_bias: bool = True,
+        warmup_steps: int = 0,
+        split: bool = False,
+        foreach: bool = True,
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        palm: bool = C.use_default,
+        precond_scheduler=(1 / 3, 9),
+        beta2_scale: float = 0.8,
+        use_precond_schedule: bool = C.use_default,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        storage_dtype: str = "float32",
+        stochastic_schedule: bool = False,
+    ):
         use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
 
         defaults = locals()
@@ -170,13 +379,20 @@ class ForeachSOLP(C.BaseOpt):
         params = defaults.pop("params")
 
         if use_precond_schedule:
-            del defaults['precondition_frequency']
+            del defaults["precondition_frequency"]
             self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
         else:
-            del defaults['precond_scheduler']
+            del defaults["precond_scheduler"]
             self.precond_schedule = 1 / defaults.pop("precondition_frequency")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,  #
-                         functools.partial(C.scale_by_soap, inner='laprop'))
+        super().__init__(
+            params,
+            defaults,
+            foreach,
+            gradient_clipping,
+            update_clipping,
+            palm,  #
+            functools.partial(C.scale_by_soap, inner="laprop"),
+        )
 
 
 class PaLMForeachSOAP(ForeachSOAP):
@@ -194,27 +410,71 @@ class PrecondSchedulePaLMForeachSOAP(ForeachSOAP):
 
 
 class OrthoLaProp(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+    ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
-                         C.orthogonalize_grad_to_param, C.scale_by_laprop)
+        super().__init__(
+            params,
+            defaults,
+            foreach,
+            gradient_clipping,
+            update_clipping,
+            palm,
+            C.orthogonalize_grad_to_param,
+            C.scale_by_laprop,
+        )
 
 
 class LaPropOrtho(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+    ):
        defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop,
-                         C.orthogonalize_grad_to_param)
+        super().__init__(
+            params,
+            defaults,
+            foreach,
+            gradient_clipping,
+            update_clipping,
+            palm,
+            C.scale_by_laprop,
+            C.orthogonalize_grad_to_param,
+        )
 
 
 class ForeachPSGDKron(C.BaseOpt):
@@ -228,20 +488,43 @@ class ForeachPSGDKron(C.BaseOpt):
     cached: bool = False
     exp_avg_input: bool = True
 
-    def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
-                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
-                 momentum_into_precond_update=True, warmup_steps: int = 0, merge_dims: bool = False,
-                 split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
-                 stochastic_schedule: bool = False, storage_dtype: str = 'float32', mars: bool = False,
-                 caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
-                 cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
-                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,  #
-                 # expert parameters
-                 precond_init_scale=1.0, precond_lr=0.1):
+    def __init__(
+        self,
+        params,
+        lr=0.001,
+        beta=0.9,
+        weight_decay=0.0,
+        preconditioner_update_probability=None,
+        max_size_triangular=2048,
+        min_ndim_triangular=2,
+        memory_save_mode=None,
+        momentum_into_precond_update=True,
+        warmup_steps: int = 0,
+        merge_dims: bool = False,
+        split: bool = False,
+        store_triu_as_line: bool = True,
+        foreach: bool = True,
+        q_dtype="float32",
+        stochastic_schedule: bool = False,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        delayed: Optional[bool] = C.use_default,
+        cached: Optional[bool] = C.use_default,
+        exp_avg_input: Optional[bool] = C.use_default,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,  #
+        # expert parameters
+        precond_init_scale=None,
+        precond_init_scale_scale=1,
+        precond_lr=0.1,
+    ):
         defaults = locals()
         defaults.pop("self")
-        self.precond_schedule = defaults.pop(
-            "preconditioner_update_probability") or utils.precond_update_prob_schedule()
+        self.precond_schedule = (
+            defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
+        )
         params = defaults.pop("params")
 
         delayed = C.default(delayed, self.delayed)
@@ -249,9 +532,16 @@ class ForeachPSGDKron(C.BaseOpt):
         exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
         update_clipping = C.default(update_clipping, utils.trust_region_clip_)
 
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, False,  #
-                         *(C.exp_avg,) * exp_avg_input,  #
-                         functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached))
+        super().__init__(
+            params,
+            defaults,
+            foreach,
+            gradient_clipping,
+            update_clipping,
+            False,  #
+            *(C.exp_avg,) * exp_avg_input,  #
+            functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
+        )
 
 
 class ForeachPurePSGD(ForeachPSGDKron):
275
565
  hessian_approx = True
276
566
 
277
567
 
568
+ class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
569
+ hvp_interval = 2
570
+
571
+
572
+ class ForeachPSGDLRA(C.BaseOpt):
573
+ """
574
+ Originally from Evan Walters and Omead Pooladzandi, 2024
575
+ Modified under Creative Commons Attribution 4.0 International
576
+ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
577
+ """
578
+
579
+ delayed: bool = False
580
+ exp_avg_input: bool = True
581
+
582
+ def __init__(
583
+ self,
584
+ params,
585
+ lr=0.001,
586
+ beta=0.9,
587
+ weight_decay=0.0,
588
+ preconditioner_update_probability=None,
589
+ momentum_into_precond_update=True,
590
+ rank: Optional[int] = None,
591
+ warmup_steps: int = 0,
592
+ foreach: bool = True,
593
+ q_dtype="float32",
594
+ stochastic_schedule: bool = False,
595
+ storage_dtype: str = "float32",
596
+ mars: bool = False,
597
+ caution: bool = False,
598
+ mars_gamma: float = 0.0025,
599
+ delayed: Optional[bool] = C.use_default,
600
+ exp_avg_input: Optional[bool] = C.use_default,
601
+ gradient_clipping: C.str_or_fn = C.use_default,
602
+ update_clipping: C.str_or_fn = C.use_default,
603
+ eps: float = 1e-8, #
604
+ # expert parameters
605
+ precond_init_scale=None,
606
+ precond_init_scale_scale=1,
607
+ precond_lr=0.1,
608
+ ):
609
+ defaults = locals()
610
+ defaults.pop("self")
611
+ self.precond_schedule = (
612
+ defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
613
+ )
614
+ params = defaults.pop("params")
615
+
616
+ if rank is None:
617
+ utils.warn_once(
618
+ f"{rank=}. It will be set to log2(param_count). This requires `params` to be of type list. Currently, {type(params)=}"
619
+ )
620
+ params = list(params)
621
+ defaults["rank"] = round(math.log2(sum(p.numel() for p in params)))
622
+ utils.warn_once(f"rank was set to {defaults['rank']}")
623
+
624
+ delayed = C.default(delayed, self.delayed)
625
+ exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
626
+ update_clipping = C.default(update_clipping, utils.trust_region_clip_)
627
+
628
+ super().__init__(
629
+ params,
630
+ defaults,
631
+ foreach,
632
+ gradient_clipping,
633
+ update_clipping,
634
+ False, #
635
+ *(C.exp_avg,) * exp_avg_input, #
636
+ C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra,
637
+ )
638
+
639
+
640
+ class ForeachDelayedPSGDLRA(ForeachPSGDLRA):
641
+ delayed: bool = True
642
+
643
+
644
+ class ForeachNewtonPSGDLRA(ForeachPSGDLRA):
645
+ hessian_approx = True
646
+
647
+
648
+ class NewtonHybrid2PSGDLRA(ForeachNewtonPSGDLRA):
649
+ hvp_interval = 2
650
+
651
+
278
652
  PalmForEachSoap = PaLMForeachSOAP
279
653
  PaLMSOAP = PaLMForeachSOAP
280
654
  PaLMSFAdamW = PaLMForeachSFAdamW
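The hunk above adds the low-rank PSGD (LRA) family. When `rank` is None, ForeachPSGDLRA materializes `params` as a list and derives the rank from the total parameter count, warning once along the way. A self-contained sketch of that derivation, assuming torch and an illustrative model:

import math

import torch

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 10))
params = list(model.parameters())

# Mirrors the diff: rank defaults to round(log2(total parameter count)).
total = sum(p.numel() for p in params)  # 64*64 + 64 + 64*10 + 10 = 4810
rank = round(math.log2(total))  # log2(4810) is about 12.23, rounds to 12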
@@ -293,11 +667,52 @@ CachedPSGDKron = ForeachCachedPSGDKron
 CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
 Muon = ForeachMuon
 SignLaProp = ForeachSignLaProp
-
-__all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
-           "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
-           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', 'ForeachSignLaProp'  #
-           "ForeachAdamW", "ForeachSFAdamW",
-           "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron", "ForeachPurePSGD", "ForeachDelayedPSGD",
-           "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron", "ForeachRMSprop", "ForeachMuon",
-           'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho', 'SignLaProp']
+DelayedPSGDLRA = ForeachDelayedPSGDLRA
+PSGDLRA = ForeachPSGDLRA
+NewtonPSGDLRA = ForeachNewtonPSGDLRA
+
+__all__ = [
+    "Muon",
+    "RMSprop",
+    "PrecondSchedulePaLMSOAP",
+    "PSGDKron",
+    "PurePSGD",
+    "DelayedPSGD",
+    "CachedPSGDKron",
+    "CachedDelayedPSGDKron",
+    "PalmForEachSoap",
+    "PaLMSOAP",
+    "PaLMSFAdamW",
+    "LaProp",
+    "ADOPT",
+    "PrecondScheduleSOAP",
+    "PrecondSchedulePaLMSOAP",
+    "RMSprop",
+    "MuonLaProp",
+    "ForeachSignLaProp",
+    "ForeachDelayedPSGDLRA",
+    "ForeachPSGDLRA",
+    "ForeachPSGDLRA",
+    "ForeachNewtonPSGDLRA",  #
+    "ForeachAdamW",
+    "ForeachSFAdamW",
+    "ForeachLaProp",
+    "ForeachADOPT",
+    "ForeachSOAP",
+    "ForeachPSGDKron",
+    "ForeachPurePSGD",
+    "ForeachDelayedPSGD",
+    "ForeachCachedPSGDKron",
+    "ForeachCachedDelayedPSGDKron",
+    "ForeachRMSprop",
+    "ForeachMuon",
+    "ForeachCachedNewtonPSGD",
+    "OrthoLaProp",
+    "LaPropOrtho",
+    "SignLaProp",
+    "DelayedPSGD",
+    "PSGDLRA",
+    "NewtonPSGDLRA",
+    "NewtonHybrid2PSGDLRA",
+    "NewtonHybrid2PSGDKron",
+]
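With the three new aliases and the rebuilt `__all__`, the LRA and hybrid-Newton variants are importable from the package root. Note that `__all__` ships with a few duplicated entries ("RMSprop", "PrecondSchedulePaLMSOAP", "DelayedPSGD", "ForeachPSGDLRA"), reproduced verbatim above; duplicates in `__all__` are harmless. A quick sanity check, assuming heavyball 1.7.1 is installed:

import heavyball
from heavyball import NewtonHybrid2PSGDKron, NewtonHybrid2PSGDLRA, NewtonPSGDLRA, PSGDLRA

assert PSGDLRA is heavyball.ForeachPSGDLRA  # alias, not a subclass
assert NewtonHybrid2PSGDLRA.hvp_interval == 2  # class attribute set in the diff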