heavyball 0.18.1__tar.gz → 0.18.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {heavyball-0.18.1 → heavyball-0.18.2}/PKG-INFO +1 -1
  2. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/cached_delayed_psgd_kron.py +3 -1
  3. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/cached_psgd_kron.py +1 -3
  4. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/utils.py +7 -6
  5. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball.egg-info/PKG-INFO +1 -1
  6. {heavyball-0.18.1 → heavyball-0.18.2}/setup.py +1 -1
  7. {heavyball-0.18.1 → heavyball-0.18.2}/test/test_stochastic_updates.py +3 -2
  8. {heavyball-0.18.1 → heavyball-0.18.2}/LICENSE +0 -0
  9. {heavyball-0.18.1 → heavyball-0.18.2}/README.md +0 -0
  10. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/__init__.py +0 -0
  11. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/delayed_psgd.py +0 -0
  12. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/foreach_adamw.py +0 -0
  13. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/foreach_adopt.py +0 -0
  14. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/foreach_laprop.py +0 -0
  15. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/foreach_sfadamw.py +0 -0
  16. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/foreach_soap.py +0 -0
  17. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/p_adam.py +0 -0
  18. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/palm_foreach_sfadamw.py +0 -0
  19. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/palm_foreach_soap.py +0 -0
  20. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/precond_schedule_foreach_soap.py +0 -0
  21. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  22. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/precond_schedule_sfpsoap.py +0 -0
  23. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/psgd_kron.py +0 -0
  24. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/pure_psgd.py +0 -0
  25. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  26. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball.egg-info/SOURCES.txt +0 -0
  27. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball.egg-info/dependency_links.txt +0 -0
  28. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball.egg-info/requires.txt +0 -0
  29. {heavyball-0.18.1 → heavyball-0.18.2}/heavyball.egg-info/top_level.txt +0 -0
  30. {heavyball-0.18.1 → heavyball-0.18.2}/setup.cfg +0 -0
  31. {heavyball-0.18.1 → heavyball-0.18.2}/test/test_bf16_q.py +0 -0
  32. {heavyball-0.18.1 → heavyball-0.18.2}/test/test_closure.py +0 -0
  33. {heavyball-0.18.1 → heavyball-0.18.2}/test/test_foreach.py +0 -0
  34. {heavyball-0.18.1 → heavyball-0.18.2}/test/test_memory.py +0 -0
  35. {heavyball-0.18.1 → heavyball-0.18.2}/test/test_merge.py +0 -0
  36. {heavyball-0.18.1 → heavyball-0.18.2}/test/test_no_grad.py +0 -0
  37. {heavyball-0.18.1 → heavyball-0.18.2}/test/test_psgd.py +0 -0
  38. {heavyball-0.18.1 → heavyball-0.18.2}/test/test_soap.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.18.1
3
+ Version: 0.18.2
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -113,6 +113,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
113
113
  q_orig = Q_list.pop(0)
114
114
  ea = exp_avg_list.pop(0)
115
115
 
116
+ new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
117
+
116
118
  if self.should_update(group):
117
119
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
118
120
  q32 = [promote(q_) for q_ in q]
@@ -124,7 +126,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
124
126
  else:
125
127
  torch.mul(q_.conj(), q_, out=c_)
126
128
 
127
- set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
129
+ set_(g, new)
128
130
  grad_list = self.clip_fn(grad_list)
129
131
 
130
132
  lr = -warmup(lr, group['step'], group['warmup_steps'])
@@ -111,8 +111,6 @@ class ForeachCachedPSGDKron(PSGDBase):
111
111
  q_orig = Q_list.pop(0)
112
112
  ea = exp_avg_list.pop(0)
113
113
 
114
- new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
115
-
116
114
  if self.should_update(group):
117
115
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
118
116
  q32 = [promote(q_) for q_ in q]
@@ -124,7 +122,7 @@ class ForeachCachedPSGDKron(PSGDBase):
124
122
  else:
125
123
  torch.mul(q_.conj(), q_, out=c_)
126
124
 
127
- set_(g, new)
125
+ set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
128
126
 
129
127
  grad_list = self.clip_fn(grad_list)
130
128
 
@@ -817,14 +817,15 @@ class PSGDBase(StatefulOptimizer):
817
817
 
818
818
  def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
819
819
  store_triu_as_line=False):
820
- if self.should_update(group, self.balance_probability, 'balance_prob'):
821
- for g, q in zip(grad_list, q_list):
822
- if g.dim() > 1:
823
- psgd_balance_Q(q)
824
-
825
820
  for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
826
821
  psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
827
- if original_q:
822
+
823
+ for g, q in zip(grad_list, q_list):
824
+ if g.dim() > 1:
825
+ psgd_balance_Q(q)
826
+
827
+ if original_q:
828
+ for q in q_list:
828
829
  if store_triu_as_line:
829
830
  update_triu_(original_q[i], Q)
830
831
  else:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.18.1
3
+ Version: 0.18.2
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -10,7 +10,7 @@ setuptools.setup(
10
10
  name='heavyball',
11
11
  license='BSD',
12
12
  description='Efficient optimizers',
13
- version='0.18.1',
13
+ version='0.18.2',
14
14
  long_description=README,
15
15
  url='https://github.com/clashluke/heavyball',
16
16
  packages=setuptools.find_packages(),
@@ -16,7 +16,7 @@ def get_memory():
16
16
 
17
17
 
18
18
  @pytest.mark.parametrize("opt", heavyball.__all__)
19
- @pytest.mark.parametrize("size,depth", [(128, 2)])
19
+ @pytest.mark.parametrize("size,depth", [(128, 1)])
20
20
  def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations: int = 3):
21
21
  set_torch()
22
22
 
@@ -28,12 +28,13 @@ def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations
28
28
  losses = []
29
29
 
30
30
  for stochastic in [False, True]:
31
+ print('stochastic', stochastic)
31
32
  torch.manual_seed(0x2131290)
32
33
  peaks.append([])
33
34
  losses.append([])
34
35
 
35
36
  for i in range(outer_iterations):
36
- model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
37
+ model = nn.Sequential(*[nn.Linear(size, size, bias=False) for _ in range(depth)]).cuda()
37
38
  o = get_optim(opt, model.parameters(), lr=1e-3, stochastic_schedule=stochastic)
38
39
 
39
40
  for _ in range(iterations):
File without changes
File without changes
File without changes
File without changes
File without changes