heavyball 0.17.2__tar.gz → 0.18.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {heavyball-0.17.2 → heavyball-0.18.0}/PKG-INFO +1 -1
  2. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/cached_delayed_psgd_kron.py +6 -19
  3. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/cached_psgd_kron.py +4 -17
  4. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/delayed_psgd.py +4 -17
  5. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/p_adam.py +9 -23
  6. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/psgd_kron.py +4 -17
  7. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/pure_psgd.py +6 -19
  8. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/utils.py +30 -13
  9. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball.egg-info/PKG-INFO +1 -1
  10. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball.egg-info/SOURCES.txt +2 -1
  11. {heavyball-0.17.2 → heavyball-0.18.0}/setup.py +1 -1
  12. {heavyball-0.17.2 → heavyball-0.18.0}/test/test_bf16_q.py +1 -1
  13. {heavyball-0.17.2 → heavyball-0.18.0}/test/test_foreach.py +1 -1
  14. heavyball-0.18.0/test/test_stochastic_updates.py +52 -0
  15. {heavyball-0.17.2 → heavyball-0.18.0}/LICENSE +0 -0
  16. {heavyball-0.17.2 → heavyball-0.18.0}/README.md +0 -0
  17. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/__init__.py +0 -0
  18. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/foreach_adamw.py +0 -0
  19. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/foreach_adopt.py +0 -0
  20. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/foreach_laprop.py +0 -0
  21. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/foreach_sfadamw.py +0 -0
  22. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/foreach_soap.py +0 -0
  23. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/palm_foreach_sfadamw.py +0 -0
  24. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/palm_foreach_soap.py +0 -0
  25. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/precond_schedule_foreach_soap.py +0 -0
  26. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  27. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/precond_schedule_sfpsoap.py +0 -0
  28. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  29. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball.egg-info/dependency_links.txt +0 -0
  30. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball.egg-info/requires.txt +0 -0
  31. {heavyball-0.17.2 → heavyball-0.18.0}/heavyball.egg-info/top_level.txt +0 -0
  32. {heavyball-0.17.2 → heavyball-0.18.0}/setup.cfg +0 -0
  33. {heavyball-0.17.2 → heavyball-0.18.0}/test/test_closure.py +0 -0
  34. {heavyball-0.17.2 → heavyball-0.18.0}/test/test_memory.py +0 -0
  35. {heavyball-0.17.2 → heavyball-0.18.0}/test/test_merge.py +0 -0
  36. {heavyball-0.17.2 → heavyball-0.18.0}/test/test_no_grad.py +0 -0
  37. {heavyball-0.17.2 → heavyball-0.18.0}/test/test_psgd.py +0 -0
  38. {heavyball-0.17.2 → heavyball-0.18.0}/test/test_soap.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.17.2
+Version: 0.18.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -42,7 +42,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32'):
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -50,12 +50,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 
-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn
 
         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
@@ -63,20 +59,11 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line,
-                        q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
+                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
-        self._prob_step = 0
 
     def _step(self, group):
-        # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -128,11 +115,11 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
 
-            if do_update:
+            if self.should_update(group):
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                 q32 = [promote(q_) for q_ in q]
-                self.balance([g], [q32])
-                self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
+                               store_triu_as_line)
                 for c_, q_ in zip(cached_q, q):
                     if q_.ndim == 2:
                         torch.matmul(q_.T.conj(), q_, out=c_)
@@ -40,7 +40,7 @@ class ForeachCachedPSGDKron(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32'):
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -48,12 +48,8 @@ class ForeachCachedPSGDKron(PSGDBase):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 
-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn
 
         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
@@ -63,18 +59,10 @@ class ForeachCachedPSGDKron(PSGDBase):
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
                         store_triu_as_line=store_triu_as_line,
                         q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
-        self._prob_step = 0
 
     def _step(self, group):
-        # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -128,11 +116,10 @@ class ForeachCachedPSGDKron(PSGDBase):
 
             new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
 
-            if do_update:
+            if self.should_update(group):
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                 q32 = [promote(q_) for q_ in q]
-                self.balance([g], [q32])
-                self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line)
                 for c_, q_ in zip(cached_q, q):
                     if q_.ndim == 2:
                         torch.matmul(q_.T.conj(), q_, out=c_)
@@ -39,7 +39,7 @@ class ForeachDelayedPSGD(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32'):
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -47,12 +47,8 @@ class ForeachDelayedPSGD(PSGDBase):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 
-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn
 
         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
@@ -61,18 +57,10 @@ class ForeachDelayedPSGD(PSGDBase):
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
                         store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
-        self._prob_step = 0
 
     def _step(self, group):
-        # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -114,10 +102,9 @@ class ForeachDelayedPSGD(PSGDBase):
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
             new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
-            if do_update:
+            if self.should_update(group):
                 q32 = [promote(q_) for q_ in q]
-                self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
-                self.balance([g], [q32])
+                self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line)
             set_(g, new)
 
         grad_list = self.clip_fn(grad_list)
@@ -5,7 +5,7 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-from heavyball.utils import triu_to_line, line_to_triu
+from heavyball.utils import triu_to_line, line_to_triu, identity
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
     exp_avg_sq_, beta_debias, split_p_and_g_in_group, promote
@@ -38,8 +38,8 @@ class ForeachPaLMPAdam(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
                  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
-                 store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32'):
+                 store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
+                 stochastic_schedule: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -47,12 +47,8 @@ class ForeachPaLMPAdam(PSGDBase):
         if betas[0] is not None:
             beta = betas[0]
 
-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
-            clip_fn = lambda x: x
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn
+            clip_fn = identity
 
         defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
@@ -61,18 +57,9 @@ class ForeachPaLMPAdam(PSGDBase):
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
                         split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
-
-        self._prob_step = 0
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
-        # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
         min_ndim_triangular = group['min_ndim_triangular']
@@ -91,8 +78,8 @@ class ForeachPaLMPAdam(PSGDBase):
             if 'Q' not in state:
                 state['exp_avg'] = torch.zeros_like(g)
                 state['exp_avg_sq'] = torch.zeros_like(g)
-                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
-                                                 min_ndim_triangular, memory_save_mode, dtype=q_dtype)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=q_dtype)
                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
@@ -106,11 +93,10 @@ class ForeachPaLMPAdam(PSGDBase):
         group["step"] += 1
 
         Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
-        if do_update:
+        if self.should_update(group):
             for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
                 q32 = [promote(qq_) for qq_ in q_]
-                self.balance([g], [q32])
-                self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
         torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
 
         beta2 = 1 - group['step'] ** -group['beta2_scale']
@@ -39,7 +39,7 @@ class ForeachPSGDKron(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32'):
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -47,12 +47,8 @@ class ForeachPSGDKron(PSGDBase):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 
-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn
 
         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
@@ -61,18 +57,10 @@ class ForeachPSGDKron(PSGDBase):
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
                         store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
-        self._prob_step = 0
 
     def _step(self, group):
-        # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -114,10 +102,9 @@ class ForeachPSGDKron(PSGDBase):
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            if do_update:
+            if self.should_update(group):
                 q32 = [promote(q_) for q_ in q]
-                self.balance([ea if momentum_into_precond_update else g], [q32])
-                self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                self.do_update(group,[p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
             set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
 
         grad_list = self.clip_fn(grad_list)
@@ -5,7 +5,7 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-from heavyball.utils import copy_stochastic_list_
+from heavyball.utils import copy_stochastic_list_, identity
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
     split_p_and_g_in_group, line_to_triu, triu_to_line, promote
@@ -38,18 +38,14 @@ class ForeachPurePSGD(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32'):
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 
-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
-            clip_fn = lambda x: x
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn
+            clip_fn = identity
 
         defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
@@ -58,18 +54,10 @@ class ForeachPurePSGD(PSGDBase):
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
                         store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
-
-        self._prob_step = 0
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
         # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
         min_ndim_triangular = group['min_ndim_triangular']
@@ -105,10 +93,9 @@ class ForeachPurePSGD(PSGDBase):
             q_orig = Q_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            if do_update:
+            if self.should_update(group):
                 q32 = [promote(q_) for q_ in q]
-                self.balance([g], [q32])
-                self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                self.do_update(group,[p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
             psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
 
         grad_list = self.clip_fn(grad_list)
@@ -668,7 +668,10 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
 
         term2 += term1  # a + b
         term1 *= 2  # 2a
-        term1 -= term2  # 2a - (a + b) == a - b
+        if term1.dtype == term2.dtype:
+            term1 -= term2  # 2a - (a + b) == a - b
+        else:
+            term1 = term1 - term2
 
         term1 *= step
         norm = term2.norm(float('inf'))
@@ -790,21 +793,35 @@ def update_triu_(q_state, materialised):
 class PSGDBase(StatefulOptimizer):
     balance_probability: float = 0.01
 
-    def __init__(self, parameters, groups, foreach: bool = True):
-        super().__init__(parameters, groups, foreach)
+    def __init__(self, parameters, groups, foreach: bool, stochastic_schedule: bool, clip_fn,
+                 preconditioner_update_probability):
+        super().__init__(parameters, {**groups, 'stochastic_schedule': stochastic_schedule}, foreach)
         self.rng = random.Random(0x1923213)
         self._tiny = torch.finfo(torch.bfloat16).tiny
-
-    def balance(self, grad_list, Q_list):
-        if self.rng.random() > self.balance_probability:
-            return
-
-        for g, q in zip(grad_list, Q_list):
-            if g.dim() > 1:
-                psgd_balance_Q(q)
-
-    def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
+        if clip_fn is None:
+            clip_fn = identity
+        if preconditioner_update_probability is None:
+            preconditioner_update_probability = precond_update_prob_schedule()
+        self.clip_fn = clip_fn
+        self.preconditioner_update_probability = preconditioner_update_probability
+
+    def should_update(self, group, prob: Optional[float] = None, name: str = 'cumulative_prob'):
+        group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
+        if prob is None:
+            prob = self.preconditioner_update_probability(group[f'{name}_prob_step'])
+        if group['stochastic_schedule']:
+            return self.rng.random() < prob
+        cumulative_prob = group.get(name, 0)
+        group[name] = cumulative_prob + prob
+        return int(group[name]) > int(cumulative_prob)
+
+    def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
                   store_triu_as_line=False):
+        if self.should_update(group, self.balance_probability, 'balance_prob'):
+            for g, q in zip(grad_list, q_list):
+                if g.dim() > 1:
+                    psgd_balance_Q(q)
+
         for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
             psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
             if original_q:
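
Note on the scheduling change above: PSGDBase now owns the clip function and the preconditioner-update probability, and the per-optimizer probability bookkeeping and `balance` calls are folded into `should_update` and `do_update`. With `stochastic_schedule=True` (the default) each step still draws a random number against the scheduled probability; with `stochastic_schedule=False` the probability is accumulated per parameter group and an update fires whenever the running sum crosses the next integer, so preconditioner (and Q-balancing) updates land on deterministic, evenly spaced steps. A minimal standalone sketch of that deterministic trigger (illustrative only; `deterministic_trigger` is a hypothetical name, not part of heavyball):

def deterministic_trigger(group: dict, prob: float, name: str = 'cumulative_prob') -> bool:
    # Mirrors the non-stochastic branch of PSGDBase.should_update: accumulate the
    # scheduled probability and fire when the running sum crosses the next integer.
    cumulative = group.get(name, 0.0)
    group[name] = cumulative + prob
    return int(group[name]) > int(cumulative)

group = {}
fired = [step for step in range(1, 21) if deterministic_trigger(group, 0.25)]
print(fired)  # [4, 8, 12, 16, 20] -- roughly one update per 1/prob steps, at fixed steps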
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.17.2
+Version: 0.18.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,4 +32,5 @@ test/test_memory.py
 test/test_merge.py
 test/test_no_grad.py
 test/test_psgd.py
-test/test_soap.py
+test/test_soap.py
+test/test_stochastic_updates.py
@@ -10,7 +10,7 @@ setuptools.setup(
     name='heavyball',
     license='BSD',
     description='Efficient optimizers',
-    version='0.17.2',
+    version='0.18.0',
     long_description=README,
     url='https://github.com/clashluke/heavyball',
     packages=setuptools.find_packages(),
@@ -28,11 +28,11 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations:
     losses = []
 
     for q_dtype in ['float32', 'bfloat16']:
+        torch.manual_seed(0x2131290)
         peaks.append([])
         losses.append([])
 
         for i in range(outer_iterations):
-            torch.manual_seed(0x2131290)
             model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
             o = get_optim(opt, model.parameters(), lr=1e-3, q_dtype=q_dtype)
 
@@ -26,11 +26,11 @@ def test_foreach(opt, size, depth: int, iterations: int = 5, outer_iterations: i
     losses = []
 
     for foreach in [True, False]:
+        torch.manual_seed(0x2131290)
         peaks.append([])
         losses.append([])
 
         for i in range(outer_iterations):
-            torch.manual_seed(0x2131290)
             clean()
             model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
             clean()
@@ -0,0 +1,52 @@
+import heavyball
+import heavyball.utils
+import pytest
+import torch
+from benchmark.utils import get_optim
+from heavyball.utils import clean, set_torch, PSGDBase
+from torch import nn
+
+
+def get_memory():
+    clean()
+    torch.cuda.synchronize()
+    clean()
+    torch.cuda.synchronize()
+    return torch.cuda.memory_allocated()
+
+
+@pytest.mark.parametrize("opt", heavyball.__all__)
+@pytest.mark.parametrize("size,depth", [(128, 2)])
+def test_foreach(opt, size, depth: int, iterations: int = 1024, outer_iterations: int = 3):
+    set_torch()
+
+    opt = getattr(heavyball, opt)
+    if not issubclass(opt, PSGDBase):
+        raise pytest.skip('Only PSGD is supported')
+
+    peaks = []
+    losses = []
+
+    for stochastic in [False, True]:
+        torch.manual_seed(0x2131290)
+        peaks.append([])
+        losses.append([])
+
+        for i in range(outer_iterations):
+            model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
+            o = get_optim(opt, model.parameters(), lr=1e-3, stochastic_schedule=stochastic)
+
+            for _ in range(iterations):
+                loss = model(torch.randn((128, size)).cuda()).square().mean()
+                loss.backward()
+                o.step()
+                o.zero_grad()
+                losses[-1].append(loss.detach())
+
+            del model, o
+            clean()
+
+    stochastic = sum([l.item() for l in losses[1]])
+    deterministic = sum([l.item() for l in losses[0]])
+    print(f"{deterministic=}, {stochastic=}")
+    assert deterministic < stochastic
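
Every PSGD-based optimizer in this release accepts the `stochastic_schedule` keyword shown in the signatures above. A hedged usage sketch (assuming `ForeachPSGDKron` is exported from the top-level `heavyball` package, as the test's `getattr(heavyball, opt)` lookup implies):

import torch
from torch import nn

import heavyball

model = nn.Linear(16, 16)
# stochastic_schedule=False opts into the deterministic cumulative-probability
# schedule for preconditioner updates introduced in 0.18.0.
opt = heavyball.ForeachPSGDKron(model.parameters(), lr=1e-3, stochastic_schedule=False)

loss = model(torch.randn(4, 16)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()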