heavyball 0.20.0__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/utils.py CHANGED
@@ -329,23 +329,33 @@ def get_orthogonal_matrix(mat):
329
329
 
330
330
 
331
331
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
332
- def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
333
- x32 = [promote(x_) for x_ in x]
334
- y32 = [promote(y_) for y_ in y]
332
+ def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
333
+ for x_, y_ in zip(x, y):
334
+ x32 = promote(x_)
335
+ y32 = promote(y_)
336
+ x32.lerp_(y32, a)
337
+ copy_stochastic_(x_, x32)
335
338
 
336
- torch._foreach_lerp_(x32, y32, a)
337
339
 
338
- copy_stochastic_list_(x, x32)
340
+ def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
341
+ if not isinstance(a, torch.Tensor):
342
+ a = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(a)
343
+ _compilable_stochastic_lerp_(x, y, a)
339
344
 
340
345
 
341
346
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
342
- def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
343
- x32 = [promote(x_) for x_ in x]
344
- y32 = [promote(y_) for y_ in y]
347
+ def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
348
+ for x_, y_ in zip(x, y):
349
+ x32 = promote(x_)
350
+ y32 = promote(y_)
351
+ x32.add_(y32, alpha=alpha)
352
+ copy_stochastic_(x_, x32)
345
353
 
346
- [x_.add_(y_, alpha=alpha) for x_, y_ in zip(x32, y32)]
347
354
 
348
- copy_stochastic_list_(x, x32)
355
+ def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
356
+ if not isinstance(alpha, torch.Tensor):
357
+ alpha = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(alpha)
358
+ _compilable_stochastic_add_(x, y, alpha)
349
359
 
350
360
 
351
361
  @decorator
@@ -572,7 +582,7 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
572
582
 
573
583
 
574
584
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
575
- def exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
585
+ def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
576
586
  beta1 = beta_debias(beta1, step)
577
587
  beta2 = beta_debias(beta2, step)
578
588
 
@@ -585,6 +595,18 @@ def exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
585
595
  return denom
586
596
 
587
597
 
598
+ def exp_avg_(exp_avg: List[torch.Tensor], exp_avg_sq: List[torch.Tensor], grad: List[torch.Tensor],
599
+ grad_projected: List[torch.Tensor], beta1: float, beta2: float, step: int):
600
+ if isinstance(beta1, float):
601
+ beta1 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta1)
602
+ if isinstance(beta2, float):
603
+ beta2 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta2)
604
+ if isinstance(step, int):
605
+ step = torch.empty((), dtype=torch.int32, device=exp_avg[0].device).fill_(step)
606
+ denom = _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step)
607
+ return denom
608
+
609
+
588
610
  # this can be dynamic for most optimizers - just not for PSGD. So, it's disabled for all
589
611
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True)
590
612
  def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
@@ -721,7 +743,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
721
743
  return [Q, (exprA, tuple(exprGs), exprP)]
722
744
 
723
745
 
724
- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
746
+ @decorator
725
747
  def psgd_balance_Q(Q_in):
726
748
  norms = torch.stack([q.norm(float("inf")) for q in Q_in])
727
749
  geometric_mean = norms.log().mean().exp()
@@ -734,8 +756,7 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
734
756
  A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
735
757
  order = G.dim()
736
758
  p = list(range(order))
737
- V = torch.randn_like(G, dtype=promote(G.dtype))
738
- conjB = torch.permute(V, p[1:] + p[:1])
759
+ conjB = torch.randn(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
739
760
  Q = [promote(q) for q in Q]
740
761
  for i, q in enumerate(Q):
741
762
  if q.dim() <= 1:
@@ -755,27 +776,19 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
755
776
  def psgd_lb(A, max_abs):
756
777
  A /= max_abs
757
778
  a0 = torch.einsum('ij,ij->j', A, A)
758
- a1 = torch.einsum('ij,ij->i', A, A)
759
- value0 = torch.max(a0)
760
- value1 = torch.max(a1)
761
779
  i = torch.argmax(a0)
762
- j = torch.argmax(a1)
763
780
 
764
- comp = value0 > value1
765
- x = torch.cond(comp, lambda a: torch.index_select(a, 1, i).flatten().contiguous(), #
766
- lambda a: torch.index_select(a, 0, j).flatten().contiguous(), (A,))
781
+ x = torch.index_select(A, 1, i).flatten().contiguous()
767
782
 
768
- x = torch.cond(comp, lambda x_, a: torch.einsum('i,ij->j', x_, a), lambda x_, a: torch.einsum('i,ji->j', x_, a),
769
- (x, A,))
783
+ x = torch.einsum('i,ij->j', x, A)
770
784
  x /= x.norm()
771
- x = torch.cond(comp, lambda x_, a: torch.einsum('j,kj->k', x_, a), lambda x_, a: torch.einsum('j,jk->k', x_, a),
772
- (x, A,))
785
+ x = torch.einsum('j,kj->k', x, A)
773
786
  x = x.norm()
774
787
  x *= max_abs
775
788
  return x
776
789
 
777
790
 
778
- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
791
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
779
792
  def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
780
793
  """Update Kronecker product preconditioner Q with pair (V, G)."""
781
794
  exprA, exprGs, _ = exprs
@@ -808,7 +821,7 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
808
821
  stochastic_add_([o], [term1], -1)
809
822
 
810
823
 
811
- @decorator
824
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
812
825
  def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
813
826
  """Precondition gradient G with preconditioner Q."""
814
827
  md = min_dtype(Q)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.20.0
3
+ Version: 0.21.0
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
32
32
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
33
33
  largely static alternative to `torch.optim` with more and better optimizers.
34
34
 
35
- Currently (2024-11-22, 0.19.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
35
+ Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
36
36
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
37
37
 
38
38
  ## Features
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=PNneiOkrRyV1yIZn91lPmYofd1_OiLqJTDy
16
16
  heavyball/psgd_kron.py,sha256=RSLJi5FnSmYjvYBufNDClnYRm-eN_Kpa1Ar2tNP6-X0,5824
17
17
  heavyball/pure_psgd.py,sha256=LZK0qmvZkBF8g00evaVLtW-sIUJmdoxag1K7O26AqEo,4820
18
18
  heavyball/schedule_free_palm_foreach_soap.py,sha256=_AqrnChY6iDlQUkF2YUxS7eLjSWCIuvEUOHvMHVM1yY,6873
19
- heavyball/utils.py,sha256=ESazD0yv14Aa8XKi_pz2CyfVkpcbgYcG2-WMvhQOnxk,35719
20
- heavyball-0.20.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
- heavyball-0.20.0.dist-info/METADATA,sha256=dJ43LOTrNqh7cDTDzZDSu57goP1gNhU3dfZ26BUK9hA,11926
22
- heavyball-0.20.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
- heavyball-0.20.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
- heavyball-0.20.0.dist-info/RECORD,,
19
+ heavyball/utils.py,sha256=H8RsADNAXVbjQ9wWstYIKkXMq9E81aUF1j-2wfCeSLA,36471
20
+ heavyball-0.21.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
+ heavyball-0.21.0.dist-info/METADATA,sha256=hbXhr4XcPAkgfW8hpgoPRPrUoKeTCTvhZdofj4h8_8c,11926
22
+ heavyball-0.21.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
+ heavyball-0.21.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
+ heavyball-0.21.0.dist-info/RECORD,,