heavyball-1.6.0-py3-none-any.whl → heavyball-1.6.2-py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- heavyball/utils.py +57 -55
- {heavyball-1.6.0.dist-info → heavyball-1.6.2.dist-info}/METADATA +19 -19
- heavyball-1.6.2.dist-info/RECORD +8 -0
- {heavyball-1.6.0.dist-info → heavyball-1.6.2.dist-info}/WHEEL +1 -1
- heavyball-1.6.0.dist-info/RECORD +0 -8
- {heavyball-1.6.0.dist-info → heavyball-1.6.2.dist-info}/LICENSE +0 -0
- {heavyball-1.6.0.dist-info → heavyball-1.6.2.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -1,18 +1,12 @@
-import copy
 import functools
 import gc
-import inspect
 import math
 import random
 import string
-import sys
-import time
 import warnings
-from datetime import datetime
 from typing import List, Optional, Tuple, Callable, Union
 from unittest.mock import patch
 
-import hyperopt
 import numpy as np
 import torch
 from torch import Tensor
@@ -165,14 +159,17 @@ def beta_debias(beta, step):
     return 1 - (1 - beta) / (1 - beta ** step)
 
 
+def eps_sqrt(item, eps):
+    return item.sqrt().clamp(min=eps)
+
+
 @decorator_knowngood
 def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
                             out: List[Optional[Tensor]]):
-
-    s32 =
-    copy_stochastic_list_(state, s32)
+    g32 = promote(grad)
+    s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
 
-    denom = [d
+    denom = [eps_sqrt(d, eps) for d in s32]
 
     if out[0] is None:
         return denom
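Note: the new eps_sqrt helper clamps the denominator after taking the square root rather than adding eps under or after it, so an exactly-zero second moment yields eps instead of sqrt(eps). A standalone sketch in plain torch (outside heavyball's compiled path) to illustrate the difference:

import torch

def eps_sqrt(item, eps):
    # as added in this diff: sqrt first, then clamp from below
    return item.sqrt().clamp(min=eps)

v = torch.tensor([0.0, 1e-20, 4.0])
print(eps_sqrt(v, 1e-8))  # tensor([1.0000e-08, 1.0000e-08, 2.0000e+00])
print((v + 1e-8).sqrt())  # add-then-sqrt instead gives ~1e-4 for the zero entry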
@@ -189,7 +186,7 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):
 
 @decorator_knowngood
 def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
-    g32 =
+    g32 = promote(grad)
     denom = _compilable_exp_avg_sq_(state, g32, beta2, eps, [None])
     out = torch._foreach_div(g32, denom)
     copy_stochastic_list_(grad, out)
@@ -265,8 +262,8 @@ def set_torch(benchmark_limit: int = 32):
     cudnn.benchmark_limit = benchmark_limit
     torch.use_deterministic_algorithms(False)
     torch.set_float32_matmul_precision("high") # highest: FP32, high: TF32, medium: bf16
-    opt_einsum.enabled =
-    opt_einsum.strategy = "
+    opt_einsum.enabled = False
+    opt_einsum.strategy = "auto"
 
     # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
     _ignore_warning(
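Note: opt_einsum here presumably refers to PyTorch's torch.backends.opt_einsum module (assumed from the attribute names; the import lies outside this hunk). A sketch of what the two knobs control:

from torch.backends import opt_einsum  # assumed import; matches the names above

# With opt_einsum disabled, torch.einsum contracts operands left to right
# instead of searching for a cheaper contraction order; "auto" is the default
# planning strategy when it is enabled ("greedy" and "optimal" are the others).
opt_einsum.enabled = False
opt_einsum.strategy = "auto"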
@@ -379,7 +376,7 @@ def _compilable_scatter_set(target, source, index):
     target[:] = source.contiguous()[index].reshape_as(target)
 
 
-@decorator_knowngood
+# @decorator_knowngood
 def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optional[Tensor] = None):
     """
     Computes the eigenbases of the preconditioner using one round of power iteration
@@ -398,7 +395,8 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optional[Tensor] = None):
     new_qs = []
 
     for m, q in zip(GG, Q):
-        if
+        if m is None:
+            new_qs.append(None)
             continue
 
         m = promote(m.data)
@@ -420,52 +418,60 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optional[Tensor] = None):
     in_str = einsum_base[:exp_avg.dim()]
     out_str = einsum_base[exp_avg.dim():2 * exp_avg.dim()]
 
-    from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if
+    from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if m is not None])
     if not from_shampoo:
         return
 
-    to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if
+    to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if m is not None])
     out_str = ''.join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
 
     subscripts = f'{in_str},{from_shampoo},{to_shampoo}->{out_str}'
-    exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q
+    exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q if q is not None],
+                               *[q for q in new_qs if q is not None])
     copy_stochastic_(exp_avg, exp_avg_new)
 
     for q, q_new in zip(Q, new_qs):
-
+        if q is not None:
+            copy_stochastic_(q, q_new)
 
 
-def get_orthogonal_matrix(mat):
+def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
     """
     Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
     """
 
     final = []
     for m in mat:
-        if
-        final.append(
+        if m is None:
+            final.append(None)
             continue
 
         m = promote(m.data)
 
         device, dtype = m.device, m.dtype
-
-
-        m = m.to(modifier)
+        eps = min_eps
+        while True:
             try:
-
+                eye = torch.eye(m.shape[0], device=m.device, dtype=m.dtype)
+                eigval, eigvec = torch.linalg.eigh(m + eps * eye)
                 eigvec = eigvec.to(device=device, dtype=dtype)
                 break
             except torch.OutOfMemoryError:
-
+                if m.device.type == 'cpu':
+                    raise
+                else:
+                    m = m.cpu()
             except RuntimeError: # failed to compute eigenvalues
-
+                if m.dtype != torch.double:
+                    m = m.double()
+                elif eps < max_eps:
+                    eps = eps ** (2 / 3)
+                else:
+                    raise
             clean()
-        else:
-            raise RuntimeError("Failed to compute eigenvalues.")
 
+        eigvec = eigvec.to(device=m.device, dtype=m.dtype)
         eigvec = torch.flip(eigvec, [1])
-
         final.append(eigvec)
 
     return final
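Note: the rewritten get_orthogonal_matrix retries torch.linalg.eigh with a growing diagonal jitter instead of failing outright. A minimal sketch of just that retry logic, omitting the out-of-memory CPU fallback and heavyball's clean() call:

import torch

def robust_eigh(m, max_eps=1e-3, min_eps=1e-30):
    eps = min_eps
    while True:
        try:
            eye = torch.eye(m.shape[0], device=m.device, dtype=m.dtype)
            return torch.linalg.eigh(m + eps * eye)  # jittered for stability
        except RuntimeError:  # eigendecomposition failed to converge
            if m.dtype != torch.double:
                m = m.double()        # first retry in higher precision
            elif eps < max_eps:
                eps = eps ** (2 / 3)  # eps < 1, so this grows it toward 1
            else:
                raise

Since eps starts at 1e-30, the schedule visits roughly 1e-20, 1e-13, 1e-9, ... until it exceeds max_eps, at which point the error is re-raised.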
@@ -476,7 +482,9 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
         y32 = promote(y_)
-
+        if x32.dtype != y32.dtype:
+            y32 = y32.to(x32.dtype)
+        copy_stochastic_(x_, x32 * (1 - a) + y32 * a)
 
 
 def get_beta1(group):
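Note: after promote(), the two operand lists can still disagree on dtype, and x * (1 - a) + y * a then silently upcasts the result, which presumably motivated the guard. A trivial plain-torch sketch:

import torch

x32 = torch.randn(3, dtype=torch.float32)
y32 = torch.randn(3, dtype=torch.float64)
a = 0.1
if x32.dtype != y32.dtype:
    y32 = y32.to(x32.dtype)  # align before interpolating, as in the fix above
out = x32 * (1 - a) + y32 * a  # lerp: (1 - a) * x + a * y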
@@ -575,7 +583,7 @@ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
     g0 = einsum_base[:grad.dim()]
     g1 = g0.replace(b, b.upper())
     outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
-    m
+    stochastic_lerp_(m, outer_product, 1 - beta)
 
 
 def tree_apply(fn):
@@ -618,7 +626,8 @@ def init_preconditioner(grad, state, max_precond_dim, precondition_1d):
     state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
     if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d):
         for sh in grad.shape:
-            if sh > max_precond_dim:
+            if sh > max_precond_dim or sh == 1:
+                # via @francois-rozet: https://github.com/HomebrewML/HeavyBall/commit/8b86be04967e2d095136d5603724f488f2d46592#diff-a430393dd0a6ee393944a9ed16416115c175de2414cf4a96e647197697f265e9R621
                 state['GG'].append(None)
             else:
                 state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
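Note: a dimension of size 1 would yield a 1×1 Gram matrix that cannot rotate anything, so it is now skipped like an oversized dimension, and the `is not None` filters added throughout this diff make the einsum projections ignore the None entries. A toy illustration with hypothetical shapes:

import torch

grad = torch.zeros(1, 32)  # e.g. a weight with a singleton leading dim
max_precond_dim = 64

GG = []
for sh in grad.shape:
    # mirrors the updated condition: skip oversized *and* singleton dims
    GG.append(None if sh > max_precond_dim or sh == 1 else torch.zeros(sh, sh))

print([None if g is None else tuple(g.shape) for g in GG])  # [None, (32, 32)]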
@@ -638,10 +647,10 @@ def project(grad, Q, back: bool):
     :return:
     """
     param = einsum_base[:grad.dim()]
-    preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if
+    preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if m is not None])
     if preconditioners:
         out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
-        out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if
+        out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if q is not None])
         grad = out.to(grad.dtype)
     return grad
 
@@ -876,7 +885,7 @@ def _lerp(state: List[Tensor], grad: List[Tensor], beta):
     ea32 = list(map(promote, state))
     grad = list(map(promote, grad))
     beta = promote(beta)
-    ea32
+    stochastic_lerp_(ea32, grad, 1 - beta)
     copy_stochastic_list_(state, ea32)
     return ea32
 
@@ -890,7 +899,7 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
     g32 = list(map(promote, grad))
     exp_avg32 = _lerp(exp_avg, g32, beta1)
     denom = _compilable_exp_avg_sq_(exp_avg_sq, g32, beta2, eps, [None])
-    u32 =
+    u32 = torch._foreach_div(exp_avg32, denom)
     copy_stochastic_list_(grad, u32)
 
 
@@ -973,14 +982,11 @@ def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2,
     _compilable_update_(y, u32, decay, lr, caution, g32)
 
     beta1 = beta_debias(beta1, step)
-    denom =
-
-    exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
-    copy_stochastic_list_(exp_avg, exp_avg32)
+    denom = [eps_sqrt(d, eps) for d in exp_avg_sq32]
+    stochastic_lerp_(exp_avg, torch._foreach_div(g32, denom), 1 - beta1)
 
     beta2 = beta_debias(beta2, step + 1)
-
-    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+    stochastic_lerp_(exp_avg_sq, torch._foreach_mul(g32, g32), 1 - beta2)
 
 
 def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
@@ -990,27 +996,23 @@ def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
 
 
 @decorator_knowngood
-def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
+def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps):
     g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
     update = [e.clone() for e in exp_avg]
 
     beta1 = beta_debias(beta1, step)
-    denom =
-
-    exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
-    copy_stochastic_list_(exp_avg, exp_avg32)
+    denom = [eps_sqrt(d, eps) for d in exp_avg_sq32]
+    stochastic_lerp_(exp_avg, torch._foreach_div(g32, denom), 1 - beta1)
 
-
-    exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
-    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+    stochastic_lerp_(exp_avg_sq, torch._foreach_mul(g32, g32), 1 - beta2)
 
     copy_stochastic_list_(grad, update)
 
 
-def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
+def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps: float = 1e-8):
     exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
-    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
-    _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
+    beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
+    _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps)
     return grad
 
 
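Note: both ADOPT paths now share one structure: emit the previous exp_avg as the update, advance the first moment with an eps-clamped denominator, then advance the second moment with the squared gradient. A per-tensor plain-torch paraphrase (no _foreach batching or stochastic rounding; step >= 1 assumed):

import torch

def adopt_step(g, exp_avg, exp_avg_sq, beta1=0.9, beta2=0.999, step=1, eps=1e-8):
    update = exp_avg.clone()                     # ADOPT applies the *old* momentum
    db1 = 1 - (1 - beta1) / (1 - beta1 ** step)  # beta_debias
    denom = exp_avg_sq.sqrt().clamp(min=eps)     # eps_sqrt
    exp_avg.lerp_(g / denom, 1 - db1)
    db2 = 1 - (1 - beta2) / (1 - beta2 ** (step + 1))
    exp_avg_sq.lerp_(g * g, 1 - db2)
    return update

g = torch.randn(4)
u = adopt_step(g, torch.zeros(4), torch.ones(4))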
{heavyball-1.6.0.dist-info → heavyball-1.6.2.dist-info}/METADATA
CHANGED
@@ -1,26 +1,26 @@
-Metadata-Version: 2.
+Metadata-Version: 2.2
 Name: heavyball
-Version: 1.6.
-Summary: Efficient
-
-
-
-
-Classifier: Development Status :: 5 - Production/Stable
-Classifier: License :: OSI Approved :: BSD License
-Classifier: Programming Language :: Python
-Classifier: Programming Language :: Python :: 3.7
-Classifier: Programming Language :: Python :: 3.8
-Classifier: Programming Language :: Python :: 3.9
-Classifier: Topic :: Software Development :: Libraries
-Classifier: Topic :: Software Development :: Libraries :: Python Modules
+Version: 1.6.2
+Summary: Efficient Optimizers
+Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
+Project-URL: source, https://github.com/HomebrewML/HeavyBall
+Project-URL: tracker, https://github.com/HomebrewML/HeavyBall/issues
+Keywords: torch,optimizer,muon,soap,psgd
 Classifier: Intended Audience :: Developers
-
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: BSD License
+Classifier: Natural Language :: English
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python :: 3
+Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: opt-einsum
-Requires-Dist: torch
-
+Requires-Dist: opt-einsum>=3.0.0
+Requires-Dist: torch>=2.0.0
+Provides-Extra: dev
+Requires-Dist: pre-commit; extra == "dev"
+Requires-Dist: pytest; extra == "dev"
+Requires-Dist: ruff; extra == "dev"
 
 # `heavyball`: Efficient Optimizers
 
heavyball-1.6.2.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=DKp8rEAf7mt2-j9XRVlgjaLjyfuwUsyl_uXJoOKWAHg,15362
+heavyball/chainable.py,sha256=n_u0QS92WitbtnENvNQ0m4dZTHuJ5ObQ88XA3cmhCfo,27298
+heavyball/utils.py,sha256=CFBFHTekWaqKhmrSLuMvRsxZ41YxPfsYihEPvJMKOQc,56088
+heavyball-1.6.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.6.2.dist-info/METADATA,sha256=q2CEAHIg6jdGn7dey36EMExwJBrNFTDgZFpEzEHDvBY,43479
+heavyball-1.6.2.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
+heavyball-1.6.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.6.2.dist-info/RECORD,,
heavyball-1.6.0.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=DKp8rEAf7mt2-j9XRVlgjaLjyfuwUsyl_uXJoOKWAHg,15362
-heavyball/chainable.py,sha256=n_u0QS92WitbtnENvNQ0m4dZTHuJ5ObQ88XA3cmhCfo,27298
-heavyball/utils.py,sha256=Nk0q_sfv47F-QC9Wwi5KCt-C_71OhuzM98XHlYGvl24,55905
-heavyball-1.6.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.6.0.dist-info/METADATA,sha256=5suezTlZCOBwCgHeFgkLaywYwjAWN1SPg6yhvAv1WgE,43441
-heavyball-1.6.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.6.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.6.0.dist-info/RECORD,,
{heavyball-1.6.0.dist-info → heavyball-1.6.2.dist-info}/LICENSE
File without changes
{heavyball-1.6.0.dist-info → heavyball-1.6.2.dist-info}/top_level.txt
File without changes