heavyball 0.15.1__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -39,7 +39,8 @@ class ForeachCachedPSGDKron(PSGDBase):
39
39
  def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
40
40
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
41
41
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
42
- split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
42
+ split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
43
+ foreach: bool = True):
43
44
  if not 0.0 <= lr:
44
45
  raise ValueError(f"Invalid learning rate: {lr}")
45
46
  if not 0.0 <= beta < 1.0:
@@ -61,7 +62,7 @@ class ForeachCachedPSGDKron(PSGDBase):
61
62
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
62
63
  step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
63
64
  store_triu_as_line=store_triu_as_line)
64
- super().__init__(params, defaults)
65
+ super().__init__(params, defaults, foreach)
65
66
 
66
67
  self._prob_step = 0
67
68
 
heavyball/delayed_psgd.py CHANGED
@@ -38,7 +38,8 @@ class ForeachDelayedPSGD(PSGDBase):
38
38
  def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
39
39
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
40
40
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
41
- split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
41
+ split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
42
+ foreach: bool = True):
42
43
  if not 0.0 <= lr:
43
44
  raise ValueError(f"Invalid learning rate: {lr}")
44
45
  if not 0.0 <= beta < 1.0:
@@ -60,7 +61,7 @@ class ForeachDelayedPSGD(PSGDBase):
60
61
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
61
62
  step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
62
63
  store_triu_as_line=store_triu_as_line)
63
- super().__init__(params, defaults)
64
+ super().__init__(params, defaults, foreach)
64
65
 
65
66
  self._prob_step = 0
66
67
 
@@ -113,7 +114,8 @@ class ForeachDelayedPSGD(PSGDBase):
113
114
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
114
115
  new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
115
116
  if do_update:
116
- self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
117
+ self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
118
+ [q_orig] if store_triu_as_line else None)
117
119
  self.balance([g], [q])
118
120
  set_(g, new)
119
121
 
@@ -5,10 +5,11 @@ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOpti
5
5
 
6
6
 
7
7
  class ForeachAdamW(StatefulOptimizer):
8
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0):
8
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
9
+ foreach: bool = True):
9
10
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
10
11
  lr_max=-1.0, weight_decay=weight_decay)
11
- super().__init__(params, defaults)
12
+ super().__init__(params, defaults, foreach)
12
13
 
13
14
  def _step(self, group):
14
15
  eps = group['eps']
@@ -6,10 +6,11 @@ from .utils import warmup, beta_debias, update_param_, StatefulOptimizer
6
6
 
7
7
  class ForeachADOPT(StatefulOptimizer):
8
8
 
9
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0):
9
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
10
+ foreach: bool = True):
10
11
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
11
12
  lr_max=-1.0, weight_decay=weight_decay)
12
- super().__init__(params, defaults)
13
+ super().__init__(params, defaults, foreach)
13
14
 
14
15
  def _step(self, group):
15
16
  eps = group['eps']
@@ -6,10 +6,11 @@ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOpti
6
6
 
7
7
  class ForeachLaProp(StatefulOptimizer):
8
8
 
9
- def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1):
9
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
10
+ foreach: bool = True):
10
11
  defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
11
12
  lr_max=-1.0, weight_decay=weight_decay)
12
- super().__init__(params, defaults)
13
+ super().__init__(params, defaults, foreach)
13
14
 
14
15
  def _step(self, group):
15
16
  eps = group['eps']
@@ -6,12 +6,12 @@ from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debia
6
6
 
7
7
  class ForeachSFAdamW(ScheduleFree):
8
8
  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
9
- weight_lr_power=2.0, foreach=hasattr(torch, "_foreach_mul_")):
9
+ weight_lr_power=2.0, foreach: bool = True):
10
10
 
11
11
  defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
12
12
  weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
13
13
  foreach=foreach)
14
- super().__init__(params, defaults)
14
+ super().__init__(params, defaults, foreach)
15
15
 
16
16
  def _step(self, group):
17
17
  eps = group['eps']
@@ -48,7 +48,7 @@ class ForeachSFAdamW(ScheduleFree):
48
48
  torch._foreach_add_(grad, y, alpha=decay)
49
49
 
50
50
  lr = warmup(group['lr'], k + 1, group['warmup_steps'])
51
- group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
52
- y, z, grad, group['r'], k + 1)
51
+ group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0], y, z,
52
+ grad, group['r'], k + 1)
53
53
 
54
54
  group['k'] = k + 1
heavyball/foreach_soap.py CHANGED
@@ -26,12 +26,13 @@ class ForeachSOAP(StatefulOptimizer):
26
26
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
27
27
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
28
28
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
29
- split: bool = False):
29
+ split: bool = False,
30
+ foreach: bool = True):
30
31
  defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
31
32
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
32
33
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
33
34
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split}
34
- super().__init__(params, defaults)
35
+ super().__init__(params, defaults, foreach)
35
36
  self._data_format = data_format
36
37
 
37
38
  def _step(self, group):
@@ -59,7 +60,7 @@ class ForeachSOAP(StatefulOptimizer):
59
60
  vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
60
61
 
61
62
  if not vals:
62
- return
63
+ return
63
64
 
64
65
  p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
65
66
  beta1, beta2 = group["betas"]
heavyball/p_adam.py CHANGED
@@ -38,7 +38,8 @@ class ForeachPaLMPAdam(PSGDBase):
38
38
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
39
39
  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
40
40
  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
41
- store_triu_as_line: bool = True):
41
+ store_triu_as_line: bool = True,
42
+ foreach: bool = True):
42
43
  if not 0.0 <= lr:
43
44
  raise ValueError(f"Invalid learning rate: {lr}")
44
45
  if not 0.0 <= weight_decay:
@@ -60,7 +61,7 @@ class ForeachPaLMPAdam(PSGDBase):
60
61
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
61
62
  step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
62
63
  split=split, store_triu_as_line=store_triu_as_line)
63
- super().__init__(params, defaults)
64
+ super().__init__(params, defaults, foreach)
64
65
 
65
66
  self._prob_step = 0
66
67
 
@@ -90,7 +91,7 @@ class ForeachPaLMPAdam(PSGDBase):
90
91
  state['exp_avg'] = torch.zeros_like(g)
91
92
  state['exp_avg_sq'] = torch.zeros_like(g)
92
93
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
93
- min_ndim_triangular, memory_save_mode, dtype=g.dtype)
94
+ min_ndim_triangular, memory_save_mode, dtype=g.dtype)
94
95
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
95
96
 
96
97
  vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
@@ -6,13 +6,14 @@ from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debia
6
6
 
7
7
  class PaLMForeachSFAdamW(ScheduleFree):
8
8
  def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
9
- weight_lr_power=2.0, beta2_scale: float = 0.8):
9
+ weight_lr_power=2.0, beta2_scale: float = 0.8,
10
+ foreach: bool = True):
10
11
  if betas[0] is not None:
11
12
  beta = betas[0]
12
13
  defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
13
14
  lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
14
15
  beta2_scale=beta2_scale)
15
- super().__init__(params, defaults)
16
+ super().__init__(params, defaults, foreach)
16
17
 
17
18
  def _step(self, group):
18
19
  eps = group['eps']
@@ -32,7 +32,8 @@ class PaLMForeachSOAP(StatefulOptimizer):
32
32
  max_precond_dim: int = 2048, #
33
33
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
34
34
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
35
- beta2_scale: float = 0.8, split: bool = False):
35
+ beta2_scale: float = 0.8, split: bool = False,
36
+ foreach: bool = True):
36
37
  if betas[0] is not None:
37
38
  beta = betas[0]
38
39
  defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -40,7 +41,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
40
41
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
41
42
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
42
43
  'split': split}
43
- super().__init__(params, defaults)
44
+ super().__init__(params, defaults, foreach)
44
45
  self._data_format = data_format
45
46
 
46
47
  def _step(self, group):
@@ -27,13 +27,14 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
27
27
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
28
28
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
29
29
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
30
- precond_scheduler=(1 / 3, 9), split: bool = False):
30
+ precond_scheduler=(1 / 3, 9), split: bool = False,
31
+ foreach: bool = True):
31
32
  defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
32
33
  "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
33
34
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
34
35
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
35
36
  'split': split}
36
- super().__init__(params, defaults)
37
+ super().__init__(params, defaults, foreach)
37
38
  self._data_format = data_format
38
39
  self.rng = random.Random(0x120983109)
39
40
 
@@ -32,7 +32,8 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
32
32
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
33
33
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
34
34
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
35
- precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False):
35
+ precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
36
+ foreach: bool = True):
36
37
  if betas[0] is not None:
37
38
  beta = betas[0]
38
39
  defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -40,7 +41,7 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
40
41
  "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
41
42
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
42
43
  'beta2_scale': beta2_scale, 'split': split}
43
- super().__init__(params, defaults)
44
+ super().__init__(params, defaults, foreach)
44
45
  self._data_format = data_format
45
46
  self.rng = random.Random(0x120983109)
46
47
 
@@ -41,7 +41,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
41
41
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
42
42
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
43
43
  weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9),
44
- betas=(None, None), split: bool = False):
44
+ betas=(None, None), split: bool = False, foreach: bool = True):
45
45
  if betas[0] is not None:
46
46
  beta = betas[0]
47
47
  defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -50,7 +50,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
50
50
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
51
51
  'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1, 'weight_sum': 0,
52
52
  'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split}
53
- super().__init__(params, defaults)
53
+ super().__init__(params, defaults, foreach)
54
54
  self._data_format = data_format
55
55
  self.rng = random.Random(0x120983109)
56
56
 
@@ -59,7 +59,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
59
59
  max_precond_dim = group['max_precond_dim']
60
60
  precondition_1d = group['precondition_1d']
61
61
 
62
- step = group['step'] = group.get("step", -1) + 1
62
+ step = group['step'] = group.get("step", 0) + 1
63
63
 
64
64
  for p in group["params"]:
65
65
  if p.grad is None:
heavyball/psgd_kron.py CHANGED
@@ -38,7 +38,8 @@ class ForeachPSGDKron(PSGDBase):
38
38
  def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
39
39
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
40
40
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
41
- split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
41
+ split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
42
+ foreach: bool = True):
42
43
  if not 0.0 <= lr:
43
44
  raise ValueError(f"Invalid learning rate: {lr}")
44
45
  if not 0.0 <= beta < 1.0:
@@ -60,7 +61,7 @@ class ForeachPSGDKron(PSGDBase):
60
61
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
61
62
  step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
62
63
  store_triu_as_line=store_triu_as_line)
63
- super().__init__(params, defaults)
64
+ super().__init__(params, defaults, foreach)
64
65
 
65
66
  self._prob_step = 0
66
67
 
@@ -114,7 +115,8 @@ class ForeachPSGDKron(PSGDBase):
114
115
 
115
116
  if do_update:
116
117
  self.balance([g], [q])
117
- self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
118
+ self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
119
+ [q_orig] if store_triu_as_line else None)
118
120
  set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
119
121
 
120
122
  grad_list = self.clip_fn(grad_list)
heavyball/pure_psgd.py CHANGED
@@ -36,7 +36,8 @@ class ForeachPurePSGD(PSGDBase):
36
36
  def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
37
37
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
38
38
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
39
- split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
39
+ split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
40
+ foreach: bool = True):
40
41
  if not 0.0 <= lr:
41
42
  raise ValueError(f"Invalid learning rate: {lr}")
42
43
  if not 0.0 <= weight_decay:
@@ -56,7 +57,7 @@ class ForeachPurePSGD(PSGDBase):
56
57
  precond_init_scale=1.0, # precond init scale hardcoded to 1.0
57
58
  step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
58
59
  store_triu_as_line=store_triu_as_line)
59
- super().__init__(params, defaults)
60
+ super().__init__(params, defaults, foreach)
60
61
 
61
62
  self._prob_step = 0
62
63
 
@@ -33,7 +33,8 @@ class SFPaLMForeachSOAP(ScheduleFree):
33
33
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
34
34
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
35
35
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
36
- weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False):
36
+ weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False,
37
+ foreach: bool = True):
37
38
  if betas[0] is not None:
38
39
  beta = betas[0]
39
40
  defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -42,7 +43,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
42
43
  "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
43
44
  'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1,
44
45
  'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split}
45
- super().__init__(params, defaults)
46
+ super().__init__(params, defaults, foreach)
46
47
  self._data_format = data_format
47
48
  self.rng = random.Random(0x120983109)
48
49
 
@@ -51,7 +52,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
51
52
  max_precond_dim = group['max_precond_dim']
52
53
  precondition_1d = group['precondition_1d']
53
54
 
54
- step = group['step'] = group.get("step", -1) + 1
55
+ step = group['step'] = group.get("step", 0) + 1
55
56
 
56
57
  for p in group["params"]:
57
58
  if p.grad is None:
heavyball/utils.py CHANGED
@@ -383,8 +383,25 @@ def project(grad, Q, back: bool):
383
383
 
384
384
 
385
385
  class StatefulOptimizer(torch.optim.Optimizer):
386
+ def __init__(self, params, defaults, foreach: bool = True):
387
+ super().__init__(params, {**defaults, 'foreach': foreach})
388
+ self.fake_groups = {}
389
+
390
+ def key(self, param: torch.Tensor):
391
+ return (param.data_ptr(), tuple(param.shape))
392
+
393
+ def get_groups(self, group):
394
+ if group['foreach']:
395
+ return [group]
396
+
397
+ for p in group['params']:
398
+ if self.key(p) not in self.fake_groups:
399
+ self.fake_groups[self.key(p)] = {**group, 'params': [p]}
400
+
401
+ return [self.fake_groups[self.key(p)] for p in group['params']]
402
+
386
403
  def state_(self, arg: torch.Tensor):
387
- return self.state[(arg.data_ptr(), tuple(arg.shape))]
404
+ return self.state[self.key(arg)]
388
405
 
389
406
  def state_size(self) -> int:
390
407
  total_bytes = 0
@@ -409,8 +426,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
409
426
  with torch.enable_grad():
410
427
  loss = closure()
411
428
  with torch.no_grad():
412
- for group in self.param_groups:
413
- self._step(group)
429
+ for top_group in self.param_groups:
430
+ for group in self.get_groups(top_group):
431
+ self._step(group)
414
432
  return loss
415
433
 
416
434
 
@@ -754,8 +772,8 @@ def update_triu_(q_state, materialised):
754
772
 
755
773
 
756
774
  class PSGDBase(StatefulOptimizer):
757
- def __init__(self, parameters, groups):
758
- super().__init__(parameters, groups)
775
+ def __init__(self, parameters, groups, foreach: bool = True):
776
+ super().__init__(parameters, groups, foreach)
759
777
  self.rng = random.Random(0x1923213)
760
778
  self._tiny = torch.finfo(torch.bfloat16).tiny
761
779
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.15.1
3
+ Version: 0.16.0
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -39,12 +39,14 @@ recommended experimental optimizer is `ForeachPSGDKron`.
39
39
 
40
40
  * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
41
41
  * **Inplace EMA**: Same math, but less memory, less compute and higher stability
42
- * **Foreach**: Fast multi-tensor application
42
+ * **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
43
43
  * **PaLM Beta2**: Fast initial
44
44
  convergence, [stable late convergence](https://x.com/_clashluke/status/1820810798693818761)
45
45
  * **ScheduleFree**: No learning rate schedule, but better convergence
46
46
  * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
47
47
  better step-per-second in late convergence (explained below)
48
+ * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
49
+ bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
48
50
 
49
51
  ## Getting started
50
52
 
@@ -0,0 +1,23 @@
1
+ heavyball/__init__.py,sha256=KbT0GMU0DKqZxq9laCrD7XgiqS9yxC1W52zhte5kjKs,2054
2
+ heavyball/cached_psgd_kron.py,sha256=vJuy639G-_ZLSRX3goSFMXALv-ucYjrxaEtpj0IHo-M,6802
3
+ heavyball/delayed_psgd.py,sha256=sbwgAed5gmQpHNTPvuE7Si-gB-s0NVvN4d-4rNUJj4c,5893
4
+ heavyball/foreach_adamw.py,sha256=h_ar0ZZRM0q_wxuEkxEOEYe-2p-mB4OMgAHivrUnPl8,1777
5
+ heavyball/foreach_adopt.py,sha256=ogOw2JjwEQNj7AKlweAphQFdMJ_GcMDm-RyDvEzugoc,1911
6
+ heavyball/foreach_laprop.py,sha256=yGVmGqWiSw8Y2Xj70ndkR8ZMygakTB4_iRwV02Svkqg,1816
7
+ heavyball/foreach_sfadamw.py,sha256=15-n6-lx4PAHYsKYmXbugxsR5MnqaPYy2vUudPRiitg,2087
8
+ heavyball/foreach_soap.py,sha256=h6ptMch7oaynvu3eIJtWnVXypDA_5JDVm3Zb3PNEma0,4634
9
+ heavyball/p_adam.py,sha256=aCu4Qn0eHJETHuCGrfNKp2aygKk2ZoNQyxut3Vcqmoc,6112
10
+ heavyball/palm_foreach_sfadamw.py,sha256=yvZbPyjDW8qd3r4qDXb6uTr5RozQ7JSDj4aYYRnKGLA,2248
11
+ heavyball/palm_foreach_soap.py,sha256=g4hbiGRcti-J-a0SwAkP4ii5pU-aalsZH5bssyhroLk,5938
12
+ heavyball/precond_schedule_foreach_soap.py,sha256=WLg5SzpJnKPZUvFyIvdwSZa1Umt5cpr3Kow_42orM-E,4863
13
+ heavyball/precond_schedule_palm_foreach_soap.py,sha256=ammQrvRZFF-wc-wEiPEoFhS_7b8pdV61QfcLoQfimSo,6211
14
+ heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdojklrY8pA,6770
15
+ heavyball/psgd_kron.py,sha256=iWTAViuzxTodtQGZnkLsEXrLG8tNU-BQB3KkTYAVcX4,5874
16
+ heavyball/pure_psgd.py,sha256=EuCPNM8TX13cOop-mvvBFh6Uo1UjD1vsE053hvil92Q,5136
17
+ heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
18
+ heavyball/utils.py,sha256=z6taEvpgszKTrscqgowKYqb0xIVpBDVDBNGgvTE4Pb8,28484
19
+ heavyball-0.16.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
20
+ heavyball-0.16.0.dist-info/METADATA,sha256=yjpldOTN2rXN2-KG7R9ytuyBfmSCDpznZeRuziANChE,11941
21
+ heavyball-0.16.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
22
+ heavyball-0.16.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
23
+ heavyball-0.16.0.dist-info/RECORD,,
@@ -1,23 +0,0 @@
1
- heavyball/__init__.py,sha256=KbT0GMU0DKqZxq9laCrD7XgiqS9yxC1W52zhte5kjKs,2054
2
- heavyball/cached_psgd_kron.py,sha256=mXDtxq2WJST_aUJhrLr_xCCXSFaDvD5gCTSEveBUtac,6754
3
- heavyball/delayed_psgd.py,sha256=dN3NW1jmjxmUkgqxPwUVrqLY8nnBOFp4TVtJ_BhPDR4,5814
4
- heavyball/foreach_adamw.py,sha256=NSzoIgNm7eavzbJgkAF0k7TUEnWAgOpt9-4juIFoaSA,1729
5
- heavyball/foreach_adopt.py,sha256=WA07m5jocLfb1GPU8s6mJ2PteS-03ronkKm-VJrAm5I,1863
6
- heavyball/foreach_laprop.py,sha256=mE2NDGX9XgvRhsewcWnk_-FulZPqGA65ejYF_9-A1Xk,1768
7
- heavyball/foreach_sfadamw.py,sha256=ussHfPd99u3RTfMrCuu5oIbwNFLXK19wO1Fbz3JShlc,2097
8
- heavyball/foreach_soap.py,sha256=WWvssYKg607uoEJHftp8ag8mtKSKSeHrT0QTgqBucVg,4587
9
- heavyball/p_adam.py,sha256=ms7BoMHu3jKGsuztUeECrsXufGAwBpqGsxgZ5LBXLQg,6073
10
- heavyball/palm_foreach_sfadamw.py,sha256=wjUb_fNZNUmzWXyKvwB0unP9lvNMmaYSQo5YoeS5cj0,2200
11
- heavyball/palm_foreach_soap.py,sha256=2Sb4hUHQeexJcCgjHeQM_ENkZ6lG1DVxW72ryrvR6iY,5890
12
- heavyball/precond_schedule_foreach_soap.py,sha256=bHsDyh-UvHpHjumjqqy0PePoR1ZMsJV6o5wWvpLAA04,4815
13
- heavyball/precond_schedule_palm_foreach_soap.py,sha256=myLTJNQKLtZ3Xi3MVTB-RYtx_XeMRJw5CIMJW75ndUY,6163
14
- heavyball/precond_schedule_sfpsoap.py,sha256=xeNWetBzBEYqfOSzl98aAVJsHk43QkrUUhHH_YD_mS4,6740
15
- heavyball/psgd_kron.py,sha256=rMG5UPEgyfQs_n1MHSEicekVDpbbIzinlL8akEyY918,5795
16
- heavyball/pure_psgd.py,sha256=LLVJhUAb04hgAmT3BTz_faswwQEQUkLhm_VwGQmbBUo,5088
17
- heavyball/schedule_free_palm_foreach_soap.py,sha256=w0P7lMmoijTpL9V7NwOHcNBFJQ7S1TS9aCiwPhY2yVw,6319
18
- heavyball/utils.py,sha256=PWmwjZPL4oxMjK79a5R1e7JHykphNi5GdpYqO_xmmFU,27829
19
- heavyball-0.15.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
20
- heavyball-0.15.1.dist-info/METADATA,sha256=0wImMJNYM-Zg0akh9hRf7X8ofVW6zlmpyDGgAkK5GFA,11667
21
- heavyball-0.15.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
22
- heavyball-0.15.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
23
- heavyball-0.15.1.dist-info/RECORD,,