heavyball-0.18.2-py3-none-any.whl → heavyball-0.18.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/cached_delayed_psgd_kron.py +2 -1
- heavyball/cached_psgd_kron.py +4 -3
- heavyball/delayed_psgd.py +2 -1
- heavyball/p_adam.py +2 -1
- heavyball/psgd_kron.py +2 -1
- heavyball/pure_psgd.py +2 -2
- heavyball/utils.py +1 -3
- {heavyball-0.18.2.dist-info → heavyball-0.18.3.dist-info}/METADATA +1 -1
- {heavyball-0.18.2.dist-info → heavyball-0.18.3.dist-info}/RECORD +12 -12
- {heavyball-0.18.2.dist-info → heavyball-0.18.3.dist-info}/LICENSE +0 -0
- {heavyball-0.18.2.dist-info → heavyball-0.18.3.dist-info}/WHEEL +0 -0
- {heavyball-0.18.2.dist-info → heavyball-0.18.3.dist-info}/top_level.txt +0 -0
@@ -62,6 +62,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
|
|
62
62
|
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
63
63
|
|
64
64
|
def _step(self, group):
|
65
|
+
should_update = self.should_update(group)
|
65
66
|
momentum_into_precond_update = group.get("momentum_into_precond_update", True)
|
66
67
|
precond_init_scale = group['precond_init_scale']
|
67
68
|
max_size_triangular = group['max_size_triangular']
|
@@ -115,7 +116,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
|
|
115
116
|
|
116
117
|
new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
|
117
118
|
|
118
|
-
if
|
119
|
+
if should_update:
|
119
120
|
q = line_to_triu(q_orig) if store_triu_as_line else q_orig
|
120
121
|
q32 = [promote(q_) for q_ in q]
|
121
122
|
self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
|
heavyball/cached_psgd_kron.py
CHANGED
@@ -8,8 +8,8 @@ from typing import Optional
|
|
8
8
|
|
9
9
|
import torch
|
10
10
|
|
11
|
-
from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
|
12
|
-
|
11
|
+
from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
|
12
|
+
line_to_triu, triu_to_line, set_, einsum_base, promote
|
13
13
|
|
14
14
|
|
15
15
|
class ForeachCachedPSGDKron(PSGDBase):
|
@@ -71,6 +71,7 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
71
71
|
beta = group['beta']
|
72
72
|
store_triu_as_line = group['store_triu_as_line']
|
73
73
|
q_dtype = getattr(torch, group['q_dtype'])
|
74
|
+
should_update = self.should_update(group)
|
74
75
|
|
75
76
|
vals = []
|
76
77
|
|
@@ -111,7 +112,7 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
111
112
|
q_orig = Q_list.pop(0)
|
112
113
|
ea = exp_avg_list.pop(0)
|
113
114
|
|
114
|
-
if
|
115
|
+
if should_update:
|
115
116
|
q = line_to_triu(q_orig) if store_triu_as_line else q_orig
|
116
117
|
q32 = [promote(q_) for q_ in q]
|
117
118
|
self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
|
heavyball/delayed_psgd.py
CHANGED
@@ -62,6 +62,7 @@ class ForeachDelayedPSGD(PSGDBase):
|
|
62
62
|
|
63
63
|
|
64
64
|
def _step(self, group):
|
65
|
+
should_update = self.should_update(group)
|
65
66
|
momentum_into_precond_update = group.get("momentum_into_precond_update", True)
|
66
67
|
precond_init_scale = group['precond_init_scale']
|
67
68
|
max_size_triangular = group['max_size_triangular']
|
@@ -103,7 +104,7 @@ class ForeachDelayedPSGD(PSGDBase):
|
|
103
104
|
ea = exp_avg_list.pop(0)
|
104
105
|
q = line_to_triu(q_orig) if store_triu_as_line else q_orig
|
105
106
|
new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
|
106
|
-
if
|
107
|
+
if should_update:
|
107
108
|
q32 = [promote(q_) for q_ in q]
|
108
109
|
self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line)
|
109
110
|
set_(g, new)
|
heavyball/p_adam.py
CHANGED
@@ -61,6 +61,7 @@ class ForeachPaLMPAdam(PSGDBase):
|
|
61
61
|
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
62
62
|
|
63
63
|
def _step(self, group):
|
64
|
+
should_update = self.should_update(group)
|
64
65
|
precond_init_scale = group['precond_init_scale']
|
65
66
|
max_size_triangular = group['max_size_triangular']
|
66
67
|
min_ndim_triangular = group['min_ndim_triangular']
|
@@ -94,7 +95,7 @@ class ForeachPaLMPAdam(PSGDBase):
|
|
94
95
|
group["step"] += 1
|
95
96
|
|
96
97
|
Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
|
97
|
-
if
|
98
|
+
if should_update:
|
98
99
|
for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
|
99
100
|
q32 = [promote(qq_) for qq_ in q_]
|
100
101
|
self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
|
heavyball/psgd_kron.py
CHANGED
@@ -60,6 +60,7 @@ class ForeachPSGDKron(PSGDBase):
|
|
60
60
|
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
61
61
|
|
62
62
|
def _step(self, group):
|
63
|
+
should_update = self.should_update(group)
|
63
64
|
momentum_into_precond_update = group.get("momentum_into_precond_update", True)
|
64
65
|
precond_init_scale = group['precond_init_scale']
|
65
66
|
max_size_triangular = group['max_size_triangular']
|
@@ -101,7 +102,7 @@ class ForeachPSGDKron(PSGDBase):
|
|
101
102
|
ea = exp_avg_list.pop(0)
|
102
103
|
q = line_to_triu(q_orig) if store_triu_as_line else q_orig
|
103
104
|
|
104
|
-
if
|
105
|
+
if should_update:
|
105
106
|
q32 = [promote(q_) for q_ in q]
|
106
107
|
self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
|
107
108
|
set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
|
heavyball/pure_psgd.py
CHANGED
@@ -57,7 +57,7 @@ class ForeachPurePSGD(PSGDBase):
|
|
57
57
|
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
58
58
|
|
59
59
|
def _step(self, group):
|
60
|
-
|
60
|
+
should_update = self.should_update(group)
|
61
61
|
precond_init_scale = group['precond_init_scale']
|
62
62
|
max_size_triangular = group['max_size_triangular']
|
63
63
|
min_ndim_triangular = group['min_ndim_triangular']
|
@@ -93,7 +93,7 @@ class ForeachPurePSGD(PSGDBase):
|
|
93
93
|
q_orig = Q_list.pop(0)
|
94
94
|
q = line_to_triu(q_orig) if store_triu_as_line else q_orig
|
95
95
|
|
96
|
-
if
|
96
|
+
if group:
|
97
97
|
q32 = [promote(q_) for q_ in q]
|
98
98
|
self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
|
99
99
|
psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
|
heavyball/utils.py
CHANGED
@@ -335,9 +335,7 @@ def promote(x):
|
|
335
335
|
def min_dtype(xs: List[torch.Tensor]):
|
336
336
|
dtypes = [x.dtype for x in xs]
|
337
337
|
for d in (torch.float32, torch.bfloat16, torch.float16):
|
338
|
-
if all(d
|
339
|
-
return d
|
340
|
-
if all(d in (x, torch.float32, torch.float64) for x in dtypes):
|
338
|
+
if all(x in (d, torch.float32, torch.float64) for x in dtypes):
|
341
339
|
return d
|
342
340
|
return torch.float32
|
343
341
|
|
@@ -1,24 +1,24 @@
|
|
1
1
|
heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
|
2
|
-
heavyball/cached_delayed_psgd_kron.py,sha256=
|
3
|
-
heavyball/cached_psgd_kron.py,sha256=
|
4
|
-
heavyball/delayed_psgd.py,sha256=
|
2
|
+
heavyball/cached_delayed_psgd_kron.py,sha256=PQAER6UgVh5l87DGRZrJ8CVP9UhyCG5wJD9rPLnj_G8,6460
|
3
|
+
heavyball/cached_psgd_kron.py,sha256=GaeneBp0irksCSBIrJY4D_0hCpZ-uSRPMhqVX_a-og8,6417
|
4
|
+
heavyball/delayed_psgd.py,sha256=fhBWFLTSl1S2gHWCeYak-STaXRwpC56sWZGLFMKFEJM,5589
|
5
5
|
heavyball/foreach_adamw.py,sha256=h_ar0ZZRM0q_wxuEkxEOEYe-2p-mB4OMgAHivrUnPl8,1777
|
6
6
|
heavyball/foreach_adopt.py,sha256=ogOw2JjwEQNj7AKlweAphQFdMJ_GcMDm-RyDvEzugoc,1911
|
7
7
|
heavyball/foreach_laprop.py,sha256=yGVmGqWiSw8Y2Xj70ndkR8ZMygakTB4_iRwV02Svkqg,1816
|
8
8
|
heavyball/foreach_sfadamw.py,sha256=15-n6-lx4PAHYsKYmXbugxsR5MnqaPYy2vUudPRiitg,2087
|
9
9
|
heavyball/foreach_soap.py,sha256=h6ptMch7oaynvu3eIJtWnVXypDA_5JDVm3Zb3PNEma0,4634
|
10
|
-
heavyball/p_adam.py,sha256=
|
10
|
+
heavyball/p_adam.py,sha256=4zJDGJrpgUyVzr3GiELETFre4xr3-PE10OuAZj-jFM8,5883
|
11
11
|
heavyball/palm_foreach_sfadamw.py,sha256=yvZbPyjDW8qd3r4qDXb6uTr5RozQ7JSDj4aYYRnKGLA,2248
|
12
12
|
heavyball/palm_foreach_soap.py,sha256=g4hbiGRcti-J-a0SwAkP4ii5pU-aalsZH5bssyhroLk,5938
|
13
13
|
heavyball/precond_schedule_foreach_soap.py,sha256=WLg5SzpJnKPZUvFyIvdwSZa1Umt5cpr3Kow_42orM-E,4863
|
14
14
|
heavyball/precond_schedule_palm_foreach_soap.py,sha256=ammQrvRZFF-wc-wEiPEoFhS_7b8pdV61QfcLoQfimSo,6211
|
15
15
|
heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdojklrY8pA,6770
|
16
|
-
heavyball/psgd_kron.py,sha256=
|
17
|
-
heavyball/pure_psgd.py,sha256=
|
16
|
+
heavyball/psgd_kron.py,sha256=u46dorOUXx-do1IYeno2wj-6l1zYKMQQC-N2Zr2PzLI,5476
|
17
|
+
heavyball/pure_psgd.py,sha256=iUy7mMKWxwNiVUMYrQ7SBnreu3t_XSbnhTW3a1yw4m0,4835
|
18
18
|
heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
|
19
|
-
heavyball/utils.py,sha256=
|
20
|
-
heavyball-0.18.
|
21
|
-
heavyball-0.18.
|
22
|
-
heavyball-0.18.
|
23
|
-
heavyball-0.18.
|
24
|
-
heavyball-0.18.
|
19
|
+
heavyball/utils.py,sha256=qs_WfzJdS-3XyEuw-m6mWMEeR95r7bGFVC8wWCHtD48,30365
|
20
|
+
heavyball-0.18.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
21
|
+
heavyball-0.18.3.dist-info/METADATA,sha256=Cx8LM2g3BFOk8WJH3B8ve8kQ7HghMCIRLggdJp37x4g,11810
|
22
|
+
heavyball-0.18.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
23
|
+
heavyball-0.18.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
24
|
+
heavyball-0.18.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|