heavyball 0.24.3__py3-none-any.whl → 0.24.4__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
heavyball/foreach_soap.py CHANGED
@@ -71,12 +71,11 @@ class ForeachSOAP(StatefulOptimizer):
71
71
  # Decay the first and second moment running average coefficient
72
72
  # In-place operations to update the averages at the same time
73
73
  step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
74
- denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
75
74
 
76
75
  step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
77
76
 
78
- for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
79
- state = self.state_(p)
77
+ for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
78
+ d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
80
79
  # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
81
80
  # i.e. projecting to the eigenbases of matrices in state['GG']
82
81
  exp_avg_projected = project(ea, state['Q'], False)
heavyball/palm_foreach_soap.py CHANGED
@@ -81,11 +81,10 @@ class PaLMForeachSOAP(StatefulOptimizer):
81
81
  # In-place operations to update the averages at the same time
82
82
  beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
83
83
  step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
84
- denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
85
84
  step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
86
85
 
87
- for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
88
- state = self.state_(p)
86
+ for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
87
+ d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
89
88
  # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
90
89
  # i.e. projecting to the eigenbases of matrices in state['GG']
91
90
  exp_avg_projected = project(ea, state['Q'], False)
heavyball/precond_schedule_foreach_soap.py CHANGED
@@ -73,12 +73,12 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
73
73
  # Decay the first and second moment running average coefficient
74
74
  # In-place operations to update the averages at the same time
75
75
  step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
76
- denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
77
76
 
78
77
  update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
79
78
  step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
80
79
 
81
- for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
80
+ for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
81
+ d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
82
82
  state = self.state_(p)
83
83
  # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
84
84
  # i.e. projecting to the eigenbases of matrices in state['GG']
heavyball/precond_schedule_palm_foreach_soap.py CHANGED
@@ -84,12 +84,12 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
84
84
  # In-place operations to update the averages at the same time
85
85
  beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
86
86
  step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
87
- denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
88
87
 
89
88
  update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
90
89
  step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
91
90
 
92
- for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
91
+ for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
92
+ d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
93
93
  state = self.state_(p)
94
94
  # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
95
95
  # i.e. projecting to the eigenbases of matrices in state['GG']
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.24.3
3
+ Version: 0.24.4
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -6,19 +6,19 @@ heavyball/foreach_adamw.py,sha256=K4xTes4drylAqaqWky8O_Bg_mmbAmcHZ5DEBs5vMD-s,28
6
6
  heavyball/foreach_adopt.py,sha256=fHnbEqvKKc5IKPDWC9Qo9PiISSjj1MEViy0Jb3BRgZQ,3582
7
7
  heavyball/foreach_laprop.py,sha256=EXkwFQ-H7hHWLmiNUsxUcmXhzNNLMjieHjfOlY_6kmo,2868
8
8
  heavyball/foreach_sfadamw.py,sha256=TeWf0nKXQEFcz02rADYRJenDM9mX1dGHhvILLks6OW8,3087
9
- heavyball/foreach_soap.py,sha256=408jRysE9ek0ea-TphhSBMTa9zcjkgMX3qlx8qTCt34,4803
9
+ heavyball/foreach_soap.py,sha256=Tgwg4_Sir9nI_3R85f8NMQagquUBJmAEMQqh0uD3b0Y,4771
10
10
  heavyball/p_adam.py,sha256=qEcuU8VEc35vaWAXjT0O65vfCuNn_3ttwL4RlJKN3Xw,6389
11
11
  heavyball/palm_foreach_sfadamw.py,sha256=1qOr-uniSmI1sNCJc1SnvyKH5iFu80Z6H5h93lDTwcE,3410
12
- heavyball/palm_foreach_soap.py,sha256=cExM9nTC3zAgsRr42VOIMWNwYA4dAJaA8-pIo7SWilc,6230
13
- heavyball/precond_schedule_foreach_soap.py,sha256=EL_Z-v5l7BC98QgI-Zg9iyM77TAreVgD5Zln59ewGoI,4966
14
- heavyball/precond_schedule_palm_foreach_soap.py,sha256=HWo2t7yY-_n4pPGmDiELccy0jdELTVhdlH-eyFBih5k,6502
12
+ heavyball/palm_foreach_soap.py,sha256=zSjpYYm1hfgIudjo_q3ozu3Vkfhz8w8im1c-ou1U3sI,6198
13
+ heavyball/precond_schedule_foreach_soap.py,sha256=p7oD2bESyCPsdGkJYhHluraDb_1K5Q28RNL6fIvD5C8,4969
14
+ heavyball/precond_schedule_palm_foreach_soap.py,sha256=Sb3Fhv-EG28_oXnbVpE0iHe5R8i5_hltqoi_DgPuoEU,6505
15
15
  heavyball/precond_schedule_sfpsoap.py,sha256=KUKdZzd336w24zPRcqwRatj7IVmd1Us0a_VuzASluIo,7565
16
16
  heavyball/psgd_kron.py,sha256=PtTe6eR547Y-4CvgjpchgkQsr_kWr4AN-uY9L_JO_C8,6088
17
17
  heavyball/pure_psgd.py,sha256=344NdVNHwUFX3fU2R1S_Xh9SXAML3E4ryHr7xfMh9Cc,5076
18
18
  heavyball/schedule_free_palm_foreach_soap.py,sha256=KTQY37MZH7YnOSTLKY8uVySUXxWXbFVUA1QXN3iv8Ds,7244
19
19
  heavyball/utils.py,sha256=AxhcHzbFAvhTgTFyIcdxs9TJkH4AgVEaNeBRjOLzoBM,40095
20
- heavyball-0.24.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
- heavyball-0.24.3.dist-info/METADATA,sha256=32T-Q-a4k096KjxoR-3DQt25XpO_h0zs7lWKTDQLugI,11926
22
- heavyball-0.24.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
- heavyball-0.24.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
- heavyball-0.24.3.dist-info/RECORD,,
20
+ heavyball-0.24.4.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
+ heavyball-0.24.4.dist-info/METADATA,sha256=oksy8cvcHSdMEs9Mzv7WDAVkkkUPcpA1uYNlUgZM_bk,11926
22
+ heavyball-0.24.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
+ heavyball-0.24.4.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
+ heavyball-0.24.4.dist-info/RECORD,,