heavyball-0.18.1.tar.gz → heavyball-0.18.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {heavyball-0.18.1 → heavyball-0.18.3}/PKG-INFO +1 -1
  2. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/cached_delayed_psgd_kron.py +5 -2
  3. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/cached_psgd_kron.py +5 -6
  4. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/delayed_psgd.py +2 -1
  5. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/p_adam.py +2 -1
  6. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/psgd_kron.py +2 -1
  7. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/pure_psgd.py +2 -2
  8. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/utils.py +8 -9
  9. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball.egg-info/PKG-INFO +1 -1
  10. {heavyball-0.18.1 → heavyball-0.18.3}/setup.py +1 -1
  11. {heavyball-0.18.1 → heavyball-0.18.3}/test/test_stochastic_updates.py +3 -2
  12. {heavyball-0.18.1 → heavyball-0.18.3}/LICENSE +0 -0
  13. {heavyball-0.18.1 → heavyball-0.18.3}/README.md +0 -0
  14. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/__init__.py +0 -0
  15. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/foreach_adamw.py +0 -0
  16. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/foreach_adopt.py +0 -0
  17. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/foreach_laprop.py +0 -0
  18. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/foreach_sfadamw.py +0 -0
  19. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/foreach_soap.py +0 -0
  20. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/palm_foreach_sfadamw.py +0 -0
  21. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/palm_foreach_soap.py +0 -0
  22. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/precond_schedule_foreach_soap.py +0 -0
  23. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  24. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/precond_schedule_sfpsoap.py +0 -0
  25. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  26. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball.egg-info/SOURCES.txt +0 -0
  27. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball.egg-info/dependency_links.txt +0 -0
  28. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball.egg-info/requires.txt +0 -0
  29. {heavyball-0.18.1 → heavyball-0.18.3}/heavyball.egg-info/top_level.txt +0 -0
  30. {heavyball-0.18.1 → heavyball-0.18.3}/setup.cfg +0 -0
  31. {heavyball-0.18.1 → heavyball-0.18.3}/test/test_bf16_q.py +0 -0
  32. {heavyball-0.18.1 → heavyball-0.18.3}/test/test_closure.py +0 -0
  33. {heavyball-0.18.1 → heavyball-0.18.3}/test/test_foreach.py +0 -0
  34. {heavyball-0.18.1 → heavyball-0.18.3}/test/test_memory.py +0 -0
  35. {heavyball-0.18.1 → heavyball-0.18.3}/test/test_merge.py +0 -0
  36. {heavyball-0.18.1 → heavyball-0.18.3}/test/test_no_grad.py +0 -0
  37. {heavyball-0.18.1 → heavyball-0.18.3}/test/test_psgd.py +0 -0
  38. {heavyball-0.18.1 → heavyball-0.18.3}/test/test_soap.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.18.1
+Version: 0.18.3
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
heavyball/cached_delayed_psgd_kron.py
@@ -62,6 +62,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
+        should_update = self.should_update(group)
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -113,7 +114,9 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
 
-            if self.should_update(group):
+            new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
+
+            if should_update:
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
@@ -124,7 +127,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
             else:
                 torch.mul(q_.conj(), q_, out=c_)
 
-            set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
+            set_(g, new)
         grad_list = self.clip_fn(grad_list)
 
         lr = -warmup(lr, group['step'], group['warmup_steps'])
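Note: the hunks above also make the "delayed" semantics of this optimizer explicit: the preconditioned update (new) is now computed from the cached preconditioner before do_update refreshes it, so every step applies the previous step's preconditioner. A minimal self-contained sketch of that order of operations (toy names, not heavyball's API):

    import torch

    def delayed_precondition(grad, cached_q, refresh_q):
        # Apply the *stale* cached preconditioner to the gradient first ...
        new = cached_q @ grad
        # ... and only afterwards refresh the preconditioner for the next step.
        refresh_q(cached_q)
        return new

    q = torch.eye(4)
    g = torch.randn(4)
    out = delayed_precondition(g, q, lambda q_: q_.mul_(0.5))  # out uses the old q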
heavyball/cached_psgd_kron.py
@@ -8,8 +8,8 @@ from typing import Optional
 
 import torch
 
-from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote
+from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
+    line_to_triu, triu_to_line, set_, einsum_base, promote
 
 
 class ForeachCachedPSGDKron(PSGDBase):
@@ -71,6 +71,7 @@ class ForeachCachedPSGDKron(PSGDBase):
         beta = group['beta']
         store_triu_as_line = group['store_triu_as_line']
         q_dtype = getattr(torch, group['q_dtype'])
+        should_update = self.should_update(group)
 
         vals = []
 
@@ -111,9 +112,7 @@ class ForeachCachedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
 
-            new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
-
-            if self.should_update(group):
+            if should_update:
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
@@ -124,7 +123,7 @@ class ForeachCachedPSGDKron(PSGDBase):
             else:
                 torch.mul(q_.conj(), q_, out=c_)
 
-            set_(g, new)
+            set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
 
         grad_list = self.clip_fn(grad_list)
 
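Note: for contrast, the non-delayed cached variant above drops the precomputed new and moves the einsum after the preconditioner and cache refresh, so the current step's freshly updated cache preconditions the current gradient. In the toy terms of the previous sketch:

    import torch

    def cached_precondition(grad, cached_q, refresh_q):
        # Refresh the preconditioner (and its cache) first ...
        refresh_q(cached_q)
        # ... then apply the *fresh* cache to the current gradient.
        return cached_q @ grad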
heavyball/delayed_psgd.py
@@ -62,6 +62,7 @@ class ForeachDelayedPSGD(PSGDBase):
 
 
     def _step(self, group):
+        should_update = self.should_update(group)
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -103,7 +104,7 @@ class ForeachDelayedPSGD(PSGDBase):
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
             new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
-            if self.should_update(group):
+            if should_update:
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line)
             set_(g, new)
heavyball/p_adam.py
@@ -61,6 +61,7 @@ class ForeachPaLMPAdam(PSGDBase):
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
+        should_update = self.should_update(group)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
         min_ndim_triangular = group['min_ndim_triangular']
@@ -94,7 +95,7 @@ class ForeachPaLMPAdam(PSGDBase):
         group["step"] += 1
 
         Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
-        if self.should_update(group):
+        if should_update:
             for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
                 q32 = [promote(qq_) for qq_ in q_]
                 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
heavyball/psgd_kron.py
@@ -60,6 +60,7 @@ class ForeachPSGDKron(PSGDBase):
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
+        should_update = self.should_update(group)
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -101,7 +102,7 @@ class ForeachPSGDKron(PSGDBase):
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            if self.should_update(group):
+            if should_update:
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
             set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
heavyball/pure_psgd.py
@@ -57,7 +57,7 @@ class ForeachPurePSGD(PSGDBase):
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
-        # update preconditioners all together
+        should_update = self.should_update(group)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
         min_ndim_triangular = group['min_ndim_triangular']
@@ -93,7 +93,7 @@ class ForeachPurePSGD(PSGDBase):
             q_orig = Q_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            if self.should_update(group):
+            if should_update:
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
             psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
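Note: across all five optimizer files above, the stochastic update decision is now sampled once per _step and reused for every parameter, instead of being re-sampled inside the parameter loop. With a stochastic schedule this keeps a parameter group consistent: either all of its preconditioners update this step or none do. A toy illustration of the hoisting (hypothetical class, not heavyball's API):

    import random

    class ToyStochasticPreconditioner:
        def __init__(self, update_probability=0.1):
            self.update_probability = update_probability

        def should_update(self):
            return random.random() < self.update_probability

        def step(self, params):
            should_update = self.should_update()  # one draw per step
            for p in params:
                if should_update:  # every parameter agrees
                    self.update_preconditioner(p)

        def update_preconditioner(self, p):
            pass  # stand-in for the real PSGD preconditioner update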
heavyball/utils.py
@@ -335,9 +335,7 @@ def promote(x):
 def min_dtype(xs: List[torch.Tensor]):
     dtypes = [x.dtype for x in xs]
     for d in (torch.float32, torch.bfloat16, torch.float16):
-        if all(d == x for x in dtypes):
-            return d
-        if all(d in (x, torch.float32, torch.float64) for x in dtypes):
+        if all(x in (d, torch.float32, torch.float64) for x in dtypes):
            return d
     return torch.float32
 
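Note: the old membership test was checked the wrong way around: d in (x, torch.float32, torch.float64) is trivially true on the first iteration, where d is torch.float32 and therefore always a member of the tuple, so min_dtype effectively always answered float32. The fix tests each tensor's dtype against the candidate instead. A worked example of the corrected logic:

    import torch

    # Mixed bfloat16/float32 inputs now resolve to bfloat16: float32 tensors
    # may be cast down, and the bfloat16 tensor already matches.
    dtypes = [torch.bfloat16, torch.float32]
    for d in (torch.float32, torch.bfloat16, torch.float16):
        if all(x in (d, torch.float32, torch.float64) for x in dtypes):
            print(d)  # torch.bfloat16
            break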
@@ -817,14 +815,15 @@ class PSGDBase(StatefulOptimizer):
 
     def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
                   store_triu_as_line=False):
-        if self.should_update(group, self.balance_probability, 'balance_prob'):
-            for g, q in zip(grad_list, q_list):
-                if g.dim() > 1:
-                    psgd_balance_Q(q)
-
         for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
             psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
-            if original_q:
+
+        for g, q in zip(grad_list, q_list):
+            if g.dim() > 1:
+                psgd_balance_Q(q)
+
+        if original_q:
+            for q in q_list:
                if store_triu_as_line:
                    update_triu_(original_q[i], Q)
                else:
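Note: do_update now runs psgd_update_precond first, balances Q on every call instead of behind the balance_prob gate, and only then writes the factors back to their original storage, so the stored (possibly line-packed) copy always matches the balanced Q. For intuition, a hedged sketch of what a Kronecker-factor balancing step of this kind typically computes (illustrative only, not necessarily heavyball's exact psgd_balance_Q):

    import torch

    def balance_q(q_factors):
        # Rescale each factor so its infinity norm equals the geometric mean
        # of all the norms. The rescale factors multiply to one, so the
        # Kronecker product is unchanged while no single factor drifts
        # toward overflow or underflow in low precision.
        norms = torch.stack([q.norm(float('inf')) for q in q_factors])
        target = norms.log().mean().exp()
        for q, n in zip(q_factors, norms):
            q.mul_(target / n)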
heavyball.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.18.1
+Version: 0.18.3
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
setup.py
@@ -10,7 +10,7 @@ setuptools.setup(
     name='heavyball',
     license='BSD',
     description='Efficient optimizers',
-    version='0.18.1',
+    version='0.18.3',
     long_description=README,
     url='https://github.com/clashluke/heavyball',
     packages=setuptools.find_packages(),
test/test_stochastic_updates.py
@@ -16,7 +16,7 @@ def get_memory():
 
 
 @pytest.mark.parametrize("opt", heavyball.__all__)
-@pytest.mark.parametrize("size,depth", [(128, 2)])
+@pytest.mark.parametrize("size,depth", [(128, 1)])
 def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations: int = 3):
     set_torch()
 
@@ -28,12 +28,13 @@ def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations
     losses = []
 
     for stochastic in [False, True]:
+        print('stochastic', stochastic)
         torch.manual_seed(0x2131290)
         peaks.append([])
         losses.append([])
 
         for i in range(outer_iterations):
-            model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
+            model = nn.Sequential(*[nn.Linear(size, size, bias=False) for _ in range(depth)]).cuda()
             o = get_optim(opt, model.parameters(), lr=1e-3, stochastic_schedule=stochastic)
 
             for _ in range(iterations):
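Note: the test now uses a single bias-free linear layer (depth 1) and logs which schedule mode is running. A minimal sketch of the setup it exercises (assumes a CUDA device, as the test does; ForeachPSGDKron stands in for whichever entry of heavyball.__all__ is parametrized, and constructor details may differ between versions):

    import torch
    from torch import nn
    import heavyball

    torch.manual_seed(0x2131290)
    model = nn.Sequential(nn.Linear(128, 128, bias=False)).cuda()
    # stochastic_schedule toggles between random sampling and a deterministic
    # schedule for the preconditioner-update probability.
    opt = heavyball.ForeachPSGDKron(model.parameters(), lr=1e-3,
                                    stochastic_schedule=True)

    inp = torch.randn(128, 128, device='cuda')
    loss = model(inp).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()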