heavyball 0.17.1__tar.gz → 0.17.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {heavyball-0.17.1 → heavyball-0.17.2}/PKG-INFO +1 -1
  2. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/utils.py +28 -11
  3. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball.egg-info/PKG-INFO +1 -1
  4. {heavyball-0.17.1 → heavyball-0.17.2}/setup.py +1 -1
  5. {heavyball-0.17.1 → heavyball-0.17.2}/LICENSE +0 -0
  6. {heavyball-0.17.1 → heavyball-0.17.2}/README.md +0 -0
  7. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/__init__.py +0 -0
  8. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/cached_delayed_psgd_kron.py +0 -0
  9. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/cached_psgd_kron.py +0 -0
  10. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/delayed_psgd.py +0 -0
  11. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/foreach_adamw.py +0 -0
  12. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/foreach_adopt.py +0 -0
  13. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/foreach_laprop.py +0 -0
  14. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/foreach_sfadamw.py +0 -0
  15. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/foreach_soap.py +0 -0
  16. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/p_adam.py +0 -0
  17. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/palm_foreach_sfadamw.py +0 -0
  18. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/palm_foreach_soap.py +0 -0
  19. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/precond_schedule_foreach_soap.py +0 -0
  20. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  21. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/precond_schedule_sfpsoap.py +0 -0
  22. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/psgd_kron.py +0 -0
  23. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/pure_psgd.py +0 -0
  24. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  25. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball.egg-info/SOURCES.txt +0 -0
  26. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball.egg-info/dependency_links.txt +0 -0
  27. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball.egg-info/requires.txt +0 -0
  28. {heavyball-0.17.1 → heavyball-0.17.2}/heavyball.egg-info/top_level.txt +0 -0
  29. {heavyball-0.17.1 → heavyball-0.17.2}/setup.cfg +0 -0
  30. {heavyball-0.17.1 → heavyball-0.17.2}/test/test_bf16_q.py +0 -0
  31. {heavyball-0.17.1 → heavyball-0.17.2}/test/test_closure.py +0 -0
  32. {heavyball-0.17.1 → heavyball-0.17.2}/test/test_foreach.py +0 -0
  33. {heavyball-0.17.1 → heavyball-0.17.2}/test/test_memory.py +0 -0
  34. {heavyball-0.17.1 → heavyball-0.17.2}/test/test_merge.py +0 -0
  35. {heavyball-0.17.1 → heavyball-0.17.2}/test/test_no_grad.py +0 -0
  36. {heavyball-0.17.1 → heavyball-0.17.2}/test/test_psgd.py +0 -0
  37. {heavyball-0.17.1 → heavyball-0.17.2}/test/test_soap.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.17.1
3
+ Version: 0.17.2
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -332,6 +332,16 @@ def promote(x):
332
332
  return x
333
333
 
334
334
 
335
+ def min_dtype(xs: List[torch.Tensor]):
336
+ dtypes = [x.dtype for x in xs]
337
+ for d in (torch.float32, torch.bfloat16, torch.float16):
338
+ if all(d == x for x in dtypes):
339
+ return d
340
+ if all(d in (x, torch.float32, torch.float64) for x in dtypes):
341
+ return d
342
+ return torch.float32
343
+
344
+
335
345
  def update_preconditioner(grad, state, max_precond_dim, precondition_1d, beta, update_precond):
336
346
  """
337
347
  Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
@@ -471,13 +481,8 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
471
481
  copy_stochastic_(t, s)
472
482
 
473
483
 
474
- def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
475
- if target.data_ptr() == source.data_ptr():
476
- return
477
- if target.dtype != torch.bfloat16:
478
- set_(target, source)
479
- return
480
-
484
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
485
+ def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
481
486
  """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
482
487
  # create a random 16 bit integer
483
488
  result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
@@ -492,6 +497,15 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
492
497
  target.copy_(result.view(dtype=torch.float32))
493
498
 
494
499
 
500
+ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
501
+ if target.data_ptr() == source.data_ptr():
502
+ return
503
+ if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
504
+ set_(target, source)
505
+ return
506
+ _compilable_copy_stochastic_(target, source)
507
+
508
+
495
509
  def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
496
510
  add_fn: callable = None):
497
511
  param32 = [promote(p) for p in param]
@@ -602,7 +616,8 @@ def psgd_balance_Q(Q_in):
602
616
 
603
617
 
604
618
  def psgd_calc_A_and_conjB(exprA, G, Q, V):
605
- A = torch.einsum(exprA, *Q, G)
619
+ md = min_dtype(Q)
620
+ A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md))
606
621
  order = G.dim()
607
622
  p = list(range(order))
608
623
  conjB = torch.permute(V.conj(), p[1:] + p[:1])
@@ -669,7 +684,8 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
669
684
  @decorator
670
685
  def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
671
686
  """Precondition gradient G with preconditioner Q."""
672
- out = torch.einsum(exprs[-1], *[q.conj() for q in Q], *Q, G.to(Q[0].dtype))
687
+ md = min_dtype(Q)
688
+ out = torch.einsum(exprs[-1], *[q.conj().to(md) for q in Q], *[q.to(md) for q in Q], G.to(md))
673
689
  if inplace:
674
690
  set_(G, out)
675
691
  return G
@@ -787,14 +803,15 @@ class PSGDBase(StatefulOptimizer):
787
803
  if g.dim() > 1:
788
804
  psgd_balance_Q(q)
789
805
 
790
- def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None, store_triu_as_line=False):
806
+ def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
807
+ store_triu_as_line=False):
791
808
  for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
792
809
  psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
793
810
  if original_q:
794
811
  if store_triu_as_line:
795
812
  update_triu_(original_q[i], Q)
796
813
  else:
797
- copy_stochastic_(original_q[i], Q)
814
+ copy_stochastic_list_(original_q[i], Q)
798
815
 
799
816
 
800
817
  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.17.1
3
+ Version: 0.17.2
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -10,7 +10,7 @@ setuptools.setup(
10
10
  name='heavyball',
11
11
  license='BSD',
12
12
  description='Efficient optimizers',
13
- version='0.17.1',
13
+ version='0.17.2',
14
14
  long_description=README,
15
15
  url='https://github.com/clashluke/heavyball',
16
16
  packages=setuptools.find_packages(),
File without changes
File without changes
File without changes
File without changes
File without changes