heavyball 0.21.8__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -1,3 +1,4 @@
+ from .cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron
  from .cached_psgd_kron import ForeachCachedPSGDKron
  from .delayed_psgd import ForeachDelayedPSGD
  from .foreach_adamw import ForeachAdamW
@@ -14,7 +15,6 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
  from .psgd_kron import ForeachPSGDKron
  from .pure_psgd import ForeachPurePSGD
  from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
- from .cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron

  PalmForEachSoap = PaLMForeachSOAP

@@ -39,7 +39,8 @@ CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
  __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
  'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
  'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
- 'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron', #
- 'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP',
- 'SOAP', 'SFAdamW', 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD',
- 'CachedPSGDKron', 'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP']
+ 'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron',
+ #
+ 'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP', 'SOAP', 'SFAdamW',
+ 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD', 'CachedPSGDKron',
+ 'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP']
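The newly exported ForeachCachedDelayedPSGDKron is now importable from the package root, and this release threads the same three keyword arguments (mars, caution, mars_gamma) through every optimizer's constructor and param-group defaults. A minimal usage sketch, assuming heavyball 0.22.0 is installed; the toy model and data are placeholders, not part of the package:

import torch
import torch.nn as nn
import heavyball

model = nn.Linear(16, 4)                       # any nn.Module works
opt = heavyball.ForeachCachedDelayedPSGDKron(  # new top-level export in 0.22.0
    model.parameters(), lr=1e-3,
    mars=True, caution=True, mars_gamma=0.0025)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = (model(x) - y).square().mean()
loss.backward()
opt.step()
opt.zero_grad()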
heavyball/cached_delayed_psgd_kron.py CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
  import torch
  from heavyball.utils import min_dtype, precond_grad_cached_

- from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
+ from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
  line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_


@@ -43,7 +43,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
  foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
- storage_dtype: str = 'float32', #
+ storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
+ #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -61,7 +62,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
  split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
- storage_dtype=storage_dtype)
+ storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -81,7 +82,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):

  vals = []

- for p, g in split_p_and_g_in_group(group, should_promote=False):
+ for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
  state = self.state_(p)

  if 'Q' not in state:
@@ -120,7 +121,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  q_orig = Q_list.pop(0)
  ea = exp_avg_list.pop(0)

- precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn)
+ precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn, group['caution'], g)

  if should_update:
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
heavyball/cached_psgd_kron.py CHANGED
@@ -8,7 +8,7 @@ from typing import Optional

  import torch

- from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
+ from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
  line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_


@@ -40,7 +40,8 @@ class ForeachCachedPSGDKron(PSGDBase):
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
  foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
- storage_dtype: str = 'float32', #
+ storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
+ #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -58,7 +59,7 @@ class ForeachCachedPSGDKron(PSGDBase):
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
  split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
- storage_dtype=storage_dtype)
+ storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -78,7 +79,7 @@ class ForeachCachedPSGDKron(PSGDBase):

  vals = []

- for p, g in split_p_and_g_in_group(group, should_promote=False):
+ for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
  state = self.state_(p)

  if 'Q' not in state:
@@ -128,4 +129,5 @@ class ForeachCachedPSGDKron(PSGDBase):
  else:
  torch.mul(q_.conj(), q_, out=c_)

- precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn)
+ precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn,
+ group['caution'], g)
heavyball/delayed_psgd.py CHANGED
@@ -8,14 +8,13 @@ import torch
  from heavyball.utils import stochastic_lerp_, beta_debias

  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
- split_p_and_g_in_group, triu_to_line, line_to_triu, promote
+ triu_to_line, line_to_triu, promote
+

- # TODO: E1123 00:51:55.423000 159394 site-packages/torch/_guards.py:283] [5/0] Error while creating guard:
- # E1123 00:51:55.423000 159394 site-packages/torch/_guards.py:283] [5/0] Name: "G['psgd_precond_grad'].__defaults__[0]"
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
- def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_decay, clip_fn):
+ def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_deca, clip_fn, caution, grad):
  new = psgd_precond_grad(q, exprs, ea)
- update_param_([p], clip_fn([new]), lr, weight_decay)
+ update_param_([p], clip_fn([new]), lr, weight_decay, caution=caution, grad=grad)


  class ForeachDelayedPSGD(PSGDBase):
@@ -46,7 +45,8 @@ class ForeachDelayedPSGD(PSGDBase):
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
- q_dtype='float32', stochastic_schedule: bool = True, storage_dtype: str = 'float32', #
+ q_dtype='float32', stochastic_schedule: bool = True, storage_dtype: str = 'float32',
+ mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -63,7 +63,9 @@ class ForeachDelayedPSGD(PSGDBase):
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+ storage_dtype=storage_dtype,
+ caution=caution, mars_gamma=mars_gamma, mars=mars)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -83,7 +85,7 @@ class ForeachDelayedPSGD(PSGDBase):

  vals = []

- for p, g in split_p_and_g_in_group(group, should_promote=False):
+ for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
  state = self.state_(p)

  if 'Q' not in state:
@@ -112,7 +114,8 @@ class ForeachDelayedPSGD(PSGDBase):
  q_orig = Q_list.pop(0)
  ea = exp_avg_list.pop(0)
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
- _compilable_psgd_precond_grad_(q, self.state_(p)["exprs"], ea, p, lr, weight_decay, self.clip_fn)
+ _compilable_psgd_precond_grad_(q, self.state_(p)["exprs"], ea, p, lr, weight_decay, self.clip_fn, group['caution'],
+ g)
  if should_update:
  q32 = [promote(q_) for q_ in q]
  self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
heavyball/foreach_adamw.py CHANGED
@@ -1,18 +1,19 @@
  import torch
  import torch.optim
-
  from heavyball.utils import copy_stochastic_list_
+
  from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote


- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
- def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
  g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]

  torch._foreach_lerp_(exp_avg32, g32, 1 - beta_debias(beta1, step + 1))
  denom = list(exp_avg_sq_(exp_avg_sq32, g32, beta_debias(beta2, step + 1), eps))

- update_param_(y, exp_avg32, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l))
+ update_param_(y, exp_avg32, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l), caution=caution,
+ grad=g32)

  copy_stochastic_list_(exp_avg, exp_avg32)
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
@@ -20,9 +21,11 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,

  class ForeachAdamW(StatefulOptimizer):
  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
- foreach: bool = True, storage_dtype: str = 'float32'):
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025):
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
- lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
+ lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+ mars_gamma=mars_gamma)
  super().__init__(params, defaults, foreach)

  def _step(self, group):
@@ -48,9 +51,13 @@ class ForeachAdamW(StatefulOptimizer):
  y, grad, exp_avg_sq, exp_avg = zip(
  *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])

+ if group['mars']:
+ self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
  lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
  lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
  step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k)
- _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay)
+ _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay,
+ group['caution'])

  group['k'] = k + 1
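Because mars, caution, and mars_gamma are stored in each param group's defaults and read back per group inside _step, they can be overridden per parameter group like any other hyperparameter. A short sketch, assuming heavyball 0.22.0; the weight/bias split below is purely illustrative:

import torch.nn as nn
from heavyball import ForeachAdamW

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 10))
weights = [p for p in model.parameters() if p.ndim >= 2]
biases = [p for p in model.parameters() if p.ndim < 2]

opt = ForeachAdamW(
    [{'params': weights, 'mars': True, 'caution': True},   # MARS-corrected, cautious updates
     {'params': biases, 'weight_decay': 0.0}],              # plain AdamW behaviour
    lr=1e-3, weight_decay=0.01, mars_gamma=0.0025)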
@@ -5,10 +5,10 @@ from heavyball.utils import copy_stochastic_list_
  from .utils import warmup, beta_debias, update_param_, StatefulOptimizer, promote


- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
- def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
  g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
- update_param_(y, exp_avg, lr, decay)
+ update_param_(y, exp_avg, lr, decay, caution=caution, grad=g32)

  beta1 = beta_debias(beta1, step)
  denom = torch._foreach_sqrt(exp_avg_sq32)
@@ -27,9 +27,11 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
  class ForeachADOPT(StatefulOptimizer):

  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
- foreach: bool = True, storage_dtype: str = 'float32'):
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025):
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
- lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
+ lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+ mars_gamma=mars_gamma)
  super().__init__(params, defaults, foreach)

  def _step(self, group):
@@ -57,11 +59,14 @@ class ForeachADOPT(StatefulOptimizer):

  group['k'] = k + 1

+ if group['mars']:
+ self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
  if k > 1:
  lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
  lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
  k = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k)
- _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], k, lr, eps, decay)
+ _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], k, lr, eps, decay, group['caution'])
  return

  grad = [promote(g) for g in grad]
@@ -4,8 +4,8 @@ import torch.optim
  from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote, copy_stochastic_list_


- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
- def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
  g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]

  denom = exp_avg_sq_(exp_avg_sq32, g32, beta_debias(beta2, step), eps)
@@ -14,7 +14,7 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
  torch._foreach_mul_(exp_avg32, beta1)
  [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]

- update_param_(y, exp_avg32, lr, decay)
+ update_param_(y, exp_avg32, lr, decay, caution=caution, grad=g32)

  copy_stochastic_list_(exp_avg, exp_avg32)
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
@@ -23,9 +23,11 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
  class ForeachLaProp(StatefulOptimizer):

  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
- foreach: bool = True, storage_dtype: str = 'float32'):
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025):
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
- lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
+ lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+ mars_gamma=mars_gamma)
  super().__init__(params, defaults, foreach)

  def _step(self, group):
@@ -52,10 +54,14 @@ class ForeachLaProp(StatefulOptimizer):
  *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) #
  for p in active_p])

+ if group['mars']:
+ self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
  lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
  lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
  step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k + 1)

- _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay)
+ _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay,
+ group['caution'])

  group['k'] = k + 1
@@ -5,7 +5,7 @@ from heavyball.utils import get_ckp1, copy_stochastic_list_
  from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, promote, _compilable_schedule_free_


- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
  def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
  old_debiased2 = beta_debias(beta2, step)

@@ -21,13 +21,17 @@ def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, dec

  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)

+
  class ForeachSFAdamW(ScheduleFree):
  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
- weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32'):
+ weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
+ caution: bool = False, mars_gamma: float = 0.0025):
+
+ assert not caution, "Caution not implemented for SFAdamW"

  defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
  weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
- foreach=foreach, storage_dtype=storage_dtype)
+ foreach=foreach, storage_dtype=storage_dtype, mars=mars, caution=caution, mars_gamma=mars_gamma)
  super().__init__(params, defaults, foreach)

  def _step(self, group):
@@ -53,6 +57,9 @@ class ForeachSFAdamW(ScheduleFree):
  y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z']) #
  for p in active_p])

+ if group['mars']:
+ self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
  lr = warmup(group['lr'], k + 1, group['warmup_steps'])
  ckp1, group['weight_sum'] = get_ckp1(lr, group['weight_lr_power'], group['weight_sum'], group['r'], k + 1)

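The ScheduleFree AdamW variants (ForeachSFAdamW here and PaLMForeachSFAdamW below) accept the new keywords for API symmetry but reject caution=True up front, so misuse fails loudly at construction time instead of being silently ignored. A minimal sketch, assuming heavyball 0.22.0; the single tiny parameter only exists to satisfy the constructor:

import torch
from heavyball import ForeachSFAdamW

p = torch.nn.Parameter(torch.zeros(4))
try:
    ForeachSFAdamW([p], lr=1e-3, caution=True)
except AssertionError as err:
    print(err)  # Caution not implemented for SFAdamW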
heavyball/foreach_soap.py CHANGED
@@ -1,7 +1,7 @@
  import torch

  from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
- split_p_and_g_in_group, StatefulOptimizer, exp_avg_
+ StatefulOptimizer, exp_avg_


  class ForeachSOAP(StatefulOptimizer):
@@ -26,11 +26,13 @@ class ForeachSOAP(StatefulOptimizer):
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
- split: bool = False, foreach: bool = True):
+ split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025):
  defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
- "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split}
+ "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split, 'mars': mars,
+ 'caution': caution, 'mars_gamma': mars_gamma}
  super().__init__(params, defaults, foreach)
  self._data_format = data_format

@@ -41,7 +43,7 @@ class ForeachSOAP(StatefulOptimizer):
  max_precond_dim = group['max_precond_dim']
  precondition_1d = group['precondition_1d']

- for p, g in split_p_and_g_in_group(group):
+ for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
  state = self.state_(p)
  step = state['step'] = state.get("step", -1) + 1

@@ -71,6 +73,8 @@ class ForeachSOAP(StatefulOptimizer):
  step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
  denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)

+ step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
  for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
  state = self.state_(p)
  # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
@@ -80,11 +84,9 @@ class ForeachSOAP(StatefulOptimizer):
  # Projecting back the preconditioned (by Adam) exponential moving average of gradients
  # to the original space
  # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
- set_(d, project(exp_avg_projected / d, state['Q'], True))
+ precond = project(exp_avg_projected / d, state['Q'], True)

  update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
  step > 0 and step % group['precondition_frequency'] == 0)

- # Why does this have to be rebiased here?
- step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
- update_param_(p_list, denom, step_size, group["weight_decay"])
+ update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
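In the SOAP family the step size is now computed once per group, before the per-parameter loop, and each parameter is updated individually with its back-projected precond tensor, optionally cautioned against its raw gradient. The sign convention and linear warm-up are unchanged; a small numeric sketch with illustrative values:

lr, warmup_steps = 3e-3, 10
for step in (1, 5, 10, 50):
    step_size = -lr * min(step / warmup_steps, 1)
    print(step, step_size)  # ramps from -3e-4 at step 1 to -3e-3 at step 10, then stays flat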
heavyball/p_adam.py CHANGED
@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
  """

  import torch
-
  from heavyball.utils import triu_to_line, line_to_triu, identity, stochastic_lerp_
+
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, exp_avg_sq_, beta_debias, \
- split_p_and_g_in_group, promote
+ promote


  class ForeachPaLMPAdam(PSGDBase):
@@ -39,7 +39,8 @@ class ForeachPaLMPAdam(PSGDBase):
  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
  store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
- stochastic_schedule: bool = True, storage_dtype:str ='float32',#
+ stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
+ caution: bool = False, mars_gamma: float = 0.0025, #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -57,7 +58,8 @@ class ForeachPaLMPAdam(PSGDBase):
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, beta=beta,
  beta2_scale=beta2_scale, merge_dims=merge_dims, split=split,
- store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
+ store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype,
+ mars=mars, caution=caution, mars_gamma=mars_gamma)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -75,7 +77,7 @@ class ForeachPaLMPAdam(PSGDBase):

  vals = []

- for p, g in split_p_and_g_in_group(group, should_promote=False):
+ for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=group['beta']):
  state = self.state_(p)

  if 'Q' not in state:
@@ -107,6 +109,7 @@ class ForeachPaLMPAdam(PSGDBase):
  lr = -warmup(lr, group['step'], group['warmup_steps'])

  for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
+ gc = g.clone() if group['caution'] else None
  psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
  ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
  exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
@@ -115,5 +118,4 @@ class ForeachPaLMPAdam(PSGDBase):
  divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
  divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
  """
- update_param_([p], self.clip_fn([g]), lr, weight_decay)
-
+ update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=gc)
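ForeachPaLMPAdam overwrites the gradient buffer g in place, first via psgd_precond_grad(..., inplace=True) and then via exp_avg_sq_(..., out=g), so the raw gradient must be cloned up front if the cautious update is to compare against it. A stand-alone illustration of that aliasing concern; the tensors and operations below are illustrative, not taken from the package:

import torch

g = torch.tensor([1.0, -2.0, 3.0])
gc = g.clone()             # snapshot of the raw gradient before the buffer is reused
g.mul_(0.5).add_(1.0)      # stand-in for the in-place preconditioning / denom reuse
print(torch.equal(g, gc))  # False: without the clone, "grad" would already be the reused buffer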
@@ -5,7 +5,7 @@ from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, get_ckp1, pro
  _compilable_schedule_free_, copy_stochastic_list_


- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
  def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
  old_debiased2 = beta_debias(beta2, step)

@@ -24,12 +24,17 @@ def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, dec

  class PaLMForeachSFAdamW(ScheduleFree):
  def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
- weight_lr_power=2.0, beta2_scale: float = 0.8, foreach: bool = True, storage_dtype: str = 'float32'):
+ weight_lr_power=2.0, beta2_scale: float = 0.8, foreach: bool = True, storage_dtype: str = 'float32',
+ mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
  if betas[0] is not None:
  beta = betas[0]
+
+ assert not caution, "Caution not implemented for SFAdamW"
+
  defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
  lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
- beta2_scale=beta2_scale, storage_dtype=storage_dtype)
+ beta2_scale=beta2_scale, storage_dtype=storage_dtype, mars=mars, caution=caution,
+ mars_gamma=mars_gamma)
  super().__init__(params, defaults, foreach)

  def _step(self, group):
@@ -58,6 +63,9 @@ class PaLMForeachSFAdamW(ScheduleFree):
  y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z']) #
  for p in active_p])

+ if group['mars']:
+ self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
  lr = warmup(group['lr'], k + 1, group['warmup_steps'])
  ckp1, group['weight_sum'] = get_ckp1(lr, group['weight_lr_power'], group['weight_sum'], group['r'], k + 1)

@@ -1,8 +1,7 @@
  import torch

  from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
- split_p_and_g_in_group, StatefulOptimizer, exp_avg_
-
+ StatefulOptimizer, exp_avg_


  class PaLMForeachSOAP(StatefulOptimizer):
@@ -33,14 +32,15 @@ class PaLMForeachSOAP(StatefulOptimizer):
  max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
- beta2_scale: float = 0.8, split: bool = False, foreach: bool = True):
+ beta2_scale: float = 0.8, split: bool = False, foreach: bool = True, mars: bool = False,
+ caution: bool = False, mars_gamma: float = 0.0025):
  if betas[0] is not None:
  beta = betas[0]
  defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
- 'split': split}
+ 'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
  super().__init__(params, defaults, foreach)
  self._data_format = data_format

@@ -51,7 +51,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
  max_precond_dim = group['max_precond_dim']
  precondition_1d = group['precondition_1d']

- for p, g in split_p_and_g_in_group(group):
+ for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
  state = self.state_(p)
  step = state['step'] = state.get("step", -1) + 1

@@ -82,6 +82,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
  beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
  step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
  denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
+ step_size = -group["lr"] * min(step / group['warmup_steps'], 1)

  for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
  state = self.state_(p)
@@ -92,11 +93,9 @@ class PaLMForeachSOAP(StatefulOptimizer):
  # Projecting back the preconditioned (by Adam) exponential moving average of gradients
  # to the original space
  # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
- set_(d, project(exp_avg_projected / d, state['Q'], True))
+ precond = project(exp_avg_projected / d, state['Q'], True)

  update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
  step > 0 and step % group['precondition_frequency'] == 0)

- # Why does this have to be rebiased here?
- step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
- update_param_(p_list, denom, step_size, group["weight_decay"])
+ update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
@@ -3,7 +3,7 @@ import random
  import torch

  from .utils import init_preconditioner, update_preconditioner, project, beta_debias, update_param_, \
- precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer, exp_avg_
+ precond_schedule, set_, StatefulOptimizer, exp_avg_


  class PrecondScheduleForeachSOAP(StatefulOptimizer):
@@ -27,12 +27,13 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
- precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True):
+ precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True, mars: bool = False,
+ caution: bool = False, mars_gamma: float = 0.0025):
  defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
- 'split': split}
+ 'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
  super().__init__(params, defaults, foreach)
  self._data_format = data_format
  self.rng = random.Random(0x120983109)
@@ -44,7 +45,7 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
  max_precond_dim = group['max_precond_dim']
  precondition_1d = group['precondition_1d']

- for p, g in split_p_and_g_in_group(group):
+ for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
  state = self.state_(p)
  step = state['step'] = state.get("step", -1) + 1

@@ -75,6 +76,8 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
  denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)

  update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+ step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
  for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
  state = self.state_(p)
  # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
@@ -84,10 +87,9 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
  # Projecting back the preconditioned (by Adam) exponential moving average of gradients
  # to the original space
  # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
- set_(d, project(exp_avg_projected / d, state['Q'], True))
+ precond = project(exp_avg_projected / d, state['Q'], True)

  update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)

- # Why does this have to be rebiased here?
- step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
- update_param_(p_list, denom, step_size, group["weight_decay"])
+ update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
+
@@ -3,7 +3,7 @@ import random
  import torch

  from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_, update_param_, \
- precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer
+ precond_schedule, set_, StatefulOptimizer


  class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
@@ -33,14 +33,15 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
  precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
- foreach: bool = True):
+ foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
  if betas[0] is not None:
  beta = betas[0]
  defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
- 'beta2_scale': beta2_scale, 'split': split}
+ 'beta2_scale': beta2_scale, 'split': split, 'mars': mars, 'caution': caution,
+ 'mars_gamma': mars_gamma}
  super().__init__(params, defaults, foreach)
  self._data_format = data_format
  self.rng = random.Random(0x120983109)
@@ -52,7 +53,7 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
  max_precond_dim = group['max_precond_dim']
  precondition_1d = group['precondition_1d']

- for p, g in split_p_and_g_in_group(group):
+ for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
  state = self.state_(p)
  step = state['step'] = state.get("step", -1) + 1

@@ -86,6 +87,8 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
  denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)

  update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+ step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
  for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
  state = self.state_(p)
  # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
@@ -96,10 +99,7 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
  # to the original space
  # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
  exp_avg_projected = exp_avg_projected / d
- set_(d, project(exp_avg_projected, state['Q'], True))
+ precond = project(exp_avg_projected, state['Q'], True)

  update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
-
- # Why does this have to be rebiased here?
- step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
- update_param_(p_list, denom, step_size, group["weight_decay"])
+ update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
heavyball/precond_schedule_sfpsoap.py CHANGED
@@ -3,11 +3,11 @@ import random
  import torch

  from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
- beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, split_p_and_g_in_group, copy_stochastic_list_, \
+ beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, copy_stochastic_list_, \
  promote


- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
  def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
  eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
  denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
@@ -52,15 +52,20 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
  weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9), betas=(None, None),
- split: bool = False, foreach: bool = True):
+ split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025):
  if betas[0] is not None:
  beta = betas[0]
+
+ assert not caution, "Caution is not implemented in ScheduleFree optimizers"
+
  defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
  'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1, 'weight_sum': 0,
- 'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split}
+ 'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split,
+ 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
  super().__init__(params, defaults, foreach)
  self._data_format = data_format
  self.rng = random.Random(0x120983109)
@@ -87,7 +92,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
  # adaptive gradient clipping
  adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])

- for p, g in split_p_and_g_in_group(group):
+ for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
  state = self.state_(p)

  if "z" not in state:
heavyball/psgd_kron.py CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
  import torch

  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
- split_p_and_g_in_group, line_to_triu, triu_to_line, promote, stochastic_lerp_, beta_debias
+ line_to_triu, triu_to_line, promote, stochastic_lerp_, beta_debias


  class ForeachPSGDKron(PSGDBase):
@@ -40,7 +40,8 @@ class ForeachPSGDKron(PSGDBase):
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
  foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
- storage_dtype: str = 'float32', #
+ storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
+ #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -57,7 +58,9 @@ class ForeachPSGDKron(PSGDBase):
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+ storage_dtype=storage_dtype,
+ mars=mars, caution=caution, mars_gamma=mars_gamma)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -77,7 +80,7 @@ class ForeachPSGDKron(PSGDBase):

  vals = []

- for p, g in split_p_and_g_in_group(group, should_promote=False):
+ for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
  state = self.state_(p)

  if 'Q' not in state:
@@ -114,4 +117,4 @@ class ForeachPSGDKron(PSGDBase):
  self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
  store_triu_as_line)
  g = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
- update_param_([p], self.clip_fn([g]), lr, weight_decay)
+ update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])
heavyball/pure_psgd.py CHANGED
@@ -5,9 +5,9 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
  """

  import torch
-
  from heavyball.utils import identity
- from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, split_p_and_g_in_group, \
+
+ from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, \
  line_to_triu, triu_to_line, promote


@@ -38,7 +38,8 @@ class ForeachPurePSGD(PSGDBase):
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
- q_dtype='float32', stochastic_schedule: bool = True, #
+ q_dtype='float32', stochastic_schedule: bool = True, mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025, #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -49,11 +50,14 @@ class ForeachPurePSGD(PSGDBase):
  if clip_fn is None:
  clip_fn = identity

+ assert not mars, "MARS is not supported in this optimizer"
+
  defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, mars=mars, caution=caution,
+ mars_gamma=mars_gamma)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -70,7 +74,7 @@ class ForeachPurePSGD(PSGDBase):

  vals = []

- for p, g in split_p_and_g_in_group(group, should_promote=False):
+ for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=0.0):
  state = self.state_(p)

  if 'Q' not in state:
@@ -98,4 +102,4 @@ class ForeachPurePSGD(PSGDBase):
  q32 = [promote(q_) for q_ in q]
  self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
  psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
- update_param_([p], self.clip_fn([g]), lr, weight_decay)
+ update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])
heavyball/schedule_free_palm_foreach_soap.py CHANGED
@@ -1,12 +1,13 @@
  import random

  import torch
+ from heavyball.utils import mars_correction

  from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
- beta_debias, schedule_free_, warmup, ScheduleFree, split_p_and_g_in_group, copy_stochastic_list_, promote
+ beta_debias, schedule_free_, warmup, ScheduleFree, copy_stochastic_list_, promote


- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
  def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
  eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
  denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
@@ -44,15 +45,19 @@ class SFPaLMForeachSOAP(ScheduleFree):
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
  weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False,
- foreach: bool = True):
+ foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
  if betas[0] is not None:
  beta = betas[0]
+
+ assert not caution, "Caution is not implemented in ScheduleFree optimizers"
+
  defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
  'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1,
- 'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split}
+ 'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split, 'mars': mars,
+ 'caution': caution, 'mars_gamma': mars_gamma}
  super().__init__(params, defaults, foreach)
  self._data_format = data_format
  self.rng = random.Random(0x120983109)
@@ -61,6 +66,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
  vals = []
  max_precond_dim = group['max_precond_dim']
  precondition_1d = group['precondition_1d']
+ mars = group['mars']

  step = group['step'] = group.get("step", 0) + 1

@@ -79,12 +85,14 @@ class SFPaLMForeachSOAP(ScheduleFree):

  vals = []

- for p, g in split_p_and_g_in_group(group):
+ for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
  state = self.state_(p)

  if "z" not in state:
  state["z"] = torch.clone(p).float()
  state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+ if mars:
+ state['mars_prev_grad'] = g.clone()
  init_preconditioner(g, state, max_precond_dim, precondition_1d)
  update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
  continue # first step is skipped so that we never use the current gradients in the projection.
heavyball/utils.py CHANGED
@@ -142,18 +142,26 @@ def beta_debias(beta, step):


  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
- def exp_avg_sq_(state, grad, beta2, eps, out=None):
- if isinstance(state, torch.Tensor):
- state.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
- return torch.sqrt(state, out=out).clamp_(min=eps)
-
+ def _compilable_exp_avg_sq_(state, grad, beta2, eps, out=None):
  torch._foreach_mul_(state, beta2)
  [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(state, grad)]
  denom = torch._foreach_sqrt(state)
- torch._foreach_maximum_(denom, eps)
+ [denom.clamp_(min=eps) for denom in denom]
+ if out is not None:
+ copy_stochastic_list_(out, denom)
+ return out
+
  return denom


+ def exp_avg_sq_(state, grad, beta2, eps, out=None):
+ state, grad = list_guard(state), list_guard(grad)
+ if not isinstance(beta2, torch.Tensor):
+ beta2 = torch.empty((), dtype=torch.float32, device=state[0].device).fill_(beta2)
+ if not isinstance(eps, torch.Tensor):
+ eps = torch.empty((), dtype=torch.float32, device=state[0].device).fill_(eps)
+ return _compilable_exp_avg_sq_(state, grad, beta2, eps, out)
+
  def adaptive_gradient_clipping_(parameters: List[torch.Tensor], gradients: List[torch.Tensor], clip_val: float,
  minimum: float = 1e-3, eps: float = 1e-8):
  if clip_val <= 0:
@@ -168,12 +176,19 @@ def adaptive_gradient_clipping_(parameters: List[torch.Tensor], gradients: List[
  torch._foreach_mul_(gradients, p_norm)


+ def is_compiling():
+ try:
+ return torch.compiler.is_compiling()
+ except AttributeError:
+ return True
+
+
  def set_(dst: torch.Tensor, src: torch.Tensor):
- if not torch.compiler.is_compiling() and src.data_ptr() == dst.data_ptr():
+ if not is_compiling() and src.data_ptr() == dst.data_ptr():
  return
  if src.shape != dst.shape:
  src = src.reshape_as(dst)
- if not torch.compiler.is_compiling() and src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
+ if not is_compiling() and src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
  dst.set_(src)
  else:
  dst.copy_(src)
@@ -338,11 +353,18 @@ def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a


  def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+ x, y = list_guard(x), list_guard(y)
  if not isinstance(a, torch.Tensor):
  a = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(a)
  _compilable_stochastic_lerp_(x, y, a)


+ def list_guard(x):
+ if isinstance(x, (list, tuple)):
+ return x
+ return [x]
+
+
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
  def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
  for x_, y_ in zip(x, y):
@@ -353,6 +375,7 @@ def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], al


  def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+ x, y = list_guard(x), list_guard(y)
  if not isinstance(alpha, torch.Tensor):
  alpha = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(alpha)
  _compilable_stochastic_add_(x, y, alpha)
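The eager wrappers now normalise their inputs before calling the compiled kernels: list_guard turns a lone tensor into a one-element list, and scalar hyperparameters such as beta2 and eps are hoisted into 0-dim tensors so the dynamic=False graphs see stable shapes and dtypes. A small self-contained sketch of that behaviour; list_guard is re-declared here only for illustration:

import torch

def list_guard(x):
    # mirrors the helper added in this release
    if isinstance(x, (list, tuple)):
        return x
    return [x]

t = torch.zeros(3)
print(len(list_guard(t)), len(list_guard([t, t])))      # 1 2
eps = torch.empty((), dtype=torch.float32).fill_(1e-8)  # scalar promoted to a 0-dim tensor
print(eps.ndim, float(eps))                             # 0 1e-08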
@@ -463,6 +486,43 @@ class StatefulOptimizer(torch.optim.Optimizer):
  def state_(self, arg: torch.Tensor):
  return self.state[self.key(arg)]

+ def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
+ for p, g in zip(p_list, g_list):
+ state = self.state_(p)
+ if 'mars_old_grad' not in state:
+ state['mars_old_grad'] = torch.zeros_like(g)
+ old_gs = [self.state_(p)['mars_old_grad'] for p in p_list]
+ mars_correction(g_list, old_gs, mars_gamma, beta)
+
+ def split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_promote: bool = True,
+ beta1: float = -1.0):
+ for p in group["params"]:
+ if skip_none and p.grad is None:
+ continue
+
+ if p.grad is None:
+ grad = None
+ else:
+ if should_promote:
+ grad = promote(p.grad)
+ else:
+ grad = p.grad
+ if beta1 >= 0 and group.get('mars', False):
+ self.mars_correct_list(group, [p], [grad], group['mars_gamma'], beta1)
+
+ p.grad = None
+
+ p_views = merge_group(group, p)
+ if grad is not None:
+ grad = merge_group(group, grad)
+ if isinstance(p_views, torch.Tensor):
+ yield p_views, grad
+ continue
+ if grad is None:
+ yield from zip(p_views, [None] * len(p_views))
+ continue
+ yield from zip(p_views, grad)
+
  def state_size(self) -> int:
  total_bytes = 0

@@ -472,7 +532,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
  total_bytes += x.numel() * x.element_size()

  for group in self.param_groups:
- for p, _ in split_p_and_g_in_group(group, skip_none=False):
+ for p, _ in self.split_p_and_g_in_group(group, skip_none=False):
  tree_map(_add, self.state_(p))
  return total_bytes

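Moving split_p_and_g_in_group onto StatefulOptimizer is what lets every optimizer apply the MARS correction to each raw gradient exactly once, right before the parameter and gradient views are merged or split; the default beta1=-1.0 skips the correction, and a mars=False group skips it too. A hedged sketch of calling the generator directly, assuming heavyball 0.22.0 is installed (in normal use only an optimizer's own _step consumes it):

import torch
from heavyball import ForeachSOAP

p = torch.nn.Parameter(torch.randn(4, 4))
p.grad = torch.randn(4, 4)
opt = ForeachSOAP([p], lr=1e-3, mars=True)
group = opt.param_groups[0]

pairs = list(opt.split_p_and_g_in_group(group, beta1=group['betas'][0]))
# each element is a (parameter view, MARS-corrected gradient view) pair;
# note that the generator consumes p.grad as a side effect
print(len(pairs), p.grad is None)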
@@ -625,7 +685,7 @@ def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
625
685
 
626
686
 
627
687
  def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
628
- if not torch.compiler.is_compiling() and target.data_ptr() == source.data_ptr():
688
+ if not is_compiling() and target.data_ptr() == source.data_ptr():
629
689
  return
630
690
  if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
631
691
  set_(target, source)
@@ -633,14 +693,16 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
633
693
 
634
694
 
635
695
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
636
- def _compilable_update_(p, u, decay, add_fn, lr):
696
+ def _compilable_update_(p, u, decay, add_fn, lr, caution, g):
637
697
  u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
638
- p32, u32 = [list(map(promote, x)) for x in [p, u]]
698
+ p32, u32, g32 = [list(map(promote, x)) for x in [p, u, g]]
639
699
 
640
700
  if decay > 0:
641
701
  torch._foreach_mul_(p32, 1 - decay * lr)
642
702
 
643
- for p32_, u32_ in zip(p32, u32): # lr is data-dependent -> can't compile a foreach
703
+ for p32_, u32_, g32_ in zip(p32, u32, g32): # lr is data-dependent -> can't compile a foreach
704
+ if caution:
705
+ _compilable_cautioning_(g32_, u32_)
644
706
  if add_fn is None:
645
707
  p32_.add_(u32_, alpha=lr)
646
708
  else:
@@ -650,9 +712,12 @@ def _compilable_update_(p, u, decay, add_fn, lr):
650
712
 
651
713
 
652
714
  def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
653
- add_fn: callable = None):
715
+ add_fn: callable = None, caution: bool = False, grad: List[torch.Tensor] = None):
654
716
  lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
655
- _compilable_update_(param, update, decay, add_fn, lr_tensor)
717
+ param, update, grad = list_guard(param), list_guard(update), list_guard(grad)
718
+ if not caution:
719
+ grad = [None] * len(param)
720
+ _compilable_update_(param, update, decay, add_fn, lr_tensor, caution, grad)
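`update_param_` now guards its list arguments and threads an optional `caution`/`grad` pair into the compiled kernel; when `caution=False` the gradients are replaced by `None`s so no masking happens. In plain torch, the per-parameter arithmetic amounts to the following sketch (placeholder values; the real kernel runs under `torch.compile` with promoted dtypes):

```python
import torch

lr, decay, caution = 1e-3, 0.01, True
p, u, g = torch.randn(4), torch.randn(4), torch.randn(4)

p.mul_(1 - decay * lr)                                   # decoupled weight decay
if caution:                                              # optional cautious masking
    mask = (g * u) > 0
    u = u * mask * (mask.numel() / mask.sum().clamp(min=1))
p.add_(u, alpha=lr)                                      # p <- p + lr * u
```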
656
721
 
657
722
 
658
723
  def precond_schedule(step, precond_scheduler, rng):
@@ -965,18 +1030,45 @@ class PSGDBase(StatefulOptimizer):
965
1030
  psgd_balance_Q(q)
966
1031
 
967
1032
 
968
- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
969
- def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn):
1033
+ # TODO: Figure out why this sometimes crashes
1034
+ # @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
1035
+ def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn, caution, grad):
970
1036
  md = min_dtype(cached_q + [ea])
971
1037
  new = torch.einsum(expr, *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
972
- update_param_([param], clip_fn([new]), lr, weight_decay)
1038
+ update_param_([param], clip_fn([new]), lr, weight_decay, caution=caution, grad=grad)
973
1039
 
974
1040
 
975
1041
  def precond_grad_cached_(cached_q: List[torch.Tensor], ea: torch.Tensor, expr: str, param: torch.Tensor, lr: float,
976
- weight_decay: float, clip_fn):
1042
+ weight_decay: float, clip_fn, caution, grad):
977
1043
  if isinstance(lr, float):
978
1044
  lr = torch.empty((), dtype=torch.float32, device=param.device).fill_(lr)
979
- _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn)
1045
+ _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn, caution, grad)
1046
+
1047
+
1048
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
1049
+ def _compilable_mars_correction_(g, old_g, a):
1050
+ g_copy = [g_.clone() for g_ in g]
1051
+ _compilable_stochastic_lerp_(g, old_g, a)
1052
+ copy_stochastic_list_(old_g, g_copy)
1053
+
1054
+
1055
+ def mars_correction(g, old_g, beta1, gamma):
1056
+ a = -gamma * beta1 / (1 - beta1)
1057
+ g, old_g = list_guard(g), list_guard(old_g)
1058
+ a = torch.empty((), dtype=torch.float32, device=g[0].device).fill_(a)
1059
+ _compilable_mars_correction_(g, old_g, a)
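`mars_correction` applies the MARS variance-reduction step via a stochastic lerp against the previous gradient, then stores the uncorrected gradient as `mars_old_grad` for the next call. Up to stochastic rounding, the closed form is roughly the sketch below; the gamma/beta coefficient mapping follows the MARS formulation and is an assumption here, and all values are placeholders:

```python
import torch

gamma, beta = 0.0025, 0.9                 # mars_gamma and the first momentum coefficient
g, g_prev = torch.randn(4), torch.randn(4)

c = g + gamma * beta / (1 - beta) * (g - g_prev)   # variance-reduced gradient
g_prev = g.clone()                                 # becomes 'mars_old_grad' next step
```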
1060
+
1061
+
1062
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
1063
+ def _compilable_cautioning_(g, update):
1064
+ mask = (g * update) > 0
1065
+ update.masked_fill_(~mask, 0)
1066
+ scale = mask.numel() / mask.sum().clamp(min=1)
1067
+ update.mul_(scale)
1068
+
1069
+
1070
+ def caution(g, update):
1071
+ _compilable_cautioning_(g, update)
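The cautioning helper implements cautious-optimizer masking: update entries whose sign disagrees with the gradient are zeroed, and the survivors are rescaled so the overall update magnitude is preserved. A plain-torch sketch of the same arithmetic with placeholder values:

```python
import torch

g = torch.tensor([1.0, -2.0, 3.0, -4.0])
u = torch.tensor([0.5,  0.5, 0.5,  0.5])

mask = (g * u) > 0                                  # keep sign-agreeing entries only
u = u * mask                                        # -> [0.5, 0.0, 0.5, 0.0]
u = u * (mask.numel() / mask.sum().clamp(min=1))    # rescale by 4 / 2
print(u)                                            # tensor([1., 0., 1., 0.])
```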
980
1072
 
981
1073
 
982
1074
  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
@@ -1013,29 +1105,3 @@ def merge_group(group, *tensors):
1013
1105
  append_or_extend(out, dim_merger(t, group['max_size_triangular'] if 'max_size_triangular' in group else group[
1014
1106
  'max_precond_dim'], group.get('split', False)))
1015
1107
  return out
1016
-
1017
-
1018
- def split_p_and_g_in_group(group: dict, skip_none: bool = True, should_promote: bool = True):
1019
- for p in group["params"]:
1020
- if skip_none and p.grad is None:
1021
- continue
1022
-
1023
- if p.grad is None:
1024
- grad = None
1025
- else:
1026
- if should_promote:
1027
- grad = promote(p.grad)
1028
- else:
1029
- grad = p.grad
1030
- p.grad = None
1031
-
1032
- p_views = merge_group(group, p)
1033
- if grad is not None:
1034
- grad = merge_group(group, grad)
1035
- if isinstance(p_views, torch.Tensor):
1036
- yield p_views, grad
1037
- continue
1038
- if grad is None:
1039
- yield from zip(p_views, [None] * len(p_views))
1040
- continue
1041
- yield from zip(p_views, grad)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.21.8
3
+ Version: 0.22.0
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
32
32
  The goal is not to strive for completeness, full maintenance or abstraction, but instead to provide a simple
33
33
  largely static alternative to `torch.optim` with more and better optimizers.
34
34
 
35
- Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
35
+ Currently (2024-11-26, 0.22.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
36
36
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
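A minimal usage sketch of the recommended optimizer, assuming the standard `torch.optim`-style interface exposed by the package (model, shapes, and learning rate are placeholders):

```python
import torch
import heavyball

model = torch.nn.Linear(16, 1)
opt = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 16)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```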
37
37
 
38
38
  ## Features
@@ -0,0 +1,24 @@
1
+ heavyball/__init__.py,sha256=icHYN-MGsmHkLUlHCMcZkOlwY7GT63_ayR_a5iPKmzM,2226
2
+ heavyball/cached_delayed_psgd_kron.py,sha256=n3wIOhrop0Ls4MZ0kXpwGuImp1jzPs6VGdxIlPyoYdQ,6827
3
+ heavyball/cached_psgd_kron.py,sha256=KCLsfvj9qh_2FNwRTdWM3zjnt2oGHfsf4Y341rPcceI,6778
4
+ heavyball/delayed_psgd.py,sha256=CaG17zqorLsCSDeGEePOyb6n9ugv8W6gyRQqeQNq-e8,6272
5
+ heavyball/foreach_adamw.py,sha256=uawSbGGUD2E1RtcwspP83yQNElERdGX-diqCI5e8FqE,2825
6
+ heavyball/foreach_adopt.py,sha256=DFEaPswVzdHcbxC-mirsf_okM_HR6r34PDUTty5CrUE,3547
7
+ heavyball/foreach_laprop.py,sha256=J4Vms0nAOMh3GQtAOPyrYOe5WtpzokVv25b9oDnwc2A,2833
8
+ heavyball/foreach_sfadamw.py,sha256=HWbLekY5BloHDIgrN2J0a7IolZCt8Ah2xkLAU_-5oSc,3079
9
+ heavyball/foreach_soap.py,sha256=7B_dP2Hm_xqwpBQiPYkv_c6eoRnU1dV2VZfvSoa4uJ8,4729
10
+ heavyball/p_adam.py,sha256=F-id4qOkAaDTJaKTSNhSsonX-Js5IzIu1Bdj1S4qE2g,6306
11
+ heavyball/palm_foreach_sfadamw.py,sha256=E8raxrBIkSmTEGFzwnfWxKwDJjBQE2vdsmyqfc8aL_A,3375
12
+ heavyball/palm_foreach_soap.py,sha256=IknGm_CzrqDIFEoCkejxjoZ4sfIy6RSoInqlMUOYLB4,6156
13
+ heavyball/precond_schedule_foreach_soap.py,sha256=bJ2ifPFa8zEP9GO8eBpqZzsmP7p_iQkkCkllNeEMHPU,4892
14
+ heavyball/precond_schedule_palm_foreach_soap.py,sha256=4dT9f134-Faq2KuCMCHzMtrkMO-es5p_DYS1of5yF-s,6428
15
+ heavyball/precond_schedule_sfpsoap.py,sha256=FOR-axwlkSN7IHZWYYUVFfjSFCLxc_NdiTlb-n5gmgs,7530
16
+ heavyball/psgd_kron.py,sha256=achB23mQUT3F00IGhjjVf_8YW7VOTHR6YdoCDRyWxsI,6039
17
+ heavyball/pure_psgd.py,sha256=dbYgkunFFA6EsO6fJEhaJRxTH0smi7qLX3Np9XTQ9E4,5079
18
+ heavyball/schedule_free_palm_foreach_soap.py,sha256=0WT_gvTKymqLQzYT6ewDgCmpDq-HgMAewipw1QvyQYA,7267
19
+ heavyball/utils.py,sha256=TVpyev0oL4a78px4cvtaGoGPJqfpfTKE-xBWkRCmzkw,39785
20
+ heavyball-0.22.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
+ heavyball-0.22.0.dist-info/METADATA,sha256=LqVR3tUgxpk21zsmKxfJAQCKLPzmaQz2PQiKvlvpQe8,11926
22
+ heavyball-0.22.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
+ heavyball-0.22.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
+ heavyball-0.22.0.dist-info/RECORD,,
@@ -1,24 +0,0 @@
1
- heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
2
- heavyball/cached_delayed_psgd_kron.py,sha256=Nyxl-G-o6greKwDN-vLiw5W02GXO2LRvknc0OzvzFnE,6674
3
- heavyball/cached_psgd_kron.py,sha256=HzD6se0AYb-W5hpydUxcR9uqrpe_54PBwgL1VWX3DHU,6592
4
- heavyball/delayed_psgd.py,sha256=m4c-OvcLMrRxSAPYs2l6Up21uCyF2kvHvpcnfe3nzGs,6212
5
- heavyball/foreach_adamw.py,sha256=Rb5U80cgUcEqlEbUU250UTWdoqA7nyiqkV5w1U4bWX4,2445
6
- heavyball/foreach_adopt.py,sha256=ecdi1fKg9i087OGjtKWVbE_DD6Yf4pvpzv4ELCcusvQ,3211
7
- heavyball/foreach_laprop.py,sha256=vi6C_gfjXxw5uN0KHgzxI9itUI1dcgOf3ufoO_VVMp0,2471
8
- heavyball/foreach_sfadamw.py,sha256=rLZORmCIMu9G09FdDgMSiI6pNq34IVoxsPVWtmeDdbQ,2753
9
- heavyball/foreach_soap.py,sha256=4mWSMWYTdjgiXiboI5DwdigecruDtNGKylGAFAVhCRA,4562
10
- heavyball/p_adam.py,sha256=Xyxsavwtw-t0OyTHitYQXZSmF9UJlMDzDAURge-MbbQ,6047
11
- heavyball/palm_foreach_sfadamw.py,sha256=JbNrcoquBGGUI5XNMFouDjpNurVHUW9DbX1A3tSrtno,3025
12
- heavyball/palm_foreach_soap.py,sha256=GzAwM8kOt1X0QCmUZDTdHwPxbJwjH8ic43dyAK5BYCA,6015
13
- heavyball/precond_schedule_foreach_soap.py,sha256=HcObXLfSNN_lKNb4nmC6tkdHcqDIMNX6hILpHKScqLc,4744
14
- heavyball/precond_schedule_palm_foreach_soap.py,sha256=xZ7CJvIfdu2RNAZt2g1S7Xb0Jyy1hNC4MowOFU3nWkk,6283
15
- heavyball/precond_schedule_sfpsoap.py,sha256=PNneiOkrRyV1yIZn91lPmYofd1_OiLqJTDy75RLpXJk,7273
16
- heavyball/psgd_kron.py,sha256=RSLJi5FnSmYjvYBufNDClnYRm-eN_Kpa1Ar2tNP6-X0,5824
17
- heavyball/pure_psgd.py,sha256=LZK0qmvZkBF8g00evaVLtW-sIUJmdoxag1K7O26AqEo,4820
18
- heavyball/schedule_free_palm_foreach_soap.py,sha256=_AqrnChY6iDlQUkF2YUxS7eLjSWCIuvEUOHvMHVM1yY,6873
19
- heavyball/utils.py,sha256=xTDZEt2_DM57EYnJkRq7d7scTnro4eKPdMtEwPdLy-c,37218
20
- heavyball-0.21.8.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
- heavyball-0.21.8.dist-info/METADATA,sha256=nLyxHlENmhAGyU9GManYKKJJTykhsAMt7hkJNXPu_YY,11926
22
- heavyball-0.21.8.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
- heavyball-0.21.8.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
- heavyball-0.21.8.dist-info/RECORD,,