heavyball 1.6.1__tar.gz → 1.6.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-1.6.1 → heavyball-1.6.3}/PKG-INFO +24 -18
- {heavyball-1.6.1 → heavyball-1.6.3}/heavyball/utils.py +19 -12
- {heavyball-1.6.1 → heavyball-1.6.3}/heavyball.egg-info/PKG-INFO +24 -18
- {heavyball-1.6.1 → heavyball-1.6.3}/heavyball.egg-info/SOURCES.txt +1 -1
- heavyball-1.6.3/heavyball.egg-info/requires.txt +13 -0
- heavyball-1.6.3/pyproject.toml +52 -0
- heavyball-1.6.1/heavyball.egg-info/requires.txt +0 -3
- heavyball-1.6.1/setup.py +0 -33
- {heavyball-1.6.1 → heavyball-1.6.3}/LICENSE +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/README.md +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/heavyball/__init__.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/heavyball/chainable.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/setup.cfg +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/test/test_bf16_params.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/test/test_bf16_q.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/test/test_bf16_storage.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/test/test_caution.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/test/test_channels_last.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/test/test_closure.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/test/test_ema.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/test/test_foreach.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/test/test_hook.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/test/test_mars.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/test/test_memory.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/test/test_merge.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/test/test_no_grad.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/test/test_psgd.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/test/test_soap.py +0 -0
- {heavyball-1.6.1 → heavyball-1.6.3}/test/test_stochastic_updates.py +0 -0
@@ -1,26 +1,32 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.2
|
2
2
|
Name: heavyball
|
3
|
-
Version: 1.6.
|
4
|
-
Summary: Efficient
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
Classifier: Development Status :: 5 - Production/Stable
|
10
|
-
Classifier: License :: OSI Approved :: BSD License
|
11
|
-
Classifier: Programming Language :: Python
|
12
|
-
Classifier: Programming Language :: Python :: 3.7
|
13
|
-
Classifier: Programming Language :: Python :: 3.8
|
14
|
-
Classifier: Programming Language :: Python :: 3.9
|
15
|
-
Classifier: Topic :: Software Development :: Libraries
|
16
|
-
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
3
|
+
Version: 1.6.3
|
4
|
+
Summary: Efficient Optimizers
|
5
|
+
Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
|
6
|
+
Project-URL: source, https://github.com/HomebrewML/HeavyBall
|
7
|
+
Project-URL: tracker, https://github.com/HomebrewML/HeavyBall/issues
|
8
|
+
Keywords: torch,optimizer,muon,soap,psgd
|
17
9
|
Classifier: Intended Audience :: Developers
|
18
|
-
|
10
|
+
Classifier: Intended Audience :: Science/Research
|
11
|
+
Classifier: License :: OSI Approved :: BSD License
|
12
|
+
Classifier: Natural Language :: English
|
13
|
+
Classifier: Operating System :: OS Independent
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
15
|
+
Requires-Python: >=3.9
|
19
16
|
Description-Content-Type: text/markdown
|
20
17
|
License-File: LICENSE
|
21
|
-
Requires-Dist: opt-einsum
|
22
|
-
Requires-Dist: torch
|
18
|
+
Requires-Dist: opt-einsum>=3.4.0
|
19
|
+
Requires-Dist: torch>=2.1.0
|
23
20
|
Requires-Dist: numpy
|
21
|
+
Provides-Extra: dev
|
22
|
+
Requires-Dist: pre-commit; extra == "dev"
|
23
|
+
Requires-Dist: pytest; extra == "dev"
|
24
|
+
Requires-Dist: ruff; extra == "dev"
|
25
|
+
Requires-Dist: matplotlib; extra == "dev"
|
26
|
+
Requires-Dist: seaborn; extra == "dev"
|
27
|
+
Requires-Dist: hyperopt; extra == "dev"
|
28
|
+
Requires-Dist: pandas; extra == "dev"
|
29
|
+
Requires-Dist: typer; extra == "dev"
|
24
30
|
|
25
31
|
# `heavyball`: Efficient Optimizers
|
26
32
|
|
@@ -376,7 +376,7 @@ def _compilable_scatter_set(target, source, index):
|
|
376
376
|
target[:] = source.contiguous()[index].reshape_as(target)
|
377
377
|
|
378
378
|
|
379
|
-
|
379
|
+
# @decorator_knowngood
|
380
380
|
def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optional[Tensor] = None):
|
381
381
|
"""
|
382
382
|
Computes the eigenbases of the preconditioner using one round of power iteration
|
@@ -426,7 +426,8 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
426
426
|
out_str = ''.join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
|
427
427
|
|
428
428
|
subscripts = f'{in_str},{from_shampoo},{to_shampoo}->{out_str}'
|
429
|
-
exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q if q is not None],
|
429
|
+
exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q if q is not None],
|
430
|
+
*[q for q in new_qs if q is not None])
|
430
431
|
copy_stochastic_(exp_avg, exp_avg_new)
|
431
432
|
|
432
433
|
for q, q_new in zip(Q, new_qs):
|
@@ -434,7 +435,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
434
435
|
copy_stochastic_(q, q_new)
|
435
436
|
|
436
437
|
|
437
|
-
def get_orthogonal_matrix(mat):
|
438
|
+
def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
|
438
439
|
"""
|
439
440
|
Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
|
440
441
|
"""
|
@@ -448,23 +449,29 @@ def get_orthogonal_matrix(mat):
|
|
448
449
|
m = promote(m.data)
|
449
450
|
|
450
451
|
device, dtype = m.device, m.dtype
|
451
|
-
|
452
|
-
|
453
|
-
m = m.to(modifier)
|
452
|
+
eps = min_eps
|
453
|
+
while True:
|
454
454
|
try:
|
455
|
-
|
455
|
+
eye = torch.eye(m.shape[0], device=m.device, dtype=m.dtype)
|
456
|
+
eigval, eigvec = torch.linalg.eigh(m + eps * eye)
|
456
457
|
eigvec = eigvec.to(device=device, dtype=dtype)
|
457
458
|
break
|
458
459
|
except torch.OutOfMemoryError:
|
459
|
-
|
460
|
+
if m.device.type == 'cpu':
|
461
|
+
raise
|
462
|
+
else:
|
463
|
+
m = m.cpu()
|
460
464
|
except RuntimeError: # failed to compute eigenvalues
|
461
|
-
|
465
|
+
if m.dtype != torch.double:
|
466
|
+
m = m.double()
|
467
|
+
elif eps < max_eps:
|
468
|
+
eps = eps ** (2 / 3)
|
469
|
+
else:
|
470
|
+
raise
|
462
471
|
clean()
|
463
|
-
else:
|
464
|
-
raise RuntimeError("Failed to compute eigenvalues.")
|
465
472
|
|
473
|
+
eigvec = eigvec.to(device=m.device, dtype=m.dtype)
|
466
474
|
eigvec = torch.flip(eigvec, [1])
|
467
|
-
|
468
475
|
final.append(eigvec)
|
469
476
|
|
470
477
|
return final
|
@@ -1,26 +1,32 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.2
|
2
2
|
Name: heavyball
|
3
|
-
Version: 1.6.
|
4
|
-
Summary: Efficient
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
Classifier: Development Status :: 5 - Production/Stable
|
10
|
-
Classifier: License :: OSI Approved :: BSD License
|
11
|
-
Classifier: Programming Language :: Python
|
12
|
-
Classifier: Programming Language :: Python :: 3.7
|
13
|
-
Classifier: Programming Language :: Python :: 3.8
|
14
|
-
Classifier: Programming Language :: Python :: 3.9
|
15
|
-
Classifier: Topic :: Software Development :: Libraries
|
16
|
-
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
3
|
+
Version: 1.6.3
|
4
|
+
Summary: Efficient Optimizers
|
5
|
+
Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
|
6
|
+
Project-URL: source, https://github.com/HomebrewML/HeavyBall
|
7
|
+
Project-URL: tracker, https://github.com/HomebrewML/HeavyBall/issues
|
8
|
+
Keywords: torch,optimizer,muon,soap,psgd
|
17
9
|
Classifier: Intended Audience :: Developers
|
18
|
-
|
10
|
+
Classifier: Intended Audience :: Science/Research
|
11
|
+
Classifier: License :: OSI Approved :: BSD License
|
12
|
+
Classifier: Natural Language :: English
|
13
|
+
Classifier: Operating System :: OS Independent
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
15
|
+
Requires-Python: >=3.9
|
19
16
|
Description-Content-Type: text/markdown
|
20
17
|
License-File: LICENSE
|
21
|
-
Requires-Dist: opt-einsum
|
22
|
-
Requires-Dist: torch
|
18
|
+
Requires-Dist: opt-einsum>=3.4.0
|
19
|
+
Requires-Dist: torch>=2.1.0
|
23
20
|
Requires-Dist: numpy
|
21
|
+
Provides-Extra: dev
|
22
|
+
Requires-Dist: pre-commit; extra == "dev"
|
23
|
+
Requires-Dist: pytest; extra == "dev"
|
24
|
+
Requires-Dist: ruff; extra == "dev"
|
25
|
+
Requires-Dist: matplotlib; extra == "dev"
|
26
|
+
Requires-Dist: seaborn; extra == "dev"
|
27
|
+
Requires-Dist: hyperopt; extra == "dev"
|
28
|
+
Requires-Dist: pandas; extra == "dev"
|
29
|
+
Requires-Dist: typer; extra == "dev"
|
24
30
|
|
25
31
|
# `heavyball`: Efficient Optimizers
|
26
32
|
|
@@ -0,0 +1,52 @@
|
|
1
|
+
[build-system]
|
2
|
+
requires = ["setuptools>=75.0"]
|
3
|
+
build-backend = "setuptools.build_meta"
|
4
|
+
|
5
|
+
[project]
|
6
|
+
name = "heavyball"
|
7
|
+
description = "Efficient Optimizers"
|
8
|
+
version = "1.6.3"
|
9
|
+
authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
|
10
|
+
classifiers = ["Intended Audience :: Developers",
|
11
|
+
"Intended Audience :: Science/Research",
|
12
|
+
"License :: OSI Approved :: BSD License",
|
13
|
+
"Natural Language :: English",
|
14
|
+
"Operating System :: OS Independent",
|
15
|
+
"Programming Language :: Python :: 3",
|
16
|
+
]
|
17
|
+
dependencies = ["opt-einsum>=3.4.0",
|
18
|
+
"torch>=2.1.0",
|
19
|
+
"numpy",
|
20
|
+
]
|
21
|
+
keywords = ["torch",
|
22
|
+
"optimizer",
|
23
|
+
"muon",
|
24
|
+
"soap",
|
25
|
+
"psgd",
|
26
|
+
]
|
27
|
+
readme = "README.md"
|
28
|
+
requires-python = ">=3.9"
|
29
|
+
|
30
|
+
[project.optional-dependencies]
|
31
|
+
dev = ["pre-commit", "pytest", "ruff", "matplotlib", "seaborn", "hyperopt", "pandas", "typer"]
|
32
|
+
|
33
|
+
[project.urls]
|
34
|
+
source = "https://github.com/HomebrewML/HeavyBall"
|
35
|
+
tracker = "https://github.com/HomebrewML/HeavyBall/issues"
|
36
|
+
|
37
|
+
[tool.ruff]
|
38
|
+
line-length = 120
|
39
|
+
|
40
|
+
[tool.ruff.lint]
|
41
|
+
extend-select = ["I", "W"]
|
42
|
+
ignore = ["E741"]
|
43
|
+
preview = true
|
44
|
+
|
45
|
+
[tool.ruff.lint.isort]
|
46
|
+
relative-imports-order = "closest-to-furthest"
|
47
|
+
|
48
|
+
[tool.ruff.format]
|
49
|
+
preview = true
|
50
|
+
|
51
|
+
[tool.setuptools.packages.find]
|
52
|
+
include = ["heavyball*"]
|
heavyball-1.6.1/setup.py
DELETED
@@ -1,33 +0,0 @@
|
|
1
|
-
import setuptools
|
2
|
-
|
3
|
-
|
4
|
-
with open('README.md') as f:
|
5
|
-
README = f.read()
|
6
|
-
|
7
|
-
setuptools.setup(
|
8
|
-
author="HeavyBall Authors",
|
9
|
-
author_email="github.heavyball@nestler.sh",
|
10
|
-
name='heavyball',
|
11
|
-
license='BSD',
|
12
|
-
description='Efficient optimizers',
|
13
|
-
version='1.6.1',
|
14
|
-
long_description=README,
|
15
|
-
url='https://github.com/HomebrewML/HeavyBall',
|
16
|
-
packages=setuptools.find_packages(),
|
17
|
-
python_requires=">=3.7",
|
18
|
-
long_description_content_type="text/markdown",
|
19
|
-
install_requires=['opt-einsum', 'torch', 'numpy'],
|
20
|
-
classifiers=[
|
21
|
-
# Trove classifiers
|
22
|
-
# (https://pypi.python.org/pypi?%3Aaction=list_classifiers)
|
23
|
-
'Development Status :: 5 - Production/Stable',
|
24
|
-
'License :: OSI Approved :: BSD License',
|
25
|
-
'Programming Language :: Python',
|
26
|
-
'Programming Language :: Python :: 3.7',
|
27
|
-
'Programming Language :: Python :: 3.8',
|
28
|
-
'Programming Language :: Python :: 3.9',
|
29
|
-
'Topic :: Software Development :: Libraries',
|
30
|
-
'Topic :: Software Development :: Libraries :: Python Modules',
|
31
|
-
'Intended Audience :: Developers',
|
32
|
-
],
|
33
|
-
)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|