heavyball 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- heavyball/__init__.py +18 -4
- heavyball/chainable.py +100 -43
- heavyball/utils.py +128 -81
- {heavyball-1.0.0.dist-info → heavyball-1.1.0.dist-info}/METADATA +3 -3
- heavyball-1.1.0.dist-info/RECORD +8 -0
- heavyball-1.0.0.dist-info/RECORD +0 -8
- {heavyball-1.0.0.dist-info → heavyball-1.1.0.dist-info}/LICENSE +0 -0
- {heavyball-1.0.0.dist-info → heavyball-1.1.0.dist-info}/WHEEL +0 -0
- {heavyball-1.0.0.dist-info → heavyball-1.1.0.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -59,6 +59,19 @@ class ForeachADOPT(C.BaseOpt):
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
 
 
+class ForeachMuon(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8,
+                 nesterov: bool = True):
+        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+                        mars_gamma=mars_gamma, beta2_scale=beta2_scale)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
+                         C.nesterov_momentum if nesterov else C.heavyball_momentum, C.orthogonalize_update)
+
+
 class ForeachLaProp(C.BaseOpt):
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
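For orientation, a minimal usage sketch of the new optimizer exposed by this release (hypothetical model and training loop; only the `ForeachMuon` constructor shown above is taken from the package):

```python
import torch
import heavyball

# hypothetical model; any torch.nn.Module works
model = torch.nn.Linear(128, 10)

# ForeachMuon (also exported as `Muon`) applies nesterov/heavyball momentum
# followed by an orthogonalized update
opt = heavyball.ForeachMuon(model.parameters(), lr=0.0025, nesterov=True)

for _ in range(10):
    x = torch.randn(32, 128)
    loss = model(x).square().mean()  # dummy objective
    opt.zero_grad()
    loss.backward()
    opt.step()
```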
@@ -200,10 +213,11 @@ PurePSGD = ForeachPurePSGD
 DelayedPSGD = ForeachDelayedPSGD
 CachedPSGDKron = ForeachCachedPSGDKron
 CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
+Muon = ForeachMuon
 
-__all__ = ["PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
-           "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
-           "PrecondSchedulePaLMSOAP", 'RMSprop',
+__all__ = ["Muon","RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
+           "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
+           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop',  #
            "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
            "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
-           "ForeachRMSprop"]
+           "ForeachRMSprop", "ForeachMuon"]
heavyball/chainable.py
CHANGED
@@ -33,6 +33,23 @@ def _guard_in_state(state, key, template_fn):
     return state[key]
 
 
+class FunctionTransform:
+    def __init__(self, fn):
+        self.fn = fn
+        self.fn_name = self.get_fn().__name__
+
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        raise NotImplementedError
+
+    def get_fn(self):
+        if hasattr(self.fn, 'get_fn'):
+            return self.fn.get_fn()
+        return self.fn
+
+    def val_name(self, name):
+        return f"{self.fn_name}_{name}"
+
+
 def _zero_guard(state, key, ref, dtype):
     return _guard_in_state(state, key,
                            lambda: torch.zeros_like(ref, dtype=torch.float32, memory_format=torch.preserve_format))
@@ -43,61 +60,77 @@ def _storage_dtype(group):
     return getattr(torch, dtype)
 
 
-
-def
-
-
-        return fn(state, group, update, grad, param, *args, *vars, **kwargs)
+class ZeroGuard(FunctionTransform):
+    def __init__(self, fn, names):
+        super().__init__(fn)
+        self.names = names
 
-
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        vars = [[_zero_guard(state(p), self.val_name(name), p, _storage_dtype(group)) for p in param]  #
+                for name in self.names]
+        return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
-    return _outer
 
+class CopyGuard(FunctionTransform):
+    def __init__(self, fn, index, names):
+        super().__init__(fn)
+        self.index = index
+        self.names = names
 
-    def
-
-
-
-
-                for name in names]
-        return fn(state, group, update, grad, param, *args, *vars, **kwargs)
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        val = [update, grad, param, *args][self.index]
+        vars = [[_guard_in_state(state(p), self.val_name(name), lambda: torch.clone(v)) for p, v in zip(param, val)]  #
+                for name in self.names]
+        return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
-        return _fn
 
-
+class GeneralGuard(FunctionTransform):  # We can't guard against reuse in the general case
+    def __init__(self, fn, names, init_fn):
+        super().__init__(fn)
+        self.names = names
+        self.init_fn = init_fn
 
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        vars = []
+        skip_update = False
+        for p, g, u in zip(param, grad, update):
+            st = state(p)
+            skip_update |= _inplace_guard_(st, self.names, lambda: self.init_fn(st, group, u, g, p, **kwargs))
+            vars.append([st[name] if isinstance(name, str) else st.get(name[0], name[1]) for name in self.names])
+        if skip_update:
+            raise SkipUpdate
+        return self.fn(state, group, update, grad, param, *args, *zip(*vars), **kwargs)
 
-def general_guard(*names, init_fn):
-    def _outer(fn):
-        def _fn(state, group, update, grad, param, *args, **kwargs):
-            vars = []
-            skip_update = False
-            for p, g, u in zip(param, grad, update):
-                st = state(p)
-                skip_update |= _inplace_guard_(st, names, lambda: init_fn(st, group, u, g, p, **kwargs))
-                vars.append([st[name] if isinstance(name, str) else st.get(name[0], name[1]) for name in names])
-            if skip_update:
-                raise SkipUpdate
-            return fn(state, group, update, grad, param, *args, *zip(*vars), **kwargs)
 
-
+class NoState(FunctionTransform):
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        return self.fn(group, update, grad, param, *args, **kwargs)
 
-    return _outer
 
+class NoStateNoForeach(FunctionTransform):
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        for a in zip(update, grad, param, *args):
+            return self.fn(group, *a, **kwargs)
 
-def no_state(fn):
-    def _fn(state, *args, **kwargs):
-        return fn(*args, **kwargs)
 
-
+def zero_guard(*names):
+    return functools.partial(ZeroGuard, names=names)
 
 
-def
-
-
-
+def copy_guard(index, *names):
+    return functools.partial(CopyGuard, index=index, names=names)
+
+
+def general_guard(*names, init_fn):
+    return functools.partial(GeneralGuard, names=names, init_fn=init_fn)
 
-
+
+def no_state(fn):
+    return NoState(fn)
+
+
+def no_state_no_foreach(fn):
+    return NoStateNoForeach(fn)
 
 
 class SkipUpdate(ValueError):
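To make the refactor above concrete: `zero_guard`, `copy_guard`, `general_guard`, `no_state`, and `no_state_no_foreach` now build instances of the `FunctionTransform` subclasses instead of nested closures. A construction-only sketch (`my_transform` is a made-up example; actually invoking the wrapper requires the optimizer's state plumbing):

```python
from heavyball.chainable import ZeroGuard, NoState, no_state, zero_guard

@zero_guard("momentum")   # outer wrapper: lazily creates a zero-initialized "momentum" buffer per parameter
@no_state                 # inner wrapper: drops the `state` argument before calling the function
def my_transform(group, update, grad, param, momentum):
    # toy update rule: decay the per-parameter momentum buffers and add the raw update
    for m, u in zip(momentum, update):
        m.mul_(0.9).add_(u)
    return [m.clone() for m in momentum]

# The decorator stack is plain object construction:
assert isinstance(my_transform, ZeroGuard) and isinstance(my_transform.fn, NoState)
print(my_transform.fn_name)                # "my_transform", recovered via FunctionTransform.get_fn()
print(my_transform.val_name("momentum"))   # "my_transform_momentum", the key used in the state dict
```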
@@ -116,8 +149,6 @@ def exp_avg(group, update, grad, param, exp_avg):
 def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
     out = utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]),
                                      group['eps'])
-    if group['step'] == 1:
-        raise SkipUpdate
     return out
 
 
@@ -232,6 +263,29 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str
     raise ValueError("No preconditioner update schedule specified.")
 
 
+@no_state_no_foreach
+def orthogonalize_update(group, update, grad, param):
+    if update.dim() == 1:
+        return update
+    original_shape = update.shape
+    # doing it this way, as tmp and update are not guaranteed to share memory address or layout
+    tmp = update.flatten(1, -1)
+    utils.inplace_orthogonal_(tmp, utils.zeroth_power_mode, tmp)
+    return tmp.reshape(original_shape)
+
+
+@zero_guard("momentum")
+@no_state
+def nesterov_momentum(group, updates, grads, params, momentum):
+    return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))
+
+
+@zero_guard("momentum")
+@no_state
+def heavyball_momentum(group, updates, grads, params, momentum):
+    return utils.heavyball_momentum(momentum, updates, utils.get_beta1(group))
+
+
 @zero_guard("exp_avg", "exp_avg_sq")
 @general_guard("Q", "GG", init_fn=_init_soap)
 @no_state
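A small shape sketch of what `orthogonalize_update` does to a non-vector parameter (illustrative only; `torch.linalg.qr` stands in for `utils.inplace_orthogonal_` in its default `'qr'` mode, and the weight shape is chosen so the flattened matrix is tall):

```python
import torch

w = torch.randn(256, 16, 2, 2)   # e.g. a conv weight; 1D parameters are passed through unchanged
tmp = w.flatten(1, -1)           # -> (256, 64): everything after dim 0 is folded into one axis
q = torch.linalg.qr(tmp).Q       # orthogonalize the flattened update; shape stays (256, 64) for a tall matrix
restored = q.reshape(w.shape)    # -> (256, 16, 2, 2), matching the original parameter
assert restored.shape == w.shape
```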
@@ -346,7 +400,7 @@ def apply_to_idx(fn, idx):
 
 
 def chain(state: Union[callable, dict], group, grad, param, *fns):
-    update = grad
+    update = [torch.clone(g) for g in grad]
     skip_update = False
     for fn in fns:
         try:
@@ -375,7 +429,10 @@ class ChainOpt(utils.StatefulOptimizer):
         else:
             group['lr'] = -group['base_lr']
 
-
+        vals = list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group)))
+        if not vals:
+            return
+        p, g = zip(*vals)
 
         if not group['foreach'] or len(p) == 1:
             for param, grad in zip(p, g):
heavyball/utils.py
CHANGED
@@ -23,7 +23,7 @@ from torch.utils._pytree import tree_map
 compile_mode = "max-autotune-no-cudagraphs"
 dynamic = False
 compile_mode_recommended_to_none = None
-zeroth_power_mode = 'qr' # 'qr' is baseline, 'newtonschulz' converges better and faster
+zeroth_power_mode = 'qr'  # 'qr' is baseline, 'newtonschulz' converges better and faster
 tiny_bf16 = torch.finfo(torch.bfloat16).tiny
 
 
@@ -91,11 +91,9 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
         ckp1 = weight / weight_sum
     except ZeroDivisionError:
         ckp1 = 0
-        ckp1 = 0
 
-
-
-    lr, ckp1, beta1 = scalar_guard(lr, parameters[0]), scalar_guard(ckp1, parameters[0]), scalar_guard(beta1, parameters[0])
+    grad, parameters, z = list_guard(grad, parameters, z)
+    lr, ckp1, beta1 = scalar_guard(lr, ckp1, beta1, grad[0])
     _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1, decay)
     return weight_sum
 
@@ -179,28 +177,28 @@ def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tens
 
 
 def exp_avg_sq_(state, grad, beta2, eps, out=None):
-    state, grad, out = list_guard(state
-    beta2, eps = scalar_guard(beta2,
+    state, grad, out = list_guard(state, grad, out)
+    beta2, eps = scalar_guard(beta2, eps, state[0])
     return _compilable_exp_avg_sq_(state, grad, beta2, eps, out)
 
 
 @decorator_knowngood
-def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor
-                                     out: List[Optional[Tensor]]):
+def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
     torch._foreach_mul_(s32, beta2)
     [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
     denom = torch._foreach_sqrt(s32)
     [d.clamp_(min=eps) for d in denom]
-    out = torch.
+    out = torch._foreach_div_(g32, denom)
     copy_stochastic_list_(state, s32)
-
+    copy_stochastic_list_(grad, out)
 
 
-def scale_by_exp_avg_sq_(
-    grad, exp_avg_sq = list_guard(grad
-    beta2, eps = scalar_guard(beta2,
-
+def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
+    grad, exp_avg_sq = list_guard(grad, exp_avg_sq)
+    beta2, eps = scalar_guard(beta2, eps, grad[0])
+    _compilable_scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps)
+    return grad
 
 
 @decorator_knowngood
@@ -219,7 +217,7 @@ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor
                                 minimum: float = 1e-3, eps: float = 1e-8):
     if clip_val <= 0:
         return gradients
-    parameters, gradients = list_guard(parameters
+    parameters, gradients = list_guard(parameters, gradients)
     clip_val = scalar_guard(clip_val, parameters[0])
     return _compilable_agc_(parameters, gradients, clip_val, minimum, eps)
 
@@ -254,33 +252,29 @@ def set_torch():
 
 
 @decorator
-def zeropower_via_newtonschulz5(G,
+def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
     """
-    Modified from "modded-nanogpt" under the MIT license:
-    Original: https://github.com/KellerJordan/modded-nanogpt/blob/a0dcbfdd9a0617d091d5123cfc354745428e40d3/train_gpt2.py
-
     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
     quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
     of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
     zero even beyond the point where the iteration no longer converges all the way to one everywhere
     on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-    where S' is diagonal with S_{ii}'
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
     performance at all relative to UV^T, where USV^T = G is the SVD.
     """
     assert len(G.shape) == 2
     a, b, c = (3.4445, -4.7750, 2.0315)
-    X = G.
-
-    X = X / (X.norm() + eps)  # ensure top singular value <= 1
+    X = G.bfloat16()
+    X /= (X.norm() + eps)  # ensure top singular value <= 1
     if G.size(0) > G.size(1):
         X = X.T
     for _ in range(steps):
-        A = X @ X.T
-        B = A @
-
+        A = X @ X.T
+        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
     if G.size(0) > G.size(1):
         X = X.T
-    return X
+    return X.to(G.dtype)
 
 
 def ortho(x):
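A standalone sketch of the quintic Newton-Schulz iteration above (coefficients copied from the diff; float32 instead of bfloat16 for clarity), showing that the result has singular values pushed roughly into (0.5, 1.5) rather than exactly 1:

```python
import torch

def newtonschulz5(G, steps=5, eps=1e-7):
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.float()                  # the library casts to bfloat16; float32 here for clarity
    X /= (X.norm() + eps)          # ensure the top singular value is <= 1
    if G.size(0) > G.size(1):      # work on the wide orientation
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X

X = newtonschulz5(torch.randn(64, 256))
s = torch.linalg.svdvals(X)
print(s.min().item(), s.max().item())  # roughly inside (0.5, 1.5) by design
```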
@@ -292,6 +286,53 @@ def ortho(x):
     raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")
 
 
+@decorator_knowngood
+def _compilable_heavyball_momentum_(state, grad, beta):
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    torch._foreach_mul_(s32, beta)
+    torch._foreach_add_(s32, g32)
+    copy_stochastic_list_(state, s32)
+    copy_stochastic_list_(grad, s32)
+
+
+@decorator_knowngood
+def _compilable_nesterov_momentum_(state, grad, beta):
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    torch._foreach_mul_(s32, beta)
+    torch._foreach_add_(s32, g32)
+    [g.add_(s, alpha=beta) for g, s in zip(g32, s32)]
+    copy_stochastic_list_(state, s32)
+    copy_stochastic_list_(grad, g32)
+
+
+def heavyball_momentum(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_heavyball_momentum_(state, grad, beta)
+    return grad
+
+
+def nesterov_momentum(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_nesterov_momentum_(state, grad, beta)
+    return grad
+
+
+@decorator_knowngood
+def inplace_orthogonal_(x, mode, out):
+    if mode == 'qr':
+        y = torch.linalg.qr(x).Q
+    elif mode == 'svd':
+        u, s, v = torch.linalg.svd(x)
+        y = u @ v.T
+    elif mode == 'newtonschulz':
+        y = zeropower_via_newtonschulz5(x, 5)
+    else:
+        raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
+    set_(out, y)
+
+
 def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
     """
     Computes the eigenbases of the preconditioner using one round of power iteration
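A minimal sketch of the two momentum rules added above, written for a single tensor in plain eager PyTorch (the library versions operate on lists of tensors and copy results back with stochastic rounding):

```python
import torch

def heavyball_step(buf, g, beta):
    buf.mul_(beta).add_(g)        # buf <- beta * buf + g
    return buf.clone()            # the update is the refreshed buffer itself

def nesterov_step(buf, g, beta):
    buf.mul_(beta).add_(g)        # buf <- beta * buf + g
    return g + beta * buf         # look-ahead: gradient plus the decayed buffer

buf1, buf2 = torch.zeros(3), torch.zeros(3)
g = torch.ones(3)
print(heavyball_step(buf1, g, 0.9))  # tensor([1., 1., 1.]) on the first step
print(nesterov_step(buf2, g, 0.9))   # tensor([1.9, 1.9, 1.9]) on the first step
```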
@@ -322,17 +363,7 @@ def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
         est_eig = torch.einsum('ij,ij->j', o, tmp)
         sort_idx = torch.argsort(est_eig, descending=True)
         indices.append(sort_idx)
-
-            set_(q, torch.linalg.eigh(m)[1])
-        elif zeroth_power_mode.startswith('newtonschulz'):
-            iterations = zeroth_power_mode[len('newtonschulz'):]
-            if iterations == '':
-                iterations = 10
-            else:
-                iterations = int(iterations)
-            set_(q, zeropower_via_newtonschulz5(m, o[:, sort_idx], iterations))
-        else:
-            set_(q, ortho(tmp[:, sort_idx]))
+        inplace_orthogonal_(tmp[:, sort_idx], q)
 
     indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1))  #
                     for i, ind in enumerate(indices))
@@ -407,7 +438,6 @@ def get_beta1(group):
 
 
 def get_beta2(group):
-    beta = None
     if 'beta2_scale' in group:
         step = max(group.get("step", 1), 1)
         return 1 - step ** -group['beta2_scale']
@@ -417,23 +447,36 @@ def get_beta2(group):
 
 
 def stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
-    x, y = list_guard(x
+    x, y = list_guard(x, y)
     a = scalar_guard(a, x[0])
     _compilable_stochastic_lerp_(x, y, a)
 
 
-def list_guard(
-
-
-
+def list_guard(*xs):
+    out = []
+    for x in xs:
+        if isinstance(x, (list, tuple)):
+            out.append(x)
+        else:
+            out.append([x])
+    if len(xs) == 1:
+        return out[0]
+    return out
 
 
-def scalar_guard(
-
-
-
-
-
+def scalar_guard(*args):
+    *xs, ref = args
+    out = []
+    for x in xs:
+        if isinstance(x, float):
+            out.append(torch.empty((), dtype=torch.float32, device=ref.device).fill_(x))
+        elif isinstance(x, int):
+            out.append(torch.empty((), dtype=torch.int64, device=ref.device).fill_(x))
+        else:
+            out.append(x)
+    if len(xs) == 1:
+        return out[0]
+    return out
 
 
 @decorator_knowngood
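The guard helpers are now variadic, which is what most of the call-site changes in this file are about. A short sketch of the new calling convention (assumes heavyball 1.1.0 is installed; the reference tensor goes last in `scalar_guard` so it can pick the target device):

```python
import torch
from heavyball.utils import list_guard, scalar_guard

p = torch.zeros(4)
x, y = list_guard(p, [p, p])           # -> [p], [p, p]: bare tensors get wrapped in lists
beta, step = scalar_guard(0.9, 10, p)  # -> float32 and int64 0-dim tensors on p.device
print(type(x), beta.dtype, step.dtype)
```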
@@ -446,7 +489,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
 
 
 def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
-    x, y = list_guard(x
+    x, y = list_guard(x, y)
     alpha = scalar_guard(alpha, x[0])
     _compilable_stochastic_add_(x, y, alpha)
 
@@ -695,14 +738,14 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
 
     copy_stochastic_list_(exp_avg, exp_avg32)
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
+    copy_stochastic_list_(grad, u32)
 
 
 def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
-    beta1, beta2, step = scalar_guard(beta1,
-
-    return
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
+    _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    return grad
 
 
 @decorator_knowngood
@@ -725,32 +768,32 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
 
 def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
                 beta2: float, step: int, lr: float, eps: float, decay: float, caution: bool):
-    y, exp_avg, exp_avg_sq, grad =
-    beta1, beta2, step, lr =
+    y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
+    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
     return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, decay, lr, eps, caution)
 
 
 @decorator_knowngood
-def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
+def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
                         beta2: Tensor, step: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
-    gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [
+    gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg_sq]]
 
     denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
     gp32 = torch._foreach_div(gp32, denom)
     stochastic_lerp_(exp_avg, gp32, 1 - beta1)
 
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+    copy_stochastic_list_(grad, exp_avg)
 
 
-def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
-
-
-
-
-    return exp_avg
+def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
+    exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
+    _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    return grad
 
 
 @decorator_knowngood
@@ -770,11 +813,11 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
 
 
-def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
-
-
-    beta1, beta2, step, lr =
-    _fused_compilable_laprop_(y, exp_avg, exp_avg_sq,
+def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
+                  beta2: float, step: int, lr: float, decay: float, caution: bool):
+    exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
+    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+    _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, lr, decay, caution)
 
 
 @decorator_knowngood
@@ -797,8 +840,8 @@ def _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, l
 
 
 def fused_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
-
-    beta1, beta2, step, lr =
+    exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
+    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
     _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
 
 
@@ -819,14 +862,14 @@ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
 
     copy_stochastic_list_(exp_avg, exp_avg32)
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
-    return update
+    copy_stochastic_list_(grad, update)
 
 
 def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
-
-    beta1, beta2, step = scalar_guard(beta1,
-
+    exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad, y)
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+    _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
+    return grad
 
 
 def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
@@ -877,7 +920,7 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn:
 
 def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, add_fn: callable = None,
                   caution: bool = False, grad: List[Tensor] = None):
-    param, update, grad = list_guard(param
+    param, update, grad = list_guard(param, update, grad)
     lr = scalar_guard(lr, param[0])
     if not caution:
         grad = [None] * len(param)
@@ -983,11 +1026,15 @@ def psgd_balance_Q(Q_in):
 
 
 def psgd_calc_A_and_conjB(exprA, G, Q):
+    V = torch.randn(G.shape, dtype=G.dtype, device=G.device)
+    eps = scalar_guard(math.sqrt(torch.finfo(torch.float32).eps), G)
+    eps *= G.norm() / G.numel()
+    G += V * eps
     md = min_dtype(Q + [G])
     A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
     order = G.dim()
     p = list(range(order))
-    conjB = torch.
+    conjB = torch.permute(V, p[1:] + p[:1]).to(promote(G.dtype))
     Q = [promote(q) for q in Q]
     for i, q in enumerate(Q):
         if q.dim() <= 1:
@@ -1134,7 +1181,7 @@ def _compilable_trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
 
 def trust_region_clip_(grad, lerp=0.9, scale=1.5):
     grad = list_guard(grad)
-    lerp, scale = scalar_guard(lerp,
+    lerp, scale = scalar_guard(lerp, scale, grad[0])
    return _compilable_trust_region_clip_(grad, lerp, scale)
 
 
{heavyball-1.0.0.dist-info → heavyball-1.1.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.0.0
+Version: 1.1.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -132,5 +132,5 @@ It has several handy functions:
 
 * `set_torch()` sets pytorch optimization settings (TF32, opt_einsum, benchmark, ...)
 * `compile_mode`, a string passed as-is to `torch.compile(mode=compile_mode)` in all compiled heavyball calls; `compile_mode=None` disables torch_compile
-* `zeroth_power_mode`, a string determining whether to use QR, newtonschulz
-  the eigenvectors.
+* `zeroth_power_mode`, a string determining whether to use QR, newtonschulz, or svd or eigh to approximate
+  the eigenvectors.
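Since `zeroth_power_mode` is a module-level switch in `heavyball.utils` (see the utils.py hunks above), selecting a different orthogonalization backend is a one-line assignment; a sketch:

```python
import heavyball.utils

# switch the orthogonalization backend from the default QR to Newton-Schulz
# before constructing any optimizer that uses it
heavyball.utils.zeroth_power_mode = 'newtonschulz'
```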
heavyball-1.1.0.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=ZPadmK5Q9hGrBAeNYldMUw8vciy9dCDL3d7Zk9erC3E,12702
+heavyball/chainable.py,sha256=NLyJZS_vuQxzLE3RjQy0isQGCY5xGQBCs1wirT9BsQY,20172
+heavyball/utils.py,sha256=BEbvwWswWMyEa4zozlnLaQIKEMH_k7OUGu84jlcj5t0,46736
+heavyball-1.1.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.1.0.dist-info/METADATA,sha256=hVz8TPY0A9SqkjVpLX6LVXXlBVmAC8my96DHO8jtSb8,12022
+heavyball-1.1.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.1.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.1.0.dist-info/RECORD,,
heavyball-1.0.0.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
|
|
1
|
-
heavyball/__init__.py,sha256=1QPYBIH8amnk3-_rKe6L9FJ0rkV5wVNRr7Yw9BXjIYI,11636
|
2
|
-
heavyball/chainable.py,sha256=cp-tpetPr4CNN9xJ85JSo89JYC5BWUygoE6dnET6tmc,18141
|
3
|
-
heavyball/utils.py,sha256=qUoB9EIxl7GUyLkV5a5JAKOD6TvPc1FNsqyUbJ-HY6o,46343
|
4
|
-
heavyball-1.0.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
5
|
-
heavyball-1.0.0.dist-info/METADATA,sha256=9C2btIxngp26TRCJFU6B8ftkWQt1rfZZC10rkAhaORw,12074
|
6
|
-
heavyball-1.0.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
7
|
-
heavyball-1.0.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
8
|
-
heavyball-1.0.0.dist-info/RECORD,,
|
{heavyball-1.0.0.dist-info → heavyball-1.1.0.dist-info}/LICENSE
File without changes
{heavyball-1.0.0.dist-info → heavyball-1.1.0.dist-info}/WHEEL
File without changes
{heavyball-1.0.0.dist-info → heavyball-1.1.0.dist-info}/top_level.txt
File without changes