heavyball 0.14.7__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +25 -3
- heavyball/cached_psgd_kron.py +141 -0
- heavyball/delayed_psgd.py +43 -51
- heavyball/foreach_adamw.py +22 -32
- heavyball/foreach_adopt.py +38 -48
- heavyball/foreach_laprop.py +25 -35
- heavyball/foreach_sfadamw.py +28 -38
- heavyball/foreach_soap.py +56 -70
- heavyball/p_adam.py +46 -50
- heavyball/palm_foreach_sfadamw.py +31 -41
- heavyball/palm_foreach_soap.py +56 -70
- heavyball/precond_schedule_foreach_soap.py +57 -71
- heavyball/precond_schedule_palm_foreach_soap.py +58 -73
- heavyball/precond_schedule_sfpsoap.py +60 -72
- heavyball/psgd_kron.py +43 -49
- heavyball/pure_psgd.py +36 -43
- heavyball/schedule_free_palm_foreach_soap.py +61 -72
- heavyball/utils.py +23 -7
- {heavyball-0.14.7.dist-info → heavyball-0.15.1.dist-info}/METADATA +1 -1
- heavyball-0.15.1.dist-info/RECORD +23 -0
- heavyball-0.14.7.dist-info/RECORD +0 -22
- {heavyball-0.14.7.dist-info → heavyball-0.15.1.dist-info}/LICENSE +0 -0
- {heavyball-0.14.7.dist-info → heavyball-0.15.1.dist-info}/WHEEL +0 -0
- {heavyball-0.14.7.dist-info → heavyball-0.15.1.dist-info}/top_level.txt +0 -0
@@ -46,84 +46,73 @@ class SFPaLMForeachSOAP(ScheduleFree):
|
|
46
46
|
self._data_format = data_format
|
47
47
|
self.rng = random.Random(0x120983109)
|
48
48
|
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
""
|
57
|
-
|
58
|
-
loss = None
|
59
|
-
else:
|
60
|
-
loss = closure()
|
61
|
-
|
62
|
-
for group in self.param_groups:
|
63
|
-
vals = []
|
64
|
-
max_precond_dim = group['max_precond_dim']
|
65
|
-
precondition_1d = group['precondition_1d']
|
66
|
-
|
67
|
-
step = group['step'] = group.get("step", -1) + 1
|
68
|
-
|
69
|
-
for p in group["params"]:
|
70
|
-
if p.grad is None:
|
71
|
-
continue
|
72
|
-
grad = p.grad.float()
|
73
|
-
vals.append((p, grad))
|
74
|
-
|
75
|
-
p_list, grad = zip(*vals)
|
76
|
-
|
77
|
-
adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
|
78
|
-
|
79
|
-
vals = []
|
80
|
-
|
81
|
-
for p, g in split_p_and_g_in_group(group):
|
82
|
-
state = self.state_(p)
|
83
|
-
|
84
|
-
if "z" not in state:
|
85
|
-
state["z"] = torch.clone(p).float()
|
86
|
-
state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
|
87
|
-
init_preconditioner(g, state, max_precond_dim, precondition_1d)
|
88
|
-
update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
|
89
|
-
continue # first step is skipped so that we never use the current gradients in the projection.
|
90
|
-
|
91
|
-
# Projecting gradients to the eigenbases of Shampoo's preconditioner
|
92
|
-
# i.e. projecting to the eigenbases of matrices in state['GG']
|
93
|
-
grad_projected = project(g, state['Q'], False)
|
94
|
-
z, exp_avg_sq = state["z"], state["exp_avg_sq"]
|
95
|
-
vals.append((p, g, grad_projected, z, exp_avg_sq))
|
96
|
-
|
97
|
-
if not vals:
|
49
|
+
def _step(self, group):
|
50
|
+
vals = []
|
51
|
+
max_precond_dim = group['max_precond_dim']
|
52
|
+
precondition_1d = group['precondition_1d']
|
53
|
+
|
54
|
+
step = group['step'] = group.get("step", -1) + 1
|
55
|
+
|
56
|
+
for p in group["params"]:
|
57
|
+
if p.grad is None:
|
98
58
|
continue
|
59
|
+
grad = p.grad.float()
|
60
|
+
vals.append((p, grad))
|
61
|
+
|
62
|
+
if not vals:
|
63
|
+
return
|
64
|
+
|
65
|
+
p_list, grad = zip(*vals)
|
66
|
+
|
67
|
+
adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
|
68
|
+
|
69
|
+
vals = []
|
70
|
+
|
71
|
+
for p, g in split_p_and_g_in_group(group):
|
72
|
+
state = self.state_(p)
|
73
|
+
|
74
|
+
if "z" not in state:
|
75
|
+
state["z"] = torch.clone(p).float()
|
76
|
+
state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
|
77
|
+
init_preconditioner(g, state, max_precond_dim, precondition_1d)
|
78
|
+
update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
|
79
|
+
continue # first step is skipped so that we never use the current gradients in the projection.
|
80
|
+
|
81
|
+
# Projecting gradients to the eigenbases of Shampoo's preconditioner
|
82
|
+
# i.e. projecting to the eigenbases of matrices in state['GG']
|
83
|
+
grad_projected = project(g, state['Q'], False)
|
84
|
+
z, exp_avg_sq = state["z"], state["exp_avg_sq"]
|
85
|
+
vals.append((p, g, grad_projected, z, exp_avg_sq))
|
86
|
+
|
87
|
+
if not vals:
|
88
|
+
return
|
99
89
|
|
100
|
-
|
90
|
+
p_list, grad, grad_projected, z, exp_avg_sq = zip(*vals)
|
101
91
|
|
102
|
-
|
103
|
-
|
92
|
+
beta2 = 1 - max(step, 1) ** -group['beta2_scale']
|
93
|
+
new_debiased2 = beta_debias(beta2, step)
|
104
94
|
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
95
|
+
# Decay the first and second moment running average coefficient
|
96
|
+
# In-place operations to update the averages at the same time
|
97
|
+
denom = exp_avg_sq_(exp_avg_sq, grad, new_debiased2, group["eps"])
|
98
|
+
torch._foreach_div_(grad_projected, denom)
|
109
99
|
|
110
|
-
|
100
|
+
update_precond = group['step'] > 0 and group['step'] % group['precondition_frequency'] == 0
|
111
101
|
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
102
|
+
for p, g, gp in zip(p_list, grad, grad_projected):
|
103
|
+
state = self.state_(p)
|
104
|
+
# Projecting back the preconditioned (by Adam) exponential moving average of gradients
|
105
|
+
# to the original space
|
106
|
+
# CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
|
107
|
+
set_(gp, project(gp, state['Q'], back=True))
|
118
108
|
|
119
|
-
|
120
|
-
|
109
|
+
update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2,
|
110
|
+
update_precond)
|
121
111
|
|
122
|
-
|
123
|
-
|
124
|
-
|
112
|
+
# Weight decay calculated at y
|
113
|
+
if group["weight_decay"] > 0:
|
114
|
+
torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])
|
125
115
|
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
return loss
|
116
|
+
lr = warmup(group['lr'], step, group['warmup_steps'])
|
117
|
+
group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
|
118
|
+
p_list, z, grad_projected, group['r'], step)
|
heavyball/utils.py
CHANGED
@@ -3,7 +3,7 @@ import gc
|
|
3
3
|
import math
|
4
4
|
import random
|
5
5
|
import string
|
6
|
-
from typing import List, Optional, Tuple
|
6
|
+
from typing import List, Optional, Tuple, Callable
|
7
7
|
|
8
8
|
import numpy as np
|
9
9
|
import torch
|
@@ -29,7 +29,7 @@ def decorator(func):
|
|
29
29
|
return _fn
|
30
30
|
|
31
31
|
|
32
|
-
|
32
|
+
einsum_base = string.ascii_lowercase + string.ascii_uppercase
|
33
33
|
|
34
34
|
|
35
35
|
def warmup(lr: float, step: int, warmup_steps: int):
|
@@ -317,8 +317,8 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
|
|
317
317
|
for idx, sh in enumerate(grad.shape):
|
318
318
|
if sh > max_precond_dim:
|
319
319
|
continue
|
320
|
-
b =
|
321
|
-
g0 =
|
320
|
+
b = einsum_base[idx]
|
321
|
+
g0 = einsum_base[:grad.dim()]
|
322
322
|
g1 = g0.replace(b, b.upper())
|
323
323
|
outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
|
324
324
|
GG[idx].lerp_(promote(outer_product), 1 - beta)
|
@@ -374,7 +374,7 @@ def project(grad, Q, back: bool):
|
|
374
374
|
:param back: whether to project to Shampoo eigenbases or back to original space
|
375
375
|
:return:
|
376
376
|
"""
|
377
|
-
param =
|
377
|
+
param = einsum_base[:grad.dim()]
|
378
378
|
preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if len(m) > 0])
|
379
379
|
if preconditioners:
|
380
380
|
out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
|
@@ -399,6 +399,20 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
399
399
|
tree_map(_add, self.state_(p))
|
400
400
|
return total_bytes
|
401
401
|
|
402
|
+
def _step(self, group):
|
403
|
+
raise NotImplementedError
|
404
|
+
|
405
|
+
def step(self, closure: Optional[Callable] = None):
|
406
|
+
if closure is None:
|
407
|
+
loss = None
|
408
|
+
else:
|
409
|
+
with torch.enable_grad():
|
410
|
+
loss = closure()
|
411
|
+
with torch.no_grad():
|
412
|
+
for group in self.param_groups:
|
413
|
+
self._step(group)
|
414
|
+
return loss
|
415
|
+
|
402
416
|
|
403
417
|
class ScheduleFree(StatefulOptimizer):
|
404
418
|
def eval(self):
|
@@ -684,9 +698,11 @@ def a_law_compress(x, A=87.6):
|
|
684
698
|
torch._foreach_mul_(xa, 1 / (1 + math.log(A)))
|
685
699
|
return xa
|
686
700
|
|
701
|
+
|
687
702
|
def identity(x):
|
688
703
|
return x
|
689
704
|
|
705
|
+
|
690
706
|
def trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
|
691
707
|
torch._foreach_mul_(grad, 1 / scale)
|
692
708
|
tanh = torch._foreach_tanh(grad)
|
@@ -743,8 +759,8 @@ class PSGDBase(StatefulOptimizer):
|
|
743
759
|
self.rng = random.Random(0x1923213)
|
744
760
|
self._tiny = torch.finfo(torch.bfloat16).tiny
|
745
761
|
|
746
|
-
def balance(self,
|
747
|
-
if
|
762
|
+
def balance(self, grad_list, Q_list):
|
763
|
+
if self.rng.random() > 0.01:
|
748
764
|
return
|
749
765
|
|
750
766
|
for g, q in zip(grad_list, Q_list):
|
@@ -0,0 +1,23 @@
|
|
1
|
+
heavyball/__init__.py,sha256=KbT0GMU0DKqZxq9laCrD7XgiqS9yxC1W52zhte5kjKs,2054
|
2
|
+
heavyball/cached_psgd_kron.py,sha256=mXDtxq2WJST_aUJhrLr_xCCXSFaDvD5gCTSEveBUtac,6754
|
3
|
+
heavyball/delayed_psgd.py,sha256=dN3NW1jmjxmUkgqxPwUVrqLY8nnBOFp4TVtJ_BhPDR4,5814
|
4
|
+
heavyball/foreach_adamw.py,sha256=NSzoIgNm7eavzbJgkAF0k7TUEnWAgOpt9-4juIFoaSA,1729
|
5
|
+
heavyball/foreach_adopt.py,sha256=WA07m5jocLfb1GPU8s6mJ2PteS-03ronkKm-VJrAm5I,1863
|
6
|
+
heavyball/foreach_laprop.py,sha256=mE2NDGX9XgvRhsewcWnk_-FulZPqGA65ejYF_9-A1Xk,1768
|
7
|
+
heavyball/foreach_sfadamw.py,sha256=ussHfPd99u3RTfMrCuu5oIbwNFLXK19wO1Fbz3JShlc,2097
|
8
|
+
heavyball/foreach_soap.py,sha256=WWvssYKg607uoEJHftp8ag8mtKSKSeHrT0QTgqBucVg,4587
|
9
|
+
heavyball/p_adam.py,sha256=ms7BoMHu3jKGsuztUeECrsXufGAwBpqGsxgZ5LBXLQg,6073
|
10
|
+
heavyball/palm_foreach_sfadamw.py,sha256=wjUb_fNZNUmzWXyKvwB0unP9lvNMmaYSQo5YoeS5cj0,2200
|
11
|
+
heavyball/palm_foreach_soap.py,sha256=2Sb4hUHQeexJcCgjHeQM_ENkZ6lG1DVxW72ryrvR6iY,5890
|
12
|
+
heavyball/precond_schedule_foreach_soap.py,sha256=bHsDyh-UvHpHjumjqqy0PePoR1ZMsJV6o5wWvpLAA04,4815
|
13
|
+
heavyball/precond_schedule_palm_foreach_soap.py,sha256=myLTJNQKLtZ3Xi3MVTB-RYtx_XeMRJw5CIMJW75ndUY,6163
|
14
|
+
heavyball/precond_schedule_sfpsoap.py,sha256=xeNWetBzBEYqfOSzl98aAVJsHk43QkrUUhHH_YD_mS4,6740
|
15
|
+
heavyball/psgd_kron.py,sha256=rMG5UPEgyfQs_n1MHSEicekVDpbbIzinlL8akEyY918,5795
|
16
|
+
heavyball/pure_psgd.py,sha256=LLVJhUAb04hgAmT3BTz_faswwQEQUkLhm_VwGQmbBUo,5088
|
17
|
+
heavyball/schedule_free_palm_foreach_soap.py,sha256=w0P7lMmoijTpL9V7NwOHcNBFJQ7S1TS9aCiwPhY2yVw,6319
|
18
|
+
heavyball/utils.py,sha256=PWmwjZPL4oxMjK79a5R1e7JHykphNi5GdpYqO_xmmFU,27829
|
19
|
+
heavyball-0.15.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
20
|
+
heavyball-0.15.1.dist-info/METADATA,sha256=0wImMJNYM-Zg0akh9hRf7X8ofVW6zlmpyDGgAkK5GFA,11667
|
21
|
+
heavyball-0.15.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
22
|
+
heavyball-0.15.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
23
|
+
heavyball-0.15.1.dist-info/RECORD,,
|
@@ -1,22 +0,0 @@
|
|
1
|
-
heavyball/__init__.py,sha256=ef7IWcPF8Uh3WQHzMiMqOFvUnU_LdG7BO9XVecJOph4,1156
|
2
|
-
heavyball/delayed_psgd.py,sha256=Gfa1ogkFPPL7ohayYAwbugB8hyLRUI5FgcJfsK69KGI,5936
|
3
|
-
heavyball/foreach_adamw.py,sha256=L727chOuVqdiVbYYzloy4g3oKH2FmQh40o_bqWeZtk8,2106
|
4
|
-
heavyball/foreach_adopt.py,sha256=M4zZVcYlhGjqO6qekivCuYfX6JtMcp4cJi_RrSHT3H8,2268
|
5
|
-
heavyball/foreach_laprop.py,sha256=htYGqgvlZsF_JpILdPMTnX72SqfrokBJ2J4nDeT0JVU,2157
|
6
|
-
heavyball/foreach_sfadamw.py,sha256=KIGol7Phfq1DHE_nEle4wDuqNdbGsT3kUcMKzJX3msg,2498
|
7
|
-
heavyball/foreach_soap.py,sha256=Ccz9Mc_xaHnrJ_7jUq9ZVxyR0WEqopzOXTUqUY-V8G8,5137
|
8
|
-
heavyball/p_adam.py,sha256=jQgTkKekqnLj1XPA4-fgpWG8P_BtUq2976zEt2QymTo,6060
|
9
|
-
heavyball/palm_foreach_sfadamw.py,sha256=8IGlRCdkfMzUqcSfmTM3Ce04NLNyrT2xfiBcPkrWwqc,2605
|
10
|
-
heavyball/palm_foreach_soap.py,sha256=NEJ3Xeh7pqURUk3cAP2qJe8z2WzYKg60pQe4bsGiaY4,6441
|
11
|
-
heavyball/precond_schedule_foreach_soap.py,sha256=H6Oc5IAL5MR-fgu92AboPs3Xm8mBmYUMPLsEcuJ12VI,5370
|
12
|
-
heavyball/precond_schedule_palm_foreach_soap.py,sha256=v81hRjcqS6Mm-KxT5Rk3TEiKAE8WI2IbmVbSa-YfBkE,6760
|
13
|
-
heavyball/precond_schedule_sfpsoap.py,sha256=7ScnN0in8u9hPiJE7QnOoZOH6Tn-6HeVy4f-bO3bHzY,7279
|
14
|
-
heavyball/psgd_kron.py,sha256=AH8ugd_IxKGVtY9y_Ot7myVSxFDbLlRJIqr2bBlAYy8,5911
|
15
|
-
heavyball/pure_psgd.py,sha256=jp5fnawUdgccEFlZDPrZr4ZbxYV85IIrev4tybZxBVU,5185
|
16
|
-
heavyball/schedule_free_palm_foreach_soap.py,sha256=bV7H-FNNoH5WpposLrNhkqU7mBicMorqKEALBSdROEM,6853
|
17
|
-
heavyball/utils.py,sha256=y5VAd9CQjcl_a1WUcORviAYf7Jz_c7n3-b7i5kLUJIA,27464
|
18
|
-
heavyball-0.14.7.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
19
|
-
heavyball-0.14.7.dist-info/METADATA,sha256=5QWB3nuNAp8YjeX0-Y5Uzkek_wjuGF3XG6UWrQk8R0c,11667
|
20
|
-
heavyball-0.14.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
21
|
-
heavyball-0.14.7.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
22
|
-
heavyball-0.14.7.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|