heavyball 0.23.0__py3-none-any.whl → 0.23.1__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
heavyball/utils.py CHANGED
@@ -686,9 +686,9 @@ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
686
686
  def copy_stochastic_(target: Tensor, source: Tensor):
687
687
  if not is_compiling() and target.data_ptr() == source.data_ptr():
688
688
  return
689
- if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
690
- set_(target, source)
691
- _compilable_copy_stochastic_(target, source)
689
+ if target.dtype == torch.bfloat16 and source.dtype in (torch.float16, torch.float32, torch.float64):
690
+ _compilable_copy_stochastic_(target, source.float())
691
+ set_(target, source)
692
692
 
693
693
 
694
694
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.23.0
3
+ Version: 0.23.1
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=FOR-axwlkSN7IHZWYYUVFfjSFCLxc_NdiTl
16
16
  heavyball/psgd_kron.py,sha256=4eiGPXAFjvGIXLdiai1UJfAvTozAV1TXaE9UGkE4BLc,6051
17
17
  heavyball/pure_psgd.py,sha256=344NdVNHwUFX3fU2R1S_Xh9SXAML3E4ryHr7xfMh9Cc,5076
18
18
  heavyball/schedule_free_palm_foreach_soap.py,sha256=0WT_gvTKymqLQzYT6ewDgCmpDq-HgMAewipw1QvyQYA,7267
19
- heavyball/utils.py,sha256=AZlY8dfM0d-C0FXBCJHTJOOoi3RjkMJ-XhU25aBN878,39521
20
- heavyball-0.23.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
- heavyball-0.23.0.dist-info/METADATA,sha256=3IBUhXA7VJT9GQh460OznCAcIqCG_Mv5Q7HZO8FQ40w,11926
22
- heavyball-0.23.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
- heavyball-0.23.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
- heavyball-0.23.0.dist-info/RECORD,,
19
+ heavyball/utils.py,sha256=8XE-z5T7FkbPlfo8Dh9dfoH8UsE-HgjDiJCD_XHkT54,39526
20
+ heavyball-0.23.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
+ heavyball-0.23.1.dist-info/METADATA,sha256=eE1t-LDRa2ajLlXzITHLzyOt3elr9t4gxaOk55m6pj8,11926
22
+ heavyball-0.23.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
+ heavyball-0.23.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
+ heavyball-0.23.1.dist-info/RECORD,,