heavyball 0.18.1__tar.gz → 0.18.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.18.1 → heavyball-0.18.3}/PKG-INFO +1 -1
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/cached_delayed_psgd_kron.py +5 -2
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/cached_psgd_kron.py +5 -6
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/delayed_psgd.py +2 -1
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/p_adam.py +2 -1
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/psgd_kron.py +2 -1
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/pure_psgd.py +2 -2
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/utils.py +8 -9
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.18.1 → heavyball-0.18.3}/setup.py +1 -1
- {heavyball-0.18.1 → heavyball-0.18.3}/test/test_stochastic_updates.py +3 -2
- {heavyball-0.18.1 → heavyball-0.18.3}/LICENSE +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/README.md +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/__init__.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/setup.cfg +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/test/test_bf16_q.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/test/test_closure.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/test/test_foreach.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/test/test_memory.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/test/test_merge.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/test/test_no_grad.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/test/test_psgd.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.3}/test/test_soap.py +0 -0
heavyball/cached_delayed_psgd_kron.py

@@ -62,6 +62,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
+        should_update = self.should_update(group)
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -113,7 +114,9 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
 
-            if self.should_update(group):
+            new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
+
+            if should_update:
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
@@ -124,7 +127,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
                 else:
                     torch.mul(q_.conj(), q_, out=c_)
 
-            set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
+            set_(g, new)
         grad_list = self.clip_fn(grad_list)
 
         lr = -warmup(lr, group['step'], group['warmup_steps'])
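Note on the two changes above: `should_update` is now sampled once at the top of `_step`, so a single draw of the stochastic schedule governs every parameter in the step, and the preconditioned output `new` is computed from the cached factors before `do_update` refreshes them, matching the delayed semantics the class name promises. A minimal sketch of that ordering, with `refresh` as an illustrative stand-in rather than a heavyball API:

    def delayed_step(Q, grad, refresh, should_update):
        new = Q @ grad        # precondition with the stale factors first
        if should_update:
            refresh(Q, grad)  # only afterwards refresh the preconditioner
        return new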
heavyball/cached_psgd_kron.py

@@ -8,8 +8,8 @@ from typing import Optional
 
 import torch
 
-from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
+from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
     line_to_triu, triu_to_line, set_, einsum_base, promote
 
 
 class ForeachCachedPSGDKron(PSGDBase):
@@ -71,6 +71,7 @@ class ForeachCachedPSGDKron(PSGDBase):
         beta = group['beta']
         store_triu_as_line = group['store_triu_as_line']
         q_dtype = getattr(torch, group['q_dtype'])
+        should_update = self.should_update(group)
 
         vals = []
 
@@ -111,9 +112,7 @@ class ForeachCachedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
 
-            new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
-
-            if self.should_update(group):
+            if should_update:
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
@@ -124,7 +123,7 @@ class ForeachCachedPSGDKron(PSGDBase):
                 else:
                     torch.mul(q_.conj(), q_, out=c_)
 
-            set_(g, new)
+            set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
 
         grad_list = self.clip_fn(grad_list)
 
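The non-delayed cached variant above is the mirror image of the delayed one: it drops the precomputed `new` and applies the cache only after the factors may have been refreshed, so the two classes no longer carry each other's semantics. The `cache_expr` einsum string is generated per tensor; a plausible two-dimensional instance, shown as an illustrative sketch rather than heavyball's generated expression:

    import torch

    q0, q1 = torch.randn(16, 16).triu(), torch.randn(32, 32).triu()
    # roughly what the cache-update branches (e.g. torch.mul(q_.conj(), q_, out=c_)) maintain
    cached = [q0.T.conj() @ q0, q1.T.conj() @ q1]
    ea = torch.randn(16, 32)
    new = torch.einsum('ab,cd,bd->ac', *cached, ea)  # cached[0] @ ea @ cached[1].T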
heavyball/delayed_psgd.py

@@ -62,6 +62,7 @@ class ForeachDelayedPSGD(PSGDBase):
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
+        should_update = self.should_update(group)
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -103,7 +104,7 @@ class ForeachDelayedPSGD(PSGDBase):
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
             new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
-            if self.should_update(group):
+            if should_update:
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line)
             set_(g, new)
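`ForeachDelayedPSGD` gets the same hoist as the remaining PSGD variants below: previously `self.should_update(group)` was evaluated inside the parameter loop, so with a stochastic schedule each call can draw independently and parameters within one step could disagree about refreshing their preconditioners; a single evaluation per `_step` keeps the whole group consistent. A sketch of the difference, with a hypothetical probability-based schedule:

    import random

    def should_update(prob=0.1):
        return random.random() < prob

    decision = should_update()  # one draw per optimizer step (new behaviour)
    for _param in range(5):
        if decision:            # every parameter sees the same answer
            pass                # refresh this parameter's preconditioner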
heavyball/p_adam.py

@@ -61,6 +61,7 @@ class ForeachPaLMPAdam(PSGDBase):
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
+        should_update = self.should_update(group)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
         min_ndim_triangular = group['min_ndim_triangular']
@@ -94,7 +95,7 @@ class ForeachPaLMPAdam(PSGDBase):
         group["step"] += 1
 
         Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
-        if self.should_update(group):
+        if should_update:
             for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
                 q32 = [promote(qq_) for qq_ in q_]
                 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
heavyball/psgd_kron.py

@@ -60,6 +60,7 @@ class ForeachPSGDKron(PSGDBase):
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
+        should_update = self.should_update(group)
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -101,7 +102,7 @@ class ForeachPSGDKron(PSGDBase):
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            if self.should_update(group):
+            if should_update:
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
             set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
heavyball/pure_psgd.py

@@ -57,7 +57,7 @@ class ForeachPurePSGD(PSGDBase):
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
-
+        should_update = self.should_update(group)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
         min_ndim_triangular = group['min_ndim_triangular']
@@ -93,7 +93,7 @@ class ForeachPurePSGD(PSGDBase):
             q_orig = Q_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            if self.should_update(group):
+            if group:
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
             psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
heavyball/utils.py

@@ -335,9 +335,7 @@ def promote(x):
 def min_dtype(xs: List[torch.Tensor]):
     dtypes = [x.dtype for x in xs]
     for d in (torch.float32, torch.bfloat16, torch.float16):
-        if all(d == x for x in dtypes):
-            return d
-        if all(d in (x, torch.float32, torch.float64) for x in dtypes):
+        if all(x in (d, torch.float32, torch.float64) for x in dtypes):
             return d
     return torch.float32
 
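The `min_dtype` fix above corrects an inverted containment test: the old `d in (x, torch.float32, torch.float64)` is trivially true once `d` is `torch.float32`, so the function always answered float32 on the first pass; the new `x in (d, torch.float32, torch.float64)` asks whether every tensor's dtype is representable in the candidate `d`. Self-contained repro of the fixed behaviour:

    from typing import List

    import torch

    def min_dtype(xs: List[torch.Tensor]):
        dtypes = [x.dtype for x in xs]
        for d in (torch.float32, torch.bfloat16, torch.float16):
            if all(x in (d, torch.float32, torch.float64) for x in dtypes):
                return d
        return torch.float32

    xs = [torch.zeros(1, dtype=torch.bfloat16), torch.zeros(1, dtype=torch.float32)]
    assert min_dtype(xs) == torch.bfloat16  # float32 pass fails, bfloat16 pass succeeds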
@@ -817,14 +815,15 @@ class PSGDBase(StatefulOptimizer):
 
     def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
                   store_triu_as_line=False):
-        if self.should_update(group, self.balance_probability, 'balance_prob'):
-            for g, q in zip(grad_list, q_list):
-                if g.dim() > 1:
-                    psgd_balance_Q(q)
-
         for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
             psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
-
+
+        for g, q in zip(grad_list, q_list):
+            if g.dim() > 1:
+                psgd_balance_Q(q)
+
+        if original_q:
+            for q in q_list:
                 if store_triu_as_line:
                     update_triu_(original_q[i], Q)
                 else:
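`do_update` is restructured: factor balancing no longer runs behind a separate `'balance_prob'` draw before the update, but unconditionally after every `psgd_update_precond` call, and the write-back into `original_q` is now guarded by `if original_q:`. Balancing rescales the Kronecker factors toward equal magnitude without changing their product; a sketch for two factors (illustrative only, not heavyball's `psgd_balance_Q`):

    import torch

    def balance_two(q1: torch.Tensor, q2: torch.Tensor) -> None:
        n1, n2 = q1.abs().max(), q2.abs().max()
        s = (n2 / n1).sqrt()
        q1.mul_(s)   # scaling one factor by s ...
        q2.div_(s)   # ... and the other by 1/s leaves q1 (x) q2 unchanged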
test/test_stochastic_updates.py

@@ -16,7 +16,7 @@ def get_memory():
 
 
 @pytest.mark.parametrize("opt", heavyball.__all__)
-@pytest.mark.parametrize("size,depth", [(128,
+@pytest.mark.parametrize("size,depth", [(128, 1)])
 def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations: int = 3):
     set_torch()
 
@@ -28,12 +28,13 @@ def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations
     losses = []
 
     for stochastic in [False, True]:
+        print('stochastic', stochastic)
         torch.manual_seed(0x2131290)
         peaks.append([])
         losses.append([])
 
         for i in range(outer_iterations):
-            model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
+            model = nn.Sequential(*[nn.Linear(size, size, bias=False) for _ in range(depth)]).cuda()
             o = get_optim(opt, model.parameters(), lr=1e-3, stochastic_schedule=stochastic)
 
             for _ in range(iterations):