heavyball 1.0.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +18 -4
- heavyball/chainable.py +107 -49
- heavyball/utils.py +143 -81
- {heavyball-1.0.0.dist-info → heavyball-1.1.1.dist-info}/METADATA +3 -3
- heavyball-1.1.1.dist-info/RECORD +8 -0
- heavyball-1.0.0.dist-info/RECORD +0 -8
- {heavyball-1.0.0.dist-info → heavyball-1.1.1.dist-info}/LICENSE +0 -0
- {heavyball-1.0.0.dist-info → heavyball-1.1.1.dist-info}/WHEEL +0 -0
- {heavyball-1.0.0.dist-info → heavyball-1.1.1.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -59,6 +59,19 @@ class ForeachADOPT(C.BaseOpt):
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
 
 
+class ForeachMuon(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8,
+                 nesterov: bool = True):
+        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+                        mars_gamma=mars_gamma, beta2_scale=beta2_scale)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
+                         C.nesterov_momentum if nesterov else C.heavyball_momentum, C.orthogonalize_update)
+
+
 class ForeachLaProp(C.BaseOpt):
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
@@ -200,10 +213,11 @@ PurePSGD = ForeachPurePSGD
 DelayedPSGD = ForeachDelayedPSGD
 CachedPSGDKron = ForeachCachedPSGDKron
 CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
+Muon = ForeachMuon
 
-__all__ = ["PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
-           "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
-           "PrecondSchedulePaLMSOAP", 'RMSprop',
+__all__ = ["Muon","RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
+           "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
+           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop',  #
            "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
            "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
-           "ForeachRMSprop"]
+           "ForeachRMSprop", "ForeachMuon"]
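
The new ForeachMuon optimizer (exported as Muon above) follows the standard torch.optim step()/zero_grad() interface like the other heavyball optimizers. A minimal usage sketch, assuming heavyball 1.1.1 is installed (per the project README, setting heavyball.utils.compile_mode = None disables torch.compile if compilation is unavailable in your environment):

    import torch
    import heavyball

    model = torch.nn.Linear(16, 4)
    opt = heavyball.Muon(model.parameters(), lr=0.0025, nesterov=True)

    x, y = torch.randn(8, 16), torch.randn(8, 4)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()       # momentum step followed by an orthogonalized update
    opt.zero_grad()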
heavyball/chainable.py
CHANGED
@@ -33,6 +33,23 @@ def _guard_in_state(state, key, template_fn):
     return state[key]
 
 
+class FunctionTransform:
+    def __init__(self, fn):
+        self.fn = fn
+        self.fn_name = self.get_fn().__name__
+
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        raise NotImplementedError
+
+    def get_fn(self):
+        if hasattr(self.fn, 'get_fn'):
+            return self.fn.get_fn()
+        return self.fn
+
+    def val_name(self, name):
+        return f"{self.fn_name}_{name}"
+
+
 def _zero_guard(state, key, ref, dtype):
     return _guard_in_state(state, key,
                            lambda: torch.zeros_like(ref, dtype=torch.float32, memory_format=torch.preserve_format))
@@ -43,61 +60,77 @@ def _storage_dtype(group):
     return getattr(torch, dtype)
 
 
-
-def
-
-
-        return fn(state, group, update, grad, param, *args, *vars, **kwargs)
+class ZeroGuard(FunctionTransform):
+    def __init__(self, fn, names):
+        super().__init__(fn)
+        self.names = names
 
-
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        vars = [[_zero_guard(state(p), self.val_name(name), p, _storage_dtype(group)) for p in param]  #
+                for name in self.names]
+        return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
-    return _outer
 
+class CopyGuard(FunctionTransform):
+    def __init__(self, fn, index, names):
+        super().__init__(fn)
+        self.index = index
+        self.names = names
 
-def
-
-
-
-
-                for name in names]
-        return fn(state, group, update, grad, param, *args, *vars, **kwargs)
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        val = [update, grad, param, *args][self.index]
+        vars = [[_guard_in_state(state(p), self.val_name(name), lambda: torch.clone(v)) for p, v in zip(param, val)]  #
+                for name in self.names]
+        return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
-        return _fn
 
-
+class GeneralGuard(FunctionTransform):  # We can't guard against reuse in the general case
+    def __init__(self, fn, names, init_fn):
+        super().__init__(fn)
+        self.names = names
+        self.init_fn = init_fn
 
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        vars = []
+        skip_update = False
+        for p, g, u in zip(param, grad, update):
+            st = state(p)
+            skip_update |= _inplace_guard_(st, self.names, lambda: self.init_fn(st, group, u, g, p, **kwargs))
+            vars.append([st[name] if isinstance(name, str) else st.get(name[0], name[1]) for name in self.names])
+        if skip_update:
+            raise SkipUpdate
+        return self.fn(state, group, update, grad, param, *args, *zip(*vars), **kwargs)
 
-def general_guard(*names, init_fn):
-    def _outer(fn):
-        def _fn(state, group, update, grad, param, *args, **kwargs):
-            vars = []
-            skip_update = False
-            for p, g, u in zip(param, grad, update):
-                st = state(p)
-                skip_update |= _inplace_guard_(st, names, lambda: init_fn(st, group, u, g, p, **kwargs))
-                vars.append([st[name] if isinstance(name, str) else st.get(name[0], name[1]) for name in names])
-            if skip_update:
-                raise SkipUpdate
-            return fn(state, group, update, grad, param, *args, *zip(*vars), **kwargs)
 
-
+class NoState(FunctionTransform):
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        return self.fn(group, update, grad, param, *args, **kwargs)
 
-    return _outer
 
+class NoStateNoForeach(FunctionTransform):
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        for a in zip(update, grad, param, *args):
+            return self.fn(group, *a, **kwargs)
 
-def no_state(fn):
-    def _fn(state, *args, **kwargs):
-        return fn(*args, **kwargs)
 
-
+def zero_guard(*names):
+    return functools.partial(ZeroGuard, names=names)
 
 
-def
-
-
-
+def copy_guard(index, *names):
+    return functools.partial(CopyGuard, index=index, names=names)
+
+
+def general_guard(*names, init_fn):
+    return functools.partial(GeneralGuard, names=names, init_fn=init_fn)
 
-
+
+def no_state(fn):
+    return NoState(fn)
+
+
+def no_state_no_foreach(fn):
+    return NoStateNoForeach(fn)
 
 
 class SkipUpdate(ValueError):
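
The decorator factories above (zero_guard, copy_guard, general_guard, no_state, no_state_no_foreach) now return FunctionTransform objects via functools.partial instead of nested closures, so the innermost update function's name stays recoverable and per-buffer state keys can be namespaced through val_name. A minimal standalone sketch of that mechanism (illustration only, using stand-in classes, not heavyball's actual guards):

    class FunctionTransform:
        def __init__(self, fn):
            self.fn = fn
            self.fn_name = self.get_fn().__name__

        def get_fn(self):
            # unwrap nested transforms until the plain function is reached
            return self.fn.get_fn() if hasattr(self.fn, 'get_fn') else self.fn

        def val_name(self, name):
            return f"{self.fn_name}_{name}"

    class Outer(FunctionTransform):  # stand-in for ZeroGuard / NoState / ...
        pass

    def exp_avg(*args):  # stand-in for an update rule
        pass

    wrapped = Outer(Outer(exp_avg))      # like stacking @zero_guard(...) over @no_state
    print(wrapped.fn_name)               # "exp_avg" - unwraps through the nesting
    print(wrapped.val_name("exp_avg"))   # "exp_avg_exp_avg" - namespaced state key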
@@ -107,18 +140,14 @@ class SkipUpdate(ValueError):
 @zero_guard("exp_avg")
 @no_state
 def exp_avg(group, update, grad, param, exp_avg):
-    utils.
-    return exp_avg
+    return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
 
 
 @zero_guard("exp_avg_sq")
 @no_state
 def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
-
-
-    if group['step'] == 1:
-        raise SkipUpdate
-    return out
+    return utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]),
+                                      group['eps'])
 
 
 @zero_guard("exp_avg", "exp_avg_sq")
@@ -232,6 +261,29 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str
     raise ValueError("No preconditioner update schedule specified.")
 
 
+@no_state_no_foreach
+def orthogonalize_update(group, update, grad, param):
+    if update.dim() == 1:
+        return update
+    original_shape = update.shape
+    # doing it this way, as tmp and update are not guaranteed to share memory address or layout
+    tmp = update.flatten(1, -1)
+    utils.inplace_orthogonal_(tmp, utils.zeroth_power_mode, tmp)
+    return tmp.reshape(original_shape)
+
+
+@zero_guard("momentum")
+@no_state
+def nesterov_momentum(group, updates, grads, params, momentum):
+    return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))
+
+
+@zero_guard("momentum")
+@no_state
+def heavyball_momentum(group, updates, grads, params, momentum):
+    return utils.heavyball_momentum(momentum, updates, utils.get_beta1(group))
+
+
 @zero_guard("exp_avg", "exp_avg_sq")
 @general_guard("Q", "GG", init_fn=_init_soap)
 @no_state
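
orthogonalize_update skips 1-D tensors and treats everything else as a matrix by flattening all trailing dimensions, so a conv-style update of shape (out, in, kh, kw) is orthogonalized as an (out, in*kh*kw) matrix and reshaped back. A plain-PyTorch illustration of that shape round trip (the QR call here is only a stand-in for utils.inplace_orthogonal_):

    import torch

    w = torch.randn(64, 32, 3, 3)      # e.g. a conv-weight update
    tmp = w.flatten(1, -1)             # (64, 288) matrix view of the update
    q = torch.linalg.qr(tmp.T).Q.T     # rows orthonormalized; stand-in for the real routine
    restored = q.reshape(w.shape)      # back to (64, 32, 3, 3)
    print(restored.shape)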
@@ -296,9 +348,12 @@ def _fused_cached_psgd_precond_grad(group, grad, param, cached, cache_expr, expr
 @no_state_no_foreach
 def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                   prob: Optional[callable] = None):
+    old = update
+    update = update.to(memory_format=torch.contiguous_format)
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
-
+    out = _cached_psgd_precond_grad(False, cache_expr, exprs, update, Q_mat, Q_cache)
+    return torch.as_strided(out, old.shape, old.stride())
 
 
 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
@@ -346,7 +401,7 @@ def apply_to_idx(fn, idx):
 
 
 def chain(state: Union[callable, dict], group, grad, param, *fns):
-    update = grad
+    update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
     skip_update = False
     for fn in fns:
         try:
@@ -375,7 +430,10 @@ class ChainOpt(utils.StatefulOptimizer):
         else:
             group['lr'] = -group['base_lr']
 
-
+        vals = list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group)))
+        if not vals:
+            return
+        p, g = zip(*vals)
 
         if not group['foreach'] or len(p) == 1:
             for param, grad in zip(p, g):
heavyball/utils.py
CHANGED
@@ -23,7 +23,7 @@ from torch.utils._pytree import tree_map
 compile_mode = "max-autotune-no-cudagraphs"
 dynamic = False
 compile_mode_recommended_to_none = None
-zeroth_power_mode = 'qr'  # 'qr' is baseline, 'newtonschulz' converges better and faster
+zeroth_power_mode = 'qr'  # 'qr' is baseline, 'newtonschulz' converges better and faster
 tiny_bf16 = torch.finfo(torch.bfloat16).tiny
 
 
@@ -91,11 +91,9 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
         ckp1 = weight / weight_sum
     except ZeroDivisionError:
         ckp1 = 0
-        ckp1 = 0
 
-
-
-    lr, ckp1, beta1 = scalar_guard(lr, parameters[0]), scalar_guard(ckp1, parameters[0]), scalar_guard(beta1, parameters[0])
+    grad, parameters, z = list_guard(grad, parameters, z)
+    lr, ckp1, beta1 = scalar_guard(lr, ckp1, beta1, grad[0])
     _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1, decay)
     return weight_sum
 
@@ -179,28 +177,43 @@ def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tens
 
 
 def exp_avg_sq_(state, grad, beta2, eps, out=None):
-    state, grad, out = list_guard(state
-    beta2, eps = scalar_guard(beta2,
+    state, grad, out = list_guard(state, grad, out)
+    beta2, eps = scalar_guard(beta2, eps, state[0])
     return _compilable_exp_avg_sq_(state, grad, beta2, eps, out)
 
 
 @decorator_knowngood
-def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor
-                                     out: List[Optional[Tensor]]):
+def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
     torch._foreach_mul_(s32, beta2)
     [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
     denom = torch._foreach_sqrt(s32)
     [d.clamp_(min=eps) for d in denom]
-    out = torch.
+    out = torch._foreach_div_(g32, denom)
     copy_stochastic_list_(state, s32)
-
+    copy_stochastic_list_(grad, out)
 
 
-def scale_by_exp_avg_sq_(
-    grad, exp_avg_sq = list_guard(grad
-    beta2, eps = scalar_guard(beta2,
-
+def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
+    grad, exp_avg_sq = list_guard(grad, exp_avg_sq)
+    beta2, eps = scalar_guard(beta2, eps, grad[0])
+    _compilable_scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps)
+    return grad
+
+
+@decorator_knowngood
+def _compilable_exp_avg_(state, grad, beta):
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    [s.lerp_(g, beta) for s, g in zip(s32, g32)]
+    copy_stochastic_list_(state, s32)
+    copy_stochastic_list_(grad, s32)
+
+
+def scale_by_exp_avg_(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_exp_avg_(state, grad, beta)
+    return grad
 
 
 @decorator_knowngood
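
The new scale_by_exp_avg_ pair implements a plain exponential moving average: the state buffer is lerped toward the incoming value and the result is written back into the gradient list, which is also returned. The arithmetic in plain PyTorch (a reading of the code above, not heavyball's API):

    import torch

    state, grad, beta = torch.zeros(4), torch.ones(4), 0.1
    state.lerp_(grad, beta)   # state = (1 - beta) * state + beta * grad
    grad.copy_(state)         # the update list now carries the EMA
    print(grad)               # tensor([0.1000, 0.1000, 0.1000, 0.1000])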
@@ -219,7 +232,7 @@ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor
                                 minimum: float = 1e-3, eps: float = 1e-8):
     if clip_val <= 0:
         return gradients
-    parameters, gradients = list_guard(parameters
+    parameters, gradients = list_guard(parameters, gradients)
     clip_val = scalar_guard(clip_val, parameters[0])
     return _compilable_agc_(parameters, gradients, clip_val, minimum, eps)
 
@@ -254,33 +267,29 @@ def set_torch():
 
 
 @decorator
-def zeropower_via_newtonschulz5(G,
+def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
     """
-    Modified from "modded-nanogpt" under the MIT license:
-    Original: https://github.com/KellerJordan/modded-nanogpt/blob/a0dcbfdd9a0617d091d5123cfc354745428e40d3/train_gpt2.py
-
     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
     quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
     of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
     zero even beyond the point where the iteration no longer converges all the way to one everywhere
     on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-    where S' is diagonal with S_{ii}'
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
     performance at all relative to UV^T, where USV^T = G is the SVD.
     """
     assert len(G.shape) == 2
     a, b, c = (3.4445, -4.7750, 2.0315)
-    X = G.
-
-    X = X / (X.norm() + eps)  # ensure top singular value <= 1
+    X = G.bfloat16()
+    X /= (X.norm() + eps)  # ensure top singular value <= 1
     if G.size(0) > G.size(1):
         X = X.T
     for _ in range(steps):
-        A = X @ X.T
-        B = A @
-
+        A = X @ X.T
+        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
     if G.size(0) > G.size(1):
         X = X.T
-    return X
+    return X.to(G.dtype)
 
 
 def ortho(x):
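
As the updated docstring says, the iteration yields roughly US'V^T with the entries of S' in about (0.5, 1.5) rather than an exactly orthogonal factor. A quick numerical check, assuming heavyball 1.1.1 is importable and torch.compile works in the environment (the function is wrapped in @decorator):

    import torch
    from heavyball.utils import zeropower_via_newtonschulz5

    G = torch.randn(32, 128)
    X = zeropower_via_newtonschulz5(G, steps=5)
    print(X.shape)                                    # same shape as G
    print(torch.linalg.svdvals(X.float()).aminmax())  # roughly within (0.5, 1.5)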
@@ -292,6 +301,53 @@ def ortho(x):
     raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")
 
 
+@decorator_knowngood
+def _compilable_heavyball_momentum_(state, grad, beta):
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    torch._foreach_mul_(s32, beta)
+    torch._foreach_add_(s32, g32)
+    copy_stochastic_list_(state, s32)
+    copy_stochastic_list_(grad, s32)
+
+
+@decorator_knowngood
+def _compilable_nesterov_momentum_(state, grad, beta):
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    torch._foreach_mul_(s32, beta)
+    torch._foreach_add_(s32, g32)
+    [g.add_(s, alpha=beta) for g, s in zip(g32, s32)]
+    copy_stochastic_list_(state, s32)
+    copy_stochastic_list_(grad, g32)
+
+
+def heavyball_momentum(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_heavyball_momentum_(state, grad, beta)
+    return grad
+
+
+def nesterov_momentum(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_nesterov_momentum_(state, grad, beta)
+    return grad
+
+
+@decorator_knowngood
+def inplace_orthogonal_(x, mode, out):
+    if mode == 'newtonschulz' or x.shape[0] != x.shape[1]:
+        y = zeropower_via_newtonschulz5(x, 5)
+    elif mode == 'qr':
+        y = torch.linalg.qr(x).Q
+    elif mode == 'svd':
+        u, s, v = torch.linalg.svd(x)
+        y = u @ v.T
+    else:
+        raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
+    set_(out, y)
+
+
 def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
     """
     Computes the eigenbases of the preconditioner using one round of power iteration
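
For reference, the two buffer updates added above differ only in what they hand back: heavy-ball returns the buffer itself, while the Nesterov variant adds one more look-ahead step. In plain PyTorch (a reading of the code, not heavyball's API):

    import torch

    m, g, beta = torch.zeros(3), torch.ones(3), 0.9
    m = beta * m + g           # shared buffer update in both variants
    heavy = m                  # heavyball_momentum returns the buffer
    nesterov = g + beta * m    # nesterov_momentum returns g + beta * m
    print(heavy, nesterov)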
@@ -322,17 +378,7 @@ def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
         est_eig = torch.einsum('ij,ij->j', o, tmp)
         sort_idx = torch.argsort(est_eig, descending=True)
         indices.append(sort_idx)
-
-            set_(q, torch.linalg.eigh(m)[1])
-        elif zeroth_power_mode.startswith('newtonschulz'):
-            iterations = zeroth_power_mode[len('newtonschulz'):]
-            if iterations == '':
-                iterations = 10
-            else:
-                iterations = int(iterations)
-            set_(q, zeropower_via_newtonschulz5(m, o[:, sort_idx], iterations))
-        else:
-            set_(q, ortho(tmp[:, sort_idx]))
+        inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q)
 
     indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1))  #
                     for i, ind in enumerate(indices))
@@ -407,7 +453,6 @@ def get_beta1(group):
 
 
 def get_beta2(group):
-    beta = None
     if 'beta2_scale' in group:
         step = max(group.get("step", 1), 1)
         return 1 - step ** -group['beta2_scale']
@@ -417,23 +462,36 @@ def get_beta2(group):
 
 
 def stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
-    x, y = list_guard(x
+    x, y = list_guard(x, y)
     a = scalar_guard(a, x[0])
     _compilable_stochastic_lerp_(x, y, a)
 
 
-def list_guard(
-
-
-
+def list_guard(*xs):
+    out = []
+    for x in xs:
+        if isinstance(x, (list, tuple)):
+            out.append(x)
+        else:
+            out.append([x])
+    if len(xs) == 1:
+        return out[0]
+    return out
 
 
-def scalar_guard(
-
-
-
-
-
+def scalar_guard(*args):
+    *xs, ref = args
+    out = []
+    for x in xs:
+        if isinstance(x, float):
+            out.append(torch.empty((), dtype=torch.float32, device=ref.device).fill_(x))
+        elif isinstance(x, int):
+            out.append(torch.empty((), dtype=torch.int64, device=ref.device).fill_(x))
+        else:
+            out.append(x)
+    if len(xs) == 1:
+        return out[0]
+    return out
 
 
 @decorator_knowngood
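
list_guard and scalar_guard are now variadic: several values can be guarded in one call, and the last argument to scalar_guard is the reference tensor that supplies the target device. A usage sketch, assuming heavyball 1.1.1 is installed:

    import torch
    from heavyball.utils import list_guard, scalar_guard

    p = torch.zeros(3)
    grads = list_guard(torch.ones(3))        # single value -> [tensor([1., 1., 1.])]
    a, b = list_guard(torch.ones(3), [p])    # tensors get wrapped, lists pass through
    lr, step = scalar_guard(1e-3, 10, p)     # float -> float32 scalar, int -> int64 scalar
    print(type(grads), lr.dtype, step.dtype)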
@@ -446,7 +504,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
 
 
 def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
-    x, y = list_guard(x
+    x, y = list_guard(x, y)
     alpha = scalar_guard(alpha, x[0])
     _compilable_stochastic_add_(x, y, alpha)
 
@@ -695,14 +753,14 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
 
     copy_stochastic_list_(exp_avg, exp_avg32)
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
+    copy_stochastic_list_(grad, u32)
 
 
 def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
-    beta1, beta2, step = scalar_guard(beta1,
-
-    return
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
+    _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    return grad
 
 
 @decorator_knowngood
@@ -725,32 +783,32 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
 
 def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
                 beta2: float, step: int, lr: float, eps: float, decay: float, caution: bool):
-    y, exp_avg, exp_avg_sq, grad =
-    beta1, beta2, step, lr =
+    y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
+    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
     return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, decay, lr, eps, caution)
 
 
 @decorator_knowngood
-def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
+def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
                         beta2: Tensor, step: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
-    gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [
+    gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg_sq]]
 
     denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
     gp32 = torch._foreach_div(gp32, denom)
     stochastic_lerp_(exp_avg, gp32, 1 - beta1)
 
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+    copy_stochastic_list_(grad, exp_avg)
 
 
-def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
-
-
-
-
-    return exp_avg
+def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
+    exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
+    _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    return grad
 
 
 @decorator_knowngood
@@ -770,11 +828,11 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
 
 
-def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
-
-
-    beta1, beta2, step, lr =
-    _fused_compilable_laprop_(y, exp_avg, exp_avg_sq,
+def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
+                  beta2: float, step: int, lr: float, decay: float, caution: bool):
+    exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
+    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+    _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, lr, decay, caution)
 
 
 @decorator_knowngood
@@ -797,8 +855,8 @@ def _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, l
 
 
 def fused_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
-
-    beta1, beta2, step, lr =
+    exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
+    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
     _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
 
 
@@ -819,14 +877,14 @@ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
 
     copy_stochastic_list_(exp_avg, exp_avg32)
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
-    return update
+    copy_stochastic_list_(grad, update)
 
 
 def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
-
-    beta1, beta2, step = scalar_guard(beta1,
-
+    exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad, y)
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+    _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
+    return grad
 
 
 def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
@@ -877,7 +935,7 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn:
 
 def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, add_fn: callable = None,
                   caution: bool = False, grad: List[Tensor] = None):
-    param, update, grad = list_guard(param
+    param, update, grad = list_guard(param, update, grad)
     lr = scalar_guard(lr, param[0])
     if not caution:
         grad = [None] * len(param)
@@ -983,11 +1041,15 @@ def psgd_balance_Q(Q_in):
 
 
 def psgd_calc_A_and_conjB(exprA, G, Q):
+    V = torch.randn(G.shape, dtype=G.dtype, device=G.device)
+    eps = scalar_guard(math.sqrt(torch.finfo(torch.float32).eps), G)
+    eps *= G.norm() / G.numel()
+    G += V * eps
     md = min_dtype(Q + [G])
     A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
     order = G.dim()
     p = list(range(order))
-    conjB = torch.
+    conjB = torch.permute(V, p[1:] + p[:1]).to(promote(G.dtype))
     Q = [promote(q) for q in Q]
     for i, q in enumerate(Q):
         if q.dim() <= 1:
@@ -1134,7 +1196,7 @@ def _compilable_trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
 
 def trust_region_clip_(grad, lerp=0.9, scale=1.5):
     grad = list_guard(grad)
-    lerp, scale = scalar_guard(lerp,
+    lerp, scale = scalar_guard(lerp, scale, grad[0])
     return _compilable_trust_region_clip_(grad, lerp, scale)
 
 
{heavyball-1.0.0.dist-info → heavyball-1.1.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.0.0
+Version: 1.1.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -132,5 +132,5 @@ It has several handy functions:
 
 * `set_torch()` sets pytorch optimization settings (TF32, opt_einsum, benchmark, ...)
 * `compile_mode`, a string passed as-is to `torch.compile(mode=compile_mode)` in all compiled heavyball calls; `compile_mode=None` disables torch_compile
-* `zeroth_power_mode`, a string determining whether to use QR, newtonschulz
-  the eigenvectors.
+* `zeroth_power_mode`, a string determining whether to use QR, newtonschulz, or svd or eigh to approximate
+  the eigenvectors.
heavyball-1.1.1.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=ZPadmK5Q9hGrBAeNYldMUw8vciy9dCDL3d7Zk9erC3E,12702
+heavyball/chainable.py,sha256=_KVV_bA_WYbyaiGDOaoQMHv-IM9jbgZx_cwzmRiKxl8,20321
+heavyball/utils.py,sha256=F4nwiyDCTGg3uU5XAaX9_p11qf5Uiw4WRseMbuLfq0Y,47223
+heavyball-1.1.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.1.1.dist-info/METADATA,sha256=uLpyF5G1Stjxxj23qmxpnk3ViUd0M55Q0wIDGDY58qk,12022
+heavyball-1.1.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.1.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.1.1.dist-info/RECORD,,
heavyball-1.0.0.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=1QPYBIH8amnk3-_rKe6L9FJ0rkV5wVNRr7Yw9BXjIYI,11636
-heavyball/chainable.py,sha256=cp-tpetPr4CNN9xJ85JSo89JYC5BWUygoE6dnET6tmc,18141
-heavyball/utils.py,sha256=qUoB9EIxl7GUyLkV5a5JAKOD6TvPc1FNsqyUbJ-HY6o,46343
-heavyball-1.0.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.0.0.dist-info/METADATA,sha256=9C2btIxngp26TRCJFU6B8ftkWQt1rfZZC10rkAhaORw,12074
-heavyball-1.0.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.0.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.0.0.dist-info/RECORD,,
{heavyball-1.0.0.dist-info → heavyball-1.1.1.dist-info}/LICENSE
File without changes
{heavyball-1.0.0.dist-info → heavyball-1.1.1.dist-info}/WHEEL
File without changes
{heavyball-1.0.0.dist-info → heavyball-1.1.1.dist-info}/top_level.txt
File without changes