heavyball 1.3.1.tar.gz → 1.4.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-1.3.1 → heavyball-1.4.0}/PKG-INFO +1 -1
- {heavyball-1.3.1 → heavyball-1.4.0}/heavyball/chainable.py +39 -23
- {heavyball-1.3.1 → heavyball-1.4.0}/heavyball/utils.py +19 -19
- {heavyball-1.3.1 → heavyball-1.4.0}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-1.3.1 → heavyball-1.4.0}/setup.py +1 -1
- {heavyball-1.3.1 → heavyball-1.4.0}/LICENSE +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/README.md +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/heavyball/__init__.py +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/setup.cfg +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/test/test_bf16_params.py +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/test/test_bf16_q.py +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/test/test_bf16_storage.py +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/test/test_caution.py +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/test/test_channels_last.py +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/test/test_closure.py +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/test/test_ema.py +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/test/test_foreach.py +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/test/test_hook.py +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/test/test_mars.py +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/test/test_memory.py +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/test/test_merge.py +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/test/test_no_grad.py +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/test/test_psgd.py +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/test/test_soap.py +0 -0
- {heavyball-1.3.1 → heavyball-1.4.0}/test/test_stochastic_updates.py +0 -0
{heavyball-1.3.1 → heavyball-1.4.0}/heavyball/chainable.py

@@ -1,6 +1,5 @@
 import functools
 import random
-import warnings
 from typing import Optional, Union, Literal
 
 import torch
@@ -85,10 +84,11 @@ class CopyGuard(FunctionTransform):
 
 
 class GeneralGuard(FunctionTransform): # We can't guard against reuse in the general case
-    def __init__(self, fn, names, init_fn):
+    def __init__(self, fn, names, init_fn, skip_first: bool = True):
         super().__init__(fn)
         self.names = names
         self.init_fn = init_fn
+        self.skip_first = skip_first
 
     def __call__(self, state, group, update, grad, param, *args, **kwargs):
         vars = []
@@ -97,7 +97,7 @@ class GeneralGuard(FunctionTransform): # We can't guard against reuse in the ge
             st = state(p)
             skip_update |= _inplace_guard_(st, self.names, lambda: self.init_fn(st, group, u, g, p, **kwargs))
             vars.append([st[name] if isinstance(name, str) else st.get(name[0], name[1]) for name in self.names])
-        if skip_update:
+        if skip_update and self.skip_first:
            raise SkipUpdate
         return self.fn(state, group, update, grad, param, *args, *zip(*vars), **kwargs)
 
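
The new `skip_first` flag changes what `GeneralGuard` does on the step that first initialises its state: with the old behaviour (`skip_first=True`) that step raises `SkipUpdate`; with `skip_first=False`, which the PSGD transforms later in this diff opt into, the freshly initialised state is used right away. A minimal, self-contained sketch of that decision; the names below are stand-ins, not heavyball's internals:

```python
class SkipUpdate(Exception):
    """Stand-in for heavyball's SkipUpdate control-flow exception."""


def guard(state_was_just_created: bool, skip_first: bool = True) -> str:
    # Mirrors `if skip_update and self.skip_first: raise SkipUpdate`:
    # a brand-new state only aborts the step when skip_first is True.
    if state_was_just_created and skip_first:
        raise SkipUpdate
    return "update applied"


print(guard(True, skip_first=False))   # update applied
try:
    guard(True, skip_first=True)
except SkipUpdate:
    print("first step skipped")
```
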
@@ -109,8 +109,17 @@ class NoState(FunctionTransform):
 
 class NoStateNoForeach(FunctionTransform):
     def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        updates = []
+        skip_update = False
         for a in zip(update, grad, param, *args):
-
+            try:
+                updates.append(self.fn(group, *a, **kwargs))
+            except SkipUpdate:
+                skip_update = True
+                pass
+        if skip_update:
+            raise SkipUpdate
+        return updates
 
 
 def zero_guard(*names):
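
For context, a small sketch of the behaviour the rewritten `NoStateNoForeach.__call__` implements: every element is processed, a `SkipUpdate` raised for one element no longer aborts the loop, and the exception is only re-raised once all elements have been visited. `SkipUpdate` and `double_or_skip` here are stand-ins rather than heavyball's own definitions:

```python
class SkipUpdate(Exception):
    """Stand-in for heavyball's SkipUpdate control-flow exception."""


def call_elementwise(fn, elements):
    # Same control flow as the new NoStateNoForeach.__call__: collect results,
    # remember that some element asked to skip, and re-raise only at the end.
    updates = []
    skip_update = False
    for a in elements:
        try:
            updates.append(fn(a))
        except SkipUpdate:
            skip_update = True
    if skip_update:
        raise SkipUpdate
    return updates


def double_or_skip(x):
    if x is None:
        raise SkipUpdate
    return 2 * x


try:
    call_elementwise(double_or_skip, [1, None, 3])
except SkipUpdate:
    print("skipped, but elements 1 and 3 were still transformed")
```
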
@@ -118,11 +127,11 @@ def zero_guard(*names):
 
 
 def copy_guard(index, *names):
-    return functools.partial(CopyGuard, index=index, names=names)
+    return functools.partial(CopyGuard, index=index, names=names,)
 
 
-def general_guard(*names, init_fn):
-    return functools.partial(GeneralGuard, names=names, init_fn=init_fn)
+def general_guard(*names, init_fn, skip_first: bool = True):
+    return functools.partial(GeneralGuard, names=names, init_fn=init_fn, skip_first=skip_first)
 
 
 def no_state(fn):
@@ -311,18 +320,18 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
     if prob is None:
         prob = utils.precond_update_prob_schedule()
     if not precond_schedule(group, prob, name=f"cumulative_prob_{id(Q)}"):
-        return
+        return Q_mat
 
-    Q = [utils.promote(q_) for q_ in Q]
     utils.psgd_update_precond(Q_mat, exprs, grad, group['precond_lr'], Q, group['store_triu_as_line'])
 
-    if grad.dim() > 1 and precond_schedule(group, balance_probability, "
+    if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
         if group['store_triu_as_line']:
             utils.psgd_balance_Q([q_ for _, q_ in Q])
         else:
             utils.psgd_balance_Q(Q)
 
-    _update_psgd_cache(cached, Q_cache, Q_mat)
+    return _update_psgd_cache(cached, Q_cache, Q_mat)
+
 
 def _update_psgd_cache(cached, Q_cache, q):
     if not cached:
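
The other functional change in this hunk is the return value: `_update_psgd_precond` now always hands back a usable `Q_mat` (the input when the probabilistic schedule decides to skip, the refreshed cache otherwise) instead of returning `None` on the skip path. A hedged sketch of that calling convention, with illustrative names only:

```python
import random


def maybe_refresh(q_mat, refresh, prob=0.1, rng=random.random):
    # The caller always gets matrices back, whether or not the refresh ran,
    # so `q_mat = maybe_refresh(q_mat, ...)` is safe on every step.
    if rng() > prob:          # schedule says "not this step"
        return q_mat
    return refresh(q_mat)     # schedule fired: return the refreshed matrices


q_mat = [[1.0]]
q_mat = maybe_refresh(q_mat, refresh=lambda q: [[0.5 * v for v in row] for row in q])
print(q_mat)
```
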
@@ -351,44 +360,47 @@ def _fused_cached_psgd_precond_grad(group, grad, param, cached, cache_expr, expr
                                     group['caution'], *Q_mat)
 
 
-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                   prob: Optional[callable] = None):
-    old = update
     update = update.to(memory_format=torch.contiguous_format)
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    _update_psgd_precond(cached, Q_cache, group, param,
+    Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
+                                 update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
     return _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
 
 
-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                           prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     precond = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
-    _update_psgd_precond(cached, Q_cache, group, param,
+    _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
+                         Q_mat, Q, exprs, prob)
     return precond
 
 
-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                    prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    _update_psgd_precond(cached, Q_cache, group, param,
+    Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
+                                 update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
     _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
     raise SkipUpdate
 
 
-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                            prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
-    _update_psgd_precond(cached, Q_cache, group, param,
+    _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
+                         Q_mat, Q, exprs, prob)
     raise SkipUpdate
 
 
@@ -422,7 +434,6 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
 
 
 class ChainOpt(utils.StatefulOptimizer):
-    compile_step: bool = False
     promote: bool = False
 
     def __init__(self, params, defaults, foreach: bool, *fns):
@@ -432,6 +443,10 @@ class ChainOpt(utils.StatefulOptimizer):
     def _step(self, group):
         if 'base_lr' not in group:
             group['base_lr'] = group['lr']
+        if 'prev_lr' in group and group['prev_lr'] != group['lr']:
+            utils.warn_once(f'Learning rate changed between steps. This is an experimental feature and '
+                            f'only supported with foreach=True (currently foreach={group["foreach"]}).')
+            group['base_lr'] = group['lr']
 
         vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))
         if not vals:
@@ -451,9 +466,10 @@ class ChainOpt(utils.StatefulOptimizer):
         group['step'] = state['step'] = step = step + 1
 
         if group['warmup_steps'] and step < group['warmup_steps']:
-            group['lr'] = group['base_lr'] * step / group['warmup_steps']
+            group['prev_lr'] = group['lr'] = group['base_lr'] * step / group['warmup_steps']
+
         else:
-            group['lr'] = group['base_lr']
+            group['prev_lr'] = group['lr'] = group['base_lr']
 
         if not group['foreach'] or len(p) == 1:
             for param, grad in zip(p, g):
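
For reference, the arithmetic behind these two assignments: during warmup the effective learning rate ramps linearly from `base_lr / warmup_steps` up to `base_lr`, and `prev_lr` simply records whatever was last written to `group['lr']` so a user-driven change can be detected (and warned about) on the next step. A minimal sketch of the schedule:

```python
def effective_lr(base_lr: float, step: int, warmup_steps: int) -> float:
    # Linear warmup as in ChainOpt._step, then a constant base_lr afterwards.
    if warmup_steps and step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr


# base_lr=1e-3, warmup_steps=100:
for step in (1, 50, 100, 200):
    print(step, effective_lr(1e-3, step, 100))
# 1 -> 1e-05, 50 -> 5e-04, 100 and later -> 1e-03
```
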
{heavyball-1.3.1 → heavyball-1.4.0}/heavyball/utils.py

@@ -193,12 +193,12 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
     return grad
 
 
-# TODO: This lerp was fucked - check other lerps
 @decorator_knowngood
 def _compilable_exp_avg_(state, grad, beta):
-
-
-
+    for s, g in zip(state, grad):
+        lerped = s.lerp(g, 1 - beta)
+        copy_stochastic_(s, lerped)
+        copy_stochastic_(g, lerped)
 
 
 def scale_by_exp_avg_(state, grad, beta):
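
A plain-tensor sketch of what the rewritten body computes: the exponential moving average s <- beta * s + (1 - beta) * g via `lerp`, written back to both the state and the gradient buffers. heavyball routes those writes through `copy_stochastic_` (presumably a stochastic-rounding copy for low-precision state); the sketch uses an ordinary `copy_` instead:

```python
import torch


def exp_avg_(state, grad, beta):
    # s.lerp(g, 1 - beta) == beta * s + (1 - beta) * g; the result lands in
    # both buffers, so the "grad" slot now carries the smoothed value too.
    for s, g in zip(state, grad):
        lerped = s.lerp(g, 1 - beta)
        s.copy_(lerped)
        g.copy_(lerped)


s, g = [torch.zeros(3)], [torch.ones(3)]
exp_avg_(s, g, beta=0.9)
print(s[0], g[0])  # both tensor([0.1000, 0.1000, 0.1000])
```
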
@@ -592,6 +592,7 @@ def project(grad, Q, back: bool):
 
 class StatefulOptimizer(torch.optim.Optimizer):
     ema_decay: float = 0.001
+    compile_step: bool = False
 
     def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
         super().__init__(params, {**defaults, 'foreach': foreach})
@@ -637,6 +638,10 @@ class StatefulOptimizer(torch.optim.Optimizer):
 
             p.grad = None
 
+            if self.compile_step:
+                yield p, grad
+                continue
+
             p_views = merge_group(group, p)
             if grad is not None:
                 grad = merge_group(group, grad)
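
A short sketch of the control flow this adds: when `compile_step` is set, parameters and gradients are yielded untouched and the view-merging path below is bypassed. Apart from `compile_step`, the names here are illustrative rather than heavyball's API:

```python
def split_params_and_grads(pairs, compile_step=False, merge=lambda t: t):
    # With compile_step=True the generator short-circuits before any
    # merge/reshape work, mirroring the `yield p, grad; continue` branch.
    for p, g in pairs:
        if compile_step:
            yield p, g
            continue
        yield merge(p), (merge(g) if g is not None else None)


print(list(split_params_and_grads([("p", "g")], compile_step=True)))
```
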
@@ -1030,7 +1035,7 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
     V = torch.randn(G.shape, dtype=G.dtype, device=G.device)
     eps = scalar_guard(math.sqrt(torch.finfo(torch.float32).eps), G)
     eps *= G.norm() / G.numel()
-    G
+    G = G + V * eps
     md = min_dtype(Q + [G])
     A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
     order = G.dim()
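
For context, the surrounding lines add a small Gaussian perturbation to `G` before the einsum; the change here makes that addition out-of-place (`G = G + V * eps`). A self-contained sketch of the same scaling, with `scalar_guard` dropped:

```python
import math

import torch


def jitter(G: torch.Tensor) -> torch.Tensor:
    # Noise magnitude: sqrt(float32 eps) * ||G|| / numel(G), as in the context
    # lines above; building a new tensor leaves the caller's G untouched.
    V = torch.randn(G.shape, dtype=G.dtype, device=G.device)
    eps = math.sqrt(torch.finfo(torch.float32).eps) * G.norm() / G.numel()
    return G + V * eps


print(jitter(torch.ones(4, 4)).shape)
```
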
@@ -1078,26 +1083,20 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line):
         term1 = promote(torch.einsum(exprG, A, A))
         term2 = promote(torch.einsum(exprG, conjB, conjB))
 
-        term2
-        term1 *= 2  # 2a
-        if term1.dtype == term2.dtype:
-            term1 -= term2  # 2a - (a + b) == a - b
-        else:
-            term1 = term1 - term2
+        term1, term2 = term1 - term2, term1 + term2
 
         term1 *= precond_lr
         norm = term2.norm(float('inf'))
         if q.dim() < 2:
-            term1 *= q.to(term1.dtype)
-            term1 /= norm.clamp_(min=tiny_bf16)
+            term1 *= q.to(term1.dtype) / norm.clamp_(min=tiny_bf16)
         else:
             torch.triu(term1, out=term1)
-            term1 /= psgd_lb(term2, norm).clamp_(tiny_bf16)
-            torch.
+            term1 /= torch.where(norm > 0, psgd_lb(term2, norm), norm).clamp_(tiny_bf16)
+            term1 = torch.mm(term1, q)
         if store_triu_as_line:
             term1 = triu_to_line([term1])[0][1]
             o = o[1]
-        stochastic_add_(
+        stochastic_add_(o, term1, -1)
 
 
 @decorator_knowngood
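
The branchy block on the left reduced, per its own comment, to `2a - (a + b) == a - b`; the replacement produces the difference and the sum in one simultaneous assignment (the difference is then scaled by `precond_lr` for the Q update, the sum feeds the infinity-norm bound). A tiny numeric illustration with made-up values:

```python
import torch

a = torch.tensor([3.0, 5.0])  # stand-in for einsum(exprG, A, A)
b = torch.tensor([1.0, 2.0])  # stand-in for einsum(exprG, conjB, conjB)

# New form: both quantities from one simultaneous assignment.
term1, term2 = a - b, a + b
print(term1, term2)  # tensor([2., 3.]) tensor([4., 7.])

# Identity the removed comment relied on: 2a - (a + b) == a - b.
assert torch.equal(2 * a - (a + b), a - b)
```
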
@@ -1162,7 +1161,7 @@ def mu_law_compress(x, mu=127.0):
     """
     x = list_guard(x)
     mu = scalar_guard(mu, x[0])
-
+    _compilable_mu_law_compress_(x, mu)
     return x
 
 
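
The fix here restores the call that actually performs the compression. As a reference only (heavyball's in-place `_compilable_mu_law_compress_` may differ in detail), the textbook mu-law companding formula looks like this:

```python
import math

import torch


def mu_law_reference(x: torch.Tensor, mu: float = 127.0) -> torch.Tensor:
    # Standard mu-law companding: sign(x) * log(1 + mu*|x|) / log(1 + mu).
    return torch.sign(x) * torch.log1p(mu * x.abs()) / math.log1p(mu)


print(mu_law_reference(torch.tensor([-1.0, 0.0, 0.25, 1.0])))
```
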
@@ -1191,7 +1190,7 @@ def a_law_compress(x, A=87.6):
     """
     x = list_guard(x)
     A = scalar_guard(A, x[0])
-
+    _compilable_a_law_compress_(x, A)
     return x
 
 
@@ -1295,6 +1294,7 @@ def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, gra
 
 
 def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
+
     lr = scalar_guard(lr, param[0])
     _compilable_fused_precond_grad_cached_(expr, ea, param, lr, grad, decay, caution, *cached_q)
 
@@ -1310,7 +1310,7 @@ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
 
 @decorator_knowngood
 def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
-    precond = psgd_precond_grad(expr,
+    precond = psgd_precond_grad(expr, ea, *preconds)
     update_param_(param, precond, lr, decay, caution=caution, grad=grad)
 
 