heavyball 0.18.3__py3-none-any.whl → 0.18.5__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -34,9 +34,9 @@ class ForeachAdamW(StatefulOptimizer):
34
34
 
35
35
  # Decay the first and second moment running average coefficient
36
36
  torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
37
- denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
37
+ denom = list(exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps))
38
38
 
39
39
  # Normalize grad in-place for memory efficiency
40
40
  lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
41
- update_param_(y, exp_avg, lr, decay, lambda p, e, l: torch._foreach_addcdiv_(p, e, denom, l))
41
+ update_param_(y, exp_avg, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l))
42
42
  group['k'] = k + 1
heavyball/utils.py CHANGED
@@ -38,6 +38,18 @@ def warmup(lr: float, step: int, warmup_steps: int):
38
38
  return lr * step / warmup_steps
39
39
 
40
40
 
41
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
42
+ def _compilable_schedule_free_(p, z, ckp1, grad, lr, beta1):
43
+ p32 = p.float()
44
+ z32 = z.float()
45
+ p32.lerp_(end=z32, weight=1 - ckp1)
46
+ p32.add_(grad, alpha=lr * (beta1 * (1 - ckp1) - 1))
47
+ _compilable_copy_stochastic_(p, p32)
48
+
49
+ z32.add_(grad, alpha=-lr)
50
+ _compilable_copy_stochastic_(z, z32)
51
+
52
+
41
53
  def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[torch.Tensor],
42
54
  z: List[torch.Tensor], grad: list[torch.Tensor], r: float = 0.0, step: int = 0):
43
55
  weight = lr ** weight_lr_power * max(step, 1) ** r
@@ -50,15 +62,10 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
50
62
 
51
63
  # These operations update y in-place,
52
64
  # without computing x explicitly.
53
- p32 = [promote(p) for p in parameters]
54
- z32 = [promote(z_) for z_ in z]
55
- torch._foreach_lerp_(p32, z32, weight=ckp1)
56
- torch._foreach_add_(p32, grad, alpha=lr * (beta1 * (1 - ckp1) - 1))
57
- copy_stochastic_list_(parameters, p32)
58
-
59
- # z step
60
- torch._foreach_sub_(z, grad, alpha=lr)
61
- copy_stochastic_list_(z, z32)
65
+ lr_tensor = torch.empty((), dtype=torch.float32, device=parameters[0].device).fill_(lr)
66
+ ckp1_tensor = torch.empty((), dtype=torch.float32, device=parameters[0].device).fill_(ckp1)
67
+ for p, z_, g in zip(parameters, z, grad):
68
+ _compilable_schedule_free_(p, z_, ckp1_tensor, g, lr_tensor, beta1)
62
69
  return weight_sum
63
70
 
64
71
 
@@ -479,7 +486,7 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
479
486
  copy_stochastic_(t, s)
480
487
 
481
488
 
482
- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
489
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
483
490
  def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
484
491
  """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
485
492
  # create a random 16 bit integer
@@ -504,17 +511,24 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
504
511
  _compilable_copy_stochastic_(target, source)
505
512
 
506
513
 
507
- def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
508
- add_fn: callable = None):
509
- param32 = [promote(p) for p in param]
510
- update32 = [promote(u.view(p.shape)) for u, p in zip(update, param)]
514
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
515
+ def _compilable_update_one_(p, u, decay, add_fn, lr):
516
+ p32 = p.float()
517
+ u32 = u.view(p.shape).float()
511
518
  if decay > 0:
512
- torch._foreach_mul_(param32, 1 - decay * lr)
519
+ p32.mul_(1 - decay * lr)
513
520
  if add_fn is None:
514
- torch._foreach_add_(param32, update32, alpha=lr)
521
+ p32.add_(u32, alpha=lr)
515
522
  else:
516
- add_fn(param32, update32, lr)
517
- copy_stochastic_list_(param, param32)
523
+ add_fn(p32, u32, lr)
524
+ _compilable_copy_stochastic_(p, p32)
525
+
526
+
527
+ def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
528
+ add_fn: callable = None):
529
+ lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
530
+ for p, u in zip(param, update):
531
+ _compilable_update_one_(p, u, decay, add_fn, lr_tensor)
518
532
 
519
533
 
520
534
  def precond_schedule(step, precond_scheduler, rng):
@@ -815,19 +829,23 @@ class PSGDBase(StatefulOptimizer):
815
829
 
816
830
  def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
817
831
  store_triu_as_line=False):
818
- for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
832
+ if original_q:
833
+ if store_triu_as_line:
834
+ update_fn = update_triu_
835
+ else:
836
+ update_fn = copy_stochastic_list_
837
+ else:
838
+ update_fn = lambda x, y: None
839
+ for i, (p, grad, Q, oq) in enumerate(zip(p_list, grad_list, q_list, original_q)):
819
840
  psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
841
+ update_fn(oq, Q)
820
842
 
821
- for g, q in zip(grad_list, q_list):
843
+ for g, q in zip(grad_list, original_q if original_q else q_list):
822
844
  if g.dim() > 1:
823
- psgd_balance_Q(q)
824
-
825
- if original_q:
826
- for q in q_list:
827
845
  if store_triu_as_line:
828
- update_triu_(original_q[i], Q)
846
+ psgd_balance_Q([q_ for _, q_ in q])
829
847
  else:
830
- copy_stochastic_list_(original_q[i], Q)
848
+ psgd_balance_Q(q)
831
849
 
832
850
 
833
851
  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.18.3
3
+ Version: 0.18.5
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -2,7 +2,7 @@ heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
2
2
  heavyball/cached_delayed_psgd_kron.py,sha256=PQAER6UgVh5l87DGRZrJ8CVP9UhyCG5wJD9rPLnj_G8,6460
3
3
  heavyball/cached_psgd_kron.py,sha256=GaeneBp0irksCSBIrJY4D_0hCpZ-uSRPMhqVX_a-og8,6417
4
4
  heavyball/delayed_psgd.py,sha256=fhBWFLTSl1S2gHWCeYak-STaXRwpC56sWZGLFMKFEJM,5589
5
- heavyball/foreach_adamw.py,sha256=h_ar0ZZRM0q_wxuEkxEOEYe-2p-mB4OMgAHivrUnPl8,1777
5
+ heavyball/foreach_adamw.py,sha256=CTg7rfUmlTSjihD5KY9xP0sT2dUKZyZ4-2V42Vlr28U,1780
6
6
  heavyball/foreach_adopt.py,sha256=ogOw2JjwEQNj7AKlweAphQFdMJ_GcMDm-RyDvEzugoc,1911
7
7
  heavyball/foreach_laprop.py,sha256=yGVmGqWiSw8Y2Xj70ndkR8ZMygakTB4_iRwV02Svkqg,1816
8
8
  heavyball/foreach_sfadamw.py,sha256=15-n6-lx4PAHYsKYmXbugxsR5MnqaPYy2vUudPRiitg,2087
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdo
16
16
  heavyball/psgd_kron.py,sha256=u46dorOUXx-do1IYeno2wj-6l1zYKMQQC-N2Zr2PzLI,5476
17
17
  heavyball/pure_psgd.py,sha256=iUy7mMKWxwNiVUMYrQ7SBnreu3t_XSbnhTW3a1yw4m0,4835
18
18
  heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
19
- heavyball/utils.py,sha256=qs_WfzJdS-3XyEuw-m6mWMEeR95r7bGFVC8wWCHtD48,30365
20
- heavyball-0.18.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
- heavyball-0.18.3.dist-info/METADATA,sha256=Cx8LM2g3BFOk8WJH3B8ve8kQ7HghMCIRLggdJp37x4g,11810
22
- heavyball-0.18.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
- heavyball-0.18.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
- heavyball-0.18.3.dist-info/RECORD,,
19
+ heavyball/utils.py,sha256=2VBEQhtQ4mwsD99JMu7iWbiYPkutspjG3hGwCbIHZ9U,31134
20
+ heavyball-0.18.5.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
+ heavyball-0.18.5.dist-info/METADATA,sha256=Zcc87BhCxDTX7bjJ3pGG7VdIRmpZuYLwmWBKDiLc3AU,11810
22
+ heavyball-0.18.5.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
+ heavyball-0.18.5.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
+ heavyball-0.18.5.dist-info/RECORD,,