heavyball 0.18.1.tar.gz → 0.18.2.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.18.1 → heavyball-0.18.2}/PKG-INFO +1 -1
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/cached_delayed_psgd_kron.py +3 -1
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/cached_psgd_kron.py +1 -3
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/utils.py +7 -6
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.18.1 → heavyball-0.18.2}/setup.py +1 -1
- {heavyball-0.18.1 → heavyball-0.18.2}/test/test_stochastic_updates.py +3 -2
- {heavyball-0.18.1 → heavyball-0.18.2}/LICENSE +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/README.md +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/__init__.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/delayed_psgd.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/p_adam.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/psgd_kron.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/pure_psgd.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/setup.cfg +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/test/test_bf16_q.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/test/test_closure.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/test/test_foreach.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/test/test_memory.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/test/test_merge.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/test/test_no_grad.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/test/test_psgd.py +0 -0
- {heavyball-0.18.1 → heavyball-0.18.2}/test/test_soap.py +0 -0
{heavyball-0.18.1 → heavyball-0.18.2}/heavyball/cached_delayed_psgd_kron.py

@@ -113,6 +113,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
 
+            new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
+
             if self.should_update(group):
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                 q32 = [promote(q_) for q_ in q]
@@ -124,7 +126,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
                     else:
                         torch.mul(q_.conj(), q_, out=c_)
 
-            set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
+            set_(g, new)
         grad_list = self.clip_fn(grad_list)
 
         lr = -warmup(lr, group['step'], group['warmup_steps'])
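
Taken together, these two hunks implement the "delayed" behavior in the class name: the preconditioned update `new` is now computed from the still-stale `cached_q` *before* the `should_update` block refreshes the cache, and only written back afterwards via `set_(g, new)`. Below is a minimal sketch of the cached-einsum pattern, with hypothetical shapes and a hand-written `cache_expr` (the real expression is built per-parameter and stored in the optimizer state):

    import torch

    ea = torch.randn(32, 16)  # exp_avg buffer for one 2D parameter
    cached_q = [torch.randn(32, 32), torch.randn(16, 16)]  # cached Kronecker factors
    cache_expr = 'ij,kl,jl->ik'  # one contraction index pair per tensor dimension

    # "Delayed" order: precondition with last step's cache first; the cache
    # itself is refreshed later in the same step.
    new = torch.einsum(cache_expr, *cached_q, ea)
    print(new.shape)  # torch.Size([32, 16])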
{heavyball-0.18.1 → heavyball-0.18.2}/heavyball/cached_psgd_kron.py

@@ -111,8 +111,6 @@ class ForeachCachedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
 
-            new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
-
             if self.should_update(group):
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                 q32 = [promote(q_) for q_ in q]
@@ -124,7 +122,7 @@ class ForeachCachedPSGDKron(PSGDBase):
                     else:
                         torch.mul(q_.conj(), q_, out=c_)
 
-            set_(g, new)
+            set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
 
         grad_list = self.clip_fn(grad_list)
 
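
This is the mirror image of the change in the delayed variant: `ForeachCachedPSGDKron` drops the early `new = ...` temporary and evaluates the einsum inline, after the cache refresh, so the update always uses the freshly rebuilt cache. A runnable sketch of that ordering for the simplest (1D, diagonal-factor) case; `refresh` is a hypothetical stand-in for the matmul/mul cache rebuild shown above:

    import torch

    def refresh(cached_q, q):
        # stand-in for the torch.matmul/torch.mul cache rebuild in the hunks
        for c_, q_ in zip(cached_q, q):
            torch.mul(q_.conj(), q_, out=c_)

    ea = torch.randn(8)
    q = [torch.rand(8) + 0.5]   # a diagonal (1D) Q factor
    cached_q = [torch.empty(8)]

    refresh(cached_q, q)                       # non-delayed: rebuild cache first...
    g = torch.einsum('i,i->i', *cached_q, ea)  # ...then precondition with it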
{heavyball-0.18.1 → heavyball-0.18.2}/heavyball/utils.py

@@ -817,14 +817,15 @@ class PSGDBase(StatefulOptimizer):
 
     def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
                   store_triu_as_line=False):
-        if self.should_update(group, self.balance_probability, 'balance_prob'):
-            for g, q in zip(grad_list, q_list):
-                if g.dim() > 1:
-                    psgd_balance_Q(q)
-
         for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
             psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
-
+
+        for g, q in zip(grad_list, q_list):
+            if g.dim() > 1:
+                psgd_balance_Q(q)
+
+        if original_q:
+            for q in q_list:
                 if store_triu_as_line:
                     update_triu_(original_q[i], Q)
                 else:
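
Two behavioral changes in `PSGDBase.do_update`: factor balancing moves from a probabilistic pre-step (gated by `should_update(group, self.balance_probability, 'balance_prob')`) to an unconditional pass *after* `psgd_update_precond`, and the write-back into `original_q` is now skipped when `original_q` is empty. For intuition, balancing rescales the Kronecker factors so their magnitudes match while their product, and hence the preconditioner, is unchanged. A hypothetical simplified version, not the library's actual `psgd_balance_Q`:

    import torch

    def balance_q(Q):
        # Equalize factor magnitudes without changing their product.
        norms = torch.stack([q.abs().max() for q in Q])
        gmean = norms.log().mean().exp()  # geometric mean of the max-norms
        for q_, n in zip(Q, norms):
            q_.mul_(gmean / n)

    Q = [torch.randn(4, 4) * 100.0, torch.randn(3, 3) * 0.01]
    before = Q[0].abs().max() * Q[1].abs().max()
    balance_q(Q)
    after = Q[0].abs().max() * Q[1].abs().max()
    print(torch.allclose(before, after))  # True: the overall scale is preserved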
{heavyball-0.18.1 → heavyball-0.18.2}/test/test_stochastic_updates.py

@@ -16,7 +16,7 @@ def get_memory():
 
 
 @pytest.mark.parametrize("opt", heavyball.__all__)
-@pytest.mark.parametrize("size,depth", [(128,
+@pytest.mark.parametrize("size,depth", [(128, 1)])
 def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations: int = 3):
     set_torch()
 
@@ -28,12 +28,13 @@ def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations
     losses = []
 
     for stochastic in [False, True]:
+        print('stochastic', stochastic)
         torch.manual_seed(0x2131290)
         peaks.append([])
         losses.append([])
 
         for i in range(outer_iterations):
-            model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
+            model = nn.Sequential(*[nn.Linear(size, size, bias=False) for _ in range(depth)]).cuda()
             o = get_optim(opt, model.parameters(), lr=1e-3, stochastic_schedule=stochastic)
 
             for _ in range(iterations):
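
The test tweaks are small: log which `stochastic_schedule` mode is running, pin the parametrization to a single `(size, depth)` case, and build the toy model without biases so every parameter is a 2D weight matrix, the shape the Kronecker-factored preconditioners actually exercise (my reading of the intent). A sketch of the resulting model, minus the `.cuda()` call the test itself makes:

    import torch
    from torch import nn

    size, depth = 128, 1
    # bias=False leaves only 2D weights; no 1D bias vectors reach the optimizer
    model = nn.Sequential(*[nn.Linear(size, size, bias=False) for _ in range(depth)])
    print([tuple(p.shape) for p in model.parameters()])  # [(128, 128)]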