heavyball 0.21.8__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,7 +3,7 @@ import random
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_, update_param_, \
-    precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer
+    precond_schedule, set_, StatefulOptimizer
 
 
 class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
@@ -33,14 +33,15 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
                  precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
-                 foreach: bool = True):
+                 foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
         if betas[0] is not None:
             beta = betas[0]
         defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
-                    'beta2_scale': beta2_scale, 'split': split}
+                    'beta2_scale': beta2_scale, 'split': split, 'mars': mars, 'caution': caution,
+                    'mars_gamma': mars_gamma}
         super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
@@ -52,7 +53,7 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
         max_precond_dim = group['max_precond_dim']
         precondition_1d = group['precondition_1d']
 
-        for p, g in split_p_and_g_in_group(group):
+        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
             state = self.state_(p)
             step = state['step'] = state.get("step", -1) + 1
 
@@ -86,6 +87,8 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
         denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
         update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
         for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
             state = self.state_(p)
             # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
@@ -96,10 +99,7 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
             # to the original space
             # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
             exp_avg_projected = exp_avg_projected / d
-            set_(d, project(exp_avg_projected, state['Q'], True))
+            precond = project(exp_avg_projected, state['Q'], True)
 
             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
-
-        # Why does this have to be rebiased here?
-        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-        update_param_(p_list, denom, step_size, group["weight_decay"])
+            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
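
Across the files in this diff, the optimizer constructors gain three new keyword arguments (mars, caution, mars_gamma) and the split_p_and_g_in_group helper moves from heavyball.utils onto the optimizer classes as a method that also receives beta1. A minimal usage sketch follows, assuming PrecondSchedulePaLMForeachSOAP is re-exported at the heavyball package level as in earlier releases; the learning rate and the inline descriptions of the flags are placeholders inferred from the call sites above, not taken from heavyball's documentation.

import torch
import heavyball  # assumes the optimizer classes are exported at package level

model = torch.nn.Linear(64, 64)
opt = heavyball.PrecondSchedulePaLMForeachSOAP(
    model.parameters(),
    lr=1e-3,            # placeholder value, not a recommendation from this diff
    mars=True,          # new in 0.23.0; appears to enable a MARS-style gradient correction
    caution=True,       # new in 0.23.0; forwarded to update_param_(..., caution=..., grad=...)
    mars_gamma=0.0025,  # new in 0.23.0; default taken from the signatures in this diff
)

loss = model(torch.randn(8, 64)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
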
@@ -3,11 +3,11 @@ import random
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
-    beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, split_p_and_g_in_group, copy_stochastic_list_, \
+    beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, copy_stochastic_list_, \
     promote
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
     eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
     denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
@@ -52,15 +52,20 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
                  weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9), betas=(None, None),
-                 split: bool = False, foreach: bool = True):
+                 split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025):
         if betas[0] is not None:
             beta = betas[0]
+
+        assert not caution, "Caution is not implemented in ScheduleFree optimizers"
+
         defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
                     'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1, 'weight_sum': 0,
-                    'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split}
+                    'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split,
+                    'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
         super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
@@ -87,7 +92,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
         # adaptive gradient clipping
         adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
 
-        for p, g in split_p_and_g_in_group(group):
+        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
             state = self.state_(p)
 
             if "z" not in state:
heavyball/psgd_kron.py CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    split_p_and_g_in_group, line_to_triu, triu_to_line, promote, stochastic_lerp_, beta_debias
+    line_to_triu, triu_to_line, promote, stochastic_lerp_, beta_debias
 
 
 class ForeachPSGDKron(PSGDBase):
@@ -40,7 +40,8 @@ class ForeachPSGDKron(PSGDBase):
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
                  foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
-                 storage_dtype: str = 'float32', #
+                 storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
+                 #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
@@ -57,7 +58,9 @@
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+                        storage_dtype=storage_dtype,
+                        mars=mars, caution=caution, mars_gamma=mars_gamma)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
@@ -77,7 +80,7 @@
 
         vals = []
 
-        for p, g in split_p_and_g_in_group(group, should_promote=False):
+        for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
             state = self.state_(p)
 
             if 'Q' not in state:
@@ -113,5 +116,5 @@
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
                                store_triu_as_line)
-            g = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
-            update_param_([p], self.clip_fn([g]), lr, weight_decay)
+            g = psgd_precond_grad(False, self.state_(p)["exprs"][-1], ea, *q)
+            update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])
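
update_param_ now receives caution=group['caution'] together with the raw gradient, but its body lives in heavyball/utils.py and is not part of this diff. As orientation only, the sketch below shows the generic "cautious optimizer" masking idea such a flag usually toggles; the helper name, the sign convention, and the rescaling step are assumptions rather than heavyball's actual code.

import torch

def cautious_step_sketch(param: torch.Tensor, update: torch.Tensor, grad: torch.Tensor, lr: float) -> None:
    # Keep only the update components whose sign agrees with the current gradient.
    mask = (update * grad > 0).to(update.dtype)
    # Rescale so the surviving entries keep the original mean magnitude
    # (a common choice; heavyball may handle this differently).
    mask = mask * (mask.numel() / mask.sum().clamp(min=1))
    # Descend along the masked update.
    param.add_(update * mask, alpha=-lr)

Note that in this release only the non-ScheduleFree optimizers thread the flag through; the ScheduleFree classes in this diff assert that caution stays disabled.
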
heavyball/pure_psgd.py CHANGED
@@ -5,9 +5,9 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-
 from heavyball.utils import identity
-from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, split_p_and_g_in_group, \
+
+from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, \
     line_to_triu, triu_to_line, promote
 
 
@@ -38,7 +38,8 @@ class ForeachPurePSGD(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
-                 q_dtype='float32', stochastic_schedule: bool = True, #
+                 q_dtype='float32', stochastic_schedule: bool = True, mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
@@ -49,11 +50,14 @@
         if clip_fn is None:
             clip_fn = identity
 
+        assert not mars, "MARS is not supported in this optimizer"
+
         defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, mars=mars, caution=caution,
+                        mars_gamma=mars_gamma)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
@@ -70,7 +74,7 @@
 
         vals = []
 
-        for p, g in split_p_and_g_in_group(group, should_promote=False):
+        for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=0.0):
             state = self.state_(p)
 
             if 'Q' not in state:
@@ -97,5 +101,5 @@
             if group:
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
-            psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
-            update_param_([p], self.clip_fn([g]), lr, weight_decay)
+            psgd_precond_grad(True, self.state_(p)["exprs"][-1], g, *q)
+            update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])
@@ -1,12 +1,13 @@
 import random
 
 import torch
+from heavyball.utils import mars_correction
 
 from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
-    beta_debias, schedule_free_, warmup, ScheduleFree, split_p_and_g_in_group, copy_stochastic_list_, promote
+    beta_debias, schedule_free_, warmup, ScheduleFree, copy_stochastic_list_, promote
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
     eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
     denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
@@ -44,15 +45,19 @@ class SFPaLMForeachSOAP(ScheduleFree):
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
                  weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False,
-                 foreach: bool = True):
+                 foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
         if betas[0] is not None:
             beta = betas[0]
+
+        assert not caution, "Caution is not implemented in ScheduleFree optimizers"
+
         defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
                     'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1,
-                    'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split}
+                    'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split, 'mars': mars,
+                    'caution': caution, 'mars_gamma': mars_gamma}
         super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
@@ -61,6 +66,7 @@
         vals = []
         max_precond_dim = group['max_precond_dim']
         precondition_1d = group['precondition_1d']
+        mars = group['mars']
 
         step = group['step'] = group.get("step", 0) + 1
 
@@ -79,12 +85,14 @@
 
         vals = []
 
-        for p, g in split_p_and_g_in_group(group):
+        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
             state = self.state_(p)
 
             if "z" not in state:
                 state["z"] = torch.clone(p).float()
                 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                if mars:
+                    state['mars_prev_grad'] = g.clone()
                 init_preconditioner(g, state, max_precond_dim, precondition_1d)
                 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
                 continue # first step is skipped so that we never use the current gradients in the projection.
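
The mars flag, the mars_gamma=0.0025 default, the beta1 now passed into split_p_and_g_in_group, and the mars_prev_grad buffer initialised above all feed the mars_correction helper imported from heavyball.utils, which is not shown in this diff. Under the assumption that it follows the published MARS recipe (a variance-reduction correction of the incoming gradient), a rough sketch looks like this; heavyball's implementation may differ in details such as clipping or buffer handling.

import torch

def mars_corrected_grad_sketch(grad: torch.Tensor, prev_grad: torch.Tensor,
                               beta1: float, mars_gamma: float = 0.0025) -> torch.Tensor:
    # Push the gradient further along the direction it moved since the last step ...
    corrected = grad + mars_gamma * (beta1 / (1 - beta1)) * (grad - prev_grad)
    # ... and clip the result back to unit norm if the correction made it too large.
    return corrected / corrected.norm().clamp(min=1.0)

The previous-gradient buffer would then be refreshed with the raw gradient after each step, consistent with the state['mars_prev_grad'] = g.clone() initialisation above.
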