heavyball 0.17.3__tar.gz → 0.18.1__tar.gz

This diff shows the content of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those versions.
Files changed (38)
  1. {heavyball-0.17.3 → heavyball-0.18.1}/PKG-INFO +1 -1
  2. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/cached_delayed_psgd_kron.py +12 -27
  3. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/cached_psgd_kron.py +12 -27
  4. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/delayed_psgd.py +8 -20
  5. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/p_adam.py +17 -30
  6. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/psgd_kron.py +10 -24
  7. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/pure_psgd.py +14 -27
  8. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/utils.py +26 -12
  9. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball.egg-info/PKG-INFO +1 -1
  10. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball.egg-info/SOURCES.txt +2 -1
  11. {heavyball-0.17.3 → heavyball-0.18.1}/setup.py +1 -1
  12. {heavyball-0.17.3 → heavyball-0.18.1}/test/test_bf16_q.py +1 -1
  13. {heavyball-0.17.3 → heavyball-0.18.1}/test/test_foreach.py +1 -1
  14. heavyball-0.18.1/test/test_stochastic_updates.py +52 -0
  15. {heavyball-0.17.3 → heavyball-0.18.1}/LICENSE +0 -0
  16. {heavyball-0.17.3 → heavyball-0.18.1}/README.md +0 -0
  17. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/__init__.py +0 -0
  18. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/foreach_adamw.py +0 -0
  19. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/foreach_adopt.py +0 -0
  20. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/foreach_laprop.py +0 -0
  21. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/foreach_sfadamw.py +0 -0
  22. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/foreach_soap.py +0 -0
  23. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/palm_foreach_sfadamw.py +0 -0
  24. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/palm_foreach_soap.py +0 -0
  25. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/precond_schedule_foreach_soap.py +0 -0
  26. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  27. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/precond_schedule_sfpsoap.py +0 -0
  28. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  29. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball.egg-info/dependency_links.txt +0 -0
  30. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball.egg-info/requires.txt +0 -0
  31. {heavyball-0.17.3 → heavyball-0.18.1}/heavyball.egg-info/top_level.txt +0 -0
  32. {heavyball-0.17.3 → heavyball-0.18.1}/setup.cfg +0 -0
  33. {heavyball-0.17.3 → heavyball-0.18.1}/test/test_closure.py +0 -0
  34. {heavyball-0.17.3 → heavyball-0.18.1}/test/test_memory.py +0 -0
  35. {heavyball-0.17.3 → heavyball-0.18.1}/test/test_merge.py +0 -0
  36. {heavyball-0.17.3 → heavyball-0.18.1}/test/test_no_grad.py +0 -0
  37. {heavyball-0.17.3 → heavyball-0.18.1}/test/test_psgd.py +0 -0
  38. {heavyball-0.17.3 → heavyball-0.18.1}/test/test_soap.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.17.3
+Version: 0.18.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
heavyball/cached_delayed_psgd_kron.py

@@ -7,10 +7,9 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 from typing import Optional
 
 import torch
-from heavyball.utils import einsum_base
 
-from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote
+from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
+    line_to_triu, triu_to_line, set_, einsum_base, promote
 
 
 class ForeachCachedDelayedPSGDKron(PSGDBase):
@@ -42,7 +41,9 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32'):
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,  #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -50,33 +51,17 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 
-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn
 
         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
-                        # precond lr hardcoded to 0.1
-                        precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line,
-                        q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
-
-        self._prob_step = 0
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
-        # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -128,11 +113,11 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
 
-            if do_update:
+            if self.should_update(group):
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                 q32 = [promote(q_) for q_ in q]
-                self.balance([g], [q32])
-                self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
+                               store_triu_as_line)
                 for c_, q_ in zip(cached_q, q):
                     if q_.ndim == 2:
                         torch.matmul(q_.T.conj(), q_, out=c_)
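The same two changes recur in all six PSGD optimizer files below: the previously hardcoded expert parameters (precond_init_scale, precond_lr) become constructor arguments, and a new stochastic_schedule flag plus the clip_fn and preconditioner_update_probability handling move into the PSGDBase constructor. A minimal usage sketch of the new 0.18.1 signature, assuming the class is exported from the package root (the test suite accesses the optimizers via getattr(heavyball, ...)); the model and training step are illustrative only:

    import torch
    from torch import nn

    import heavyball

    model = nn.Linear(128, 128)
    # stochastic_schedule=False opts into the new deterministic update cadence;
    # precond_init_scale=1.0 and precond_lr=0.1 match the values that were
    # hardcoded before 0.18.1.
    opt = heavyball.ForeachCachedDelayedPSGDKron(model.parameters(), lr=1e-3, stochastic_schedule=False,
                                                 precond_init_scale=1.0, precond_lr=0.1)
    loss = model(torch.randn(16, 128)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()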
heavyball/cached_psgd_kron.py

@@ -7,10 +7,9 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 from typing import Optional
 
 import torch
-from heavyball.utils import einsum_base
 
-from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote
+from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
+    split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote
 
 
 class ForeachCachedPSGDKron(PSGDBase):
@@ -40,7 +39,9 @@ class ForeachCachedPSGDKron(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32'):
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,  #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -48,33 +49,17 @@ class ForeachCachedPSGDKron(PSGDBase):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 
-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn
 
         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
-                        # precond lr hardcoded to 0.1
-                        precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line,
-                        q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
-
-        self._prob_step = 0
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
-        # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -128,11 +113,11 @@ class ForeachCachedPSGDKron(PSGDBase):
 
             new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
 
-            if do_update:
+            if self.should_update(group):
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                 q32 = [promote(q_) for q_ in q]
-                self.balance([g], [q32])
-                self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
+                               store_triu_as_line)
                 for c_, q_ in zip(cached_q, q):
                     if q_.ndim == 2:
                         torch.matmul(q_.T.conj(), q_, out=c_)
heavyball/delayed_psgd.py

@@ -39,7 +39,9 @@ class ForeachDelayedPSGD(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32'):
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,  #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -47,32 +49,19 @@ class ForeachDelayedPSGD(PSGDBase):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 
-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn
 
         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
-                        # precond lr hardcoded to 0.1
-                        precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale,
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
                         store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
-        self._prob_step = 0
 
     def _step(self, group):
-        # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -114,10 +103,9 @@ class ForeachDelayedPSGD(PSGDBase):
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
             new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
-            if do_update:
+            if self.should_update(group):
                 q32 = [promote(q_) for q_ in q]
-                self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
-                self.balance([g], [q32])
+                self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line)
             set_(g, new)
 
         grad_list = self.clip_fn(grad_list)
heavyball/p_adam.py

@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-from heavyball.utils import triu_to_line, line_to_triu
 
-from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
-    exp_avg_sq_, beta_debias, split_p_and_g_in_group, promote
+from heavyball.utils import triu_to_line, line_to_triu, identity
+from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, exp_avg_sq_, beta_debias, \
+    split_p_and_g_in_group, promote
 
 
 class ForeachPaLMPAdam(PSGDBase):
@@ -38,8 +38,10 @@ class ForeachPaLMPAdam(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
                  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
-                 store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32'):
+                 store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
+                 stochastic_schedule: bool = True,  #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -47,32 +49,18 @@ class ForeachPaLMPAdam(PSGDBase):
         if betas[0] is not None:
             beta = betas[0]
 
-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
-            clip_fn = lambda x: x
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn
+            clip_fn = identity
 
         defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
-                        # precond lr hardcoded to 0.1
-                        precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
-
-        self._prob_step = 0
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, beta=beta,
+                        beta2_scale=beta2_scale, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
-        # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
         min_ndim_triangular = group['min_ndim_triangular']
@@ -91,8 +79,8 @@ class ForeachPaLMPAdam(PSGDBase):
             if 'Q' not in state:
                 state['exp_avg'] = torch.zeros_like(g)
                 state['exp_avg_sq'] = torch.zeros_like(g)
-                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
-                                                 min_ndim_triangular, memory_save_mode, dtype=q_dtype)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=q_dtype)
                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
@@ -106,11 +94,10 @@ class ForeachPaLMPAdam(PSGDBase):
         group["step"] += 1
 
         Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
-        if do_update:
+        if self.should_update(group):
             for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
                 q32 = [promote(qq_) for qq_ in q_]
-                self.balance([g], [q32])
-                self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
         torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
 
         beta2 = 1 - group['step'] ** -group['beta2_scale']
heavyball/psgd_kron.py

@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, promote
+    split_p_and_g_in_group, line_to_triu, triu_to_line, set_, promote
 
 
 class ForeachPSGDKron(PSGDBase):
@@ -39,7 +39,9 @@ class ForeachPSGDKron(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32'):
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,  #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -47,32 +49,17 @@ class ForeachPSGDKron(PSGDBase):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 
-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn
 
         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
-                        # precond lr hardcoded to 0.1
-                        precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
-
-        self._prob_step = 0
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
-        # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -114,10 +101,9 @@ class ForeachPSGDKron(PSGDBase):
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            if do_update:
+            if self.should_update(group):
                 q32 = [promote(q_) for q_ in q]
-                self.balance([ea if momentum_into_precond_update else g], [q32])
-                self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
             set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
 
         grad_list = self.clip_fn(grad_list)
heavyball/pure_psgd.py

@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-from heavyball.utils import copy_stochastic_list_
 
-from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
-    split_p_and_g_in_group, line_to_triu, triu_to_line, promote
+from heavyball.utils import identity
+from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, split_p_and_g_in_group, \
+    line_to_triu, triu_to_line, promote
 
 
 class ForeachPurePSGD(PSGDBase):
@@ -37,39 +37,27 @@ class ForeachPurePSGD(PSGDBase):
     def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32'):
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
+                 q_dtype='float32', stochastic_schedule: bool = True,  #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 
-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
-            clip_fn = lambda x: x
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn
+            clip_fn = identity
 
         defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
-                        # precond lr hardcoded to 0.1
-                        precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
-
-        self._prob_step = 0
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
         # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
         min_ndim_triangular = group['min_ndim_triangular']
@@ -105,10 +93,9 @@ class ForeachPurePSGD(PSGDBase):
             q_orig = Q_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            if do_update:
+            if self.should_update(group):
                 q32 = [promote(q_) for q_ in q]
-                self.balance([g], [q32])
-                self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
             psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
 
         grad_list = self.clip_fn(grad_list)
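One small change above: the default clip_fn for ForeachPaLMPAdam and ForeachPurePSGD is now the named helper identity from heavyball.utils rather than an inline lambda x: x. The helper's body is not part of this diff; a minimal equivalent is sketched below, and one plausible motivation for the switch is that a module-level function, unlike a lambda stored on the optimizer, survives pickling:

    def identity(x):
        # Minimal stand-in for heavyball.utils.identity (not shown in this
        # diff): a named no-op used as the default clip function.
        return x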
heavyball/utils.py

@@ -793,21 +793,35 @@ def update_triu_(q_state, materialised):
 class PSGDBase(StatefulOptimizer):
     balance_probability: float = 0.01
 
-    def __init__(self, parameters, groups, foreach: bool = True):
-        super().__init__(parameters, groups, foreach)
+    def __init__(self, parameters, groups, foreach: bool, stochastic_schedule: bool, clip_fn,
+                 preconditioner_update_probability):
+        super().__init__(parameters, {**groups, 'stochastic_schedule': stochastic_schedule}, foreach)
         self.rng = random.Random(0x1923213)
         self._tiny = torch.finfo(torch.bfloat16).tiny
-
-    def balance(self, grad_list, Q_list):
-        if self.rng.random() > self.balance_probability:
-            return
-
-        for g, q in zip(grad_list, Q_list):
-            if g.dim() > 1:
-                psgd_balance_Q(q)
-
-    def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
+        if clip_fn is None:
+            clip_fn = identity
+        if preconditioner_update_probability is None:
+            preconditioner_update_probability = precond_update_prob_schedule()
+        self.clip_fn = clip_fn
+        self.preconditioner_update_probability = preconditioner_update_probability
+
+    def should_update(self, group, prob: Optional[float] = None, name: str = 'cumulative_prob'):
+        group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
+        if prob is None:
+            prob = self.preconditioner_update_probability(group[f'{name}_prob_step'])
+        if group['stochastic_schedule']:
+            return self.rng.random() < prob
+        cumulative_prob = group.get(name, 0)
+        group[name] = cumulative_prob + prob
+        return int(group[name]) > int(cumulative_prob)
+
+    def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
                   store_triu_as_line=False):
+        if self.should_update(group, self.balance_probability, 'balance_prob'):
+            for g, q in zip(grad_list, q_list):
+                if g.dim() > 1:
+                    psgd_balance_Q(q)
+
         for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
             psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
             if original_q:
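This PSGDBase rewrite is the core of the release: should_update replaces the per-optimizer do_update/_prob_step bookkeeping, Q balancing is folded into do_update behind the same mechanism, and the schedule state now lives in the param group. With stochastic_schedule=True the old behavior is kept (a fresh self.rng.random() < prob draw each step); with stochastic_schedule=False the probability is accumulated and an update fires exactly when the integer part of the running sum increases. A self-contained sketch of that deterministic trigger (illustrative, not package code):

    def deterministic_trigger(probs):
        # Fire on the steps where the integer part of the cumulative
        # probability increases, e.g. prob=0.25 fires every 4th step.
        cumulative = 0.0
        for prob in probs:
            previous, cumulative = cumulative, cumulative + prob
            yield int(cumulative) > int(previous)

    print(list(deterministic_trigger([0.25] * 8)))
    # [False, False, False, True, False, False, False, True]

The deterministic cadence is reproducible run-to-run, which the new test/test_stochastic_updates.py below relies on when comparing the two schedules.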
heavyball.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.17.3
+Version: 0.18.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
heavyball.egg-info/SOURCES.txt

@@ -32,4 +32,5 @@ test/test_memory.py
 test/test_merge.py
 test/test_no_grad.py
 test/test_psgd.py
-test/test_soap.py
+test/test_soap.py
+test/test_stochastic_updates.py
setup.py

@@ -10,7 +10,7 @@ setuptools.setup(
    name='heavyball',
    license='BSD',
    description='Efficient optimizers',
-   version='0.17.3',
+   version='0.18.1',
    long_description=README,
    url='https://github.com/clashluke/heavyball',
    packages=setuptools.find_packages(),
test/test_bf16_q.py

@@ -28,11 +28,11 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations:
     losses = []
 
     for q_dtype in ['float32', 'bfloat16']:
+        torch.manual_seed(0x2131290)
         peaks.append([])
         losses.append([])
 
         for i in range(outer_iterations):
-            torch.manual_seed(0x2131290)
             model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
             o = get_optim(opt, model.parameters(), lr=1e-3, q_dtype=q_dtype)
 
test/test_foreach.py

@@ -26,11 +26,11 @@ def test_foreach(opt, size, depth: int, iterations: int = 5, outer_iterations: i
     losses = []
 
     for foreach in [True, False]:
+        torch.manual_seed(0x2131290)
         peaks.append([])
         losses.append([])
 
         for i in range(outer_iterations):
-            torch.manual_seed(0x2131290)
             clean()
             model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
             clean()
test/test_stochastic_updates.py (new file)

@@ -0,0 +1,52 @@
+import heavyball
+import heavyball.utils
+import pytest
+import torch
+from benchmark.utils import get_optim
+from heavyball.utils import clean, set_torch, PSGDBase
+from torch import nn
+
+
+def get_memory():
+    clean()
+    torch.cuda.synchronize()
+    clean()
+    torch.cuda.synchronize()
+    return torch.cuda.memory_allocated()
+
+
+@pytest.mark.parametrize("opt", heavyball.__all__)
+@pytest.mark.parametrize("size,depth", [(128, 2)])
+def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations: int = 3):
+    set_torch()
+
+    opt = getattr(heavyball, opt)
+    if not issubclass(opt, PSGDBase):
+        raise pytest.skip('Only PSGD is supported')
+
+    peaks = []
+    losses = []
+
+    for stochastic in [False, True]:
+        torch.manual_seed(0x2131290)
+        peaks.append([])
+        losses.append([])
+
+        for i in range(outer_iterations):
+            model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
+            o = get_optim(opt, model.parameters(), lr=1e-3, stochastic_schedule=stochastic)
+
+            for _ in range(iterations):
+                loss = model(torch.randn((128, size)).cuda()).square().mean()
+                loss.backward()
+                o.step()
+                o.zero_grad()
+                losses[-1].append(loss.detach())
+
+            del model, o
+            clean()
+
+    stochastic = sum([l.item() for l in losses[1]])
+    deterministic = sum([l.item() for l in losses[0]])
+    print(f"{deterministic=}, {stochastic=}")
+    assert deterministic < stochastic