heavyball-0.20.1-py3-none-any.whl → heavyball-0.21.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/utils.py CHANGED
@@ -329,23 +329,33 @@ def get_orthogonal_matrix(mat):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
-def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
-    x32 = [promote(x_) for x_ in x]
-    y32 = [promote(y_) for y_ in y]
+def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+    for x_, y_ in zip(x, y):
+        x32 = promote(x_)
+        y32 = promote(y_)
+        x32.lerp_(y32, a)
+        copy_stochastic_(x_, x32)
 
-    torch._foreach_lerp_(x32, y32, a)
 
-    copy_stochastic_list_(x, x32)
+def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+    if not isinstance(a, torch.Tensor):
+        a = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(a)
+    _compilable_stochastic_lerp_(x, y, a)
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
-def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
-    x32 = [promote(x_) for x_ in x]
-    y32 = [promote(y_) for y_ in y]
+def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+    for x_, y_ in zip(x, y):
+        x32 = promote(x_)
+        y32 = promote(y_)
+        x32.add_(y32, alpha=alpha)
+        copy_stochastic_(x_, x32)
 
-    [x_.add_(y_, alpha=alpha) for x_, y_ in zip(x32, y32)]
 
-    copy_stochastic_list_(x, x32)
+def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+    if not isinstance(alpha, torch.Tensor):
+        alpha = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(alpha)
+    _compilable_stochastic_add_(x, y, alpha)
 
 
 @decorator
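The refactor above splits each list op into a torch.compile'd kernel plus a thin eager wrapper that converts Python scalars to 0-d tensors before entering the compiled function, so a changed scalar arrives as a tensor input to the graph rather than forcing a re-specialization. A minimal self-contained sketch of that pattern (names are illustrative; heavyball's promote/copy_stochastic_ mixed-precision handling is omitted):

import torch
from typing import List, Union

@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
def _compilable_list_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: torch.Tensor):
    # `a` arrives as a 0-d tensor, so it is traced as a graph input rather than
    # baked into the graph as a Python constant
    for x_, y_ in zip(x, y):
        x_.lerp_(y_, a)

def list_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
    # wrap Python scalars in a 0-d tensor on the right device before the compiled call
    if not isinstance(a, torch.Tensor):
        a = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(a)
    _compilable_list_lerp_(x, y, a)

params = [torch.zeros(4), torch.zeros(3)]
targets = [torch.ones(4), torch.ones(3)]
list_lerp_(params, targets, 0.1)
list_lerp_(params, targets, 0.2)  # the new scalar value reuses the compiled graph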
@@ -572,7 +582,7 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
-def exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
+def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
@@ -585,6 +595,18 @@ def exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
     return denom
 
 
+def exp_avg_(exp_avg: List[torch.Tensor], exp_avg_sq: List[torch.Tensor], grad: List[torch.Tensor],
+             grad_projected: List[torch.Tensor], beta1: float, beta2: float, step: int):
+    if isinstance(beta1, float):
+        beta1 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta1)
+    if isinstance(beta2, float):
+        beta2 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta2)
+    if isinstance(step, int):
+        step = torch.empty((), dtype=torch.int32, device=exp_avg[0].device).fill_(step)
+    denom = _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step)
+    return denom
+
+
 # this can be dynamic for most optimizers - just not for PSGD. So, it's disabled for all
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True)
 def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
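Both the old and new helpers funnel their results through copy_stochastic_ (or its list variant). For context, stochastic rounding from float32 to bfloat16 is commonly implemented by adding random bits below the bfloat16 mantissa and then truncating; the sketch below illustrates that bit trick and is an assumption about the approach, not a copy of heavyball's implementation:

import torch

def copy_stochastic_sketch(target: torch.Tensor, source: torch.Tensor):
    # target: bfloat16, source: float32, same shape.
    # add uniform random bits below the bfloat16 mantissa of the fp32 bit pattern...
    noise = torch.randint_like(source, low=0, high=1 << 16, dtype=torch.int32)
    bits = source.view(dtype=torch.int32).add(noise)
    # ...then zero the low 16 bits; the surviving value is exactly representable in bfloat16
    bits.bitwise_and_(-65536)  # 0xFFFF0000 as a signed int32 mask
    target.copy_(bits.view(dtype=torch.float32))

x32 = torch.randn(1024)
x16 = torch.empty_like(x32, dtype=torch.bfloat16)
copy_stochastic_sketch(x16, x32)  # x16 is a stochastically rounded copy of x32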
@@ -734,7 +756,7 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
     A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
     order = G.dim()
     p = list(range(order))
-    conjB = torch.randn_like(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
+    conjB = torch.randn(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
     Q = [promote(q) for q in Q]
     for i, q in enumerate(Q):
         if q.dim() <= 1:
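The conjB fix swaps torch.randn_like for torch.randn: randn_like expects a tensor whose shape, dtype, and device it mirrors, while randn takes an explicit size, and `G.shape[1:] + G.shape[:1]` is a size (the dimensions of G rotated left by one), so the old call could not have worked. A quick illustration:

import torch

G = torch.zeros(2, 3, 4)
size = G.shape[1:] + G.shape[:1]  # (3, 4, 2): the dims of G rotated left by one
conjB = torch.randn(size, dtype=torch.float32)  # fine: randn accepts a size
# torch.randn_like(size, ...) raises, because randn_like wants a tensor, not a size
print(conjB.shape)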
@@ -756,17 +778,17 @@ def psgd_lb(A, max_abs):
     a0 = torch.einsum('ij,ij->j', A, A)
     i = torch.argmax(a0)
 
-    x = torch.index_select(a, 1, i).flatten().contiguous()
+    x = torch.index_select(A, 1, i).flatten().contiguous()
 
-    x = torch.einsum('i,ij->j', x_, a)
+    x = torch.einsum('i,ij->j', x, A)
     x /= x.norm()
-    x = torch.einsum('j,kj->k', x_, a)
+    x = torch.einsum('j,kj->k', x, A)
     x = x.norm()
     x *= max_abs
     return x
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
 def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
     """Update Kronecker product preconditioner Q with pair (V, G)."""
     exprA, exprGs, _ = exprs
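The psgd_lb edits replace the undefined names `a` and `x_` with `A` and `x` (as previously written, the function would have raised NameError when called), and psgd_update_precond now compiles with dynamic=True so differently shaped preconditioner factors can share one graph. The bound itself seeds a single power-iteration-style step with the largest-norm column of A, which cheaply lower-bounds the largest singular value; a standalone sketch of that computation (heavyball's max_abs rescaling is left out here):

import torch

def spectral_norm_lower_bound(A: torch.Tensor) -> torch.Tensor:
    # pick the column with the largest norm as the starting vector
    a0 = torch.einsum('ij,ij->j', A, A)
    i = torch.argmax(a0)
    x = torch.index_select(A, 1, i).flatten().contiguous()
    # one power-iteration-like step: x <- A^T x, normalize, then measure ||A x||
    x = torch.einsum('i,ij->j', x, A)
    x /= x.norm()
    x = torch.einsum('j,kj->k', x, A)
    return x.norm()

A = torch.randn(64, 64)
print(spectral_norm_lower_bound(A) <= torch.linalg.matrix_norm(A, ord=2))  # tensor(True)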
@@ -799,7 +821,7 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
         stochastic_add_([o], [term1], -1)
 
 
-@decorator
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
 def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
     """Precondition gradient G with preconditioner Q."""
     md = min_dtype(Q)
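psgd_precond_grad is now routed through torch.compile instead of the generic `decorator` wrapper. As background on the einsum it evaluates: applying a Kronecker-factored preconditioner P = Q^T Q contracts every dimension of G twice with its factor, which for a 2-D gradient reduces to Ql^T Ql G Qr^T Qr. The sketch below is an illustration with hypothetical factor names Ql/Qr, not heavyball's expr-string machinery:

import torch

def precond_grad_2d(Ql: torch.Tensor, Qr: torch.Tensor, G: torch.Tensor) -> torch.Tensor:
    # contract each factor twice with the gradient: result = Ql^T Ql @ G @ Qr^T Qr
    return torch.einsum('li,lj,jk,mk,mn->in', Ql, Ql, G, Qr, Qr)

Ql, Qr, G = torch.randn(8, 8), torch.randn(16, 16), torch.randn(8, 16)
out = precond_grad_2d(Ql, Qr, G)
ref = Ql.T @ Ql @ G @ Qr.T @ Qr  # explicit matrix form for comparison
print(torch.allclose(out, ref, rtol=1e-4, atol=1e-4))  # True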

heavyball-0.20.1.dist-info/METADATA → heavyball-0.21.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.20.1
+Version: 0.21.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-22, 0.19.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features

heavyball-0.20.1.dist-info/RECORD → heavyball-0.21.0.dist-info/RECORD CHANGED
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=PNneiOkrRyV1yIZn91lPmYofd1_OiLqJTDy
 heavyball/psgd_kron.py,sha256=RSLJi5FnSmYjvYBufNDClnYRm-eN_Kpa1Ar2tNP6-X0,5824
 heavyball/pure_psgd.py,sha256=LZK0qmvZkBF8g00evaVLtW-sIUJmdoxag1K7O26AqEo,4820
 heavyball/schedule_free_palm_foreach_soap.py,sha256=_AqrnChY6iDlQUkF2YUxS7eLjSWCIuvEUOHvMHVM1yY,6873
-heavyball/utils.py,sha256=14vt4r_MeTsp1q3m0lpgF-Q3PCJg6GLGJrhjRxnbWwQ,35174
-heavyball-0.20.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-0.20.1.dist-info/METADATA,sha256=qzF2P7e2EREeTy_4h85tvUY53omjNm32z83CUHTqt3U,11926
-heavyball-0.20.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-0.20.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-0.20.1.dist-info/RECORD,,
+heavyball/utils.py,sha256=H8RsADNAXVbjQ9wWstYIKkXMq9E81aUF1j-2wfCeSLA,36471
+heavyball-0.21.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-0.21.0.dist-info/METADATA,sha256=hbXhr4XcPAkgfW8hpgoPRPrUoKeTCTvhZdofj4h8_8c,11926
+heavyball-0.21.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-0.21.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-0.21.0.dist-info/RECORD,,