heavyball 0.23.0__tar.gz → 0.23.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. {heavyball-0.23.0 → heavyball-0.23.1}/PKG-INFO +1 -1
  2. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/utils.py +3 -3
  3. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball.egg-info/PKG-INFO +1 -1
  4. {heavyball-0.23.0 → heavyball-0.23.1}/setup.py +1 -1
  5. {heavyball-0.23.0 → heavyball-0.23.1}/LICENSE +0 -0
  6. {heavyball-0.23.0 → heavyball-0.23.1}/README.md +0 -0
  7. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/__init__.py +0 -0
  8. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/cached_delayed_psgd_kron.py +0 -0
  9. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/cached_psgd_kron.py +0 -0
  10. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/delayed_psgd.py +0 -0
  11. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/foreach_adamw.py +0 -0
  12. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/foreach_adopt.py +0 -0
  13. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/foreach_laprop.py +0 -0
  14. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/foreach_sfadamw.py +0 -0
  15. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/foreach_soap.py +0 -0
  16. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/p_adam.py +0 -0
  17. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/palm_foreach_sfadamw.py +0 -0
  18. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/palm_foreach_soap.py +0 -0
  19. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/precond_schedule_foreach_soap.py +0 -0
  20. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  21. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/precond_schedule_sfpsoap.py +0 -0
  22. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/psgd_kron.py +0 -0
  23. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/pure_psgd.py +0 -0
  24. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  25. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball.egg-info/SOURCES.txt +0 -0
  26. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball.egg-info/dependency_links.txt +0 -0
  27. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball.egg-info/requires.txt +0 -0
  28. {heavyball-0.23.0 → heavyball-0.23.1}/heavyball.egg-info/top_level.txt +0 -0
  29. {heavyball-0.23.0 → heavyball-0.23.1}/setup.cfg +0 -0
  30. {heavyball-0.23.0 → heavyball-0.23.1}/test/test_bf16_params.py +0 -0
  31. {heavyball-0.23.0 → heavyball-0.23.1}/test/test_bf16_q.py +0 -0
  32. {heavyball-0.23.0 → heavyball-0.23.1}/test/test_bf16_storage.py +0 -0
  33. {heavyball-0.23.0 → heavyball-0.23.1}/test/test_caution.py +0 -0
  34. {heavyball-0.23.0 → heavyball-0.23.1}/test/test_closure.py +0 -0
  35. {heavyball-0.23.0 → heavyball-0.23.1}/test/test_ema.py +0 -0
  36. {heavyball-0.23.0 → heavyball-0.23.1}/test/test_foreach.py +0 -0
  37. {heavyball-0.23.0 → heavyball-0.23.1}/test/test_mars.py +0 -0
  38. {heavyball-0.23.0 → heavyball-0.23.1}/test/test_memory.py +0 -0
  39. {heavyball-0.23.0 → heavyball-0.23.1}/test/test_merge.py +0 -0
  40. {heavyball-0.23.0 → heavyball-0.23.1}/test/test_no_grad.py +0 -0
  41. {heavyball-0.23.0 → heavyball-0.23.1}/test/test_psgd.py +0 -0
  42. {heavyball-0.23.0 → heavyball-0.23.1}/test/test_soap.py +0 -0
  43. {heavyball-0.23.0 → heavyball-0.23.1}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.23.0
3
+ Version: 0.23.1
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -686,9 +686,9 @@ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
686
686
  def copy_stochastic_(target: Tensor, source: Tensor):
687
687
  if not is_compiling() and target.data_ptr() == source.data_ptr():
688
688
  return
689
- if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
690
- set_(target, source)
691
- _compilable_copy_stochastic_(target, source)
689
+ if target.dtype == torch.bfloat16 and source.dtype in (torch.float16, torch.float32, torch.float64):
690
+ _compilable_copy_stochastic_(target, source.float())
691
+ set_(target, source)
692
692
 
693
693
 
694
694
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.23.0
3
+ Version: 0.23.1
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -10,7 +10,7 @@ setuptools.setup(
10
10
  name='heavyball',
11
11
  license='BSD',
12
12
  description='Efficient optimizers',
13
- version='0.23.0',
13
+ version='0.23.1',
14
14
  long_description=README,
15
15
  url='https://github.com/clashluke/heavyball',
16
16
  packages=setuptools.find_packages(),
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes