heavyball 0.15.1__tar.gz → 0.16.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {heavyball-0.15.1 → heavyball-0.16.0}/PKG-INFO +4 -2
  2. {heavyball-0.15.1 → heavyball-0.16.0}/README.md +3 -1
  3. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/cached_psgd_kron.py +3 -2
  4. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/delayed_psgd.py +5 -3
  5. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/foreach_adamw.py +3 -2
  6. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/foreach_adopt.py +3 -2
  7. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/foreach_laprop.py +3 -2
  8. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/foreach_sfadamw.py +4 -4
  9. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/foreach_soap.py +4 -3
  10. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/p_adam.py +4 -3
  11. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/palm_foreach_sfadamw.py +3 -2
  12. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/palm_foreach_soap.py +3 -2
  13. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/precond_schedule_foreach_soap.py +3 -2
  14. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/precond_schedule_palm_foreach_soap.py +3 -2
  15. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/precond_schedule_sfpsoap.py +3 -3
  16. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/psgd_kron.py +5 -3
  17. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/pure_psgd.py +3 -2
  18. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/schedule_free_palm_foreach_soap.py +4 -3
  19. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/utils.py +23 -5
  20. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball.egg-info/PKG-INFO +4 -2
  21. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball.egg-info/SOURCES.txt +1 -0
  22. {heavyball-0.15.1 → heavyball-0.16.0}/setup.py +1 -1
  23. heavyball-0.16.0/test/test_foreach.py +65 -0
  24. {heavyball-0.15.1 → heavyball-0.16.0}/LICENSE +0 -0
  25. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball/__init__.py +0 -0
  26. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball.egg-info/dependency_links.txt +0 -0
  27. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball.egg-info/requires.txt +0 -0
  28. {heavyball-0.15.1 → heavyball-0.16.0}/heavyball.egg-info/top_level.txt +0 -0
  29. {heavyball-0.15.1 → heavyball-0.16.0}/setup.cfg +0 -0
  30. {heavyball-0.15.1 → heavyball-0.16.0}/test/test_closure.py +0 -0
  31. {heavyball-0.15.1 → heavyball-0.16.0}/test/test_memory.py +0 -0
  32. {heavyball-0.15.1 → heavyball-0.16.0}/test/test_merge.py +0 -0
  33. {heavyball-0.15.1 → heavyball-0.16.0}/test/test_no_grad.py +0 -0
  34. {heavyball-0.15.1 → heavyball-0.16.0}/test/test_psgd.py +0 -0
  35. {heavyball-0.15.1 → heavyball-0.16.0}/test/test_soap.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.15.1
+ Version: 0.16.0
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -39,12 +39,14 @@ recommended experimental optimizer is `ForeachPSGDKron`.

  * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
  * **Inplace EMA**: Same math, but less memory, less compute and higher stability
- * **Foreach**: Fast multi-tensor application
+ * **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
  * **PaLM Beta2**: Fast initial
  convergence, [stable late convergence](https://x.com/_clashluke/status/1820810798693818761)
  * **ScheduleFree**: No learning rate schedule, but better convergence
  * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
  better step-per-second in late convergence (explained below)
+ * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
+ bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))

  ## Getting started


README.md

@@ -15,12 +15,14 @@ recommended experimental optimizer is `ForeachPSGDKron`.

  * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
  * **Inplace EMA**: Same math, but less memory, less compute and higher stability
- * **Foreach**: Fast multi-tensor application
+ * **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
  * **PaLM Beta2**: Fast initial
  convergence, [stable late convergence](https://x.com/_clashluke/status/1820810798693818761)
  * **ScheduleFree**: No learning rate schedule, but better convergence
  * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
  better step-per-second in late convergence (explained below)
+ * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
+ bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))

  ## Getting started

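To make the two new README knobs above concrete, here is a minimal usage sketch (not part of the package; the model, sizes, and learning rate are placeholders) that passes the new `foreach` flag and the `store_triu_as_line` flag to `ForeachPSGDKron`, whose updated signature appears in the `psgd_kron.py` hunks further down. It assumes `ForeachPSGDKron` is importable from the top-level `heavyball` package, as the new `test/test_foreach.py` suggests via `heavyball.__all__`:

```python
import torch
from torch import nn

import heavyball

model = nn.Linear(128, 128)

# foreach=False gives up the fast multi-tensor path to reduce peak memory;
# store_triu_as_line=True (the default) keeps triangular preconditioner factors
# in a flattened form, trading extra memory bandwidth for a smaller footprint.
opt = heavyball.ForeachPSGDKron(model.parameters(), lr=1e-3,
                                foreach=False, store_triu_as_line=True)

loss = model(torch.randn(16, 128)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```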
heavyball/cached_psgd_kron.py

@@ -39,7 +39,8 @@ class ForeachCachedPSGDKron(PSGDBase):
  def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
- split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
+ split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
+ foreach: bool = True):
  if not 0.0 <= lr:
  raise ValueError(f"Invalid learning rate: {lr}")
  if not 0.0 <= beta < 1.0:
@@ -61,7 +62,7 @@ class ForeachCachedPSGDKron(PSGDBase):
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
  step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
  store_triu_as_line=store_triu_as_line)
- super().__init__(params, defaults)
+ super().__init__(params, defaults, foreach)

  self._prob_step = 0

heavyball/delayed_psgd.py

@@ -38,7 +38,8 @@ class ForeachDelayedPSGD(PSGDBase):
  def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
- split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
+ split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
+ foreach: bool = True):
  if not 0.0 <= lr:
  raise ValueError(f"Invalid learning rate: {lr}")
  if not 0.0 <= beta < 1.0:
@@ -60,7 +61,7 @@ class ForeachDelayedPSGD(PSGDBase):
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
  step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
  store_triu_as_line=store_triu_as_line)
- super().__init__(params, defaults)
+ super().__init__(params, defaults, foreach)

  self._prob_step = 0

@@ -113,7 +114,8 @@ class ForeachDelayedPSGD(PSGDBase):
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
  new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
  if do_update:
- self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
+ self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
+ [q_orig] if store_triu_as_line else None)
  self.balance([g], [q])
  set_(g, new)

heavyball/foreach_adamw.py

@@ -5,10 +5,11 @@ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOpti


  class ForeachAdamW(StatefulOptimizer):
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0):
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+ foreach: bool = True):
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
  lr_max=-1.0, weight_decay=weight_decay)
- super().__init__(params, defaults)
+ super().__init__(params, defaults, foreach)

  def _step(self, group):
  eps = group['eps']
heavyball/foreach_adopt.py

@@ -6,10 +6,11 @@ from .utils import warmup, beta_debias, update_param_, StatefulOptimizer

  class ForeachADOPT(StatefulOptimizer):

- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0):
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+ foreach: bool = True):
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
  lr_max=-1.0, weight_decay=weight_decay)
- super().__init__(params, defaults)
+ super().__init__(params, defaults, foreach)

  def _step(self, group):
  eps = group['eps']
heavyball/foreach_laprop.py

@@ -6,10 +6,11 @@ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOpti

  class ForeachLaProp(StatefulOptimizer):

- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1):
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
+ foreach: bool = True):
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
  lr_max=-1.0, weight_decay=weight_decay)
- super().__init__(params, defaults)
+ super().__init__(params, defaults, foreach)

  def _step(self, group):
  eps = group['eps']
heavyball/foreach_sfadamw.py

@@ -6,12 +6,12 @@ from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debia

  class ForeachSFAdamW(ScheduleFree):
  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
- weight_lr_power=2.0, foreach=hasattr(torch, "_foreach_mul_")):
+ weight_lr_power=2.0, foreach: bool = True):

  defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
  weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
  foreach=foreach)
- super().__init__(params, defaults)
+ super().__init__(params, defaults, foreach)

  def _step(self, group):
  eps = group['eps']
@@ -48,7 +48,7 @@ class ForeachSFAdamW(ScheduleFree):
  torch._foreach_add_(grad, y, alpha=decay)

  lr = warmup(group['lr'], k + 1, group['warmup_steps'])
- group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
- y, z, grad, group['r'], k + 1)
+ group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0], y, z,
+ grad, group['r'], k + 1)

  group['k'] = k + 1
heavyball/foreach_soap.py

@@ -26,12 +26,13 @@ class ForeachSOAP(StatefulOptimizer):
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
- split: bool = False):
+ split: bool = False,
+ foreach: bool = True):
  defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split}
- super().__init__(params, defaults)
+ super().__init__(params, defaults, foreach)
  self._data_format = data_format

  def _step(self, group):
@@ -59,7 +60,7 @@ class ForeachSOAP(StatefulOptimizer):
  vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))

  if not vals:
- return
+ return

  p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
  beta1, beta2 = group["betas"]
heavyball/p_adam.py

@@ -38,7 +38,8 @@ class ForeachPaLMPAdam(PSGDBase):
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
- store_triu_as_line: bool = True):
+ store_triu_as_line: bool = True,
+ foreach: bool = True):
  if not 0.0 <= lr:
  raise ValueError(f"Invalid learning rate: {lr}")
  if not 0.0 <= weight_decay:
@@ -60,7 +61,7 @@ class ForeachPaLMPAdam(PSGDBase):
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
  step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
  split=split, store_triu_as_line=store_triu_as_line)
- super().__init__(params, defaults)
+ super().__init__(params, defaults, foreach)

  self._prob_step = 0

@@ -90,7 +91,7 @@ class ForeachPaLMPAdam(PSGDBase):
  state['exp_avg'] = torch.zeros_like(g)
  state['exp_avg_sq'] = torch.zeros_like(g)
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
- min_ndim_triangular, memory_save_mode, dtype=g.dtype)
+ min_ndim_triangular, memory_save_mode, dtype=g.dtype)
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q

  vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
heavyball/palm_foreach_sfadamw.py

@@ -6,13 +6,14 @@ from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debia

  class PaLMForeachSFAdamW(ScheduleFree):
  def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
- weight_lr_power=2.0, beta2_scale: float = 0.8):
+ weight_lr_power=2.0, beta2_scale: float = 0.8,
+ foreach: bool = True):
  if betas[0] is not None:
  beta = betas[0]
  defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
  lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
  beta2_scale=beta2_scale)
- super().__init__(params, defaults)
+ super().__init__(params, defaults, foreach)

  def _step(self, group):
  eps = group['eps']
heavyball/palm_foreach_soap.py

@@ -32,7 +32,8 @@ class PaLMForeachSOAP(StatefulOptimizer):
  max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
- beta2_scale: float = 0.8, split: bool = False):
+ beta2_scale: float = 0.8, split: bool = False,
+ foreach: bool = True):
  if betas[0] is not None:
  beta = betas[0]
  defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -40,7 +41,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
  'split': split}
- super().__init__(params, defaults)
+ super().__init__(params, defaults, foreach)
  self._data_format = data_format

  def _step(self, group):
heavyball/precond_schedule_foreach_soap.py

@@ -27,13 +27,14 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
- precond_scheduler=(1 / 3, 9), split: bool = False):
+ precond_scheduler=(1 / 3, 9), split: bool = False,
+ foreach: bool = True):
  defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
  'split': split}
- super().__init__(params, defaults)
+ super().__init__(params, defaults, foreach)
  self._data_format = data_format
  self.rng = random.Random(0x120983109)

heavyball/precond_schedule_palm_foreach_soap.py

@@ -32,7 +32,8 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
- precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False):
+ precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
+ foreach: bool = True):
  if betas[0] is not None:
  beta = betas[0]
  defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -40,7 +41,7 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
  'beta2_scale': beta2_scale, 'split': split}
- super().__init__(params, defaults)
+ super().__init__(params, defaults, foreach)
  self._data_format = data_format
  self.rng = random.Random(0x120983109)

heavyball/precond_schedule_sfpsoap.py

@@ -41,7 +41,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
  weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9),
- betas=(None, None), split: bool = False):
+ betas=(None, None), split: bool = False, foreach: bool = True):
  if betas[0] is not None:
  beta = betas[0]
  defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -50,7 +50,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
  'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1, 'weight_sum': 0,
  'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split}
- super().__init__(params, defaults)
+ super().__init__(params, defaults, foreach)
  self._data_format = data_format
  self.rng = random.Random(0x120983109)

@@ -59,7 +59,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
  max_precond_dim = group['max_precond_dim']
  precondition_1d = group['precondition_1d']

- step = group['step'] = group.get("step", -1) + 1
+ step = group['step'] = group.get("step", 0) + 1

  for p in group["params"]:
  if p.grad is None:
heavyball/psgd_kron.py

@@ -38,7 +38,8 @@ class ForeachPSGDKron(PSGDBase):
  def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
- split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
+ split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
+ foreach: bool = True):
  if not 0.0 <= lr:
  raise ValueError(f"Invalid learning rate: {lr}")
  if not 0.0 <= beta < 1.0:
@@ -60,7 +61,7 @@ class ForeachPSGDKron(PSGDBase):
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
  step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
  store_triu_as_line=store_triu_as_line)
- super().__init__(params, defaults)
+ super().__init__(params, defaults, foreach)

  self._prob_step = 0

@@ -114,7 +115,8 @@ class ForeachPSGDKron(PSGDBase):

  if do_update:
  self.balance([g], [q])
- self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
+ self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
+ [q_orig] if store_triu_as_line else None)
  set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))

  grad_list = self.clip_fn(grad_list)
heavyball/pure_psgd.py

@@ -36,7 +36,8 @@ class ForeachPurePSGD(PSGDBase):
  def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
- split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
+ split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
+ foreach: bool = True):
  if not 0.0 <= lr:
  raise ValueError(f"Invalid learning rate: {lr}")
  if not 0.0 <= weight_decay:
@@ -56,7 +57,7 @@ class ForeachPurePSGD(PSGDBase):
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
  step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
  store_triu_as_line=store_triu_as_line)
- super().__init__(params, defaults)
+ super().__init__(params, defaults, foreach)

  self._prob_step = 0

heavyball/schedule_free_palm_foreach_soap.py

@@ -33,7 +33,8 @@ class SFPaLMForeachSOAP(ScheduleFree):
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
- weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False):
+ weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False,
+ foreach: bool = True):
  if betas[0] is not None:
  beta = betas[0]
  defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -42,7 +43,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
  'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1,
  'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split}
- super().__init__(params, defaults)
+ super().__init__(params, defaults, foreach)
  self._data_format = data_format
  self.rng = random.Random(0x120983109)

@@ -51,7 +52,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
  max_precond_dim = group['max_precond_dim']
  precondition_1d = group['precondition_1d']

- step = group['step'] = group.get("step", -1) + 1
+ step = group['step'] = group.get("step", 0) + 1

  for p in group["params"]:
  if p.grad is None:
heavyball/utils.py

@@ -383,8 +383,25 @@ def project(grad, Q, back: bool):


  class StatefulOptimizer(torch.optim.Optimizer):
+ def __init__(self, params, defaults, foreach: bool = True):
+ super().__init__(params, {**defaults, 'foreach': foreach})
+ self.fake_groups = {}
+
+ def key(self, param: torch.Tensor):
+ return (param.data_ptr(), tuple(param.shape))
+
+ def get_groups(self, group):
+ if group['foreach']:
+ return [group]
+
+ for p in group['params']:
+ if self.key(p) not in self.fake_groups:
+ self.fake_groups[self.key(p)] = {**group, 'params': [p]}
+
+ return [self.fake_groups[self.key(p)] for p in group['params']]
+
  def state_(self, arg: torch.Tensor):
- return self.state[(arg.data_ptr(), tuple(arg.shape))]
+ return self.state[self.key(arg)]

  def state_size(self) -> int:
  total_bytes = 0
@@ -409,8 +426,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
  with torch.enable_grad():
  loss = closure()
  with torch.no_grad():
- for group in self.param_groups:
- self._step(group)
+ for top_group in self.param_groups:
+ for group in self.get_groups(top_group):
+ self._step(group)
  return loss

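The `StatefulOptimizer` hunks above are where the new `foreach` flag actually takes effect: the flag is stored in the group defaults, and when it is `False`, `get_groups` wraps each parameter in its own cached single-parameter "fake group" (keyed by `(data_ptr, shape)`, the same key already used for per-tensor state), so `_step` runs once per tensor instead of once per multi-tensor group. The following standalone sketch (illustrative only, not heavyball's code) shows just that grouping behaviour:

```python
import torch

def split_group(group: dict) -> list:
    """Mimic the get_groups idea: one multi-tensor group when foreach=True,
    otherwise one single-parameter group per tensor."""
    if group.get('foreach', True):
        return [group]
    groups = {}
    for p in group['params']:
        key = (p.data_ptr(), tuple(p.shape))  # same tensor -> same fake group
        groups.setdefault(key, {**group, 'params': [p]})
    return [groups[(p.data_ptr(), tuple(p.shape))] for p in group['params']]

params = [torch.randn(4, 4, requires_grad=True) for _ in range(3)]
print(len(split_group({'params': params, 'lr': 1e-3, 'foreach': True})))   # 1
print(len(split_group({'params': params, 'lr': 1e-3, 'foreach': False})))  # 3
```

In heavyball itself the fake groups are cached on the optimizer (`self.fake_groups`), so per-parameter state keyed this way persists across `step` calls.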
heavyball/utils.py (continued)

@@ -754,8 +772,8 @@ def update_triu_(q_state, materialised):


  class PSGDBase(StatefulOptimizer):
- def __init__(self, parameters, groups):
- super().__init__(parameters, groups)
+ def __init__(self, parameters, groups, foreach: bool = True):
+ super().__init__(parameters, groups, foreach)
  self.rng = random.Random(0x1923213)
  self._tiny = torch.finfo(torch.bfloat16).tiny

heavyball.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.15.1
+ Version: 0.16.0
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -39,12 +39,14 @@ recommended experimental optimizer is `ForeachPSGDKron`.

  * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
  * **Inplace EMA**: Same math, but less memory, less compute and higher stability
- * **Foreach**: Fast multi-tensor application
+ * **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
  * **PaLM Beta2**: Fast initial
  convergence, [stable late convergence](https://x.com/_clashluke/status/1820810798693818761)
  * **ScheduleFree**: No learning rate schedule, but better convergence
  * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
  better step-per-second in late convergence (explained below)
+ * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
+ bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))

  ## Getting started

heavyball.egg-info/SOURCES.txt

@@ -25,6 +25,7 @@ heavyball.egg-info/dependency_links.txt
  heavyball.egg-info/requires.txt
  heavyball.egg-info/top_level.txt
  test/test_closure.py
+ test/test_foreach.py
  test/test_memory.py
  test/test_merge.py
  test/test_no_grad.py
setup.py

@@ -10,7 +10,7 @@ setuptools.setup(
  name='heavyball',
  license='BSD',
  description='Efficient optimizers',
- version='0.15.1',
+ version='0.16.0',
  long_description=README,
  url='https://github.com/clashluke/heavyball',
  packages=setuptools.find_packages(),
test/test_foreach.py (new file)

@@ -0,0 +1,65 @@
+ import heavyball
+ import heavyball.utils
+ import pytest
+ import torch
+ from benchmark.utils import get_optim
+ from heavyball.utils import clean, set_torch, PSGDBase
+ from torch import nn
+
+
+ def get_memory():
+ clean()
+ torch.cuda.synchronize()
+ clean()
+ torch.cuda.synchronize()
+ return torch.cuda.memory_allocated()
+
+
+ @pytest.mark.parametrize("opt", heavyball.__all__)
+ @pytest.mark.parametrize("size,depth", [(256, 128)])
+ def test_foreach(opt, size, depth: int, iterations: int = 5, outer_iterations: int = 3):
+ set_torch()
+
+ opt = getattr(heavyball, opt)
+
+ peaks = []
+ losses = []
+
+ for foreach in [True, False]:
+ peaks.append([])
+ losses.append([])
+
+ for i in range(outer_iterations):
+ torch.manual_seed(0x2131290)
+ clean()
+ model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
+ clean()
+
+ torch.cuda.reset_peak_memory_stats()
+ torch.cuda.reset_max_memory_allocated()
+ torch.cuda.reset_max_memory_cached()
+ torch.cuda.reset_accumulated_memory_stats()
+
+ clean()
+ o = get_optim(opt, model.parameters(), lr=1e-3, foreach=foreach)
+ clean()
+
+ for _ in range(iterations):
+ loss = model(torch.randn((1, size)).cuda()).sum()
+ loss.backward()
+ o.step()
+ o.zero_grad()
+ losses[-1].append(loss.detach())
+
+ del model, o
+ clean()
+
+ peak = torch.cuda.memory_stats()['allocated_bytes.all.peak']
+
+ if i > 0:
+ peaks[-1].append(peak)
+
+ for p0, p1 in zip(*peaks):
+ assert p0 > p1
+ for l0, l1 in zip(*losses): # increase error tolerance for PSGD, as we have different RNGs -> expected differences
+ assert torch.allclose(l0, l1, rtol=0.01 if isinstance(opt, PSGDBase) else 1e-5)