heavyball 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -14,6 +14,7 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
  from .psgd_kron import ForeachPSGDKron
  from .pure_psgd import ForeachPurePSGD
  from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
+ from .cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron
 
  PalmForEachSoap = PaLMForeachSOAP
 
@@ -34,11 +35,12 @@ PurePSGD = ForeachPurePSGD
  PaLMPAdam = ForeachPaLMPAdam
  DelayedPSGD = ForeachDelayedPSGD
  CachedPSGDKron = ForeachCachedPSGDKron
+ CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
 
  __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
             'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
             'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
-            'ForeachDelayedPSGD', 'ForeachCachedPSGDKron',  #
+            'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron',  #
             'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
             'SOAP', 'SFAdamW', 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD',
-            'CachedPSGDKron']
+            'CachedPSGDKron', 'CachedDelayedPSGDKron']
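
With the export in place, the Foreach-prefixed class and its short alias refer to the same object, following the alias pattern already used in this module. A minimal sanity check (illustrative usage, not part of the diff):

    from heavyball import CachedDelayedPSGDKron, ForeachCachedDelayedPSGDKron

    # the short name is only an alias, mirroring e.g. CachedPSGDKron above
    assert CachedDelayedPSGDKron is ForeachCachedDelayedPSGDKron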
heavyball/cached_delayed_psgd_kron.py ADDED
@@ -0,0 +1,146 @@
+ """
+ Originally from Evan Walters and Omead Pooladzandi, 2024
+ Modified under Creative Commons Attribution 4.0 International
+ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
+ """
+
+ from typing import Optional
+
+ import torch
+ from heavyball.utils import einsum_base
+
+ from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
+     precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote
+
+
+ class ForeachCachedDelayedPSGDKron(PSGDBase):
+     """
+     Implements PSGD with off-by-one preconditioning (akin to ADOPT and SOAP) with cached preconditioners.
+
+     Args:
+         params (iterable): Iterable of parameters to optimize or dicts defining
+             parameter groups.
+         lr (float): Learning rate.
+         beta (float): Momentum parameter.
+         weight_decay (float): Weight decay (L2 penalty).
+         preconditioner_update_probability (callable or float, optional): Probability of
+             updating the preconditioner. If None, defaults to a schedule that anneals
+             from 1.0 to 0.03 by 4000 steps.
+         max_size_triangular (int): Max size for a dim's preconditioner to be triangular.
+         min_ndim_triangular (int): Minimum number of dimensions a layer needs to have
+             triangular preconditioners.
+         memory_save_mode (string, optional): None, 'one_diag', or 'all_diag'. None (the
+             default) sets all preconditioners to be triangular, 'one_diag' sets the
+             largest or last dim to be diagonal per layer, and 'all_diag' sets all
+             preconditioners to be diagonal.
+         momentum_into_precond_update (bool): Whether to send momentum into the
+             preconditioner update instead of raw gradients.
+     """
+
+     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
+                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
+                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
+                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
+                  foreach: bool = True, q_dtype='float32'):
+         if not 0.0 <= lr:
+             raise ValueError(f"Invalid learning rate: {lr}")
+         if not 0.0 <= beta < 1.0:
+             raise ValueError(f"Invalid beta parameter: {beta}")
+         if not 0.0 <= weight_decay:
+             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+         if preconditioner_update_probability is None:
+             preconditioner_update_probability = precond_update_prob_schedule()
+         if clip_fn is None:
+             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
+         self.preconditioner_update_probability = preconditioner_update_probability
+         self.clip_fn = clip_fn
+
+         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
+                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
+                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
+                         # precond lr hardcoded to 0.1
+                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
+                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                         store_triu_as_line=store_triu_as_line,
+                         q_dtype=q_dtype)
+         super().__init__(params, defaults, foreach)
+
+         self._prob_step = 0
+
+     def _step(self, group):
+         # update preconditioners all together
+         update_prob = self.preconditioner_update_probability
+         if callable(update_prob):
+             update_prob = update_prob(self._prob_step)
+         do_update = self.rng.random() < update_prob
+         self._prob_step += 1
+
+         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+         precond_init_scale = group['precond_init_scale']
+         max_size_triangular = group['max_size_triangular']
+         min_ndim_triangular = group['min_ndim_triangular']
+         memory_save_mode = group['memory_save_mode']
+         precond_lr = group['precond_lr']
+         weight_decay = group['weight_decay']
+         lr = group['lr']
+         beta = group['beta']
+         store_triu_as_line = group['store_triu_as_line']
+         q_dtype = getattr(torch, group['q_dtype'])
+
+         vals = []
+
+         for p, g in split_p_and_g_in_group(group):
+             state = self.state_(p)
+
+             if 'Q' not in state:
+                 state["exp_avg"] = torch.zeros_like(g)
+                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                  memory_save_mode, dtype=q_dtype)
+                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
+                 state['Q_cache'] = [torch.empty_like(q) for q in Q]
+
+                 expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(einsum_base, Q)]
+                 expr = ','.join(expr)
+                 grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
+                 out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
+                 expr = f'{expr},{grad_expr}->{out_expr}'
+
+                 state['cache_expr'] = expr
+
+             vals.append((p, g, state["exp_avg"], state["Q"], state['Q_cache']))
+
+         if not vals:
+             return
+
+         p_list, grad_list, exp_avg_list, Q_list, Q_cache_list = zip(*vals)
+         del vals
+
+         group["step"] += 1
+
+         torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+
+         grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
+             exp_avg_list)
+         for i, (p, g) in enumerate(zip(p_list, grad_list)):
+             cached_q = Q_cache_list.pop(0)
+             q_orig = Q_list.pop(0)
+             ea = exp_avg_list.pop(0)
+
+             if do_update:
+                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
+                 q32 = [promote(q_) for q_ in q]
+                 self.balance([g], [q32])
+                 self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                 for c_, q_ in zip(cached_q, q):
+                     if q_.ndim == 2:
+                         torch.matmul(q_.T.conj(), q_, out=c_)
+                     else:
+                         torch.mul(q_.conj(), q_, out=c_)
+
+             set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
+         grad_list = self.clip_fn(grad_list)
+
+         lr = -warmup(lr, group['step'], group['warmup_steps'])
+         update_param_(p_list, grad_list, lr, weight_decay)
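
The new optimizer drops into a standard training loop; a minimal sketch (the toy model and loop below are illustrative, assuming heavyball's StatefulOptimizer base exposes the usual torch.optim.Optimizer step()/zero_grad() interface):

    import torch
    from heavyball import ForeachCachedDelayedPSGDKron

    model = torch.nn.Linear(16, 4)
    opt = ForeachCachedDelayedPSGDKron(model.parameters(), lr=1e-3, beta=0.9)

    for _ in range(10):
        loss = model(torch.randn(32, 16)).square().mean()
        opt.zero_grad()
        loss.backward()
        opt.step()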
heavyball/cached_psgd_kron.py CHANGED
@@ -10,7 +10,7 @@ import torch
  from heavyball.utils import einsum_base
 
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-     precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base
+     precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote
 
 
  class ForeachCachedPSGDKron(PSGDBase):
@@ -40,7 +40,7 @@ class ForeachCachedPSGDKron(PSGDBase):
                   max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                   momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                   split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                  foreach: bool = True):
+                  foreach: bool = True, q_dtype='float32'):
          if not 0.0 <= lr:
              raise ValueError(f"Invalid learning rate: {lr}")
          if not 0.0 <= beta < 1.0:
@@ -61,7 +61,8 @@ class ForeachCachedPSGDKron(PSGDBase):
                          # precond lr hardcoded to 0.1
                          precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                          step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                         store_triu_as_line=store_triu_as_line)
+                         store_triu_as_line=store_triu_as_line,
+                         q_dtype=q_dtype)
          super().__init__(params, defaults, foreach)
 
          self._prob_step = 0
@@ -84,6 +85,7 @@ class ForeachCachedPSGDKron(PSGDBase):
          lr = group['lr']
          beta = group['beta']
          store_triu_as_line = group['store_triu_as_line']
+         q_dtype = getattr(torch, group['q_dtype'])
 
          vals = []
 
@@ -93,7 +95,7 @@ class ForeachCachedPSGDKron(PSGDBase):
              if 'Q' not in state:
                  state["exp_avg"] = torch.zeros_like(g)
                  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                  memory_save_mode, dtype=g.dtype)
+                                                  memory_save_mode, dtype=q_dtype)
                  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
                  state['Q_cache'] = [torch.empty_like(q) for q in Q]
 
@@ -124,18 +126,21 @@ class ForeachCachedPSGDKron(PSGDBase):
              q_orig = Q_list.pop(0)
              ea = exp_avg_list.pop(0)
 
+             new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
+
              if do_update:
                  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-                 self.balance([g], [q])
-                 self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
-                                [q_orig] if store_triu_as_line else None)
+                 q32 = [promote(q_) for q_ in q]
+                 self.balance([g], [q32])
+                 self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
                  for c_, q_ in zip(cached_q, q):
                      if q_.ndim == 2:
                          torch.matmul(q_.T.conj(), q_, out=c_)
                      else:
                          torch.mul(q_.conj(), q_, out=c_)
 
-             set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
+             set_(g, new)
+
          grad_list = self.clip_fn(grad_list)
 
          lr = -warmup(lr, group['step'], group['warmup_steps'])
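
The new q_dtype argument is stored as a string in the group defaults and resolved with getattr(torch, ...), so it must name a real torch dtype attribute such as 'float32' or 'bfloat16'; lower-precision preconditioners are promoted to float32 for the update and written back stochastically. A minimal sketch (the model is a placeholder):

    import torch
    from heavyball import ForeachCachedPSGDKron

    model = torch.nn.Linear(8, 8)
    # store Q in bfloat16; updates still run in float32 via promote()
    opt = ForeachCachedPSGDKron(model.parameters(), lr=1e-3, q_dtype='bfloat16')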
heavyball/delayed_psgd.py CHANGED
@@ -8,7 +8,7 @@ import torch
  from heavyball.utils import copy_stochastic_list_
 
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-     precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_
+     precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_, promote
 
 
  class ForeachDelayedPSGD(PSGDBase):
@@ -39,7 +39,7 @@ class ForeachDelayedPSGD(PSGDBase):
                   max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                   momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                   split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
-                  foreach: bool = True):
+                  foreach: bool = True, q_dtype='float32'):
          if not 0.0 <= lr:
              raise ValueError(f"Invalid learning rate: {lr}")
          if not 0.0 <= beta < 1.0:
@@ -60,7 +60,7 @@ class ForeachDelayedPSGD(PSGDBase):
                          # precond lr hardcoded to 0.1
                          precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                          step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                         store_triu_as_line=store_triu_as_line)
+                         store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
          super().__init__(params, defaults, foreach)
 
          self._prob_step = 0
@@ -83,6 +83,7 @@ class ForeachDelayedPSGD(PSGDBase):
          lr = group['lr']
          beta = group['beta']
          store_triu_as_line = group['store_triu_as_line']
+         q_dtype = getattr(torch, group['q_dtype'])
 
          vals = []
 
@@ -92,7 +93,7 @@ class ForeachDelayedPSGD(PSGDBase):
              if 'Q' not in state:
                  state["exp_avg"] = torch.zeros_like(g)
                  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                  memory_save_mode, dtype=g.dtype)
+                                                  memory_save_mode, dtype=q_dtype)
                  state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
 
              vals.append((p, g, state["exp_avg"], state["Q"]))
@@ -114,9 +115,9 @@ class ForeachDelayedPSGD(PSGDBase):
              q = line_to_triu(q_orig) if store_triu_as_line else q_orig
              new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
              if do_update:
-                 self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
-                                [q_orig] if store_triu_as_line else None)
-                 self.balance([g], [q])
+                 q32 = [promote(q_) for q_ in q]
+                 self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                 self.balance([g], [q32])
              set_(g, new)
 
          grad_list = self.clip_fn(grad_list)
heavyball/p_adam.py CHANGED
@@ -8,7 +8,7 @@ import torch
  from heavyball.utils import triu_to_line, line_to_triu
 
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
-     exp_avg_sq_, beta_debias, split_p_and_g_in_group
+     exp_avg_sq_, beta_debias, split_p_and_g_in_group, promote
 
 
  class ForeachPaLMPAdam(PSGDBase):
@@ -39,7 +39,7 @@ class ForeachPaLMPAdam(PSGDBase):
                   momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
                   beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
                   store_triu_as_line: bool = True,
-                  foreach: bool = True):
+                  foreach: bool = True, q_dtype='float32'):
          if not 0.0 <= lr:
              raise ValueError(f"Invalid learning rate: {lr}")
          if not 0.0 <= weight_decay:
@@ -60,7 +60,7 @@ class ForeachPaLMPAdam(PSGDBase):
                          # precond lr hardcoded to 0.1
                          precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                          step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
-                         split=split, store_triu_as_line=store_triu_as_line)
+                         split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
          super().__init__(params, defaults, foreach)
 
          self._prob_step = 0
@@ -81,6 +81,7 @@ class ForeachPaLMPAdam(PSGDBase):
          weight_decay = group['weight_decay']
          lr = group['lr']
          store_triu_as_line = group['store_triu_as_line']
+         q_dtype = getattr(torch, group['q_dtype'])
 
          vals = []
 
@@ -91,7 +92,7 @@ class ForeachPaLMPAdam(PSGDBase):
                  state['exp_avg'] = torch.zeros_like(g)
                  state['exp_avg_sq'] = torch.zeros_like(g)
                  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
-                                                  min_ndim_triangular, memory_save_mode, dtype=g.dtype)
+                                                  min_ndim_triangular, memory_save_mode, dtype=q_dtype)
                  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
              vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
@@ -106,9 +107,10 @@ class ForeachPaLMPAdam(PSGDBase):
 
          Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
          if do_update:
-             self.balance(grad_list, Q_triu)
-             self.do_update(p_list, grad_list, Q_triu, precond_lr, Q_list if store_triu_as_line else None)
-
+             for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
+                 q32 = [promote(qq_) for qq_ in q_]
+                 self.balance([g], [q32])
+                 self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
          torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
 
          beta2 = 1 - group['step'] ** -group['beta2_scale']
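
As in the other PSGD variants, preconditioner_update_probability accepts either a float or a callable of the step counter, and None falls back to precond_update_prob_schedule(). A sketch of the callable form (`model` stands in for any torch.nn.Module):

    from heavyball import ForeachPaLMPAdam

    # a constant 10% update chance per step instead of the default annealed schedule
    opt = ForeachPaLMPAdam(model.parameters(), lr=1e-3,
                           preconditioner_update_probability=lambda step: 0.1)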
heavyball/psgd_kron.py CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
  import torch
 
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-     precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_
+     precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, promote
 
 
  class ForeachPSGDKron(PSGDBase):
@@ -39,7 +39,7 @@ class ForeachPSGDKron(PSGDBase):
                   max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                   momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                   split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                  foreach: bool = True):
+                  foreach: bool = True, q_dtype='float32'):
          if not 0.0 <= lr:
              raise ValueError(f"Invalid learning rate: {lr}")
          if not 0.0 <= beta < 1.0:
@@ -60,7 +60,7 @@ class ForeachPSGDKron(PSGDBase):
                          # precond lr hardcoded to 0.1
                          precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                          step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                         store_triu_as_line=store_triu_as_line)
+                         store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
          super().__init__(params, defaults, foreach)
 
          self._prob_step = 0
@@ -83,6 +83,7 @@ class ForeachPSGDKron(PSGDBase):
          lr = group['lr']
          beta = group['beta']
          store_triu_as_line = group['store_triu_as_line']
+         q_dtype = getattr(torch, group['q_dtype'])
 
          vals = []
 
@@ -92,7 +93,7 @@ class ForeachPSGDKron(PSGDBase):
              if 'Q' not in state:
                  state["exp_avg"] = torch.zeros_like(g)
                  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                  memory_save_mode, dtype=g.dtype)
+                                                  memory_save_mode, dtype=q_dtype)
                  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
              vals.append((p, g, state["exp_avg"], state["Q"]))
@@ -114,9 +115,9 @@ class ForeachPSGDKron(PSGDBase):
              q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
              if do_update:
-                 self.balance([g], [q])
-                 self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
-                                [q_orig] if store_triu_as_line else None)
+                 q32 = [promote(q_) for q_ in q]
+                 self.balance([ea if momentum_into_precond_update else g], [q32])
+                 self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
              set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
 
          grad_list = self.clip_fn(grad_list)
heavyball/pure_psgd.py CHANGED
@@ -5,9 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
  """
 
  import torch
+ from heavyball.utils import copy_stochastic_list_
 
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
-     split_p_and_g_in_group, line_to_triu, triu_to_line
+     split_p_and_g_in_group, line_to_triu, triu_to_line, promote
 
 
  class ForeachPurePSGD(PSGDBase):
@@ -37,7 +38,7 @@ class ForeachPurePSGD(PSGDBase):
                   max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                   momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                   split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
-                  foreach: bool = True):
+                  foreach: bool = True, q_dtype='float32'):
          if not 0.0 <= lr:
              raise ValueError(f"Invalid learning rate: {lr}")
          if not 0.0 <= weight_decay:
@@ -56,7 +57,7 @@ class ForeachPurePSGD(PSGDBase):
                          # precond lr hardcoded to 0.1
                          precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                          step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                         store_triu_as_line=store_triu_as_line)
+                         store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
          super().__init__(params, defaults, foreach)
 
          self._prob_step = 0
@@ -77,6 +78,7 @@ class ForeachPurePSGD(PSGDBase):
          weight_decay = group['weight_decay']
          lr = group['lr']
          store_triu_as_line = group['store_triu_as_line']
+         q_dtype = getattr(torch, group['q_dtype'])
 
          vals = []
 
@@ -85,7 +87,7 @@ class ForeachPurePSGD(PSGDBase):
 
              if 'Q' not in state:
                  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                  memory_save_mode, dtype=g.dtype)
+                                                  memory_save_mode, dtype=q_dtype)
                  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
              vals.append((p, g, state["Q"]))
@@ -104,8 +106,9 @@ class ForeachPurePSGD(PSGDBase):
              q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
              if do_update:
-                 self.balance([g], [q])
-                 self.do_update([p], [g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
+                 q32 = [promote(q_) for q_ in q]
+                 self.balance([g], [q32])
+                 self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
              psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
 
          grad_list = self.clip_fn(grad_list)
heavyball/utils.py CHANGED
@@ -325,9 +325,9 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
 
 
  def promote(x):
-     if x is (torch.bfloat16, torch.float16):
+     if x in (torch.bfloat16, torch.float16):
          return torch.float32
-     if x.dtype in (torch.bfloat16, torch.float16):
+     if hasattr(x, 'dtype') and x.dtype in (torch.bfloat16, torch.float16):
          return x.float()
      return x
 
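The first branch now tests membership with `in` (the old identity comparison against a fresh tuple was always False), and the second guards the .dtype access so that passing a plain dtype no longer raises AttributeError. Illustrative calls:

    import torch
    from heavyball.utils import promote

    promote(torch.bfloat16)                       # dtype input  -> torch.float32
    promote(torch.float32)                        # already wide -> returned unchanged
    promote(torch.zeros(3, dtype=torch.float16))  # tensor input -> float32 copy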
@@ -468,15 +468,15 @@ class ScheduleFree(StatefulOptimizer):
 
  def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]):
      for t, s in zip(target, source):
-         if t.dtype == torch.bfloat16:
-             copy_stochastic_(t, s)
-         else:
-             set_(t, s)
+         copy_stochastic_(t, s)
 
 
  def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
      if target.data_ptr() == source.data_ptr():
          return
+     if target.dtype != torch.bfloat16:
+         set_(target, source)
+         return
 
      """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
      # create a random 16 bit integer
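
For context, the stochastic-rounding body that follows this comment is not shown in the hunk; a self-contained sketch of the technique, reconstructed from the linked PyTorch issue (assumes a float32 source and a bfloat16 target):

    import torch

    def copy_stochastic_sketch_(target: torch.Tensor, source: torch.Tensor):
        # draw a random 16-bit integer per element
        result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
        # add it to the lower 16 bits of the float32 mantissa
        result.add_(source.view(dtype=torch.int32))
        # mask off the lower 16 bits, leaving a randomly rounded bfloat16 bit pattern
        result.bitwise_and_(-65536)  # -65536 == 0xFFFF0000 as int32
        # reinterpret and copy the surviving upper 16 bits into the bfloat16 target
        target.copy_(result.view(dtype=torch.float32))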
@@ -555,7 +555,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
      for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
          if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d:
              # use diagonal matrix as preconditioner for this dim
-             Q.append(scale * torch.ones(size, dtype=dtype, device=t.device))
+             Q.append(scale * torch.ones(size, dtype=promote(dtype), device=t.device))
 
              piece1A.append(letters[i])
              piece2A = piece2A + letters[i]
@@ -669,11 +669,11 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
  @decorator
  def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
      """Precondition gradient G with preconditioner Q."""
-     out = torch.einsum(exprs[-1], *[q.conj() for q in Q], *Q, G)
+     out = torch.einsum(exprs[-1], *[q.conj() for q in Q], *Q, G.to(Q[0].dtype))
      if inplace:
          set_(G, out)
          return G
-     return out
+     return out.to(G.dtype)
 
 
  def norm_clip_(x, scale=None):
@@ -768,28 +768,33 @@ def line_to_triu(Q_list: List[Tuple[Optional[List[int]], torch.Tensor]]):
  def update_triu_(q_state, materialised):
      for (shape0, q), (shape1, m) in zip(q_state, triu_to_line(materialised)):
          assert shape0 == shape1
-         set_(q, m)
+         copy_stochastic_(q, m)
 
 
  class PSGDBase(StatefulOptimizer):
+     balance_probability: float = 0.01
+
      def __init__(self, parameters, groups, foreach: bool = True):
          super().__init__(parameters, groups, foreach)
          self.rng = random.Random(0x1923213)
          self._tiny = torch.finfo(torch.bfloat16).tiny
 
      def balance(self, grad_list, Q_list):
-         if self.rng.random() > 0.01:
+         if self.rng.random() > self.balance_probability:
              return
 
          for g, q in zip(grad_list, Q_list):
              if g.dim() > 1:
                  psgd_balance_Q(q)
 
-     def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None):
+     def do_update(self, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None, store_triu_as_line=False):
          for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
              psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
              if original_q:
-                 update_triu_(original_q[i], Q)
+                 if store_triu_as_line:
+                     update_triu_(original_q[i], Q)
+                 else:
+                     copy_stochastic_list_(original_q[i], Q)
 
 
  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
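
Promoting the hardcoded 0.01 to a class attribute makes the Q-balancing frequency overridable without touching balance() itself. A minimal sketch of such an override (the subclass name is hypothetical):

    from heavyball import ForeachPSGDKron

    class FrequentBalancePSGD(ForeachPSGDKron):
        # balance() fires when rng.random() <= balance_probability,
        # so 0.05 rebalances roughly five times as often as the default
        balance_probability = 0.05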
heavyball-0.16.0.dist-info/METADATA → heavyball-0.17.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.16.0
+ Version: 0.17.0
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -1,23 +1,24 @@
1
- heavyball/__init__.py,sha256=KbT0GMU0DKqZxq9laCrD7XgiqS9yxC1W52zhte5kjKs,2054
2
- heavyball/cached_psgd_kron.py,sha256=vJuy639G-_ZLSRX3goSFMXALv-ucYjrxaEtpj0IHo-M,6802
3
- heavyball/delayed_psgd.py,sha256=sbwgAed5gmQpHNTPvuE7Si-gB-s0NVvN4d-4rNUJj4c,5893
1
+ heavyball/__init__.py,sha256=mDHahP4u0fy2YKWA4FPMAp7jLPMt5WwUkEiOrwE4u3E,2199
2
+ heavyball/cached_delayed_psgd_kron.py,sha256=DvjNNHzbnS-NDq965wve-VQ-ol7IFljYYGTuTwPHOhU,6971
3
+ heavyball/cached_psgd_kron.py,sha256=xy3-yRKFUvRTstJb_asMVp-k-5Zuw_HyILPi7BsuMKQ,6974
4
+ heavyball/delayed_psgd.py,sha256=rDDUj3miEn6HRJmKl-ZImsqkqBASSn8aC7MEV_06fzU,6017
4
5
  heavyball/foreach_adamw.py,sha256=h_ar0ZZRM0q_wxuEkxEOEYe-2p-mB4OMgAHivrUnPl8,1777
5
6
  heavyball/foreach_adopt.py,sha256=ogOw2JjwEQNj7AKlweAphQFdMJ_GcMDm-RyDvEzugoc,1911
6
7
  heavyball/foreach_laprop.py,sha256=yGVmGqWiSw8Y2Xj70ndkR8ZMygakTB4_iRwV02Svkqg,1816
7
8
  heavyball/foreach_sfadamw.py,sha256=15-n6-lx4PAHYsKYmXbugxsR5MnqaPYy2vUudPRiitg,2087
8
9
  heavyball/foreach_soap.py,sha256=h6ptMch7oaynvu3eIJtWnVXypDA_5JDVm3Zb3PNEma0,4634
9
- heavyball/p_adam.py,sha256=aCu4Qn0eHJETHuCGrfNKp2aygKk2ZoNQyxut3Vcqmoc,6112
10
+ heavyball/p_adam.py,sha256=F2b-xGNROi9VfX7isa3kffWePojpBl5BI1n464w4tGQ,6334
10
11
  heavyball/palm_foreach_sfadamw.py,sha256=yvZbPyjDW8qd3r4qDXb6uTr5RozQ7JSDj4aYYRnKGLA,2248
11
12
  heavyball/palm_foreach_soap.py,sha256=g4hbiGRcti-J-a0SwAkP4ii5pU-aalsZH5bssyhroLk,5938
12
13
  heavyball/precond_schedule_foreach_soap.py,sha256=WLg5SzpJnKPZUvFyIvdwSZa1Umt5cpr3Kow_42orM-E,4863
13
14
  heavyball/precond_schedule_palm_foreach_soap.py,sha256=ammQrvRZFF-wc-wEiPEoFhS_7b8pdV61QfcLoQfimSo,6211
14
15
  heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdojklrY8pA,6770
15
- heavyball/psgd_kron.py,sha256=iWTAViuzxTodtQGZnkLsEXrLG8tNU-BQB3KkTYAVcX4,5874
16
- heavyball/pure_psgd.py,sha256=EuCPNM8TX13cOop-mvvBFh6Uo1UjD1vsE053hvil92Q,5136
16
+ heavyball/psgd_kron.py,sha256=2IpPj2TOExNGm8hSewi3er2GczJRNgC7r2J5yYSSA_0,5998
17
+ heavyball/pure_psgd.py,sha256=uA7W9a3Qm1sxHQhtNxaUYrmE5x55lP5iJOKy_qT8XaQ,5341
17
18
  heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
18
- heavyball/utils.py,sha256=z6taEvpgszKTrscqgowKYqb0xIVpBDVDBNGgvTE4Pb8,28484
19
- heavyball-0.16.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
20
- heavyball-0.16.0.dist-info/METADATA,sha256=yjpldOTN2rXN2-KG7R9ytuyBfmSCDpznZeRuziANChE,11941
21
- heavyball-0.16.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
22
- heavyball-0.16.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
23
- heavyball-0.16.0.dist-info/RECORD,,
19
+ heavyball/utils.py,sha256=Jqh7VdWGeiSdwaPtUNB9l14wuuFPSReLaTwJA3juFbM,28765
20
+ heavyball-0.17.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
+ heavyball-0.17.0.dist-info/METADATA,sha256=GIJQ4ha-fcYR6ltOs4WUO8L_LhWGiZv2UrEZcuJD0LI,11941
22
+ heavyball-0.17.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
+ heavyball-0.17.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
+ heavyball-0.17.0.dist-info/RECORD,,