heavyball 1.7.2__py3-none-any.whl → 2.0.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -2,10 +2,45 @@ import functools
 import math
 from typing import Optional
 
+import torch.optim
+
 from . import chainable as C
 from . import utils
 
 
+class SAMWrapper(torch.optim.Optimizer):
+    def __init__(self, params, wrapped_optimizer: utils.StatefulOptimizer, ball: float = 0.1):
+        if not isinstance(wrapped_optimizer, utils.StatefulOptimizer):
+            raise ValueError(f"{wrapped_optimizer.__class__.__name__} is not a HeavyBall optimizer")
+        super().__init__(params, {"ball": ball})
+        self.wrapped_optimizer = wrapped_optimizer
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        if closure is None:
+            raise ValueError("SAM requires closure")
+        with torch.enable_grad():
+            closure()
+        old_params = [utils.sam_step(group["params"], group["ball"]) for group in self.param_groups]
+
+        originaL_handle_closure = self.wrapped_optimizer._handle_closure
+
+        def _handle_closure(closure):
+            originaL_handle_closure(closure)
+            for group, old in zip(self.param_groups, old_params):
+                utils.copy_stochastic_list_(group["params"], old)
+
+        try:
+            self.wrapped_optimizer._handle_closure = _handle_closure
+            loss = self.wrapped_optimizer.step(closure)
+        finally:
+            self.wrapped_optimizer._handle_closure = originaL_handle_closure
+        return loss
+
+    def zero_grad(self, set_to_none: bool = True):
+        self.wrapped_optimizer.zero_grad()
+
+
 class ForeachAdamW(C.BaseOpt):
     def __init__(
         self,
@@ -24,10 +59,16 @@ class ForeachAdamW(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
 
 
@@ -55,10 +96,16 @@ class ForeachRMSprop(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(
             params,
             defaults,
@@ -90,10 +137,16 @@ class ForeachSFAdamW(C.ScheduleFree):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(
             params,
             defaults,
@@ -106,6 +159,49 @@ class ForeachSFAdamW(C.ScheduleFree):
         )
 
 
+class MSAMLaProp(C.MSAM):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-6,
+        weight_decay=0,
+        warmup_steps=0,
+        r=0.0,
+        weight_lr_power=2.0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        sam_step_size: float = 0.1,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(
+            params,
+            defaults,
+            foreach,
+            gradient_clipping,
+            update_clipping,
+            palm,
+            C.scale_by_exp_avg_sq,
+            C.update_by_msam,
+        )
+
+
 class PaLMForeachSFAdamW(ForeachSFAdamW):
     palm: bool = True
 
@@ -128,10 +224,16 @@ class ForeachADOPT(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
 
 
@@ -154,10 +256,16 @@ class ForeachMuon(C.BaseOpt):
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
         nesterov: bool = True,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(
             params,
             defaults,
@@ -165,7 +273,7 @@ class ForeachMuon(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            C.nesterov_momentum if nesterov else C.heavyball_momentum,
+            C.nesterov_ema if nesterov else C.exp_avg,
             C.orthogonalize_update,
         )
 
@@ -188,10 +296,16 @@ class ForeachLaProp(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_laprop)
 
 
@@ -213,10 +327,16 @@ class MuonLaProp(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(
             params,
             defaults,
@@ -271,12 +391,18 @@ class ForeachSOAP(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         storage_dtype: str = "float32",
         stochastic_schedule: bool = False,
+        precond_grad_accum: bool = False,
+        **kwargs,
     ):
         use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
 
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
         if use_precond_schedule:
             del defaults["precondition_frequency"]
@@ -313,10 +439,16 @@ class ForeachSignLaProp(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         super().__init__(
             params,
             defaults,
@@ -371,12 +503,17 @@ class ForeachSOLP(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         storage_dtype: str = "float32",
         stochastic_schedule: bool = False,
+        **kwargs,
     ):
         use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
 
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
         if use_precond_schedule:
             del defaults["precondition_frequency"]
@@ -427,10 +564,15 @@ class OrthoLaProp(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
         super().__init__(
             params,
             defaults,
@@ -461,10 +603,15 @@ class LaPropOrtho(C.BaseOpt):
         update_clipping: C.str_or_fn = C.use_default,
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
+        **kwargs,
     ):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
         super().__init__(
             params,
             defaults,
@@ -487,12 +634,14 @@ class ForeachPSGDKron(C.BaseOpt):
     delayed: bool = False
     cached: bool = False
     exp_avg_input: bool = True
+    quad: bool = False
 
     def __init__(
         self,
         params,
         lr=0.001,
-        beta=0.9,
+        beta=None,
+        betas=(0.9, 0.999),
         weight_decay=0.0,
         preconditioner_update_probability=None,
         max_size_triangular=2048,
@@ -515,23 +664,38 @@ class ForeachPSGDKron(C.BaseOpt):
         exp_avg_input: Optional[bool] = C.use_default,
         gradient_clipping: C.str_or_fn = C.use_default,
         update_clipping: C.str_or_fn = C.use_default,  #
+        adaptive: bool = False,
+        ortho_method: Optional[str] = None,  # If None, no orthogonalization
+        precond_grad_accum: bool = False,
+        lower_bound_beta: float = 0.9,  # 0.0 recovers pre-2.0.0 PSGD
+        inverse_free: bool = C.use_default,
+        dampening: float = 2**-13,
+        precond_update_power_iterations: int = 2,
         # expert parameters
         precond_init_scale=None,
-        precond_init_scale_scale=1,
-        precond_lr=0.1,
+        precond_init_scale_scale: float = 1,
+        precond_init_scale_power: Optional[float] = None,
+        precond_lr: float = 0.1,
+        **kwargs,
     ):
+        delayed = C.default(delayed, self.delayed)
+        cached = C.default(cached, self.cached)
+        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
+        update_clipping = C.default(update_clipping, utils.trust_region_clip_)
+        inverse_free = C.default(inverse_free, self.quad)
+
         defaults = locals()
         defaults.pop("self")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         self.precond_schedule = (
             defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
         )
         params = defaults.pop("params")
 
-        delayed = C.default(delayed, self.delayed)
-        cached = C.default(cached, self.cached)
-        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
-        update_clipping = C.default(update_clipping, utils.trust_region_clip_)
-
         super().__init__(
             params,
             defaults,
@@ -569,6 +733,11 @@ class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
     hvp_interval = 2
 
 
+class QUAD(ForeachPSGDKron):
+    quad = True
+    cached = True
+
+
 class ForeachPSGDLRA(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -601,13 +770,24 @@ class ForeachPSGDLRA(C.BaseOpt):
         gradient_clipping: C.str_or_fn = C.use_default,
         update_clipping: C.str_or_fn = C.use_default,
         eps: float = 1e-8,  #
-        # expert parameters
+        precond_grad_accum: bool = False,  # expert parameters
         precond_init_scale=None,
-        precond_init_scale_scale=1,
-        precond_lr=0.1,
+        precond_init_scale_scale: float = 1,
+        precond_init_scale_power: Optional[float] = None,
+        precond_lr: float = 0.1,
+        **kwargs,
     ):
+        delayed = C.default(delayed, self.delayed)
+        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
+        update_clipping = C.default(update_clipping, utils.trust_region_clip_)
+
         defaults = locals()
         defaults.pop("self")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
         self.precond_schedule = (
             defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
         )
@@ -621,10 +801,6 @@ class ForeachPSGDLRA(C.BaseOpt):
             defaults["rank"] = round(math.log2(sum(p.numel() for p in params)))
             utils.warn_once(f"rank was set to {defaults['rank']}")
 
-        delayed = C.default(delayed, self.delayed)
-        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
-        update_clipping = C.default(update_clipping, utils.trust_region_clip_)
-
         super().__init__(
             params,
             defaults,
@@ -715,4 +891,5 @@ __all__ = [
     "NewtonPSGDLRA",
     "NewtonHybrid2PSGDLRA",
     "NewtonHybrid2PSGDKron",
+    "MSAMLaProp",
 ]
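
The new SAMWrapper shown above delegates the actual parameter update to a wrapped HeavyBall optimizer and requires a closure on every step (it raises ValueError otherwise). Below is a minimal usage sketch based only on the signatures visible in this diff; the model, data, loss, and learning-rate value are placeholders, not part of the package.

    # Hypothetical sketch: wrap a HeavyBall optimizer with the SAMWrapper added in 2.0.0.dev0.
    import torch
    import heavyball

    model = torch.nn.Linear(4, 1)  # placeholder model
    inner = heavyball.ForeachAdamW(model.parameters(), lr=1e-3)
    opt = heavyball.SAMWrapper(model.parameters(), wrapped_optimizer=inner, ball=0.1)

    x, y = torch.randn(8, 4), torch.randn(8, 1)  # placeholder data

    def closure():
        # SAMWrapper calls this under torch.enable_grad(), then perturbs the parameters
        # via utils.sam_step before running the wrapped optimizer's step.
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    loss = opt.step(closure)

Note also that every optimizer constructor in 2.0.0.dev0 now accepts **kwargs: unknown keys are merged into the defaults and reported once via utils.warn_once rather than raising a TypeError, so a misspelled hyperparameter no longer fails loudly.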