heavyball 0.18.5__py3-none-any.whl → 0.18.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/utils.py CHANGED
@@ -44,10 +44,10 @@ def _compilable_schedule_free_(p, z, ckp1, grad, lr, beta1):
44
44
  z32 = z.float()
45
45
  p32.lerp_(end=z32, weight=1 - ckp1)
46
46
  p32.add_(grad, alpha=lr * (beta1 * (1 - ckp1) - 1))
47
- _compilable_copy_stochastic_(p, p32)
47
+ _guarded_copy_stochastic(p, p32)
48
48
 
49
49
  z32.add_(grad, alpha=-lr)
50
- _compilable_copy_stochastic_(z, z32)
50
+ _guarded_copy_stochastic(z, z32)
51
51
 
52
52
 
53
53
  def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[torch.Tensor],
@@ -157,11 +157,11 @@ def adaptive_gradient_clipping_(parameters: List[torch.Tensor], gradients: List[
157
157
 
158
158
 
159
159
  def set_(dst: torch.Tensor, src: torch.Tensor):
160
- if src.data_ptr() == dst.data_ptr():
160
+ if not torch.compiler.is_compiling() and src.data_ptr() == dst.data_ptr():
161
161
  return
162
162
  if src.shape != dst.shape:
163
163
  src = src.reshape_as(dst)
164
- if src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
164
+ if not torch.compiler.is_compiling() and src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
165
165
  dst.set_(src)
166
166
  else:
167
167
  dst.copy_(src)
@@ -486,6 +486,12 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
486
486
  copy_stochastic_(t, s)
487
487
 
488
488
 
489
+ def _guarded_copy_stochastic(target: torch.Tensor, source: torch.Tensor):
490
+ if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
491
+ set_(target, source)
492
+ _compilable_copy_stochastic_(target, source)
493
+
494
+
489
495
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
490
496
  def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
491
497
  """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
@@ -505,10 +511,7 @@ def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
505
511
  def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
506
512
  if target.data_ptr() == source.data_ptr():
507
513
  return
508
- if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
509
- set_(target, source)
510
- return
511
- _compilable_copy_stochastic_(target, source)
514
+ _guarded_copy_stochastic(target, source)
512
515
 
513
516
 
514
517
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
@@ -521,7 +524,7 @@ def _compilable_update_one_(p, u, decay, add_fn, lr):
521
524
  p32.add_(u32, alpha=lr)
522
525
  else:
523
526
  add_fn(p32, u32, lr)
524
- _compilable_copy_stochastic_(p, p32)
527
+ _guarded_copy_stochastic(p, p32)
525
528
 
526
529
 
527
530
  def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
@@ -840,12 +843,13 @@ class PSGDBase(StatefulOptimizer):
840
843
  psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
841
844
  update_fn(oq, Q)
842
845
 
843
- for g, q in zip(grad_list, original_q if original_q else q_list):
844
- if g.dim() > 1:
845
- if store_triu_as_line:
846
- psgd_balance_Q([q_ for _, q_ in q])
847
- else:
848
- psgd_balance_Q(q)
846
+ if self.should_update(group, self.balance_probability, "balance_prob"):
847
+ for g, q in zip(grad_list, original_q if original_q else q_list):
848
+ if g.dim() > 1:
849
+ if store_triu_as_line:
850
+ psgd_balance_Q([q_ for _, q_ in q])
851
+ else:
852
+ psgd_balance_Q(q)
849
853
 
850
854
 
851
855
  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.18.5
3
+ Version: 0.18.7
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
32
32
  The goal is not to strive for completeness, full maintenance or abstraction, but instead to provide a simple
33
33
  largely static alternative to `torch.optim` with more and better optimizers.
34
34
 
35
- Currently (2024-11-20, 0.17.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
35
+ Currently (2024-11-21, 0.18.6), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
36
36
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
37
37
 
38
38
  ## Features
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdo
16
16
  heavyball/psgd_kron.py,sha256=u46dorOUXx-do1IYeno2wj-6l1zYKMQQC-N2Zr2PzLI,5476
17
17
  heavyball/pure_psgd.py,sha256=iUy7mMKWxwNiVUMYrQ7SBnreu3t_XSbnhTW3a1yw4m0,4835
18
18
  heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
19
- heavyball/utils.py,sha256=2VBEQhtQ4mwsD99JMu7iWbiYPkutspjG3hGwCbIHZ9U,31134
20
- heavyball-0.18.5.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
- heavyball-0.18.5.dist-info/METADATA,sha256=Zcc87BhCxDTX7bjJ3pGG7VdIRmpZuYLwmWBKDiLc3AU,11810
22
- heavyball-0.18.5.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
- heavyball-0.18.5.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
- heavyball-0.18.5.dist-info/RECORD,,
19
+ heavyball/utils.py,sha256=lKIV11qvlHITK7lwaScGbP1ryCmInse9Fe64t0OBmQQ,31408
20
+ heavyball-0.18.7.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
+ heavyball-0.18.7.dist-info/METADATA,sha256=KUYpVlwytyMmQBuby0Jf1WaklYdc2GPddiMAqyGKzsM,11810
22
+ heavyball-0.18.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
+ heavyball-0.18.7.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
+ heavyball-0.18.7.dist-info/RECORD,,