heavyball 0.18.3__tar.gz → 0.18.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {heavyball-0.18.3 → heavyball-0.18.4}/PKG-INFO +1 -1
  2. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/utils.py +11 -10
  3. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball.egg-info/PKG-INFO +1 -1
  4. {heavyball-0.18.3 → heavyball-0.18.4}/setup.py +1 -1
  5. {heavyball-0.18.3 → heavyball-0.18.4}/LICENSE +0 -0
  6. {heavyball-0.18.3 → heavyball-0.18.4}/README.md +0 -0
  7. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/__init__.py +0 -0
  8. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/cached_delayed_psgd_kron.py +0 -0
  9. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/cached_psgd_kron.py +0 -0
  10. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/delayed_psgd.py +0 -0
  11. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/foreach_adamw.py +0 -0
  12. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/foreach_adopt.py +0 -0
  13. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/foreach_laprop.py +0 -0
  14. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/foreach_sfadamw.py +0 -0
  15. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/foreach_soap.py +0 -0
  16. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/p_adam.py +0 -0
  17. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/palm_foreach_sfadamw.py +0 -0
  18. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/palm_foreach_soap.py +0 -0
  19. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/precond_schedule_foreach_soap.py +0 -0
  20. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  21. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/precond_schedule_sfpsoap.py +0 -0
  22. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/psgd_kron.py +0 -0
  23. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/pure_psgd.py +0 -0
  24. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  25. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball.egg-info/SOURCES.txt +0 -0
  26. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball.egg-info/dependency_links.txt +0 -0
  27. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball.egg-info/requires.txt +0 -0
  28. {heavyball-0.18.3 → heavyball-0.18.4}/heavyball.egg-info/top_level.txt +0 -0
  29. {heavyball-0.18.3 → heavyball-0.18.4}/setup.cfg +0 -0
  30. {heavyball-0.18.3 → heavyball-0.18.4}/test/test_bf16_q.py +0 -0
  31. {heavyball-0.18.3 → heavyball-0.18.4}/test/test_closure.py +0 -0
  32. {heavyball-0.18.3 → heavyball-0.18.4}/test/test_foreach.py +0 -0
  33. {heavyball-0.18.3 → heavyball-0.18.4}/test/test_memory.py +0 -0
  34. {heavyball-0.18.3 → heavyball-0.18.4}/test/test_merge.py +0 -0
  35. {heavyball-0.18.3 → heavyball-0.18.4}/test/test_no_grad.py +0 -0
  36. {heavyball-0.18.3 → heavyball-0.18.4}/test/test_psgd.py +0 -0
  37. {heavyball-0.18.3 → heavyball-0.18.4}/test/test_soap.py +0 -0
  38. {heavyball-0.18.3 → heavyball-0.18.4}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.18.3
3
+ Version: 0.18.4
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -479,7 +479,7 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
479
479
  copy_stochastic_(t, s)
480
480
 
481
481
 
482
- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
482
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
483
483
  def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
484
484
  """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
485
485
  # create a random 16 bit integer
@@ -815,20 +815,21 @@ class PSGDBase(StatefulOptimizer):
815
815
 
816
816
  def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
817
817
  store_triu_as_line=False):
818
- for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
818
+ if original_q:
819
+ if store_triu_as_line:
820
+ update_fn = update_triu_
821
+ else:
822
+ update_fn = copy_stochastic_list_
823
+ else:
824
+ update_fn = lambda x, y: None
825
+ for i, (p, grad, Q, oq) in enumerate(zip(p_list, grad_list, q_list, original_q)):
819
826
  psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
827
+ update_fn(oq, Q)
820
828
 
821
- for g, q in zip(grad_list, q_list):
829
+ for g, q in zip(grad_list, original_q if original_q else q_list):
822
830
  if g.dim() > 1:
823
831
  psgd_balance_Q(q)
824
832
 
825
- if original_q:
826
- for q in q_list:
827
- if store_triu_as_line:
828
- update_triu_(original_q[i], Q)
829
- else:
830
- copy_stochastic_list_(original_q[i], Q)
831
-
832
833
 
833
834
  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
834
835
  """Anneal preconditioner update probability during beginning of training.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.18.3
3
+ Version: 0.18.4
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -10,7 +10,7 @@ setuptools.setup(
10
10
  name='heavyball',
11
11
  license='BSD',
12
12
  description='Efficient optimizers',
13
- version='0.18.3',
13
+ version='0.18.4',
14
14
  long_description=README,
15
15
  url='https://github.com/clashluke/heavyball',
16
16
  packages=setuptools.find_packages(),
File without changes
File without changes
File without changes
File without changes
File without changes