heavyball 0.14.6__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/precond_schedule_sfpsoap.py CHANGED
@@ -54,85 +54,73 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
         self._data_format = data_format
         self.rng = random.Random(0x120983109)

-    @torch.no_grad()
-    def step(self, closure=None):
-        """
-        Performs a single optimization step.
-
-        Arguments:
-            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
-        """
-        if closure is None:
-            loss = None
-        else:
-            loss = closure()
-
-        for group in self.param_groups:
-            vals = []
-            max_precond_dim = group['max_precond_dim']
-            precondition_1d = group['precondition_1d']
-
-            step = group['step'] = group.get("step", -1) + 1
-
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-                grad = p.grad.float()
-                vals.append((p, grad))
-
-            p_list, grad = zip(*vals)
-            vals = []
-
-            # adaptive gradient clipping
-            adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
-
-                if "z" not in state:
-                    state["z"] = torch.clone(p.data)
-                    state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
-                    init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                    continue  # first step is skipped so that we never use the current gradients in the projection.
-
-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                grad_projected = project(g, state['Q'], False)
-                z, exp_avg_sq = state["z"], state["exp_avg_sq"]
-                vals.append((p, g, grad_projected, z, exp_avg_sq))
-
-            if not vals:
+    def _step(self, group):
+        vals = []
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        step = group['step'] = group.get("step", -1) + 1
+
+        for p in group["params"]:
+            if p.grad is None:
                 continue
+            grad = p.grad.float()
+            vals.append((p, grad))
+
+        if not vals:
+            return
+
+        p_list, grad = zip(*vals)
+        vals = []
+
+        # adaptive gradient clipping
+        adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+
+            if "z" not in state:
+                state["z"] = torch.clone(p.data)
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            z, exp_avg_sq = state["z"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, z, exp_avg_sq))

-            p_list, grad, grad_projected, z, exp_avg_sq = zip(*vals)
-            del vals
+        if not vals:
+            return

-            beta2 = 1 - max(step, 1) ** -group['beta2_scale']
-            old_debiased2 = beta_debias(beta2, step)
+        p_list, grad, grad_projected, z, exp_avg_sq = zip(*vals)
+        del vals

-            # Decay the first and second moment running average coefficient
-            # In-place operations to update the averages at the same time
-            denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
-            torch._foreach_div_(grad_projected, denom)
+        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
+        old_debiased2 = beta_debias(beta2, step)

-            update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+        torch._foreach_div_(grad_projected, denom)

-            for p, g, gp in zip(p_list, grad, grad_projected):
-                state = self.state_(p)
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                set_(gp, project(gp, state['Q'], back=True))
+        update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)

-                update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-                                      update_precond)
+        for p, g, gp in zip(p_list, grad, grad_projected):
+            state = self.state_(p)
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            set_(gp, project(gp, state['Q'], back=True))

-            # Weight decay calculated at y
-            if group["weight_decay"] > 0:
-                torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
+                                  update_precond)

-            lr = warmup(group['lr'], step, group['warmup_steps'])
-            group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
-                                                 p_list, z, grad_projected, group['r'], step)
+        # Weight decay calculated at y
+        if group["weight_decay"] > 0:
+            torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])

-        return loss
+        lr = warmup(group['lr'], step, group['warmup_steps'])
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
+                                             p_list, z, grad_projected, group['r'], step)
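The SOAP variants in this release derive beta2 from the step count (PaLM-style) instead of using a fixed constant, then debias it with heavyball's own beta_debias helper. A minimal sketch of just the schedule, in plain Python; the function name palm_beta2 and the 0.8 default for beta2_scale are illustrative assumptions, not taken from this diff.

# Minimal sketch (not heavyball code): the step-dependent beta2 used in _step above.
def palm_beta2(step: int, beta2_scale: float = 0.8) -> float:
    # beta2 = 1 - step ** -beta2_scale, exactly as computed in the optimizer.
    return 1 - max(step, 1) ** -beta2_scale

for step in (1, 10, 100, 1000, 10000):
    print(step, palm_beta2(step))
# beta2 starts at 0 on the first step and rises toward 1, so the second-moment
# EMA averages over an ever longer window as training progresses.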
heavyball/psgd_kron.py CHANGED
@@ -4,9 +4,10 @@ Modified under Creative Commons Attribution 4.0 International
 Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
 """

-import torch
 from typing import Optional

+import torch
+
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
     precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_

@@ -62,13 +63,7 @@ class ForeachPSGDKron(PSGDBase):

         self._prob_step = 0

-    @torch.no_grad()
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
@@ -76,54 +71,51 @@ class ForeachPSGDKron(PSGDBase):
         do_update = self.rng.random() < update_prob
         self._prob_step += 1

-        for group in self.param_groups:
-            momentum_into_precond_update = group.get("momentum_into_precond_update", True)
-            precond_init_scale = group['precond_init_scale']
-            max_size_triangular = group['max_size_triangular']
-            min_ndim_triangular = group['min_ndim_triangular']
-            memory_save_mode = group['memory_save_mode']
-            precond_lr = group['precond_lr']
-            weight_decay = group['weight_decay']
-            lr = group['lr']
-            beta = group['beta']
-
-            vals = []
+        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        beta = group['beta']

-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
+        vals = []

-                if 'Q' not in state:
-                    state["exp_avg"] = torch.zeros_like(g)
-                    Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                     memory_save_mode, dtype=g.dtype)
-                    state['Q'] = triu_to_line(Q)
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)

-                vals.append((p, g, state["exp_avg"], state["Q"]))
+            if 'Q' not in state:
+                state["exp_avg"] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q)

-            if not vals:
-                continue
+            vals.append((p, g, state["exp_avg"], state["Q"]))

-            p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
-            del vals
+        if not vals:
+            return

-            group["step"] += 1
+        p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
+        del vals

-            torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+        group["step"] += 1

-            grad_list, Q_list, exp_avg_list = list(grad_list), list(Q_list), list(exp_avg_list)
-            for i, (p, g) in enumerate(zip(p_list, grad_list)):
-                q_orig = Q_list.pop(0)
-                ea = exp_avg_list.pop(0)
-                q = line_to_triu(q_orig)
+        torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))

-                self.balance(do_update, [g], [q])
-                if do_update:
-                    self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
-                set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
+        grad_list, Q_list, exp_avg_list = list(grad_list), list(Q_list), list(exp_avg_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            q_orig = Q_list.pop(0)
+            ea = exp_avg_list.pop(0)
+            q = line_to_triu(q_orig)

-            grad_list = self.clip_fn(grad_list)
+            self.balance(do_update, [g], [q])
+            if do_update:
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
+            set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))

-            lr = -warmup(lr, group['step'], group['warmup_steps'])
-            update_param_(p_list, grad_list, lr, weight_decay)
+        grad_list = self.clip_fn(grad_list)

-        return loss
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
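The momentum update above folds Adam's bias correction into the interpolation weight: lerping exp_avg toward the gradient with weight (1 - beta) / (1 - beta ** step) keeps exp_avg equal to the debiased first moment. A small self-contained check of that identity, using the per-tensor lerp_ rather than the _foreach_lerp_ call from the diff; this is illustrative code, not heavyball's.

import torch

torch.manual_seed(0)
beta = 0.9
exp_avg = torch.zeros(4)   # stores the debiased average directly
m = torch.zeros(4)         # classic Adam accumulator, for comparison

for step in range(1, 6):
    g = torch.randn(4)
    exp_avg.lerp_(g, (1 - beta) / (1 - beta ** step))  # same weight as in _step above
    m.mul_(beta).add_(g, alpha=1 - beta)               # m_t = beta * m_{t-1} + (1 - beta) * g_t
    assert torch.allclose(exp_avg, m / (1 - beta ** step), atol=1e-6)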
heavyball/pure_psgd.py CHANGED
@@ -59,13 +59,7 @@ class ForeachPurePSGD(PSGDBase):

         self._prob_step = 0

-    @torch.no_grad()
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
+    def _step(self, group):
         # update preconditioners all together
         update_prob = self.preconditioner_update_probability
         if callable(update_prob):
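preconditioner_update_probability can be a constant or a callable of the internal probability counter; each _step then draws a Bernoulli sample to decide whether the preconditioner is refreshed. The real schedule is precond_update_prob_schedule from heavyball.utils, whose shape is not shown in this diff, so the decaying schedule below is only an illustrative stand-in with made-up constants.

import random

def example_update_prob(prob_step: int, max_prob: float = 1.0,
                        min_prob: float = 0.03, decay: float = 0.999) -> float:
    # Hypothetical schedule: update every step at first, decay toward a floor.
    return max(min_prob, max_prob * decay ** prob_step)

rng = random.Random(0x120983109)  # same seeding style as the optimizers above
prob_step = 0
for _ in range(5):
    update_prob = example_update_prob(prob_step)
    do_update = rng.random() < update_prob  # the gate used inside _step
    prob_step += 1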
@@ -73,48 +67,45 @@ class ForeachPurePSGD(PSGDBase):
         do_update = self.rng.random() < update_prob
         self._prob_step += 1

-        for group in self.param_groups:
-            precond_init_scale = group['precond_init_scale']
-            max_size_triangular = group['max_size_triangular']
-            min_ndim_triangular = group['min_ndim_triangular']
-            memory_save_mode = group['memory_save_mode']
-            precond_lr = group['precond_lr']
-            weight_decay = group['weight_decay']
-            lr = group['lr']
-
-            vals = []
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']

-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
+        vals = []

-                if 'Q' not in state:
-                    Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                     memory_save_mode, dtype=g.dtype)
-                    state['Q'] = triu_to_line(Q)
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)

-                vals.append((p, g, state["Q"]))
+            if 'Q' not in state:
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q)

-            if not vals:
-                continue
+            vals.append((p, g, state["Q"]))

-            p_list, grad_list, Q_list = zip(*vals)
-            del vals
+        if not vals:
+            return

-            group["step"] += 1
+        p_list, grad_list, Q_list = zip(*vals)
+        del vals

-            Q_list = list(Q_list)
-            for i, (p, g) in enumerate(zip(p_list, grad_list)):
-                q_orig = Q_list.pop(0)
-                q = line_to_triu(q_orig)
+        group["step"] += 1

-                self.balance(do_update, [g], [q])
-                if do_update:
-                    self.do_update([p], [g], [q], precond_lr, [q_orig])
-                psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
+        Q_list = list(Q_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            q_orig = Q_list.pop(0)
+            q = line_to_triu(q_orig)

-            grad_list = self.clip_fn(grad_list)
+            self.balance(do_update, [g], [q])
+            if do_update:
+                self.do_update([p], [g], [q], precond_lr, [q_orig])
+            psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)

-            lr = -warmup(lr, group['step'], group['warmup_steps'])
-            update_param_(p_list, grad_list, lr, weight_decay)
+        grad_list = self.clip_fn(grad_list)

-        return loss
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
heavyball/schedule_free_palm_foreach_soap.py CHANGED
@@ -46,84 +46,73 @@ class SFPaLMForeachSOAP(ScheduleFree):
         self._data_format = data_format
         self.rng = random.Random(0x120983109)

-    @torch.no_grad()
-    def step(self, closure=None):
-        """
-        Performs a single optimization step.
-
-        Arguments:
-            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
-        """
-        if closure is None:
-            loss = None
-        else:
-            loss = closure()
-
-        for group in self.param_groups:
-            vals = []
-            max_precond_dim = group['max_precond_dim']
-            precondition_1d = group['precondition_1d']
-
-            step = group['step'] = group.get("step", -1) + 1
-
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-                grad = p.grad.float()
-                vals.append((p, grad))
-
-            p_list, grad = zip(*vals)
-
-            adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
-
-            vals = []
-
-            for p, g in split_p_and_g_in_group(group):
-                state = self.state_(p)
-
-                if "z" not in state:
-                    state["z"] = torch.clone(p).float()
-                    state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
-                    init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                    continue  # first step is skipped so that we never use the current gradients in the projection.
-
-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                grad_projected = project(g, state['Q'], False)
-                z, exp_avg_sq = state["z"], state["exp_avg_sq"]
-                vals.append((p, g, grad_projected, z, exp_avg_sq))
-
-            if not vals:
+    def _step(self, group):
+        vals = []
+        max_precond_dim = group['max_precond_dim']
+        precondition_1d = group['precondition_1d']
+
+        step = group['step'] = group.get("step", -1) + 1
+
+        for p in group["params"]:
+            if p.grad is None:
                 continue
+            grad = p.grad.float()
+            vals.append((p, grad))
+
+        if not vals:
+            return
+
+        p_list, grad = zip(*vals)
+
+        adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
+
+        vals = []
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+
+            if "z" not in state:
+                state["z"] = torch.clone(p).float()
+                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                init_preconditioner(g, state, max_precond_dim, precondition_1d)
+                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
+                continue  # first step is skipped so that we never use the current gradients in the projection.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            grad_projected = project(g, state['Q'], False)
+            z, exp_avg_sq = state["z"], state["exp_avg_sq"]
+            vals.append((p, g, grad_projected, z, exp_avg_sq))
+
+        if not vals:
+            return

-            p_list, grad, grad_projected, z, exp_avg_sq = zip(*vals)
+        p_list, grad, grad_projected, z, exp_avg_sq = zip(*vals)

-            beta2 = 1 - max(step, 1) ** -group['beta2_scale']
-            new_debiased2 = beta_debias(beta2, step)
+        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
+        new_debiased2 = beta_debias(beta2, step)

-            # Decay the first and second moment running average coefficient
-            # In-place operations to update the averages at the same time
-            denom = exp_avg_sq_(exp_avg_sq, grad, new_debiased2, group["eps"])
-            torch._foreach_div_(grad_projected, denom)
+        # Decay the first and second moment running average coefficient
+        # In-place operations to update the averages at the same time
+        denom = exp_avg_sq_(exp_avg_sq, grad, new_debiased2, group["eps"])
+        torch._foreach_div_(grad_projected, denom)

-            update_precond = group['step'] > 0 and group['step'] % group['precondition_frequency'] == 0
+        update_precond = group['step'] > 0 and group['step'] % group['precondition_frequency'] == 0

-            for p, g, gp in zip(p_list, grad, grad_projected):
-                state = self.state_(p)
-                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-                # to the original space
-                # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-                set_(gp, project(gp, state['Q'], back=True))
+        for p, g, gp in zip(p_list, grad, grad_projected):
+            state = self.state_(p)
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
+            set_(gp, project(gp, state['Q'], back=True))

-                update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2,
-                                      update_precond)
+            update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2,
+                                  update_precond)

-            # Weight decay calculated at y
-            if group["weight_decay"] > 0:
-                torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])
+        # Weight decay calculated at y
+        if group["weight_decay"] > 0:
+            torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])

-            lr = warmup(group['lr'], step, group['warmup_steps'])
-            group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
-                                                 p_list, z, grad_projected, group['r'], step)
-            return loss
+        lr = warmup(group['lr'], step, group['warmup_steps'])
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
+                                             p_list, z, grad_projected, group['r'], step)
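The "weight decay calculated at y" step adds weight_decay * param into every gradient buffer with a single foreach call before the schedule-free update. A tiny sketch of what that call does; this is illustrative code, not heavyball's.

import torch

wd = 1e-2
params = [torch.randn(3, 3), torch.randn(5)]
grads = [torch.randn_like(p) for p in params]

expected = [g + wd * p for p, g in zip(params, grads)]
torch._foreach_add_(grads, params, alpha=wd)  # grad <- grad + wd * p, in place, per tensor
assert all(torch.allclose(g, e) for g, e in zip(grads, expected))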
heavyball/utils.py CHANGED
@@ -3,7 +3,7 @@ import gc
 import math
 import random
 import string
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Callable

 import numpy as np
 import torch
@@ -399,6 +399,20 @@ class StatefulOptimizer(torch.optim.Optimizer):
             tree_map(_add, self.state_(p))
         return total_bytes

+    def _step(self, group):
+        raise NotImplementedError
+
+    def step(self, closure: Optional[Callable] = None):
+        if closure is None:
+            loss = None
+        else:
+            with torch.enable_grad():
+                loss = closure()
+        with torch.no_grad():
+            for group in self.param_groups:
+                self._step(group)
+        return loss
+

 class ScheduleFree(StatefulOptimizer):
     def eval(self):
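This is the core refactor of the release: closure handling and the no_grad context now live once in StatefulOptimizer.step, and each optimizer only implements the per-group _step hook removed from the files above. A self-contained sketch of the same template-method pattern; ToyBase and ToySGD are illustrative stand-ins, not heavyball classes.

from typing import Callable, Optional

import torch

class ToyBase(torch.optim.Optimizer):
    def _step(self, group):
        # Subclasses implement the per-group update here.
        raise NotImplementedError

    def step(self, closure: Optional[Callable] = None):
        # Mirrors the step added to StatefulOptimizer above.
        if closure is None:
            loss = None
        else:
            with torch.enable_grad():
                loss = closure()
        with torch.no_grad():
            for group in self.param_groups:
                self._step(group)
        return loss

class ToySGD(ToyBase):
    def __init__(self, params, lr=1e-3):
        super().__init__(params, dict(lr=lr))

    def _step(self, group):
        # No @torch.no_grad() or closure plumbing needed in the subclass.
        for p in group["params"]:
            if p.grad is not None:
                p.add_(p.grad, alpha=-group["lr"])

From the caller's side nothing changes: opt.step() or opt.step(closure) behaves as before, while subclasses drop the repeated decorator and closure boilerplate seen in the removed step methods.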
@@ -570,7 +584,6 @@ def psgd_balance_Q(Q_in):


 def psgd_calc_A_and_conjB(exprA, G, Q, V):
-    print([q.shape for q in Q], G.shape, V.shape, exprA)
     A = torch.einsum(exprA, *Q, G)
     order = G.dim()
     p = list(range(order))
@@ -685,9 +698,11 @@ def a_law_compress(x, A=87.6):
     torch._foreach_mul_(xa, 1 / (1 + math.log(A)))
     return xa

+
 def identity(x):
     return x

+
 def trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
     torch._foreach_mul_(grad, 1 / scale)
     tanh = torch._foreach_tanh(grad)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.14.6
+Version: 0.15.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -0,0 +1,22 @@
+heavyball/__init__.py,sha256=AGNWRYDkPFZ9Px3117ChPgUgEk2gqflWWXO4UxHlmYc,1156
+heavyball/delayed_psgd.py,sha256=28osHU-2khgdQ1ASglxTtA5MA1j2GiYP3OmNirkqMso,5574
+heavyball/foreach_adamw.py,sha256=NSzoIgNm7eavzbJgkAF0k7TUEnWAgOpt9-4juIFoaSA,1729
+heavyball/foreach_adopt.py,sha256=WA07m5jocLfb1GPU8s6mJ2PteS-03ronkKm-VJrAm5I,1863
+heavyball/foreach_laprop.py,sha256=mE2NDGX9XgvRhsewcWnk_-FulZPqGA65ejYF_9-A1Xk,1768
+heavyball/foreach_sfadamw.py,sha256=ussHfPd99u3RTfMrCuu5oIbwNFLXK19wO1Fbz3JShlc,2097
+heavyball/foreach_soap.py,sha256=WWvssYKg607uoEJHftp8ag8mtKSKSeHrT0QTgqBucVg,4587
+heavyball/p_adam.py,sha256=xOJuws2ELPcL-TUyH-2fPvwRdBNZUmaqiKDJFK33bPM,5694
+heavyball/palm_foreach_sfadamw.py,sha256=wjUb_fNZNUmzWXyKvwB0unP9lvNMmaYSQo5YoeS5cj0,2200
+heavyball/palm_foreach_soap.py,sha256=2Sb4hUHQeexJcCgjHeQM_ENkZ6lG1DVxW72ryrvR6iY,5890
+heavyball/precond_schedule_foreach_soap.py,sha256=bHsDyh-UvHpHjumjqqy0PePoR1ZMsJV6o5wWvpLAA04,4815
+heavyball/precond_schedule_palm_foreach_soap.py,sha256=myLTJNQKLtZ3Xi3MVTB-RYtx_XeMRJw5CIMJW75ndUY,6163
+heavyball/precond_schedule_sfpsoap.py,sha256=xeNWetBzBEYqfOSzl98aAVJsHk43QkrUUhHH_YD_mS4,6740
+heavyball/psgd_kron.py,sha256=7PxvVNeXQcxHMDdn0hFn4psEH89xOctflKt5sKjryMU,5554
+heavyball/pure_psgd.py,sha256=1vafWGQ5vtEE01T1qhI9GDXvzFw5zzq0rJrewa1jY4E,4847
+heavyball/schedule_free_palm_foreach_soap.py,sha256=w0P7lMmoijTpL9V7NwOHcNBFJQ7S1TS9aCiwPhY2yVw,6319
+heavyball/utils.py,sha256=xNqBJBZyK5n5EKy2g4qkXf342uGPGvk6pzjFPzeBncM,27861
+heavyball-0.15.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-0.15.0.dist-info/METADATA,sha256=CuUQyfkSwtwfsnVo6vLM_0hIMwJfB_J26-baYAlqDvM,11667
+heavyball-0.15.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-0.15.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-0.15.0.dist-info/RECORD,,
@@ -1,22 +0,0 @@
-heavyball/__init__.py,sha256=ef7IWcPF8Uh3WQHzMiMqOFvUnU_LdG7BO9XVecJOph4,1156
-heavyball/delayed_psgd.py,sha256=Gfa1ogkFPPL7ohayYAwbugB8hyLRUI5FgcJfsK69KGI,5936
-heavyball/foreach_adamw.py,sha256=L727chOuVqdiVbYYzloy4g3oKH2FmQh40o_bqWeZtk8,2106
-heavyball/foreach_adopt.py,sha256=M4zZVcYlhGjqO6qekivCuYfX6JtMcp4cJi_RrSHT3H8,2268
-heavyball/foreach_laprop.py,sha256=htYGqgvlZsF_JpILdPMTnX72SqfrokBJ2J4nDeT0JVU,2157
-heavyball/foreach_sfadamw.py,sha256=KIGol7Phfq1DHE_nEle4wDuqNdbGsT3kUcMKzJX3msg,2498
-heavyball/foreach_soap.py,sha256=Ccz9Mc_xaHnrJ_7jUq9ZVxyR0WEqopzOXTUqUY-V8G8,5137
-heavyball/p_adam.py,sha256=jQgTkKekqnLj1XPA4-fgpWG8P_BtUq2976zEt2QymTo,6060
-heavyball/palm_foreach_sfadamw.py,sha256=8IGlRCdkfMzUqcSfmTM3Ce04NLNyrT2xfiBcPkrWwqc,2605
-heavyball/palm_foreach_soap.py,sha256=NEJ3Xeh7pqURUk3cAP2qJe8z2WzYKg60pQe4bsGiaY4,6441
-heavyball/precond_schedule_foreach_soap.py,sha256=H6Oc5IAL5MR-fgu92AboPs3Xm8mBmYUMPLsEcuJ12VI,5370
-heavyball/precond_schedule_palm_foreach_soap.py,sha256=v81hRjcqS6Mm-KxT5Rk3TEiKAE8WI2IbmVbSa-YfBkE,6760
-heavyball/precond_schedule_sfpsoap.py,sha256=7ScnN0in8u9hPiJE7QnOoZOH6Tn-6HeVy4f-bO3bHzY,7279
-heavyball/psgd_kron.py,sha256=AH8ugd_IxKGVtY9y_Ot7myVSxFDbLlRJIqr2bBlAYy8,5911
-heavyball/pure_psgd.py,sha256=jp5fnawUdgccEFlZDPrZr4ZbxYV85IIrev4tybZxBVU,5185
-heavyball/schedule_free_palm_foreach_soap.py,sha256=bV7H-FNNoH5WpposLrNhkqU7mBicMorqKEALBSdROEM,6853
-heavyball/utils.py,sha256=WfvymrU9Xv7PMfitXZvm-4XklCy6wK0tWqOXKt96Tww,27521
-heavyball-0.14.6.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-0.14.6.dist-info/METADATA,sha256=RWQo73o1ajPpDO8uJNOhdV5d4uFdkycXRCtHrM-KfDw,11667
-heavyball-0.14.6.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-0.14.6.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-0.14.6.dist-info/RECORD,,