heavyball 0.21.7__tar.gz → 0.22.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {heavyball-0.21.7 → heavyball-0.22.0}/PKG-INFO +2 -2
  2. {heavyball-0.21.7 → heavyball-0.22.0}/README.md +1 -1
  3. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/__init__.py +6 -5
  4. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/cached_delayed_psgd_kron.py +6 -5
  5. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/cached_psgd_kron.py +7 -5
  6. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/delayed_psgd.py +12 -9
  7. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/foreach_adamw.py +14 -7
  8. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/foreach_adopt.py +11 -6
  9. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/foreach_laprop.py +12 -6
  10. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/foreach_sfadamw.py +10 -3
  11. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/foreach_soap.py +10 -8
  12. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/p_adam.py +9 -7
  13. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/palm_foreach_sfadamw.py +11 -3
  14. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/palm_foreach_soap.py +8 -9
  15. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/precond_schedule_foreach_soap.py +10 -8
  16. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/precond_schedule_palm_foreach_soap.py +9 -9
  17. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/precond_schedule_sfpsoap.py +10 -5
  18. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/psgd_kron.py +8 -5
  19. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/pure_psgd.py +10 -6
  20. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/schedule_free_palm_foreach_soap.py +13 -5
  21. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball/utils.py +120 -54
  22. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball.egg-info/PKG-INFO +2 -2
  23. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball.egg-info/SOURCES.txt +2 -0
  24. {heavyball-0.21.7 → heavyball-0.22.0}/setup.py +1 -1
  25. {heavyball-0.21.7 → heavyball-0.22.0}/test/test_bf16_params.py +13 -13
  26. heavyball-0.22.0/test/test_caution.py +50 -0
  27. heavyball-0.22.0/test/test_mars.py +53 -0
  28. {heavyball-0.21.7 → heavyball-0.22.0}/LICENSE +0 -0
  29. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball.egg-info/dependency_links.txt +0 -0
  30. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball.egg-info/requires.txt +0 -0
  31. {heavyball-0.21.7 → heavyball-0.22.0}/heavyball.egg-info/top_level.txt +0 -0
  32. {heavyball-0.21.7 → heavyball-0.22.0}/setup.cfg +0 -0
  33. {heavyball-0.21.7 → heavyball-0.22.0}/test/test_bf16_q.py +0 -0
  34. {heavyball-0.21.7 → heavyball-0.22.0}/test/test_bf16_storage.py +0 -0
  35. {heavyball-0.21.7 → heavyball-0.22.0}/test/test_closure.py +0 -0
  36. {heavyball-0.21.7 → heavyball-0.22.0}/test/test_ema.py +0 -0
  37. {heavyball-0.21.7 → heavyball-0.22.0}/test/test_foreach.py +0 -0
  38. {heavyball-0.21.7 → heavyball-0.22.0}/test/test_memory.py +0 -0
  39. {heavyball-0.21.7 → heavyball-0.22.0}/test/test_merge.py +0 -0
  40. {heavyball-0.21.7 → heavyball-0.22.0}/test/test_no_grad.py +0 -0
  41. {heavyball-0.21.7 → heavyball-0.22.0}/test/test_psgd.py +0 -0
  42. {heavyball-0.21.7 → heavyball-0.22.0}/test/test_soap.py +0 -0
  43. {heavyball-0.21.7 → heavyball-0.22.0}/test/test_stochastic_updates.py +0 -0
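Judging by the signature changes in the hunks below, the headline change in 0.22.0 is that every optimizer constructor gains `mars`, `caution`, and `mars_gamma` arguments (exercised by the new `test_caution.py` and `test_mars.py`), and the compiled step helpers now receive the raw gradient so the update can be "cautioned" against it. A minimal usage sketch, assuming heavyball 0.22.0 is installed; the model, data, and learning rate are placeholders, not taken from the package:

```python
# Sketch only: exercises the new mars / caution / mars_gamma flags added in 0.22.0.
import torch
import heavyball

model = torch.nn.Linear(16, 1)
opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-3,
                             mars=True,           # MARS-style gradient correction (default: False)
                             caution=True,        # cautious (gradient-masked) update (default: False)
                             mars_gamma=0.0025)   # default value from the new signatures

x, y = torch.randn(32, 16), torch.randn(32, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()
```

Note that the two schedule-free variants (`ForeachSFAdamW`, `PaLMForeachSFAdamW`) accept the flags but `assert not caution`, as shown in their hunks further down.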
{heavyball-0.21.7 → heavyball-0.22.0}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.21.7
+ Version: 0.22.0
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
  largely static alternative to `torch.optim` with more and better optimizers.

- Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+ Currently (2024-11-26, 0.22.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

  ## Features
{heavyball-0.21.7 → heavyball-0.22.0}/README.md
@@ -8,7 +8,7 @@ A simple package of efficient optimizers
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
  largely static alternative to `torch.optim` with more and better optimizers.

- Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+ Currently (2024-11-26, 0.22.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

  ## Features
{heavyball-0.21.7 → heavyball-0.22.0}/heavyball/__init__.py
@@ -1,3 +1,4 @@
+ from .cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron
  from .cached_psgd_kron import ForeachCachedPSGDKron
  from .delayed_psgd import ForeachDelayedPSGD
  from .foreach_adamw import ForeachAdamW
@@ -14,7 +15,6 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
  from .psgd_kron import ForeachPSGDKron
  from .pure_psgd import ForeachPurePSGD
  from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
- from .cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron

  PalmForEachSoap = PaLMForeachSOAP

@@ -39,7 +39,8 @@ CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
  __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
  'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
  'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
- 'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron', #
- 'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP',
- 'SOAP', 'SFAdamW', 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD',
- 'CachedPSGDKron', 'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP']
+ 'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron',
+ #
+ 'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP', 'SOAP', 'SFAdamW',
+ 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD', 'CachedPSGDKron',
+ 'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP']
{heavyball-0.21.7 → heavyball-0.22.0}/heavyball/cached_delayed_psgd_kron.py
@@ -9,7 +9,7 @@ from typing import Optional
  import torch
  from heavyball.utils import min_dtype, precond_grad_cached_

- from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
+ from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
  line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_


@@ -43,7 +43,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
  foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
- storage_dtype: str = 'float32', #
+ storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
+ #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -61,7 +62,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
  split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
- storage_dtype=storage_dtype)
+ storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -81,7 +82,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):

  vals = []

- for p, g in split_p_and_g_in_group(group, should_promote=False):
+ for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
  state = self.state_(p)

  if 'Q' not in state:
@@ -120,7 +121,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  q_orig = Q_list.pop(0)
  ea = exp_avg_list.pop(0)

- precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn)
+ precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn, group['caution'], g)

  if should_update:
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
{heavyball-0.21.7 → heavyball-0.22.0}/heavyball/cached_psgd_kron.py
@@ -8,7 +8,7 @@ from typing import Optional

  import torch

- from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
+ from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
  line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_


@@ -40,7 +40,8 @@ class ForeachCachedPSGDKron(PSGDBase):
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
  foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
- storage_dtype: str = 'float32', #
+ storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
+ #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -58,7 +59,7 @@ class ForeachCachedPSGDKron(PSGDBase):
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
  split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
- storage_dtype=storage_dtype)
+ storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -78,7 +79,7 @@ class ForeachCachedPSGDKron(PSGDBase):

  vals = []

- for p, g in split_p_and_g_in_group(group, should_promote=False):
+ for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
  state = self.state_(p)

  if 'Q' not in state:
@@ -128,4 +129,5 @@ class ForeachCachedPSGDKron(PSGDBase):
  else:
  torch.mul(q_.conj(), q_, out=c_)

- precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn)
+ precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn,
+ group['caution'], g)
{heavyball-0.21.7 → heavyball-0.22.0}/heavyball/delayed_psgd.py
@@ -8,14 +8,13 @@ import torch
  from heavyball.utils import stochastic_lerp_, beta_debias

  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
- split_p_and_g_in_group, triu_to_line, line_to_triu, promote
+ triu_to_line, line_to_triu, promote
+

- # TODO: E1123 00:51:55.423000 159394 site-packages/torch/_guards.py:283] [5/0] Error while creating guard:
- # E1123 00:51:55.423000 159394 site-packages/torch/_guards.py:283] [5/0] Name: "G['psgd_precond_grad'].__defaults__[0]"
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
- def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_decay, clip_fn):
+ def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_deca, clip_fn, caution, grad):
  new = psgd_precond_grad(q, exprs, ea)
- update_param_([p], clip_fn([new]), lr, weight_decay)
+ update_param_([p], clip_fn([new]), lr, weight_decay, caution=caution, grad=grad)


  class ForeachDelayedPSGD(PSGDBase):
@@ -46,7 +45,8 @@ class ForeachDelayedPSGD(PSGDBase):
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
- q_dtype='float32', stochastic_schedule: bool = True, storage_dtype: str = 'float32', #
+ q_dtype='float32', stochastic_schedule: bool = True, storage_dtype: str = 'float32',
+ mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -63,7 +63,9 @@ class ForeachDelayedPSGD(PSGDBase):
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+ storage_dtype=storage_dtype,
+ caution=caution, mars_gamma=mars_gamma, mars=mars)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -83,7 +85,7 @@ class ForeachDelayedPSGD(PSGDBase):

  vals = []

- for p, g in split_p_and_g_in_group(group, should_promote=False):
+ for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
  state = self.state_(p)

  if 'Q' not in state:
@@ -112,7 +114,8 @@ class ForeachDelayedPSGD(PSGDBase):
  q_orig = Q_list.pop(0)
  ea = exp_avg_list.pop(0)
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
- _compilable_psgd_precond_grad_(q, self.state_(p)["exprs"], ea, p, lr, weight_decay, self.clip_fn)
+ _compilable_psgd_precond_grad_(q, self.state_(p)["exprs"], ea, p, lr, weight_decay, self.clip_fn, group['caution'],
+ g)
  if should_update:
  q32 = [promote(q_) for q_ in q]
  self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
{heavyball-0.21.7 → heavyball-0.22.0}/heavyball/foreach_adamw.py
@@ -1,18 +1,19 @@
  import torch
  import torch.optim
-
  from heavyball.utils import copy_stochastic_list_
+
  from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote


- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
- def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
  g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]

  torch._foreach_lerp_(exp_avg32, g32, 1 - beta_debias(beta1, step + 1))
  denom = list(exp_avg_sq_(exp_avg_sq32, g32, beta_debias(beta2, step + 1), eps))

- update_param_(y, exp_avg32, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l))
+ update_param_(y, exp_avg32, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l), caution=caution,
+ grad=g32)

  copy_stochastic_list_(exp_avg, exp_avg32)
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
@@ -20,9 +21,11 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,

  class ForeachAdamW(StatefulOptimizer):
  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
- foreach: bool = True, storage_dtype: str = 'float32'):
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025):
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
- lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
+ lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+ mars_gamma=mars_gamma)
  super().__init__(params, defaults, foreach)

  def _step(self, group):
@@ -48,9 +51,13 @@ class ForeachAdamW(StatefulOptimizer):
  y, grad, exp_avg_sq, exp_avg = zip(
  *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])

+ if group['mars']:
+ self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
  lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
  lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
  step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k)
- _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay)
+ _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay,
+ group['caution'])

  group['k'] = k + 1
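The new `caution` flag is threaded from the constructor through `_compilable_step_` into `update_param_`, which now also receives the raw gradient. The masking itself lives in `heavyball/utils.py` (changed in this release, but its hunks are not shown here); the sketch below only illustrates the usual "cautious optimizer" idea of dropping update components whose sign disagrees with the gradient, and is not the package's implementation:

```python
# Illustrative cautious update (sign-agreement mask); hypothetical helper, not heavyball's update_param_.
import torch

def cautious_update_(param: torch.Tensor, update: torch.Tensor, grad: torch.Tensor, lr: float):
    mask = (update * grad > 0).to(update.dtype)              # keep components that agree with the gradient
    mask = mask * (mask.numel() / mask.sum().clamp(min=1))   # rescale so the mean step size is preserved
    param.add_(update * mask, alpha=-lr)                     # apply the masked step
```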
{heavyball-0.21.7 → heavyball-0.22.0}/heavyball/foreach_adopt.py
@@ -5,10 +5,10 @@ from heavyball.utils import copy_stochastic_list_
  from .utils import warmup, beta_debias, update_param_, StatefulOptimizer, promote


- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
- def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
  g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
- update_param_(y, exp_avg, lr, decay)
+ update_param_(y, exp_avg, lr, decay, caution=caution, grad=g32)

  beta1 = beta_debias(beta1, step)
  denom = torch._foreach_sqrt(exp_avg_sq32)
@@ -27,9 +27,11 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
  class ForeachADOPT(StatefulOptimizer):

  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
- foreach: bool = True, storage_dtype: str = 'float32'):
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025):
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
- lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
+ lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+ mars_gamma=mars_gamma)
  super().__init__(params, defaults, foreach)

  def _step(self, group):
@@ -57,11 +59,14 @@ class ForeachADOPT(StatefulOptimizer):

  group['k'] = k + 1

+ if group['mars']:
+ self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
  if k > 1:
  lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
  lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
  k = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k)
- _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], k, lr, eps, decay)
+ _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], k, lr, eps, decay, group['caution'])
  return

  grad = [promote(g) for g in grad]
{heavyball-0.21.7 → heavyball-0.22.0}/heavyball/foreach_laprop.py
@@ -4,8 +4,8 @@ import torch.optim
  from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote, copy_stochastic_list_


- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
- def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
  g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]

  denom = exp_avg_sq_(exp_avg_sq32, g32, beta_debias(beta2, step), eps)
@@ -14,7 +14,7 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
  torch._foreach_mul_(exp_avg32, beta1)
  [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]

- update_param_(y, exp_avg32, lr, decay)
+ update_param_(y, exp_avg32, lr, decay, caution=caution, grad=g32)

  copy_stochastic_list_(exp_avg, exp_avg32)
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
@@ -23,9 +23,11 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps,
  class ForeachLaProp(StatefulOptimizer):

  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
- foreach: bool = True, storage_dtype: str = 'float32'):
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025):
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
- lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
+ lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+ mars_gamma=mars_gamma)
  super().__init__(params, defaults, foreach)

  def _step(self, group):
@@ -52,10 +54,14 @@ class ForeachLaProp(StatefulOptimizer):
  *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) #
  for p in active_p])

+ if group['mars']:
+ self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
  lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
  lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
  step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k + 1)

- _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay)
+ _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay,
+ group['caution'])

  group['k'] = k + 1
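`mars_correct_list` is a new optimizer method (defined in `utils.py`, outside this section) that is applied to the gradient list before the compiled step whenever `mars=True`. It presumably applies a MARS-style variance-reduction correction; the sketch below shows that published recipe for a single tensor, with names and the clipping step chosen for illustration rather than copied from heavyball:

```python
# Sketch of a MARS-style corrected gradient; the gamma default mirrors mars_gamma in the new signatures.
import torch

def mars_correct(grad: torch.Tensor, prev_grad: torch.Tensor,
                 beta1: float = 0.9, gamma: float = 0.0025) -> torch.Tensor:
    c = grad + gamma * (beta1 / (1 - beta1)) * (grad - prev_grad)  # add a scaled gradient difference
    return c / c.norm().clamp(min=1.0)                             # clip to unit norm if it grew too large
```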
{heavyball-0.21.7 → heavyball-0.22.0}/heavyball/foreach_sfadamw.py
@@ -5,7 +5,7 @@ from heavyball.utils import get_ckp1, copy_stochastic_list_
  from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, promote, _compilable_schedule_free_


- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
  def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
  old_debiased2 = beta_debias(beta2, step)

@@ -21,13 +21,17 @@ def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, dec

  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)

+
  class ForeachSFAdamW(ScheduleFree):
  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
- weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32'):
+ weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
+ caution: bool = False, mars_gamma: float = 0.0025):
+
+ assert not caution, "Caution not implemented for SFAdamW"

  defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
  weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
- foreach=foreach, storage_dtype=storage_dtype)
+ foreach=foreach, storage_dtype=storage_dtype, mars=mars, caution=caution, mars_gamma=mars_gamma)
  super().__init__(params, defaults, foreach)

  def _step(self, group):
@@ -53,6 +57,9 @@ class ForeachSFAdamW(ScheduleFree):
  y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z']) #
  for p in active_p])

+ if group['mars']:
+ self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
  lr = warmup(group['lr'], k + 1, group['warmup_steps'])
  ckp1, group['weight_sum'] = get_ckp1(lr, group['weight_lr_power'], group['weight_sum'], group['r'], k + 1)

{heavyball-0.21.7 → heavyball-0.22.0}/heavyball/foreach_soap.py
@@ -1,7 +1,7 @@
  import torch

  from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
- split_p_and_g_in_group, StatefulOptimizer, exp_avg_
+ StatefulOptimizer, exp_avg_


  class ForeachSOAP(StatefulOptimizer):
@@ -26,11 +26,13 @@ class ForeachSOAP(StatefulOptimizer):
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
- split: bool = False, foreach: bool = True):
+ split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025):
  defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
- "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split}
+ "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split, 'mars': mars,
+ 'caution': caution, 'mars_gamma': mars_gamma}
  super().__init__(params, defaults, foreach)
  self._data_format = data_format

@@ -41,7 +43,7 @@ class ForeachSOAP(StatefulOptimizer):
  max_precond_dim = group['max_precond_dim']
  precondition_1d = group['precondition_1d']

- for p, g in split_p_and_g_in_group(group):
+ for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
  state = self.state_(p)
  step = state['step'] = state.get("step", -1) + 1

@@ -71,6 +73,8 @@ class ForeachSOAP(StatefulOptimizer):
  step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
  denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)

+ step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
  for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
  state = self.state_(p)
  # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
@@ -80,11 +84,9 @@ class ForeachSOAP(StatefulOptimizer):
  # Projecting back the preconditioned (by Adam) exponential moving average of gradients
  # to the original space
  # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
- set_(d, project(exp_avg_projected / d, state['Q'], True))
+ precond = project(exp_avg_projected / d, state['Q'], True)

  update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
  step > 0 and step % group['precondition_frequency'] == 0)

- # Why does this have to be rebiased here?
- step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
- update_param_(p_list, denom, step_size, group["weight_decay"])
+ update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
{heavyball-0.21.7 → heavyball-0.22.0}/heavyball/p_adam.py
@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
  """

  import torch
-
  from heavyball.utils import triu_to_line, line_to_triu, identity, stochastic_lerp_
+
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, exp_avg_sq_, beta_debias, \
- split_p_and_g_in_group, promote
+ promote


  class ForeachPaLMPAdam(PSGDBase):
@@ -39,7 +39,8 @@ class ForeachPaLMPAdam(PSGDBase):
  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
  store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
- stochastic_schedule: bool = True, storage_dtype:str ='float32',#
+ stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
+ caution: bool = False, mars_gamma: float = 0.0025, #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -57,7 +58,8 @@ class ForeachPaLMPAdam(PSGDBase):
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, beta=beta,
  beta2_scale=beta2_scale, merge_dims=merge_dims, split=split,
- store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
+ store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype,
+ mars=mars, caution=caution, mars_gamma=mars_gamma)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -75,7 +77,7 @@ class ForeachPaLMPAdam(PSGDBase):

  vals = []

- for p, g in split_p_and_g_in_group(group, should_promote=False):
+ for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=group['beta']):
  state = self.state_(p)

  if 'Q' not in state:
@@ -107,6 +109,7 @@ class ForeachPaLMPAdam(PSGDBase):
  lr = -warmup(lr, group['step'], group['warmup_steps'])

  for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
+ gc = g.clone() if group['caution'] else None
  psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
  ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
  exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
@@ -115,5 +118,4 @@ class ForeachPaLMPAdam(PSGDBase):
  divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
  divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
  """
- update_param_([p], self.clip_fn([g]), lr, weight_decay)
-
+ update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=gc)
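One detail worth noting in the `ForeachPaLMPAdam` hunk above: `exp_avg_sq_(..., out=g)` reuses the gradient tensor as its output buffer, so when `caution` is enabled the gradient is cloned into `gc` first; otherwise there would be no raw gradient left for the sign comparison. A generic illustration of the pitfall (not heavyball code):

```python
# Why the clone matters: an in-place output buffer loses its original contents.
import torch

g = torch.randn(4)          # stands in for a gradient
raw = g.clone()             # saved before g is reused as scratch space
torch.mul(g, g, out=g)      # g now holds g ** 2; the original signs are gone
print(torch.equal(g, raw))  # False: any sign-based caution mask must read `raw`, not `g`
```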
{heavyball-0.21.7 → heavyball-0.22.0}/heavyball/palm_foreach_sfadamw.py
@@ -5,7 +5,7 @@ from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, get_ckp1, pro
  _compilable_schedule_free_, copy_stochastic_list_


- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
  def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
  old_debiased2 = beta_debias(beta2, step)

@@ -24,12 +24,17 @@ def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, dec

  class PaLMForeachSFAdamW(ScheduleFree):
  def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
- weight_lr_power=2.0, beta2_scale: float = 0.8, foreach: bool = True, storage_dtype: str = 'float32'):
+ weight_lr_power=2.0, beta2_scale: float = 0.8, foreach: bool = True, storage_dtype: str = 'float32',
+ mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
  if betas[0] is not None:
  beta = betas[0]
+
+ assert not caution, "Caution not implemented for SFAdamW"
+
  defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
  lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
- beta2_scale=beta2_scale, storage_dtype=storage_dtype)
+ beta2_scale=beta2_scale, storage_dtype=storage_dtype, mars=mars, caution=caution,
+ mars_gamma=mars_gamma)
  super().__init__(params, defaults, foreach)

  def _step(self, group):
@@ -58,6 +63,9 @@ class PaLMForeachSFAdamW(ScheduleFree):
  y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z']) #
  for p in active_p])

+ if group['mars']:
+ self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
  lr = warmup(group['lr'], k + 1, group['warmup_steps'])
  ckp1, group['weight_sum'] = get_ckp1(lr, group['weight_lr_power'], group['weight_sum'], group['r'], k + 1)

{heavyball-0.21.7 → heavyball-0.22.0}/heavyball/palm_foreach_soap.py
@@ -1,8 +1,7 @@
  import torch

  from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
- split_p_and_g_in_group, StatefulOptimizer, exp_avg_
-
+ StatefulOptimizer, exp_avg_


  class PaLMForeachSOAP(StatefulOptimizer):
@@ -33,14 +32,15 @@ class PaLMForeachSOAP(StatefulOptimizer):
  max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
- beta2_scale: float = 0.8, split: bool = False, foreach: bool = True):
+ beta2_scale: float = 0.8, split: bool = False, foreach: bool = True, mars: bool = False,
+ caution: bool = False, mars_gamma: float = 0.0025):
  if betas[0] is not None:
  beta = betas[0]
  defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
- 'split': split}
+ 'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
  super().__init__(params, defaults, foreach)
  self._data_format = data_format

@@ -51,7 +51,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
  max_precond_dim = group['max_precond_dim']
  precondition_1d = group['precondition_1d']

- for p, g in split_p_and_g_in_group(group):
+ for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
  state = self.state_(p)
  step = state['step'] = state.get("step", -1) + 1

@@ -82,6 +82,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
  beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
  step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
  denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
+ step_size = -group["lr"] * min(step / group['warmup_steps'], 1)

  for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
  state = self.state_(p)
@@ -92,11 +93,9 @@ class PaLMForeachSOAP(StatefulOptimizer):
  # Projecting back the preconditioned (by Adam) exponential moving average of gradients
  # to the original space
  # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
- set_(d, project(exp_avg_projected / d, state['Q'], True))
+ precond = project(exp_avg_projected / d, state['Q'], True)

  update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
  step > 0 and step % group['precondition_frequency'] == 0)

- # Why does this have to be rebiased here?
- step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
- update_param_(p_list, denom, step_size, group["weight_decay"])
+ update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])