heavyball-0.17.0-py3-none-any.whl → heavyball-0.17.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -21,26 +21,25 @@ PalmForEachSoap = PaLMForeachSOAP
  PaLMSOAP = PaLMForeachSOAP
  PaLMSFAdamW = PaLMForeachSFAdamW
  PaLMSFSoap = SFPaLMForeachSOAP
- PaLMForeachSOAP = PaLMForeachSOAP
  PrecondScheduleSFPaLMSOAP = PrecondScheduleSFPaLMSOAP
  SOAP = ForeachSOAP
  SFAdamW = ForeachSFAdamW
  LaProp = ForeachLaProp
  ADOPT = ForeachADOPT
- PrecondScheduleForeachSOAP = PrecondScheduleForeachSOAP
- PrecondSchedulePaLMForeachSOAP = PrecondSchedulePaLMForeachSOAP
+ PrecondScheduleSOAP = PrecondScheduleForeachSOAP
+ PrecondSchedulePaLMSOAP = PrecondSchedulePaLMForeachSOAP
  PSGDKron = ForeachPSGDKron
  AdamW = ForeachAdamW
  PurePSGD = ForeachPurePSGD
  PaLMPAdam = ForeachPaLMPAdam
  DelayedPSGD = ForeachDelayedPSGD
  CachedPSGDKron = ForeachCachedPSGDKron
- CachedDelayedPSGDKron
+ CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron

  __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
             'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
             'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
-            'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron' #
-            'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
+            'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron', #
+            'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP',
             'SOAP', 'SFAdamW', 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD',
-            'CachedPSGDKron', 'CachedDelayedPSGDKron']
+            'CachedPSGDKron', 'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP']
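In short, 0.17.2 adds Foreach-free aliases for the preconditioner-schedule optimizers, fixes the bare `CachedDelayedPSGDKron` statement, and restores the comma before the `#` in `__all__`, so `'ForeachCachedDelayedPSGDKron'` and `'PaLMSOAP'` are no longer implicitly concatenated into one export name. A minimal sketch of what the new aliases mean for user code (names taken from the hunk above):

```python
import heavyball

# The short aliases introduced here resolve to the existing Foreach implementations.
assert heavyball.PrecondScheduleSOAP is heavyball.PrecondScheduleForeachSOAP
assert heavyball.PrecondSchedulePaLMSOAP is heavyball.PrecondSchedulePaLMForeachSOAP
assert heavyball.CachedDelayedPSGDKron is heavyball.ForeachCachedDelayedPSGDKron
```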
heavyball/utils.py CHANGED
@@ -332,6 +332,16 @@ def promote(x):
      return x


+ def min_dtype(xs: List[torch.Tensor]):
+     dtypes = [x.dtype for x in xs]
+     for d in (torch.float32, torch.bfloat16, torch.float16):
+         if all(d == x for x in dtypes):
+             return d
+         if all(d in (x, torch.float32, torch.float64) for x in dtypes):
+             return d
+     return torch.float32
+
+
  def update_preconditioner(grad, state, max_precond_dim, precondition_1d, beta, update_precond):
      """
      Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
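The new `min_dtype` helper picks a shared computation dtype for a list of tensors and falls back to `torch.float32` when the list mixes precisions; the PSGD einsum call sites further down use it to cast every operand before contraction. A rough usage sketch (shapes are illustrative):

```python
import torch
from heavyball.utils import min_dtype

qs = [torch.randn(8, 8, dtype=torch.bfloat16), torch.randn(8, 8, dtype=torch.float32)]
md = min_dtype(qs)                # torch.float32 for this mixed list
qs_cast = [q.to(md) for q in qs]  # operands now share one dtype
```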
@@ -471,13 +481,8 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
          copy_stochastic_(t, s)


- def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
-     if target.data_ptr() == source.data_ptr():
-         return
-     if target.dtype != torch.bfloat16:
-         set_(target, source)
-         return
-
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+ def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
      """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
      # create a random 16 bit integer
      result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
@@ -492,6 +497,15 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
      target.copy_(result.view(dtype=torch.float32))


+ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
+     if target.data_ptr() == source.data_ptr():
+         return
+     if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
+         set_(target, source)
+         return
+     _compilable_copy_stochastic_(target, source)
+
+
  def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
                    add_fn: callable = None):
      param32 = [promote(p) for p in param]
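`copy_stochastic_` keeps its public signature but now routes the bfloat16 path through the `torch.compile`d `_compilable_copy_stochastic_` kernel and falls back to a plain `set_` copy for any other dtype combination. A small usage sketch (shapes and values are illustrative):

```python
import torch
from heavyball.utils import copy_stochastic_

param = torch.zeros(1024, dtype=torch.bfloat16)
update = torch.randn(1024, dtype=torch.float32) * 1e-3

# bf16 target + fp32 source -> stochastically rounded copy via the compiled kernel;
# other combinations fall back to set_(target, source).
copy_stochastic_(param, update)
```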
@@ -602,7 +616,8 @@ def psgd_balance_Q(Q_in):


  def psgd_calc_A_and_conjB(exprA, G, Q, V):
-     A = torch.einsum(exprA, *Q, G)
+     md = min_dtype(Q)
+     A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md))
      order = G.dim()
      p = list(range(order))
      conjB = torch.permute(V.conj(), p[1:] + p[:1])
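Here and in `psgd_precond_grad` below, the einsum operands are now all cast to `min_dtype(Q)` first, presumably so preconditioner factors and gradients stored in different precisions can be contracted without a dtype mismatch. A simplified sketch of the pattern; the expression `'ab,cd,bd->ac'` is illustrative, not the package's `exprA`:

```python
import torch
from heavyball.utils import min_dtype

Q = [torch.randn(4, 4, dtype=torch.bfloat16), torch.randn(3, 3, dtype=torch.float32)]
G = torch.randn(4, 3, dtype=torch.bfloat16)

md = min_dtype(Q)  # shared dtype for the contraction
A = torch.einsum('ab,cd,bd->ac', *[q.to(md) for q in Q], G.to(md))
```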
@@ -669,7 +684,8 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
  @decorator
  def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
      """Precondition gradient G with preconditioner Q."""
-     out = torch.einsum(exprs[-1], *[q.conj() for q in Q], *Q, G.to(Q[0].dtype))
+     md = min_dtype(Q)
+     out = torch.einsum(exprs[-1], *[q.conj().to(md) for q in Q], *[q.to(md) for q in Q], G.to(md))
      if inplace:
          set_(G, out)
          return G
@@ -787,14 +803,15 @@ class PSGDBase(StatefulOptimizer):
              if g.dim() > 1:
                  psgd_balance_Q(q)

-     def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None, store_triu_as_line=False):
+     def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
+                   store_triu_as_line=False):
          for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
              psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
              if original_q:
                  if store_triu_as_line:
                      update_triu_(original_q[i], Q)
                  else:
-                     copy_stochastic_(original_q[i], Q)
+                     copy_stochastic_list_(original_q[i], Q)


  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
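The `copy_stochastic_` → `copy_stochastic_list_` change looks like a straightforward fix: `original_q[i]` holds a list of `Q` factors, and `copy_stochastic_list_` (see the 471/481 hunk above) simply applies `copy_stochastic_` to each (target, source) pair. A minimal sketch of the list-wise call, with illustrative shapes:

```python
import torch
from heavyball.utils import copy_stochastic_list_

targets = [torch.zeros(4, 4, dtype=torch.bfloat16), torch.zeros(2, 2, dtype=torch.bfloat16)]
sources = [torch.randn(4, 4), torch.randn(2, 2)]

copy_stochastic_list_(targets, sources)  # one stochastic copy per (target, source) pair
```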
heavyball-0.17.0.dist-info/METADATA → heavyball-0.17.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.17.0
+ Version: 0.17.2
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -32,8 +32,8 @@ A simple package of efficient optimizers
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
  largely static alternative to `torch.optim` with more and better optimizers.

- Currently (2024-11-17, 0.15.0), the recommended stable optimizer is `PrecondSchedulePaLMForeachSOAP` (see below). The
- recommended experimental optimizer is `ForeachPSGDKron`.
+ Currently (2024-11-20, 0.17.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+ recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

  ## Features
@@ -62,7 +62,7 @@ import heavyball
  model = torch.nn.Linear(16, 1)

  # Create an optimizer
- optimizer = heavyball.PrecondSchedulePaLMForeachSOAP(model.parameters(), lr=1e-3)
+ optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)

  x = torch.randn(128, 16)
  y = torch.randn(128, 1)
@@ -76,19 +76,19 @@ for _ in range(1000):
 
  ## Optimizers
 
- | Name | Description | Advantages / Disadvantages |
- |--------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
- | **ForeachAdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
- | **ForeachLaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
- | **ForeachADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
- | **ForeachSFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
- | **PaLMForeachSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
- | **ForeachSOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
- | **PaLMForeachSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
- | **SFPaLMForeachSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
- | **PrecondScheduleSFPaLMForeachSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
- | **PrecondSchedulePaLMForeachSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
- | **PrecondScheduleForeachSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
+ | Name | Description | Advantages / Disadvantages |
+ |-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+ | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
+ | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
+ | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
+ | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
+ | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
+ | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+ | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+ | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
+ | **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
+ | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
+ | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
 
  ## Precond Schedule
 
heavyball-0.17.0.dist-info/RECORD → heavyball-0.17.2.dist-info/RECORD CHANGED
@@ -1,4 +1,4 @@
- heavyball/__init__.py,sha256=mDHahP4u0fy2YKWA4FPMAp7jLPMt5WwUkEiOrwE4u3E,2199
+ heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
  heavyball/cached_delayed_psgd_kron.py,sha256=DvjNNHzbnS-NDq965wve-VQ-ol7IFljYYGTuTwPHOhU,6971
  heavyball/cached_psgd_kron.py,sha256=xy3-yRKFUvRTstJb_asMVp-k-5Zuw_HyILPi7BsuMKQ,6974
  heavyball/delayed_psgd.py,sha256=rDDUj3miEn6HRJmKl-ZImsqkqBASSn8aC7MEV_06fzU,6017
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdo
  heavyball/psgd_kron.py,sha256=2IpPj2TOExNGm8hSewi3er2GczJRNgC7r2J5yYSSA_0,5998
  heavyball/pure_psgd.py,sha256=uA7W9a3Qm1sxHQhtNxaUYrmE5x55lP5iJOKy_qT8XaQ,5341
  heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
- heavyball/utils.py,sha256=Jqh7VdWGeiSdwaPtUNB9l14wuuFPSReLaTwJA3juFbM,28765
- heavyball-0.17.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
- heavyball-0.17.0.dist-info/METADATA,sha256=GIJQ4ha-fcYR6ltOs4WUO8L_LhWGiZv2UrEZcuJD0LI,11941
- heavyball-0.17.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- heavyball-0.17.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
- heavyball-0.17.0.dist-info/RECORD,,
+ heavyball/utils.py,sha256=8dhagAGj03D7kBEWOJmqsCjQKP069e1WwrzVp1JsBr8,29472
+ heavyball-0.17.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+ heavyball-0.17.2.dist-info/METADATA,sha256=kCK9J8gg-6lj0qao7S-yDc7jsOzGxNJ6F4_JUFaiIR4,11810
+ heavyball-0.17.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ heavyball-0.17.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+ heavyball-0.17.2.dist-info/RECORD,,