heavyball 0.15.0__tar.gz → 0.15.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {heavyball-0.15.0 → heavyball-0.15.1}/PKG-INFO +1 -1
  2. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/__init__.py +24 -2
  3. heavyball-0.15.1/heavyball/cached_psgd_kron.py +141 -0
  4. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/delayed_psgd.py +8 -7
  5. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/p_adam.py +11 -6
  6. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/psgd_kron.py +8 -6
  7. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/pure_psgd.py +8 -6
  8. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/utils.py +6 -6
  9. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball.egg-info/PKG-INFO +1 -1
  10. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball.egg-info/SOURCES.txt +1 -0
  11. {heavyball-0.15.0 → heavyball-0.15.1}/setup.py +1 -1
  12. {heavyball-0.15.0 → heavyball-0.15.1}/LICENSE +0 -0
  13. {heavyball-0.15.0 → heavyball-0.15.1}/README.md +0 -0
  14. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/foreach_adamw.py +0 -0
  15. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/foreach_adopt.py +0 -0
  16. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/foreach_laprop.py +0 -0
  17. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/foreach_sfadamw.py +0 -0
  18. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/foreach_soap.py +0 -0
  19. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/palm_foreach_sfadamw.py +0 -0
  20. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/palm_foreach_soap.py +0 -0
  21. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/precond_schedule_foreach_soap.py +0 -0
  22. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  23. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/precond_schedule_sfpsoap.py +0 -0
  24. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  25. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball.egg-info/dependency_links.txt +0 -0
  26. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball.egg-info/requires.txt +0 -0
  27. {heavyball-0.15.0 → heavyball-0.15.1}/heavyball.egg-info/top_level.txt +0 -0
  28. {heavyball-0.15.0 → heavyball-0.15.1}/setup.cfg +0 -0
  29. {heavyball-0.15.0 → heavyball-0.15.1}/test/test_closure.py +0 -0
  30. {heavyball-0.15.0 → heavyball-0.15.1}/test/test_memory.py +0 -0
  31. {heavyball-0.15.0 → heavyball-0.15.1}/test/test_merge.py +0 -0
  32. {heavyball-0.15.0 → heavyball-0.15.1}/test/test_no_grad.py +0 -0
  33. {heavyball-0.15.0 → heavyball-0.15.1}/test/test_psgd.py +0 -0
  34. {heavyball-0.15.0 → heavyball-0.15.1}/test/test_soap.py +0 -0

{heavyball-0.15.0 → heavyball-0.15.1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.15.0
+Version: 0.15.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler

{heavyball-0.15.0 → heavyball-0.15.1}/heavyball/__init__.py

@@ -1,3 +1,4 @@
+from .cached_psgd_kron import ForeachCachedPSGDKron
 from .delayed_psgd import ForeachDelayedPSGD
 from .foreach_adamw import ForeachAdamW
 from .foreach_adopt import ForeachADOPT
@@ -16,7 +17,28 @@ from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
 
 PalmForEachSoap = PaLMForeachSOAP
 
+PaLMSOAP = PaLMForeachSOAP
+PaLMSFAdamW = PaLMForeachSFAdamW
+PaLMSFSoap = SFPaLMForeachSOAP
+PaLMForeachSOAP = PaLMForeachSOAP
+PrecondScheduleSFPaLMSOAP = PrecondScheduleSFPaLMSOAP
+SOAP = ForeachSOAP
+SFAdamW = ForeachSFAdamW
+LaProp = ForeachLaProp
+ADOPT = ForeachADOPT
+PrecondScheduleForeachSOAP = PrecondScheduleForeachSOAP
+PrecondSchedulePaLMForeachSOAP = PrecondSchedulePaLMForeachSOAP
+PSGDKron = ForeachPSGDKron
+AdamW = ForeachAdamW
+PurePSGD = ForeachPurePSGD
+PaLMPAdam = ForeachPaLMPAdam
+DelayedPSGD = ForeachDelayedPSGD
+CachedPSGDKron = ForeachCachedPSGDKron
+
 __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
            'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
-           'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD',
-           'ForeachPaLMPAdam', 'ForeachDelayedPSGD']
+           'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
+           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', #
+           'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
+           'SOAP', 'SFAdamW', 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD',
+           'CachedPSGDKron']
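
Note: the new short names are plain aliases for the existing Foreach* classes, so either spelling can be imported from the package root. A minimal usage sketch, assuming the usual torch.optim-style step()/zero_grad() loop (the toy model and data below are illustrative, not part of the package):

    import torch
    from heavyball import CachedPSGDKron  # alias for ForeachCachedPSGDKron

    model = torch.nn.Linear(16, 4)
    opt = CachedPSGDKron(model.parameters(), lr=1e-3)

    x, y = torch.randn(32, 16), torch.randn(32, 4)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
    opt.zero_grad()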

heavyball-0.15.1/heavyball/cached_psgd_kron.py (new file)

@@ -0,0 +1,141 @@
+"""
+Originally from Evan Walters and Omead Pooladzandi, 2024
+Modified under Creative Commons Attribution 4.0 International
+Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
+"""
+
+from typing import Optional
+
+import torch
+from heavyball.utils import einsum_base
+
+from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
+    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base
+
+
+class ForeachCachedPSGDKron(PSGDBase):
+    """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch with cached preconditioners.
+
+    Args:
+        params (iterable): Iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float): Learning rate.
+        b1 (float): Momentum parameter.
+        weight_decay (float): Weight decay (L2 penalty).
+        preconditioner_update_probability (callable or float, optional): Probability of
+            updating the preconditioner. If None, defaults to a schedule that anneals
+            from 1.0 to 0.03 by 4000 steps.
+        max_size_triangular (int): Max size for dim's preconditioner to be triangular.
+        min_ndim_triangular (int): Minimum number of dimensions a layer needs
+            to have triangular preconditioners.
+        memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
+            to set all preconditioners to be triangular, 'one_diag' sets the largest
+            or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
+            to be diagonal.
+        momentum_into_precond_update: (bool), whether to send momentum into preconditioner
+            update instead of raw gradients.
+    """
+
+    def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
+                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
+                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
+                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= beta < 1.0:
+            raise ValueError(f"Invalid beta parameter: {beta}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        if preconditioner_update_probability is None:
+            preconditioner_update_probability = precond_update_prob_schedule()
+        if clip_fn is None:
+            clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
+        self.preconditioner_update_probability = preconditioner_update_probability
+        self.clip_fn = clip_fn
+
+        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
+                        min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
+                        # precond lr hardcoded to 0.1
+                        precond_init_scale=1.0, # precond init scale hardcoded to 1.0
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
+        super().__init__(params, defaults)
+
+        self._prob_step = 0
+
+    def _step(self, group):
+        # update preconditioners all together
+        update_prob = self.preconditioner_update_probability
+        if callable(update_prob):
+            update_prob = update_prob(self._prob_step)
+        do_update = self.rng.random() < update_prob
+        self._prob_step += 1
+
+        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        beta = group['beta']
+        store_triu_as_line = group['store_triu_as_line']
+
+        vals = []
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+
+            if 'Q' not in state:
+                state["exp_avg"] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
+                state['Q_cache'] = [torch.empty_like(q) for q in Q]
+
+                expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(einsum_base, Q)]
+                expr = ','.join(expr)
+                grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
+                out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
+                expr = f'{expr},{grad_expr}->{out_expr}'
+
+                state['cache_expr'] = expr
+
+            vals.append((p, g, state["exp_avg"], state["Q"], state['Q_cache']))
+
+        if not vals:
+            return
+
+        p_list, grad_list, exp_avg_list, Q_list, Q_cache_list = zip(*vals)
+        del vals
+
+        group["step"] += 1
+
+        torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+
+        grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
+            exp_avg_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            cached_q = Q_cache_list.pop(0)
+            q_orig = Q_list.pop(0)
+            ea = exp_avg_list.pop(0)
+
+            if do_update:
+                q = line_to_triu(q_orig) if store_triu_as_line else q_orig
+                self.balance([g], [q])
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
+                               [q_orig] if store_triu_as_line else None)
+                for c_, q_ in zip(cached_q, q):
+                    if q_.ndim == 2:
+                        torch.matmul(q_.T.conj(), q_, out=c_)
+                    else:
+                        torch.mul(q_.conj(), q_, out=c_)
+
+            set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
+        grad_list = self.clip_fn(grad_list)
+
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
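
Note: the caching is what separates this class from ForeachPSGDKron. Whenever the preconditioner is refreshed, each Kronecker factor's product Q^H Q (or q.conj() * q for diagonal factors) is written into Q_cache, and every step then preconditions the momentum with a single einsum over those cached products instead of recomputing them from Q. A standalone sketch of that contraction, using made-up shapes and plain torch in place of the library helpers:

    import torch
    from string import ascii_lowercase as letters

    g = torch.randn(8, 4)                    # stand-in for a momentum/gradient buffer
    Q = [torch.randn(8, 8), torch.randn(4)]  # one triangular and one diagonal factor

    cache, lhs = [], []
    for c, q in zip(letters, Q):
        cache.append(q.T.conj() @ q if q.ndim == 2 else q.conj() * q)  # computed once per update
        lhs.append(f'{c.upper()}{c}' if q.ndim == 2 else c)

    grad_expr = letters[:g.ndim]
    out_expr = ''.join(c.upper() if f'{c.upper()}{c}' in lhs else c for c in grad_expr)
    expr = f"{','.join(lhs)},{grad_expr}->{out_expr}"  # here: 'Aa,b,ab->Ab'

    precond_g = torch.einsum(expr, *cache, g)  # cheap per-step contraction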

{heavyball-0.15.0 → heavyball-0.15.1}/heavyball/delayed_psgd.py

@@ -38,7 +38,7 @@ class ForeachDelayedPSGD(PSGDBase):
     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None):
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -58,7 +58,8 @@ class ForeachDelayedPSGD(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0, # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split)
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
         super().__init__(params, defaults)
 
         self._prob_step = 0
@@ -80,6 +81,7 @@ class ForeachDelayedPSGD(PSGDBase):
         weight_decay = group['weight_decay']
         lr = group['lr']
         beta = group['beta']
+        store_triu_as_line = group['store_triu_as_line']
 
         vals = []
 
@@ -90,7 +92,7 @@ class ForeachDelayedPSGD(PSGDBase):
                 state["exp_avg"] = torch.zeros_like(g)
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
                                                  memory_save_mode, dtype=g.dtype)
-                state["Q"] = triu_to_line(Q)
+                state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["exp_avg"], state["Q"]))
 
@@ -108,12 +110,11 @@ class ForeachDelayedPSGD(PSGDBase):
         for i, (p, g) in enumerate(zip(p_list, grad_list)):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
-            q = line_to_triu(q_orig)
-            self.balance(do_update, [g], [q])
+            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
             new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
-
             if do_update:
-                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
+                self.balance([g], [q])
             set_(g, new)
 
         grad_list = self.clip_fn(grad_list)
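
Note: store_triu_as_line, threaded through this and the other PSGD variants, controls whether triangular factors are kept as full matrices or packed into a flat vector of their upper-triangular entries between steps (converted back with line_to_triu before use). The library's triu_to_line/line_to_triu helpers are not shown in this diff; the sketch below only illustrates the packing idea with plain torch ops and hypothetical function names:

    import torch

    def pack_triu(q: torch.Tensor) -> torch.Tensor:
        # keep only the upper-triangular entries as a flat vector
        idx = torch.triu_indices(q.shape[0], q.shape[1])
        return q[idx[0], idx[1]]

    def unpack_triu(line: torch.Tensor, n: int) -> torch.Tensor:
        # scatter the flat vector back into an n x n upper-triangular matrix
        q = torch.zeros(n, n, dtype=line.dtype)
        idx = torch.triu_indices(n, n)
        q[idx[0], idx[1]] = line
        return q

    q = torch.triu(torch.randn(4, 4))
    assert torch.equal(unpack_triu(pack_triu(q), 4), q)  # round-trip is lossless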

{heavyball-0.15.0 → heavyball-0.15.1}/heavyball/p_adam.py

@@ -5,6 +5,7 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
+from heavyball.utils import triu_to_line, line_to_triu
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
     exp_avg_sq_, beta_debias, split_p_and_g_in_group
@@ -36,7 +37,8 @@ class ForeachPaLMPAdam(PSGDBase):
     def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
-                 beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None):
+                 beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
+                 store_triu_as_line: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -57,7 +59,7 @@ class ForeachPaLMPAdam(PSGDBase):
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0, # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
-                        split=split)
+                        split=split, store_triu_as_line=store_triu_as_line)
         super().__init__(params, defaults)
 
         self._prob_step = 0
@@ -77,6 +79,7 @@ class ForeachPaLMPAdam(PSGDBase):
         precond_lr = group['precond_lr']
         weight_decay = group['weight_decay']
         lr = group['lr']
+        store_triu_as_line = group['store_triu_as_line']
 
         vals = []
 
@@ -86,8 +89,9 @@ class ForeachPaLMPAdam(PSGDBase):
             if 'Q' not in state:
                 state['exp_avg'] = torch.zeros_like(g)
                 state['exp_avg_sq'] = torch.zeros_like(g)
-                state["Q"], state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
                                                           min_ndim_triangular, memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
 
@@ -99,15 +103,16 @@ class ForeachPaLMPAdam(PSGDBase):
 
         group["step"] += 1
 
-        self.balance(do_update, grad_list, Q_list)
+        Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
         if do_update:
-            self.do_update(p_list, grad_list, Q_list, precond_lr)
+            self.balance(grad_list, Q_triu)
+            self.do_update(p_list, grad_list, Q_triu, precond_lr, Q_list if store_triu_as_line else None)
 
         torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
 
         beta2 = 1 - group['step'] ** -group['beta2_scale']
 
-        for p, Q, g, ea, eas in zip(p_list, Q_list, grad_list, exp_avg, exp_avg_sq):
+        for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
             psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
             ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
             exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)

{heavyball-0.15.0 → heavyball-0.15.1}/heavyball/psgd_kron.py

@@ -38,7 +38,7 @@ class ForeachPSGDKron(PSGDBase):
     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: Optional[callable] = None):
+                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -58,7 +58,8 @@ class ForeachPSGDKron(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0, # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split)
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
         super().__init__(params, defaults)
 
         self._prob_step = 0
@@ -80,6 +81,7 @@ class ForeachPSGDKron(PSGDBase):
         weight_decay = group['weight_decay']
         lr = group['lr']
         beta = group['beta']
+        store_triu_as_line = group['store_triu_as_line']
 
         vals = []
 
@@ -90,7 +92,7 @@ class ForeachPSGDKron(PSGDBase):
                 state["exp_avg"] = torch.zeros_like(g)
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
                                                  memory_save_mode, dtype=g.dtype)
-                state['Q'] = triu_to_line(Q)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["exp_avg"], state["Q"]))
 
@@ -108,11 +110,11 @@ class ForeachPSGDKron(PSGDBase):
         for i, (p, g) in enumerate(zip(p_list, grad_list)):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
-            q = line_to_triu(q_orig)
+            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            self.balance(do_update, [g], [q])
             if do_update:
-                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
+                self.balance([g], [q])
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
             set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
 
         grad_list = self.clip_fn(grad_list)

{heavyball-0.15.0 → heavyball-0.15.1}/heavyball/pure_psgd.py

@@ -36,7 +36,7 @@ class ForeachPurePSGD(PSGDBase):
     def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None):
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -54,7 +54,8 @@ class ForeachPurePSGD(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0, # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split)
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
         super().__init__(params, defaults)
 
         self._prob_step = 0
@@ -74,6 +75,7 @@ class ForeachPurePSGD(PSGDBase):
         precond_lr = group['precond_lr']
         weight_decay = group['weight_decay']
         lr = group['lr']
+        store_triu_as_line = group['store_triu_as_line']
 
         vals = []
 
@@ -83,7 +85,7 @@ class ForeachPurePSGD(PSGDBase):
             if 'Q' not in state:
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
                                                  memory_save_mode, dtype=g.dtype)
-                state['Q'] = triu_to_line(Q)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["Q"]))
 
@@ -98,11 +100,11 @@ class ForeachPurePSGD(PSGDBase):
         Q_list = list(Q_list)
         for i, (p, g) in enumerate(zip(p_list, grad_list)):
             q_orig = Q_list.pop(0)
-            q = line_to_triu(q_orig)
+            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            self.balance(do_update, [g], [q])
             if do_update:
-                self.do_update([p], [g], [q], precond_lr, [q_orig])
+                self.balance([g], [q])
+                self.do_update([p], [g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
             psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
 
         grad_list = self.clip_fn(grad_list)

{heavyball-0.15.0 → heavyball-0.15.1}/heavyball/utils.py

@@ -29,7 +29,7 @@ def decorator(func):
     return _fn
 
 
-_einsum_base = string.ascii_lowercase + string.ascii_uppercase
+einsum_base = string.ascii_lowercase + string.ascii_uppercase
 
 
 def warmup(lr: float, step: int, warmup_steps: int):
@@ -317,8 +317,8 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
     for idx, sh in enumerate(grad.shape):
         if sh > max_precond_dim:
             continue
-        b = _einsum_base[idx]
-        g0 = _einsum_base[:grad.dim()]
+        b = einsum_base[idx]
+        g0 = einsum_base[:grad.dim()]
         g1 = g0.replace(b, b.upper())
         outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
         GG[idx].lerp_(promote(outer_product), 1 - beta)
@@ -374,7 +374,7 @@ def project(grad, Q, back: bool):
     :param back: whether to project to Shampoo eigenbases or back to original space
     :return:
     """
-    param = _einsum_base[:grad.dim()]
+    param = einsum_base[:grad.dim()]
     preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if len(m) > 0])
     if preconditioners:
         out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
@@ -759,8 +759,8 @@ class PSGDBase(StatefulOptimizer):
         self.rng = random.Random(0x1923213)
         self._tiny = torch.finfo(torch.bfloat16).tiny
 
-    def balance(self, do_update, grad_list, Q_list):
-        if not do_update or self.rng.random() > 0.01:
+    def balance(self, grad_list, Q_list):
+        if self.rng.random() > 0.01:
             return
 
         for g, q in zip(grad_list, Q_list):

{heavyball-0.15.0 → heavyball-0.15.1}/heavyball.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.15.0
+Version: 0.15.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler

{heavyball-0.15.0 → heavyball-0.15.1}/heavyball.egg-info/SOURCES.txt

@@ -2,6 +2,7 @@ LICENSE
 README.md
 setup.py
 heavyball/__init__.py
+heavyball/cached_psgd_kron.py
 heavyball/delayed_psgd.py
 heavyball/foreach_adamw.py
 heavyball/foreach_adopt.py

{heavyball-0.15.0 → heavyball-0.15.1}/setup.py

@@ -10,7 +10,7 @@ setuptools.setup(
     name='heavyball',
     license='BSD',
     description='Efficient optimizers',
-    version='0.15.0',
+    version='0.15.1',
     long_description=README,
     url='https://github.com/clashluke/heavyball',
     packages=setuptools.find_packages(),