heavyball 0.14.7__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -46,84 +46,73 @@ class SFPaLMForeachSOAP(ScheduleFree):
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        """
-        Performs a single optimization step.
-
-        Arguments:
-            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
-        """
-        if closure is None:
-            loss = None
-        else:
-            loss = closure()
-
-        for group in self.param_groups:
-            vals = []
-            max_precond_dim = group['max_precond_dim']
-            precondition_1d = group['precondition_1d']
-
-            step = group['step'] = group.get("step", -1) + 1
-
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-                grad = p.grad.float()
-                vals.append((p, grad))
-
-            p_list, grad = zip(*vals)
-
-            adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
-
-            vals = []
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
-
-                if "z" not in state:
-                    state["z"] = torch.clone(p).float()
-                    state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
-                    init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                    continue  # first step is skipped so that we never use the current gradients in the projection.
-
-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                grad_projected = project(g, state['Q'], False)
-                z, exp_avg_sq = state["z"], state["exp_avg_sq"]
-                vals.append((p, g, grad_projected, z, exp_avg_sq))
-
-            if not vals:
+    def _step(self, group):
+        vals = []
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        step = group['step'] = group.get("step", -1) + 1
+
+        for p in group["params"]:
+            if p.grad is None:
                 continue
+            grad = p.grad.float()
+            vals.append((p, grad))
+
+        if not vals:
+            return
+
+        p_list, grad = zip(*vals)
+
+        adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
+
+        vals = []
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+
+            if "z" not in state:
+                state["z"] = torch.clone(p).float()
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            z, exp_avg_sq = state["z"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, z, exp_avg_sq))
+
+        if not vals:
+            return
 
-            p_list, grad, grad_projected, z, exp_avg_sq = zip(*vals)
+        p_list, grad, grad_projected, z, exp_avg_sq = zip(*vals)
 
-            beta2 = 1 - max(step, 1) ** -group['beta2_scale']
-            new_debiased2 = beta_debias(beta2, step)
+        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
+        new_debiased2 = beta_debias(beta2, step)
 
-            # Decay the first and second moment running average coefficient
-            # In-place operations to update the averages at the same time
-            denom = exp_avg_sq_(exp_avg_sq, grad, new_debiased2, group["eps"])
-            torch._foreach_div_(grad_projected, denom)
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        denom = exp_avg_sq_(exp_avg_sq, grad, new_debiased2, group["eps"])
+        torch._foreach_div_(grad_projected, denom)
 
-            update_precond = group['step'] > 0 and group['step'] % group['precondition_frequency'] == 0
+        update_precond = group['step'] > 0 and group['step'] % group['precondition_frequency'] == 0
 
-            for p, g, gp in zip(p_list, grad, grad_projected):
-                state = self.state_(p)
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-                set_(gp, project(gp, state['Q'], back=True))
+        for p, g, gp in zip(p_list, grad, grad_projected):
+            state = self.state_(p)
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            set_(gp, project(gp, state['Q'], back=True))
 
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2,
-                                      update_precond)
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2,
+                                  update_precond)
 
-            # Weight decay calculated at y
-            if group["weight_decay"] > 0:
-                torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])
+        # Weight decay calculated at y
+        if group["weight_decay"] > 0:
+            torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])
 
-            lr = warmup(group['lr'], step, group['warmup_steps'])
-            group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
-                                                 p_list, z, grad_projected, group['r'], step)
-        return loss
+        lr = warmup(group['lr'], step, group['warmup_steps'])
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
+                                             p_list, z, grad_projected, group['r'], step)
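The refactor moves the per-group body of step() into _step(); closure handling and the no-grad context now live in the shared StatefulOptimizer.step() shown further down. It also guards both zip(*vals) unpacks with an early return: 0.14.7 guarded only the second one (if not vals: continue), so a group in which no parameter had a gradient crashed on the first unpack. A minimal repro of that failure mode:

    vals = []                  # what the loop produces when every p.grad is None
    p_list, grad = zip(*vals)  # ValueError: not enough values to unpack (expected 2, got 0)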
heavyball/utils.py CHANGED
@@ -3,7 +3,7 @@ import gc
 import math
 import random
 import string
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Callable
 
 import numpy as np
 import torch
@@ -29,7 +29,7 @@ def decorator(func):
     return _fn
 
 
-_einsum_base = string.ascii_lowercase + string.ascii_uppercase
+einsum_base = string.ascii_lowercase + string.ascii_uppercase
 
 
 def warmup(lr: float, step: int, warmup_steps: int):
@@ -317,8 +317,8 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
     for idx, sh in enumerate(grad.shape):
         if sh > max_precond_dim:
             continue
-        b = _einsum_base[idx]
-        g0 = _einsum_base[:grad.dim()]
+        b = einsum_base[idx]
+        g0 = einsum_base[:grad.dim()]
         g1 = g0.replace(b, b.upper())
         outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
         GG[idx].lerp_(promote(outer_product), 1 - beta)
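For intuition about the einsum this builds: for each kept axis idx it contracts the gradient with itself over every other axis, producing that axis's Kronecker-factor statistic, which GG[idx] then tracks as an EMA via lerp_. A small sketch with illustrative shapes (the names grad and outer0 are ours, not the library's): for a 2-D gradient and idx = 0, b = 'a', g0 = 'ab', g1 = 'Ab', and the contraction reduces to grad @ grad.T:

    import torch

    grad = torch.randn(4, 8)
    # einsum('ab,Ab->aA') sums over axis 1, leaving the axis-0 outer product.
    outer0 = torch.einsum('ab,Ab->aA', grad, grad)
    torch.testing.assert_close(outer0, grad @ grad.T)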
@@ -374,7 +374,7 @@ def project(grad, Q, back: bool):
     :param back: whether to project to Shampoo eigenbases or back to original space
     :return:
     """
-    param = _einsum_base[:grad.dim()]
+    param = einsum_base[:grad.dim()]
     preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if len(m) > 0])
     if preconditioners:
         out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
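A quick sanity check on the subscript strings project() builds (a sketch; it assumes the trimmed remainder of the function feeds param, preconditioners, and out into the einsum it sets up here). For a 2-D gradient with both factors present, back=False yields einsum('ab,aA,bB->AB', ...) and back=True yields einsum('ab,Aa,Bb->AB', ...), so with square orthogonal factors the round trip is lossless:

    import torch
    from heavyball.utils import project  # public helper, like einsum_base in 0.15.1

    grad = torch.randn(4, 8)
    q = [torch.linalg.qr(torch.randn(4, 4)).Q,   # orthogonal stand-ins for the
         torch.linalg.qr(torch.randn(8, 8)).Q]   # eigenbases kept in state['Q']
    projected = project(grad, q, back=False)     # to the Shampoo eigenbases
    restored = project(projected, q, back=True)  # back to the original space
    torch.testing.assert_close(restored, grad, rtol=1e-4, atol=1e-5)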
@@ -399,6 +399,20 @@ class StatefulOptimizer(torch.optim.Optimizer):
                 tree_map(_add, self.state_(p))
         return total_bytes
 
+    def _step(self, group):
+        raise NotImplementedError
+
+    def step(self, closure: Optional[Callable] = None):
+        if closure is None:
+            loss = None
+        else:
+            with torch.enable_grad():
+                loss = closure()
+        with torch.no_grad():
+            for group in self.param_groups:
+                self._step(group)
+        return loss
+
 
 class ScheduleFree(StatefulOptimizer):
     def eval(self):
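step() is now a template method: the base class evaluates the closure with gradients re-enabled, then runs each param group through the subclass's _step() under torch.no_grad(), which is why the per-optimizer @torch.no_grad() decorators disappear in this release. A minimal sketch of the new contract, using a hypothetical plain-SGD subclass (not part of heavyball):

    import torch
    from heavyball.utils import StatefulOptimizer

    class PlainSGD(StatefulOptimizer):
        # Hypothetical subclass for illustration only.
        def _step(self, group):
            # Invoked once per param group, already inside torch.no_grad();
            # the inherited step() owns closure evaluation and the loss return.
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.add_(p.grad, alpha=-group["lr"])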
@@ -684,9 +698,11 @@ def a_law_compress(x, A=87.6):
     torch._foreach_mul_(xa, 1 / (1 + math.log(A)))
     return xa
 
+
 def identity(x):
     return x
 
+
 def trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
     torch._foreach_mul_(grad, 1 / scale)
     tanh = torch._foreach_tanh(grad)
@@ -743,8 +759,8 @@ class PSGDBase(StatefulOptimizer):
         self.rng = random.Random(0x1923213)
         self._tiny = torch.finfo(torch.bfloat16).tiny
 
-    def balance(self, do_update, grad_list, Q_list):
-        if not do_update or self.rng.random() > 0.01:
+    def balance(self, grad_list, Q_list):
+        if self.rng.random() > 0.01:
             return
 
         for g, q in zip(grad_list, Q_list):
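Note the contract change to balance(): the do_update argument is gone, so deciding when balancing is due now falls to the caller (presumably gated on the same preconditioner-update flag the PSGD optimizers already compute), and only the internal 1%-probability sampling remains. Any 0.14.7-style call such as self.balance(do_update, grad_list, Q_list) must drop the first argument and wrap the call in its own condition.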
heavyball-0.15.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.14.7
+Version: 0.15.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
heavyball-0.15.1.dist-info/RECORD ADDED
@@ -0,0 +1,23 @@
+heavyball/__init__.py,sha256=KbT0GMU0DKqZxq9laCrD7XgiqS9yxC1W52zhte5kjKs,2054
+heavyball/cached_psgd_kron.py,sha256=mXDtxq2WJST_aUJhrLr_xCCXSFaDvD5gCTSEveBUtac,6754
+heavyball/delayed_psgd.py,sha256=dN3NW1jmjxmUkgqxPwUVrqLY8nnBOFp4TVtJ_BhPDR4,5814
+heavyball/foreach_adamw.py,sha256=NSzoIgNm7eavzbJgkAF0k7TUEnWAgOpt9-4juIFoaSA,1729
+heavyball/foreach_adopt.py,sha256=WA07m5jocLfb1GPU8s6mJ2PteS-03ronkKm-VJrAm5I,1863
+heavyball/foreach_laprop.py,sha256=mE2NDGX9XgvRhsewcWnk_-FulZPqGA65ejYF_9-A1Xk,1768
+heavyball/foreach_sfadamw.py,sha256=ussHfPd99u3RTfMrCuu5oIbwNFLXK19wO1Fbz3JShlc,2097
+heavyball/foreach_soap.py,sha256=WWvssYKg607uoEJHftp8ag8mtKSKSeHrT0QTgqBucVg,4587
+heavyball/p_adam.py,sha256=ms7BoMHu3jKGsuztUeECrsXufGAwBpqGsxgZ5LBXLQg,6073
+heavyball/palm_foreach_sfadamw.py,sha256=wjUb_fNZNUmzWXyKvwB0unP9lvNMmaYSQo5YoeS5cj0,2200
+heavyball/palm_foreach_soap.py,sha256=2Sb4hUHQeexJcCgjHeQM_ENkZ6lG1DVxW72ryrvR6iY,5890
+heavyball/precond_schedule_foreach_soap.py,sha256=bHsDyh-UvHpHjumjqqy0PePoR1ZMsJV6o5wWvpLAA04,4815
+heavyball/precond_schedule_palm_foreach_soap.py,sha256=myLTJNQKLtZ3Xi3MVTB-RYtx_XeMRJw5CIMJW75ndUY,6163
+heavyball/precond_schedule_sfpsoap.py,sha256=xeNWetBzBEYqfOSzl98aAVJsHk43QkrUUhHH_YD_mS4,6740
+heavyball/psgd_kron.py,sha256=rMG5UPEgyfQs_n1MHSEicekVDpbbIzinlL8akEyY918,5795
+heavyball/pure_psgd.py,sha256=LLVJhUAb04hgAmT3BTz_faswwQEQUkLhm_VwGQmbBUo,5088
+heavyball/schedule_free_palm_foreach_soap.py,sha256=w0P7lMmoijTpL9V7NwOHcNBFJQ7S1TS9aCiwPhY2yVw,6319
+heavyball/utils.py,sha256=PWmwjZPL4oxMjK79a5R1e7JHykphNi5GdpYqO_xmmFU,27829
+heavyball-0.15.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-0.15.1.dist-info/METADATA,sha256=0wImMJNYM-Zg0akh9hRf7X8ofVW6zlmpyDGgAkK5GFA,11667
+heavyball-0.15.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-0.15.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-0.15.1.dist-info/RECORD,,
heavyball-0.14.7.dist-info/RECORD DELETED
@@ -1,22 +0,0 @@
-heavyball/__init__.py,sha256=ef7IWcPF8Uh3WQHzMiMqOFvUnU_LdG7BO9XVecJOph4,1156
-heavyball/delayed_psgd.py,sha256=Gfa1ogkFPPL7ohayYAwbugB8hyLRUI5FgcJfsK69KGI,5936
-heavyball/foreach_adamw.py,sha256=L727chOuVqdiVbYYzloy4g3oKH2FmQh40o_bqWeZtk8,2106
-heavyball/foreach_adopt.py,sha256=M4zZVcYlhGjqO6qekivCuYfX6JtMcp4cJi_RrSHT3H8,2268
-heavyball/foreach_laprop.py,sha256=htYGqgvlZsF_JpILdPMTnX72SqfrokBJ2J4nDeT0JVU,2157
-heavyball/foreach_sfadamw.py,sha256=KIGol7Phfq1DHE_nEle4wDuqNdbGsT3kUcMKzJX3msg,2498
-heavyball/foreach_soap.py,sha256=Ccz9Mc_xaHnrJ_7jUq9ZVxyR0WEqopzOXTUqUY-V8G8,5137
-heavyball/p_adam.py,sha256=jQgTkKekqnLj1XPA4-fgpWG8P_BtUq2976zEt2QymTo,6060
-heavyball/palm_foreach_sfadamw.py,sha256=8IGlRCdkfMzUqcSfmTM3Ce04NLNyrT2xfiBcPkrWwqc,2605
-heavyball/palm_foreach_soap.py,sha256=NEJ3Xeh7pqURUk3cAP2qJe8z2WzYKg60pQe4bsGiaY4,6441
-heavyball/precond_schedule_foreach_soap.py,sha256=H6Oc5IAL5MR-fgu92AboPs3Xm8mBmYUMPLsEcuJ12VI,5370
-heavyball/precond_schedule_palm_foreach_soap.py,sha256=v81hRjcqS6Mm-KxT5Rk3TEiKAE8WI2IbmVbSa-YfBkE,6760
-heavyball/precond_schedule_sfpsoap.py,sha256=7ScnN0in8u9hPiJE7QnOoZOH6Tn-6HeVy4f-bO3bHzY,7279
-heavyball/psgd_kron.py,sha256=AH8ugd_IxKGVtY9y_Ot7myVSxFDbLlRJIqr2bBlAYy8,5911
-heavyball/pure_psgd.py,sha256=jp5fnawUdgccEFlZDPrZr4ZbxYV85IIrev4tybZxBVU,5185
-heavyball/schedule_free_palm_foreach_soap.py,sha256=bV7H-FNNoH5WpposLrNhkqU7mBicMorqKEALBSdROEM,6853
-heavyball/utils.py,sha256=y5VAd9CQjcl_a1WUcORviAYf7Jz_c7n3-b7i5kLUJIA,27464
-heavyball-0.14.7.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-0.14.7.dist-info/METADATA,sha256=5QWB3nuNAp8YjeX0-Y5Uzkek_wjuGF3XG6UWrQk8R0c,11667
-heavyball-0.14.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-0.14.7.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-0.14.7.dist-info/RECORD,,