heavyball 0.21.7__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,7 +3,7 @@ import random
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_, update_param_, \
-    precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer
+    precond_schedule, set_, StatefulOptimizer
 
 
 class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
@@ -33,14 +33,15 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
                  precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
-                 foreach: bool = True):
+                 foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
         if betas[0] is not None:
             beta = betas[0]
         defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
-                    'beta2_scale': beta2_scale, 'split': split}
+                    'beta2_scale': beta2_scale, 'split': split, 'mars': mars, 'caution': caution,
+                    'mars_gamma': mars_gamma}
         super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
@@ -52,7 +53,7 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
         max_precond_dim = group['max_precond_dim']
         precondition_1d = group['precondition_1d']
 
-        for p, g in split_p_and_g_in_group(group):
+        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
             state = self.state_(p)
             step = state['step'] = state.get("step", -1) + 1
 
@@ -86,6 +87,8 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
         denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
         update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
         for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
             state = self.state_(p)
             # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
@@ -96,10 +99,7 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
             # to the original space
             # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
             exp_avg_projected = exp_avg_projected / d
-            set_(d, project(exp_avg_projected, state['Q'], True))
+            precond = project(exp_avg_projected, state['Q'], True)
 
             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
-
-            # Why does this have to be rebiased here?
-            step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-        update_param_(p_list, denom, step_size, group["weight_decay"])
+            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
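
Note (not part of the diff): the hunks above add three constructor keywords (`mars`, `caution`, `mars_gamma`) and move the parameter update inside the per-parameter loop, with the warmup-scaled step size now computed once before the loop. A minimal usage sketch of the new keywords, assuming `PrecondSchedulePaLMForeachSOAP` is importable from the top-level `heavyball` package as in earlier releases:

```python
# Hypothetical usage of the new 0.22.0 keyword arguments; model and values are made up.
import torch
from heavyball import PrecondSchedulePaLMForeachSOAP  # import path assumed

model = torch.nn.Linear(16, 4)
opt = PrecondSchedulePaLMForeachSOAP(model.parameters(), lr=1e-3,
                                     mars=True, caution=True, mars_gamma=0.0025)

loss = model(torch.randn(8, 16)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```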
@@ -3,11 +3,11 @@ import random
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
-    beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, split_p_and_g_in_group, copy_stochastic_list_, \
+    beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, copy_stochastic_list_, \
     promote
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
     eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
     denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
@@ -52,15 +52,20 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
                  weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9), betas=(None, None),
-                 split: bool = False, foreach: bool = True):
+                 split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025):
         if betas[0] is not None:
             beta = betas[0]
+
+        assert not caution, "Caution is not implemented in ScheduleFree optimizers"
+
         defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
                     'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1, 'weight_sum': 0,
-                    'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split}
+                    'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split,
+                    'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
         super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
@@ -87,7 +92,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
         # adaptive gradient clipping
         adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
 
-        for p, g in split_p_and_g_in_group(group):
+        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
             state = self.state_(p)
 
             if "z" not in state:
heavyball/psgd_kron.py CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    split_p_and_g_in_group, line_to_triu, triu_to_line, promote, stochastic_lerp_, beta_debias
+    line_to_triu, triu_to_line, promote, stochastic_lerp_, beta_debias
 
 
 class ForeachPSGDKron(PSGDBase):
@@ -40,7 +40,8 @@ class ForeachPSGDKron(PSGDBase):
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
                  foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
-                 storage_dtype: str = 'float32', #
+                 storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
+                 #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
@@ -57,7 +58,9 @@ class ForeachPSGDKron(PSGDBase):
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+                        storage_dtype=storage_dtype,
+                        mars=mars, caution=caution, mars_gamma=mars_gamma)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
@@ -77,7 +80,7 @@ class ForeachPSGDKron(PSGDBase):
 
         vals = []
 
-        for p, g in split_p_and_g_in_group(group, should_promote=False):
+        for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
             state = self.state_(p)
 
             if 'Q' not in state:
@@ -114,4 +117,4 @@ class ForeachPSGDKron(PSGDBase):
             self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
                            store_triu_as_line)
             g = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
-            update_param_([p], self.clip_fn([g]), lr, weight_decay)
+            update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])
heavyball/pure_psgd.py CHANGED
@@ -5,9 +5,9 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-
 from heavyball.utils import identity
-from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, split_p_and_g_in_group, \
+
+from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, \
     line_to_triu, triu_to_line, promote
 
 
@@ -38,7 +38,8 @@ class ForeachPurePSGD(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
-                 q_dtype='float32', stochastic_schedule: bool = True, #
+                 q_dtype='float32', stochastic_schedule: bool = True, mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
@@ -49,11 +50,14 @@ class ForeachPurePSGD(PSGDBase):
         if clip_fn is None:
             clip_fn = identity
 
+        assert not mars, "MARS is not supported in this optimizer"
+
         defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, mars=mars, caution=caution,
+                        mars_gamma=mars_gamma)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
@@ -70,7 +74,7 @@ class ForeachPurePSGD(PSGDBase):
 
         vals = []
 
-        for p, g in split_p_and_g_in_group(group, should_promote=False):
+        for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=0.0):
             state = self.state_(p)
 
             if 'Q' not in state:
@@ -98,4 +102,4 @@ class ForeachPurePSGD(PSGDBase):
             q32 = [promote(q_) for q_ in q]
             self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
             psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
-            update_param_([p], self.clip_fn([g]), lr, weight_decay)
+            update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])
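
Note (not part of the diff): `ForeachPurePSGD` takes the same keywords but asserts that MARS stays disabled, and its parameter loop passes `beta1=0.0` so no correction is applied. A short sketch of the constructor-time check:

```python
# Sketch: mars=True is rejected by the assert added above; the default mars=False is fine.
import torch
from heavyball.pure_psgd import ForeachPurePSGD

params = [torch.nn.Parameter(torch.randn(4, 4))]
try:
    ForeachPurePSGD(params, mars=True)
except AssertionError as err:
    print(err)  # MARS is not supported in this optimizer
```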
@@ -1,12 +1,13 @@
 import random
 
 import torch
+from heavyball.utils import mars_correction
 
 from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
-    beta_debias, schedule_free_, warmup, ScheduleFree, split_p_and_g_in_group, copy_stochastic_list_, promote
+    beta_debias, schedule_free_, warmup, ScheduleFree, copy_stochastic_list_, promote
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
     eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
     denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
@@ -44,15 +45,19 @@ class SFPaLMForeachSOAP(ScheduleFree):
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
                  weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False,
-                 foreach: bool = True):
+                 foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
         if betas[0] is not None:
             beta = betas[0]
+
+        assert not caution, "Caution is not implemented in ScheduleFree optimizers"
+
         defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
                     'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1,
-                    'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split}
+                    'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split, 'mars': mars,
+                    'caution': caution, 'mars_gamma': mars_gamma}
         super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
@@ -61,6 +66,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
         vals = []
         max_precond_dim = group['max_precond_dim']
         precondition_1d = group['precondition_1d']
+        mars = group['mars']
 
         step = group['step'] = group.get("step", 0) + 1
 
@@ -79,12 +85,14 @@ class SFPaLMForeachSOAP(ScheduleFree):
 
         vals = []
 
-        for p, g in split_p_and_g_in_group(group):
+        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
            state = self.state_(p)
 
            if "z" not in state:
                state["z"] = torch.clone(p).float()
                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+               if mars:
+                   state['mars_prev_grad'] = g.clone()
                init_preconditioner(g, state, max_precond_dim, precondition_1d)
                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
                continue  # first step is skipped so that we never use the current gradients in the projection.
heavyball/utils.py CHANGED
@@ -38,7 +38,7 @@ def warmup(lr: float, step: int, warmup_steps: int):
     return lr * step / warmup_steps
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_schedule_free_(p, z, ckp1, grad, lr, beta1):
     p32 = promote(p)
     z32 = promote(z)
@@ -141,19 +141,27 @@ def beta_debias(beta, step):
     return 1 - (1 - beta) / (1 - beta ** step)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
-def exp_avg_sq_(state, grad, beta2, eps, out=None):
-    if isinstance(state, torch.Tensor):
-        state.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-        return torch.sqrt(state, out=out).clamp_(min=eps)
-
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_exp_avg_sq_(state, grad, beta2, eps, out=None):
     torch._foreach_mul_(state, beta2)
     [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(state, grad)]
     denom = torch._foreach_sqrt(state)
-    torch._foreach_maximum_(denom, eps)
+    [denom.clamp_(min=eps) for denom in denom]
+    if out is not None:
+        copy_stochastic_list_(out, denom)
+        return out
+
     return denom
 
 
+def exp_avg_sq_(state, grad, beta2, eps, out=None):
+    state, grad = list_guard(state), list_guard(grad)
+    if not isinstance(beta2, torch.Tensor):
+        beta2 = torch.empty((), dtype=torch.float32, device=state[0].device).fill_(beta2)
+    if not isinstance(eps, torch.Tensor):
+        eps = torch.empty((), dtype=torch.float32, device=state[0].device).fill_(eps)
+    return _compilable_exp_avg_sq_(state, grad, beta2, eps, out)
+
 def adaptive_gradient_clipping_(parameters: List[torch.Tensor], gradients: List[torch.Tensor], clip_val: float,
                                 minimum: float = 1e-3, eps: float = 1e-8):
     if clip_val <= 0:
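
Note (not part of the diff): the hunk above follows a pattern used throughout this release: the `torch.compile`'d kernel only ever sees lists of tensors, and a thin Python wrapper normalises its inputs first (`list_guard` for bare tensors, 0-dim tensors for scalar hyperparameters). A minimal sketch of calling the new wrapper directly; shapes and values are made up:

```python
# Sketch: exp_avg_sq_ now accepts bare floats for beta2/eps and wraps them itself.
import torch
from heavyball.utils import exp_avg_sq_

exp_avg_sq = [torch.zeros(4)]
grad = [torch.randn(4)]
denom = exp_avg_sq_(exp_avg_sq, grad, 0.95, 1e-8)  # beta2 and eps become 0-dim tensors
print(denom[0])  # sqrt of the updated second moment, clamped at eps
```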
@@ -168,12 +176,19 @@ def adaptive_gradient_clipping_(parameters: List[torch.Tensor], gradients: List[
     torch._foreach_mul_(gradients, p_norm)
 
 
+def is_compiling():
+    try:
+        return torch.compiler.is_compiling()
+    except AttributeError:
+        return True
+
+
 def set_(dst: torch.Tensor, src: torch.Tensor):
-    if not torch.compiler.is_compiling() and src.data_ptr() == dst.data_ptr():
+    if not is_compiling() and src.data_ptr() == dst.data_ptr():
         return
     if src.shape != dst.shape:
         src = src.reshape_as(dst)
-    if not torch.compiler.is_compiling() and src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
+    if not is_compiling() and src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
         dst.set_(src)
     else:
         dst.copy_(src)
@@ -328,7 +343,7 @@ def get_orthogonal_matrix(mat):
     return final
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
@@ -338,12 +353,19 @@ def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a
 
 
 def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+    x, y = list_guard(x), list_guard(y)
     if not isinstance(a, torch.Tensor):
         a = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(a)
     _compilable_stochastic_lerp_(x, y, a)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def list_guard(x):
+    if isinstance(x, (list, tuple)):
+        return x
+    return [x]
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
@@ -353,6 +375,7 @@ def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], al
 
 
 def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+    x, y = list_guard(x), list_guard(y)
     if not isinstance(alpha, torch.Tensor):
         alpha = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(alpha)
     _compilable_stochastic_add_(x, y, alpha)
@@ -463,6 +486,43 @@ class StatefulOptimizer(torch.optim.Optimizer):
     def state_(self, arg: torch.Tensor):
         return self.state[self.key(arg)]
 
+    def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
+        for p, g in zip(p_list, g_list):
+            state = self.state_(p)
+            if 'mars_old_grad' not in state:
+                state['mars_old_grad'] = torch.zeros_like(g)
+        old_gs = [self.state_(p)['mars_old_grad'] for p in p_list]
+        mars_correction(g_list, old_gs, mars_gamma, beta)
+
+    def split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_promote: bool = True,
+                               beta1: float = -1.0):
+        for p in group["params"]:
+            if skip_none and p.grad is None:
+                continue
+
+            if p.grad is None:
+                grad = None
+            else:
+                if should_promote:
+                    grad = promote(p.grad)
+                else:
+                    grad = p.grad
+                if beta1 >= 0 and group.get('mars', False):
+                    self.mars_correct_list(group, [p], [grad], group['mars_gamma'], beta1)
+
+            p.grad = None
+
+            p_views = merge_group(group, p)
+            if grad is not None:
+                grad = merge_group(group, grad)
+            if isinstance(p_views, torch.Tensor):
+                yield p_views, grad
+                continue
+            if grad is None:
+                yield from zip(p_views, [None] * len(p_views))
+                continue
+            yield from zip(p_views, grad)
+
     def state_size(self) -> int:
         total_bytes = 0
 
@@ -472,7 +532,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
                 total_bytes += x.numel() * x.element_size()
 
         for group in self.param_groups:
-            for p, _ in split_p_and_g_in_group(group, skip_none=False):
+            for p, _ in self.split_p_and_g_in_group(group, skip_none=False):
                 tree_map(_add, self.state_(p))
         return total_bytes
 
@@ -581,7 +641,7 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
         copy_stochastic_(t, s)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
@@ -625,22 +685,24 @@ def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
 
 
 def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
-    if not torch.compiler.is_compiling() and target.data_ptr() == source.data_ptr():
+    if not is_compiling() and target.data_ptr() == source.data_ptr():
         return
     if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
         set_(target, source)
     _compilable_copy_stochastic_(target, source)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
-def _compilable_update_(p, u, decay, add_fn, lr):
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_update_(p, u, decay, add_fn, lr, caution, g):
     u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
-    p32, u32 = [list(map(promote, x)) for x in [p, u]]
+    p32, u32, g32 = [list(map(promote, x)) for x in [p, u, g]]
 
     if decay > 0:
         torch._foreach_mul_(p32, 1 - decay * lr)
 
-    for p32_, u32_ in zip(p32, u32):  # lr is data-dependent -> can't compile a foreach
+    for p32_, u32_, g32_ in zip(p32, u32, g32):  # lr is data-dependent -> can't compile a foreach
+        if caution:
+            _compilable_cautioning_(g32_, u32_)
         if add_fn is None:
             p32_.add_(u32_, alpha=lr)
         else:
@@ -650,9 +712,12 @@ def _compilable_update_(p, u, decay, add_fn, lr):
 
 
 def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
-                  add_fn: callable = None):
+                  add_fn: callable = None, caution: bool = False, grad: List[torch.Tensor] = None):
     lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
-    _compilable_update_(param, update, decay, add_fn, lr_tensor)
+    param, update, grad = list_guard(param), list_guard(update), list_guard(grad)
+    if not caution:
+        grad = [None] * len(param)
+    _compilable_update_(param, update, decay, add_fn, lr_tensor, caution, grad)
 
 
 def precond_schedule(step, precond_scheduler, rng):
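
Note (not part of the diff): `update_param_` now optionally applies the cautious-optimizer mask before the step. When `caution=True` the raw gradient must be supplied; update entries whose sign disagrees with the gradient are zeroed and the survivors are rescaled by the inverse of the kept fraction (see `_compilable_cautioning_` further down). A small self-contained sketch with made-up numbers:

```python
# Sketch: with caution=True, only update entries agreeing in sign with grad survive.
import torch
from heavyball.utils import update_param_

p = [torch.zeros(4)]
update = [torch.tensor([1., -1., 1., -1.])]
grad = [torch.tensor([1., 1., -1., -1.])]  # agrees with update at indices 0 and 3
update_param_(p, update, lr=0.1, decay=0.0, caution=True, grad=grad)
print(p[0])  # ~[0.2, 0.0, 0.0, -0.2]: masked entries dropped, kept ones scaled by 2x
```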
@@ -788,7 +853,7 @@ def psgd_lb(A, max_abs):
     return x
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
     """Update Kronecker product preconditioner Q with pair (V, G)."""
     exprA, exprGs, _ = exprs
@@ -821,7 +886,7 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
         stochastic_add_([o], [term1], -1)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
     """Precondition gradient G with preconditioner Q."""
     md = min_dtype(Q)
@@ -965,18 +1030,45 @@ class PSGDBase(StatefulOptimizer):
             psgd_balance_Q(q)
 
 
-#@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn):
+# TODO: Figure out why this sometimes crashes
+# @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn, caution, grad):
     md = min_dtype(cached_q + [ea])
     new = torch.einsum(expr, *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
-    update_param_([param], clip_fn([new]), lr, weight_decay)
+    update_param_([param], clip_fn([new]), lr, weight_decay, caution=caution, grad=grad)
 
 
 def precond_grad_cached_(cached_q: List[torch.Tensor], ea: torch.Tensor, expr: str, param: torch.Tensor, lr: float,
-                         weight_decay: float, clip_fn):
+                         weight_decay: float, clip_fn, caution, grad):
     if isinstance(lr, float):
         lr = torch.empty((), dtype=torch.float32, device=param.device).fill_(lr)
-    _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn)
+    _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn, caution, grad)
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_mars_correction_(g, old_g, a):
+    g_copy = [g_.clone() for g_ in g]
+    _compilable_stochastic_lerp_(g, old_g, a)
+    copy_stochastic_list_(old_g, g_copy)
+
+
+def mars_correction(g, old_g, beta1, gamma):
+    a = -gamma * beta1 / (1 - beta1)
+    g, old_g = list_guard(g), list_guard(old_g)
+    a = torch.empty((), dtype=torch.float32, device=g[0].device).fill_(a)
+    _compilable_mars_correction_(g, old_g, a)
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_cautioning_(g, update):
+    mask = (g * update) > 0
+    update.masked_fill_(~mask, 0)
+    scale = mask.numel() / mask.sum().clamp(min=1)
+    update.mul_(scale)
+
+
+def caution(g, update):
+    _compilable_cautioning_(g, update)
 
 
 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
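
Note (not part of the diff): the in-place `torch.lerp` with weight `-a` used by `_compilable_mars_correction_` is algebraically the extrapolation `g + a * (g - g_prev)` with `a = gamma * beta1 / (1 - beta1)`, using the parameter names of `mars_correction` above; the old-gradient buffer is then overwritten with the uncorrected gradient. A tiny numerical check of that identity, with example values only:

```python
# Check: lerp(g, g_prev, -a) == g + a * (g - g_prev) for the coefficient used above.
import torch

g = torch.randn(6)
g_prev = torch.randn(6)
a = 0.0025 * 0.9 / (1 - 0.9)  # gamma * beta1 / (1 - beta1), illustrative values
assert torch.allclose(torch.lerp(g, g_prev, -a), g + a * (g - g_prev), atol=1e-6)
```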
@@ -1013,29 +1105,3 @@ def merge_group(group, *tensors):
         append_or_extend(out, dim_merger(t, group['max_size_triangular'] if 'max_size_triangular' in group else group[
             'max_precond_dim'], group.get('split', False)))
     return out
-
-
-def split_p_and_g_in_group(group: dict, skip_none: bool = True, should_promote: bool = True):
-    for p in group["params"]:
-        if skip_none and p.grad is None:
-            continue
-
-        if p.grad is None:
-            grad = None
-        else:
-            if should_promote:
-                grad = promote(p.grad)
-            else:
-                grad = p.grad
-        p.grad = None
-
-        p_views = merge_group(group, p)
-        if grad is not None:
-            grad = merge_group(group, grad)
-        if isinstance(p_views, torch.Tensor):
-            yield p_views, grad
-            continue
-        if grad is None:
-            yield from zip(p_views, [None] * len(p_views))
-            continue
-        yield from zip(p_views, grad)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.21.7
+Version: 0.22.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+Currently (2024-11-26, 0.22.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
@@ -0,0 +1,24 @@
+heavyball/__init__.py,sha256=icHYN-MGsmHkLUlHCMcZkOlwY7GT63_ayR_a5iPKmzM,2226
+heavyball/cached_delayed_psgd_kron.py,sha256=n3wIOhrop0Ls4MZ0kXpwGuImp1jzPs6VGdxIlPyoYdQ,6827
+heavyball/cached_psgd_kron.py,sha256=KCLsfvj9qh_2FNwRTdWM3zjnt2oGHfsf4Y341rPcceI,6778
+heavyball/delayed_psgd.py,sha256=CaG17zqorLsCSDeGEePOyb6n9ugv8W6gyRQqeQNq-e8,6272
+heavyball/foreach_adamw.py,sha256=uawSbGGUD2E1RtcwspP83yQNElERdGX-diqCI5e8FqE,2825
+heavyball/foreach_adopt.py,sha256=DFEaPswVzdHcbxC-mirsf_okM_HR6r34PDUTty5CrUE,3547
+heavyball/foreach_laprop.py,sha256=J4Vms0nAOMh3GQtAOPyrYOe5WtpzokVv25b9oDnwc2A,2833
+heavyball/foreach_sfadamw.py,sha256=HWbLekY5BloHDIgrN2J0a7IolZCt8Ah2xkLAU_-5oSc,3079
+heavyball/foreach_soap.py,sha256=7B_dP2Hm_xqwpBQiPYkv_c6eoRnU1dV2VZfvSoa4uJ8,4729
+heavyball/p_adam.py,sha256=F-id4qOkAaDTJaKTSNhSsonX-Js5IzIu1Bdj1S4qE2g,6306
+heavyball/palm_foreach_sfadamw.py,sha256=E8raxrBIkSmTEGFzwnfWxKwDJjBQE2vdsmyqfc8aL_A,3375
+heavyball/palm_foreach_soap.py,sha256=IknGm_CzrqDIFEoCkejxjoZ4sfIy6RSoInqlMUOYLB4,6156
+heavyball/precond_schedule_foreach_soap.py,sha256=bJ2ifPFa8zEP9GO8eBpqZzsmP7p_iQkkCkllNeEMHPU,4892
+heavyball/precond_schedule_palm_foreach_soap.py,sha256=4dT9f134-Faq2KuCMCHzMtrkMO-es5p_DYS1of5yF-s,6428
+heavyball/precond_schedule_sfpsoap.py,sha256=FOR-axwlkSN7IHZWYYUVFfjSFCLxc_NdiTlb-n5gmgs,7530
+heavyball/psgd_kron.py,sha256=achB23mQUT3F00IGhjjVf_8YW7VOTHR6YdoCDRyWxsI,6039
+heavyball/pure_psgd.py,sha256=dbYgkunFFA6EsO6fJEhaJRxTH0smi7qLX3Np9XTQ9E4,5079
+heavyball/schedule_free_palm_foreach_soap.py,sha256=0WT_gvTKymqLQzYT6ewDgCmpDq-HgMAewipw1QvyQYA,7267
+heavyball/utils.py,sha256=TVpyev0oL4a78px4cvtaGoGPJqfpfTKE-xBWkRCmzkw,39785
+heavyball-0.22.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-0.22.0.dist-info/METADATA,sha256=LqVR3tUgxpk21zsmKxfJAQCKLPzmaQz2PQiKvlvpQe8,11926
+heavyball-0.22.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-0.22.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-0.22.0.dist-info/RECORD,,
@@ -1,24 +0,0 @@
-heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
-heavyball/cached_delayed_psgd_kron.py,sha256=Nyxl-G-o6greKwDN-vLiw5W02GXO2LRvknc0OzvzFnE,6674
-heavyball/cached_psgd_kron.py,sha256=HzD6se0AYb-W5hpydUxcR9uqrpe_54PBwgL1VWX3DHU,6592
-heavyball/delayed_psgd.py,sha256=m4c-OvcLMrRxSAPYs2l6Up21uCyF2kvHvpcnfe3nzGs,6212
-heavyball/foreach_adamw.py,sha256=Rb5U80cgUcEqlEbUU250UTWdoqA7nyiqkV5w1U4bWX4,2445
-heavyball/foreach_adopt.py,sha256=ecdi1fKg9i087OGjtKWVbE_DD6Yf4pvpzv4ELCcusvQ,3211
-heavyball/foreach_laprop.py,sha256=vi6C_gfjXxw5uN0KHgzxI9itUI1dcgOf3ufoO_VVMp0,2471
-heavyball/foreach_sfadamw.py,sha256=rLZORmCIMu9G09FdDgMSiI6pNq34IVoxsPVWtmeDdbQ,2753
-heavyball/foreach_soap.py,sha256=4mWSMWYTdjgiXiboI5DwdigecruDtNGKylGAFAVhCRA,4562
-heavyball/p_adam.py,sha256=Xyxsavwtw-t0OyTHitYQXZSmF9UJlMDzDAURge-MbbQ,6047
-heavyball/palm_foreach_sfadamw.py,sha256=JbNrcoquBGGUI5XNMFouDjpNurVHUW9DbX1A3tSrtno,3025
-heavyball/palm_foreach_soap.py,sha256=GzAwM8kOt1X0QCmUZDTdHwPxbJwjH8ic43dyAK5BYCA,6015
-heavyball/precond_schedule_foreach_soap.py,sha256=HcObXLfSNN_lKNb4nmC6tkdHcqDIMNX6hILpHKScqLc,4744
-heavyball/precond_schedule_palm_foreach_soap.py,sha256=xZ7CJvIfdu2RNAZt2g1S7Xb0Jyy1hNC4MowOFU3nWkk,6283
-heavyball/precond_schedule_sfpsoap.py,sha256=PNneiOkrRyV1yIZn91lPmYofd1_OiLqJTDy75RLpXJk,7273
-heavyball/psgd_kron.py,sha256=RSLJi5FnSmYjvYBufNDClnYRm-eN_Kpa1Ar2tNP6-X0,5824
-heavyball/pure_psgd.py,sha256=LZK0qmvZkBF8g00evaVLtW-sIUJmdoxag1K7O26AqEo,4820
-heavyball/schedule_free_palm_foreach_soap.py,sha256=_AqrnChY6iDlQUkF2YUxS7eLjSWCIuvEUOHvMHVM1yY,6873
-heavyball/utils.py,sha256=lyJRL-j_-LW6nVcbYTOrzcMACT3lIBmYioMseqgiexk,37211
-heavyball-0.21.7.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-0.21.7.dist-info/METADATA,sha256=rO1SbkbCdLKf2SD9aTxb_oDuQSZTOq1uh4NmcsjBt4g,11926
-heavyball-0.21.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-0.21.7.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-0.21.7.dist-info/RECORD,,