heavyball 1.5.2.tar.gz → 1.5.3.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-1.5.2 → heavyball-1.5.3}/PKG-INFO +1 -1
- {heavyball-1.5.2 → heavyball-1.5.3}/heavyball/__init__.py +14 -1
- {heavyball-1.5.2 → heavyball-1.5.3}/heavyball/chainable.py +63 -4
- {heavyball-1.5.2 → heavyball-1.5.3}/heavyball/utils.py +86 -11
- {heavyball-1.5.2 → heavyball-1.5.3}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-1.5.2 → heavyball-1.5.3}/setup.py +1 -1
- {heavyball-1.5.2 → heavyball-1.5.3}/LICENSE +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/README.md +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/setup.cfg +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/test/test_bf16_params.py +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/test/test_bf16_q.py +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/test/test_bf16_storage.py +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/test/test_caution.py +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/test/test_channels_last.py +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/test/test_closure.py +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/test/test_ema.py +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/test/test_foreach.py +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/test/test_hook.py +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/test/test_mars.py +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/test/test_memory.py +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/test/test_merge.py +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/test/test_no_grad.py +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/test/test_psgd.py +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/test/test_soap.py +0 -0
- {heavyball-1.5.2 → heavyball-1.5.3}/test/test_stochastic_updates.py +0 -0
--- heavyball-1.5.2/heavyball/__init__.py
+++ heavyball-1.5.3/heavyball/__init__.py
@@ -163,6 +163,19 @@ class OrthoLaProp(C.BaseOpt):
                          C.orthogonalize_grad_to_param, C.scale_by_laprop)
 
 
+
+class LaPropOrtho(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
+                         C.scale_by_laprop, C.orthogonalize_grad_to_param)
+
+
 class ForeachPSGDKron(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
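For context, a minimal usage sketch of the new LaPropOrtho optimizer (not part of the diff; argument values are illustrative, and it assumes the usual torch.optim step/zero_grad interface that heavyball optimizers expose):

import torch
import heavyball

model = torch.nn.Linear(16, 16)
opt = heavyball.LaPropOrtho(model.parameters(), lr=0.0025, betas=(0.9, 0.99))

loss = model(torch.randn(4, 16)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()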
@@ -244,4 +257,4 @@ __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD",
           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', #
           "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
           "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
-          "ForeachRMSprop", "ForeachMuon", 'ForeachCachedNewtonPSGD']
+          "ForeachRMSprop", "ForeachMuon", 'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho']
--- heavyball-1.5.2/heavyball/chainable.py
+++ heavyball-1.5.3/heavyball/chainable.py
@@ -1,6 +1,6 @@
 import functools
 import random
-from typing import Optional, Union, Literal
+from typing import Optional, Union, Literal, List
 
 import torch
 
@@ -152,6 +152,22 @@ def exp_avg(group, update, grad, param, exp_avg):
     return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
 
 
+@zero_guard('exp_avg')
+@no_state
+def weight_decay_to_ema(group, update, grad, param, exp_avg):
+    utils.weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
+                               group['weight_decay_to_ema'] * group['lr'])
+    return update
+
+
+@zero_guard('exp_avg')
+@no_state
+def l1_weight_decay_to_ema(group, update, grad, param, exp_avg):
+    utils.l1_weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
+                                  group['weight_decay_to_ema'] * group['lr'])
+    return update
+
+
 @zero_guard("exp_avg_sq")
 @no_state
 def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
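Both new transforms decay the weights toward an exponential moving average of themselves instead of toward zero. A rough sketch of the intent, assuming a single float32 parameter and ignoring the stochastic-rounding helpers (the exact interpolation weights depend on the _lerp32 convention in utils.py; the names below are illustrative, not from the package):

import torch

p = torch.randn(10)                   # parameter
ema = torch.zeros_like(p)             # EMA buffer (the 'exp_avg' state above)
ema_beta, weight_decay_to_ema, lr = 0.99, 0.1, 1e-3

ema = ema * ema_beta + p * (1 - ema_beta)        # running average of the parameter
p = p - (p - ema) * weight_decay_to_ema * lr     # pull the parameter toward its EMA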
@@ -295,6 +311,25 @@ def nesterov_momentum(group, updates, grads, params, momentum):
     return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))
 
 
+@zero_guard('momentum')
+@no_state
+def nesterov_ema(group, updates, grads, params, momentum):  # equivalent to Grokfast
+    return utils.nesterov_ema(momentum, updates, utils.get_beta1(group))
+
+
+def _store_std(state, group, update, grad, param):
+    state['init_std'] = torch.std(grad, dim=0)
+
+
+@general_guard("init_std", init_fn=_store_std)
+@no_state
+def mup_approx(group, updates, grads, params, init_std):
+    _updates = [(u, i) for u, i in zip(updates, init_std) if u.ndim > 1]
+    _updates, _init_std = zip(*_updates)
+    utils.stochastic_multiply_(_updates, _init_std)
+    return updates
+
+
 @zero_guard("momentum")
 @no_state
 def heavyball_momentum(group, updates, grads, params, momentum):
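mup_approx appears to rescale every update with two or more dimensions by the per-output standard deviation of the gradient captured once by _store_std, a rough μP-style width correction. A hypothetical sketch of the effect on one matrix-shaped update (names are illustrative, not the package API):

import torch

grad0 = torch.randn(64, 32)            # first gradient seen for this parameter
init_std = grad0.std(dim=0)            # stored once as 'init_std', shape (32,)

update = torch.randn(64, 32)
update = update * init_std             # later steps rescale the update by that std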
@@ -312,7 +347,7 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
 
     grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
     fn = _optim_fns[inner]
-    precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'])
+    precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'], group['eps'])
     precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]
 
     for u, q, gg, eas in zip(update, Q, GG, exp_avg_sq):
@@ -414,6 +449,11 @@ def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: st
     raise SkipUpdate
 
 
+@no_state
+def sign(group, update, grad, param, graft: bool = True):
+    return utils.sign_(update, graft)
+
+
 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
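The new sign transform applies a sign-SGD style update; with graft=True it rescales the sign direction back to the norm of the original update via _compilable_grafting in utils.py. A hypothetical illustration of that grafting step:

import torch

g = torch.randn(10)
direction = g.sign()
# same formula as _compilable_grafting(magnitude, direction): keep g's norm, sign-only direction
grafted = direction * (g.norm() / direction.norm().clamp(min=1e-6))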
@@ -439,8 +479,7 @@ def apply_to_idx(fn, idx):
     return _fn
 
 
-def chain(state: Union[callable, dict], group, grad, param, *fns):
-    update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
+def _inner_chain(state, group, update, grad, param, *fns):
     skip_update = False
     for fn in fns:
         try:
@@ -450,10 +489,30 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
             continue
         if update is None:
             break
+    return update, skip_update
+
+
+def chain(state: Union[callable, dict], group, grad, param, *fns):
+    update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
+    update, skip_update = _inner_chain(state, group, update, grad, param, *fns)
     if not skip_update and update is not None:
         utils.update_param_(param, update, group['lr'], group['weight_decay'], caution=group['caution'], grad=grad)
 
 
+def create_branch(branches: List[List[callable]], merge_fn: callable):
+    def _branch(state, group, update, grad, param):
+        outputs = []
+        for branch in branches:
+            branch_update = [torch.clone(g, memory_format=torch.preserve_format) for u in update]
+            branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
+            if skip_update:
+                raise ValueError("Branches should not skip updates")
+            outputs.append(branch_update)
+        return merge_fn(outputs)
+
+    return _branch
+
+
 class ChainOpt(utils.StatefulOptimizer):
     promote: bool = False
 
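chain is now split into _inner_chain (runs the transform list) and the outer chain (applies the final parameter update), which lets create_branch run several transform chains on copies of the same update and merge the results. A hypothetical sketch of assembling a branch; the merge function and the choice of transforms are illustrative, not an API documented by this diff:

import heavyball.chainable as C

def mean_merge(outputs):
    # average the per-tensor updates produced by the branches
    return [sum(us) / len(us) for us in zip(*outputs)]

branch = C.create_branch([[C.scale_by_laprop], [C.sign]], mean_merge)
# `branch` has the (state, group, update, grad, param) shape of other chainable
# transforms and could be placed where a transform list is accepted.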
--- heavyball-1.5.2/heavyball/utils.py
+++ heavyball-1.5.3/heavyball/utils.py
@@ -317,6 +317,19 @@ def nesterov_momentum(state, grad, beta):
     return grad
 
 
+@decorator_knowngood
+def _compilable_nesterov_ema_(state, grad, beta):
+    ema32 = _lerp32(state, grad, beta)
+    stochastic_add_(grad, ema32, 1)
+
+
+def nesterov_ema(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_nesterov_ema_(state, grad, beta)
+    return grad
+
+
 def _compilable_grafting(magnitude, direction):
     return direction * (magnitude.norm() / direction.norm().clamp(min=1e-6))
 
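nesterov_ema maintains an EMA of the gradient and adds it back onto the current gradient, which is why the chainable wrapper above calls it Grokfast-equivalent. A rough sketch of the arithmetic, assuming plain float32 tensors (the real code routes through _lerp32 and stochastic_add_ for stochastic-rounding support):

import torch

grad = torch.randn(10)
ema = torch.zeros_like(grad)
beta = 0.98

ema = ema * beta + grad * (1 - beta)   # update the gradient EMA
grad = grad + ema                      # amplify the slow, low-frequency component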
@@ -509,6 +522,19 @@ def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, T
     _compilable_stochastic_add_(x, y, alpha)
 
 
+@decorator_knowngood
+def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
+    for x_, y_ in zip(x, y):
+        x32 = promote(x_)
+        y32 = promote(y_)
+        copy_stochastic_(x_, x32 * y32)
+
+
+def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
+    x, y = list_guard(x, y)
+    _compilable_stochastic_multiply_(x, y)
+
+
 @decorator
 def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
     if grad.dim() == 1 and (not precondition_1d or grad.shape[0] > max_precond_dim):
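stochastic_multiply_ mirrors the existing stochastic_add_: the operands are promoted to float32, multiplied elementwise, and written back with stochastic rounding, so it stays usable with bfloat16 state. A hypothetical call:

import torch
from heavyball.utils import stochastic_multiply_

x = [torch.randn(8, dtype=torch.bfloat16)]
y = [torch.rand(8, dtype=torch.bfloat16)]
stochastic_multiply_(x, y)   # x[0] is updated in place with x * y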
@@ -783,7 +809,7 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
 
 
 def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
-          eps: float):
+          eps: float = 1e-8):
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
     _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -815,23 +841,23 @@ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor]
 
 @decorator_knowngood
 def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
-                        beta2: Tensor, step: Tensor):
+                        beta2: Tensor, step: Tensor, eps: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
     gp32 = list(map(promote, grad))
 
-    denom = exp_avg_sq_(exp_avg_sq, gp32, beta2,
+    denom = exp_avg_sq_(exp_avg_sq, gp32, beta2, eps)
     gp32 = torch._foreach_div(gp32, denom)
     gp32 = _lerp32(exp_avg, gp32, beta1)
 
     copy_stochastic_list_(grad, gp32)
 
 
-def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
+def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int, eps: float = 1e-8):
     exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
-    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
-    _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, exp_avg[0], eps)
+    _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
     return grad
 
 
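With this change laprop_ accepts an explicit eps for the second-moment denominator, matching adam_, and scale_by_soap in chainable.py now forwards group['eps'] to whichever inner function it selects. A hypothetical direct call (values are illustrative):

import torch
from heavyball.utils import laprop_

g = [torch.randn(8)]
exp_avg = [torch.zeros(8)]
exp_avg_sq = [torch.zeros(8)]
update = laprop_(exp_avg, exp_avg_sq, g, beta1=0.9, beta2=0.99, step=1, eps=1e-12)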
@@ -970,6 +996,10 @@ def get_soap_precond_schedule(precond_scheduler):
     return _inner
 
 
+def _max_idx(x: List[int]):
+    return len(x) - 1 - np.argmax(x[::-1])  # we want to start counting from the back, as torch is fan-out/fan-in
+
+
 def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtype=None):
     """For a scalar or tensor t, we initialize its preconditioner Q and
     reusable einsum expressions for updating Q and preconditioning gradient.
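_max_idx returns the index of the largest entry, breaking ties toward the last occurrence (the comment about fan-out/fan-in refers to torch's weight layout). A small worked example, restating the helper so it runs standalone:

import numpy as np

def _max_idx(x):
    return len(x) - 1 - np.argmax(x[::-1])

_max_idx([128, 256, 64])    # -> 1, the unique maximum
_max_idx([256, 64, 256])    # -> 2, ties resolved toward the trailing dimension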
@@ -992,17 +1022,20 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
 
     scale = scale ** (1 / len(shape))
 
+    dim_diag = [False for _ in shape]
     if memory_save_mode is None:
-
+        pass
     elif memory_save_mode == "one_diag":
-
-
-
+        dim_diag[_max_idx(shape)] = True
+    elif memory_save_mode == "smart_one_diag":
+        sorted_shape = sorted(shape)
+        if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
+            dim_diag[_max_idx(shape)] = True
 elif memory_save_mode == "all_diag":
-
+
     elif memory_save_mode == "all_diag":
         dim_diag = [True for _ in shape]
     else:
         raise ValueError(f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
-                         "[None, 'one_diag', 'all_diag']")
+                         "[None, 'one_diag', 'all_diag', 'smart_one_diag']")
 
     Q = []
     piece1A, piece2A, piece3A = ([], "", "")
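The new 'smart_one_diag' mode only falls back to a diagonal preconditioner for the largest dimension when that dimension is strictly larger than the next one, so square-ish tensors keep full Kronecker factors. A quick illustration of the condition (a hypothetical helper, not a package function):

def smart_one_diag_dims(shape):
    # mirrors the branch added above
    dim_diag = [False for _ in shape]
    sorted_shape = sorted(shape)
    if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
        dim_diag[shape.index(max(shape))] = True
    return dim_diag

smart_one_diag_dims((1024, 256))   # -> [True, False]: the larger dim gets a diagonal factor
smart_one_diag_dims((512, 512))    # -> [False, False]: square tensors keep full factors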
@@ -1221,6 +1254,48 @@ def identity(x):
     return x
 
 
+@decorator_knowngood
+def _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    ema32 = _lerp32(ema, p, ema_decay)
+    _lerp32(p, ema32, 1 - weight_decay)
+
+
+def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    p, ema = list_guard(p, ema)
+    ema_decay, weight_decay = scalar_guard(ema_decay, weight_decay, p[0])
+    _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
+
+
+@decorator_knowngood
+def _compilable_l1_weight_decay_to_ema_(p, ema, ema_deacy, weight_decay):
+    ema32 = _lerp32(ema, p, ema_deacy)
+    for p_, e_ in zip(p, ema32):
+        p32 = promote(p)
+        p32 = p32 + (p32 - e_).sign() * weight_decay
+        copy_stochastic_(p_, p32)
+
+
+def l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    p, ema = list_guard(p, ema)
+    ema_decay, weight_decay = scalar_guard(ema_decay, weight_decay, p[0])
+    _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
+
+
+@decorator_knowngood
+def _compilable_sign_(grad: List[Tensor], graft: bool):
+    for g_ in grad:
+        gs = g_.sign()
+        if graft:
+            gs = _compilable_grafting(g_, gs)
+        copy_stochastic_(g_, gs)
+
+
+def sign_(grad: List[Tensor], graft: bool = True):
+    grad = list_guard(grad)
+    _compilable_sign_(grad, graft)
+    return grad
+
+
 @decorator_knowngood
 def _compilable_trust_region_clip_(grad, lerp, scale):
     # (sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9).clamp_(min=-2, max=2)