heavyball 0.24.3__py3-none-any.whl → 0.24.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/foreach_soap.py +2 -3
- heavyball/palm_foreach_soap.py +2 -3
- heavyball/precond_schedule_foreach_soap.py +2 -2
- heavyball/precond_schedule_palm_foreach_soap.py +2 -2
- {heavyball-0.24.3.dist-info → heavyball-0.24.4.dist-info}/METADATA +1 -1
- {heavyball-0.24.3.dist-info → heavyball-0.24.4.dist-info}/RECORD +9 -9
- {heavyball-0.24.3.dist-info → heavyball-0.24.4.dist-info}/LICENSE +0 -0
- {heavyball-0.24.3.dist-info → heavyball-0.24.4.dist-info}/WHEEL +0 -0
- {heavyball-0.24.3.dist-info → heavyball-0.24.4.dist-info}/top_level.txt +0 -0
heavyball/foreach_soap.py
CHANGED
@@ -71,12 +71,11 @@ class ForeachSOAP(StatefulOptimizer):
|
|
71
71
|
# Decay the first and second moment running average coefficient
|
72
72
|
# In-place operations to update the averages at the same time
|
73
73
|
step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
|
74
|
-
denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
|
75
74
|
|
76
75
|
step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
|
77
76
|
|
78
|
-
for p, g, ea,
|
79
|
-
|
77
|
+
for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
|
78
|
+
d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
|
80
79
|
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
81
80
|
# i.e. projecting to the eigenbases of matrices in state['GG']
|
82
81
|
exp_avg_projected = project(ea, state['Q'], False)
|
heavyball/palm_foreach_soap.py
CHANGED
@@ -81,11 +81,10 @@ class PaLMForeachSOAP(StatefulOptimizer):
|
|
81
81
|
# In-place operations to update the averages at the same time
|
82
82
|
beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
|
83
83
|
step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
|
84
|
-
denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
|
85
84
|
step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
|
86
85
|
|
87
|
-
for p, g, ea,
|
88
|
-
|
86
|
+
for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
|
87
|
+
d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
|
89
88
|
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
90
89
|
# i.e. projecting to the eigenbases of matrices in state['GG']
|
91
90
|
exp_avg_projected = project(ea, state['Q'], False)
|
@@ -73,12 +73,12 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
|
|
73
73
|
# Decay the first and second moment running average coefficient
|
74
74
|
# In-place operations to update the averages at the same time
|
75
75
|
step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
|
76
|
-
denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
|
77
76
|
|
78
77
|
update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
|
79
78
|
step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
|
80
79
|
|
81
|
-
for p, g, ea,
|
80
|
+
for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
|
81
|
+
d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
|
82
82
|
state = self.state_(p)
|
83
83
|
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
84
84
|
# i.e. projecting to the eigenbases of matrices in state['GG']
|
@@ -84,12 +84,12 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
|
|
84
84
|
# In-place operations to update the averages at the same time
|
85
85
|
beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
|
86
86
|
step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
|
87
|
-
denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
|
88
87
|
|
89
88
|
update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
|
90
89
|
step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
|
91
90
|
|
92
|
-
for p, g, ea,
|
91
|
+
for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
|
92
|
+
d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
|
93
93
|
state = self.state_(p)
|
94
94
|
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
95
95
|
# i.e. projecting to the eigenbases of matrices in state['GG']
|
@@ -6,19 +6,19 @@ heavyball/foreach_adamw.py,sha256=K4xTes4drylAqaqWky8O_Bg_mmbAmcHZ5DEBs5vMD-s,28
|
|
6
6
|
heavyball/foreach_adopt.py,sha256=fHnbEqvKKc5IKPDWC9Qo9PiISSjj1MEViy0Jb3BRgZQ,3582
|
7
7
|
heavyball/foreach_laprop.py,sha256=EXkwFQ-H7hHWLmiNUsxUcmXhzNNLMjieHjfOlY_6kmo,2868
|
8
8
|
heavyball/foreach_sfadamw.py,sha256=TeWf0nKXQEFcz02rADYRJenDM9mX1dGHhvILLks6OW8,3087
|
9
|
-
heavyball/foreach_soap.py,sha256=
|
9
|
+
heavyball/foreach_soap.py,sha256=Tgwg4_Sir9nI_3R85f8NMQagquUBJmAEMQqh0uD3b0Y,4771
|
10
10
|
heavyball/p_adam.py,sha256=qEcuU8VEc35vaWAXjT0O65vfCuNn_3ttwL4RlJKN3Xw,6389
|
11
11
|
heavyball/palm_foreach_sfadamw.py,sha256=1qOr-uniSmI1sNCJc1SnvyKH5iFu80Z6H5h93lDTwcE,3410
|
12
|
-
heavyball/palm_foreach_soap.py,sha256=
|
13
|
-
heavyball/precond_schedule_foreach_soap.py,sha256=
|
14
|
-
heavyball/precond_schedule_palm_foreach_soap.py,sha256=
|
12
|
+
heavyball/palm_foreach_soap.py,sha256=zSjpYYm1hfgIudjo_q3ozu3Vkfhz8w8im1c-ou1U3sI,6198
|
13
|
+
heavyball/precond_schedule_foreach_soap.py,sha256=p7oD2bESyCPsdGkJYhHluraDb_1K5Q28RNL6fIvD5C8,4969
|
14
|
+
heavyball/precond_schedule_palm_foreach_soap.py,sha256=Sb3Fhv-EG28_oXnbVpE0iHe5R8i5_hltqoi_DgPuoEU,6505
|
15
15
|
heavyball/precond_schedule_sfpsoap.py,sha256=KUKdZzd336w24zPRcqwRatj7IVmd1Us0a_VuzASluIo,7565
|
16
16
|
heavyball/psgd_kron.py,sha256=PtTe6eR547Y-4CvgjpchgkQsr_kWr4AN-uY9L_JO_C8,6088
|
17
17
|
heavyball/pure_psgd.py,sha256=344NdVNHwUFX3fU2R1S_Xh9SXAML3E4ryHr7xfMh9Cc,5076
|
18
18
|
heavyball/schedule_free_palm_foreach_soap.py,sha256=KTQY37MZH7YnOSTLKY8uVySUXxWXbFVUA1QXN3iv8Ds,7244
|
19
19
|
heavyball/utils.py,sha256=AxhcHzbFAvhTgTFyIcdxs9TJkH4AgVEaNeBRjOLzoBM,40095
|
20
|
-
heavyball-0.24.
|
21
|
-
heavyball-0.24.
|
22
|
-
heavyball-0.24.
|
23
|
-
heavyball-0.24.
|
24
|
-
heavyball-0.24.
|
20
|
+
heavyball-0.24.4.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
21
|
+
heavyball-0.24.4.dist-info/METADATA,sha256=oksy8cvcHSdMEs9Mzv7WDAVkkkUPcpA1uYNlUgZM_bk,11926
|
22
|
+
heavyball-0.24.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
23
|
+
heavyball-0.24.4.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
24
|
+
heavyball-0.24.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|