heavyball 1.6.1__tar.gz → 1.6.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {heavyball-1.6.1 → heavyball-1.6.3}/PKG-INFO +24 -18
  2. {heavyball-1.6.1 → heavyball-1.6.3}/heavyball/utils.py +19 -12
  3. {heavyball-1.6.1 → heavyball-1.6.3}/heavyball.egg-info/PKG-INFO +24 -18
  4. {heavyball-1.6.1 → heavyball-1.6.3}/heavyball.egg-info/SOURCES.txt +1 -1
  5. heavyball-1.6.3/heavyball.egg-info/requires.txt +13 -0
  6. heavyball-1.6.3/pyproject.toml +52 -0
  7. heavyball-1.6.1/heavyball.egg-info/requires.txt +0 -3
  8. heavyball-1.6.1/setup.py +0 -33
  9. {heavyball-1.6.1 → heavyball-1.6.3}/LICENSE +0 -0
  10. {heavyball-1.6.1 → heavyball-1.6.3}/README.md +0 -0
  11. {heavyball-1.6.1 → heavyball-1.6.3}/heavyball/__init__.py +0 -0
  12. {heavyball-1.6.1 → heavyball-1.6.3}/heavyball/chainable.py +0 -0
  13. {heavyball-1.6.1 → heavyball-1.6.3}/heavyball.egg-info/dependency_links.txt +0 -0
  14. {heavyball-1.6.1 → heavyball-1.6.3}/heavyball.egg-info/top_level.txt +0 -0
  15. {heavyball-1.6.1 → heavyball-1.6.3}/setup.cfg +0 -0
  16. {heavyball-1.6.1 → heavyball-1.6.3}/test/test_bf16_params.py +0 -0
  17. {heavyball-1.6.1 → heavyball-1.6.3}/test/test_bf16_q.py +0 -0
  18. {heavyball-1.6.1 → heavyball-1.6.3}/test/test_bf16_storage.py +0 -0
  19. {heavyball-1.6.1 → heavyball-1.6.3}/test/test_caution.py +0 -0
  20. {heavyball-1.6.1 → heavyball-1.6.3}/test/test_channels_last.py +0 -0
  21. {heavyball-1.6.1 → heavyball-1.6.3}/test/test_closure.py +0 -0
  22. {heavyball-1.6.1 → heavyball-1.6.3}/test/test_ema.py +0 -0
  23. {heavyball-1.6.1 → heavyball-1.6.3}/test/test_foreach.py +0 -0
  24. {heavyball-1.6.1 → heavyball-1.6.3}/test/test_hook.py +0 -0
  25. {heavyball-1.6.1 → heavyball-1.6.3}/test/test_mars.py +0 -0
  26. {heavyball-1.6.1 → heavyball-1.6.3}/test/test_memory.py +0 -0
  27. {heavyball-1.6.1 → heavyball-1.6.3}/test/test_merge.py +0 -0
  28. {heavyball-1.6.1 → heavyball-1.6.3}/test/test_no_grad.py +0 -0
  29. {heavyball-1.6.1 → heavyball-1.6.3}/test/test_psgd.py +0 -0
  30. {heavyball-1.6.1 → heavyball-1.6.3}/test/test_soap.py +0 -0
  31. {heavyball-1.6.1 → heavyball-1.6.3}/test/test_stochastic_updates.py +0 -0
@@ -1,26 +1,32 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: heavyball
3
- Version: 1.6.1
4
- Summary: Efficient optimizers
5
- Home-page: https://github.com/HomebrewML/HeavyBall
6
- Author: HeavyBall Authors
7
- Author-email: github.heavyball@nestler.sh
8
- License: BSD
9
- Classifier: Development Status :: 5 - Production/Stable
10
- Classifier: License :: OSI Approved :: BSD License
11
- Classifier: Programming Language :: Python
12
- Classifier: Programming Language :: Python :: 3.7
13
- Classifier: Programming Language :: Python :: 3.8
14
- Classifier: Programming Language :: Python :: 3.9
15
- Classifier: Topic :: Software Development :: Libraries
16
- Classifier: Topic :: Software Development :: Libraries :: Python Modules
3
+ Version: 1.6.3
4
+ Summary: Efficient Optimizers
5
+ Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
6
+ Project-URL: source, https://github.com/HomebrewML/HeavyBall
7
+ Project-URL: tracker, https://github.com/HomebrewML/HeavyBall/issues
8
+ Keywords: torch,optimizer,muon,soap,psgd
17
9
  Classifier: Intended Audience :: Developers
18
- Requires-Python: >=3.7
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: License :: OSI Approved :: BSD License
12
+ Classifier: Natural Language :: English
13
+ Classifier: Operating System :: OS Independent
14
+ Classifier: Programming Language :: Python :: 3
15
+ Requires-Python: >=3.9
19
16
  Description-Content-Type: text/markdown
20
17
  License-File: LICENSE
21
- Requires-Dist: opt-einsum
22
- Requires-Dist: torch
18
+ Requires-Dist: opt-einsum>=3.4.0
19
+ Requires-Dist: torch>=2.1.0
23
20
  Requires-Dist: numpy
21
+ Provides-Extra: dev
22
+ Requires-Dist: pre-commit; extra == "dev"
23
+ Requires-Dist: pytest; extra == "dev"
24
+ Requires-Dist: ruff; extra == "dev"
25
+ Requires-Dist: matplotlib; extra == "dev"
26
+ Requires-Dist: seaborn; extra == "dev"
27
+ Requires-Dist: hyperopt; extra == "dev"
28
+ Requires-Dist: pandas; extra == "dev"
29
+ Requires-Dist: typer; extra == "dev"
24
30
 
25
31
  # `heavyball`: Efficient Optimizers
26
32
 
@@ -376,7 +376,7 @@ def _compilable_scatter_set(target, source, index):
376
376
  target[:] = source.contiguous()[index].reshape_as(target)
377
377
 
378
378
 
379
- #@decorator_knowngood
379
+ # @decorator_knowngood
380
380
  def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optional[Tensor] = None):
381
381
  """
382
382
  Computes the eigenbases of the preconditioner using one round of power iteration
@@ -426,7 +426,8 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
426
426
  out_str = ''.join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
427
427
 
428
428
  subscripts = f'{in_str},{from_shampoo},{to_shampoo}->{out_str}'
429
- exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None])
429
+ exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q if q is not None],
430
+ *[q for q in new_qs if q is not None])
430
431
  copy_stochastic_(exp_avg, exp_avg_new)
431
432
 
432
433
  for q, q_new in zip(Q, new_qs):
@@ -434,7 +435,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
434
435
  copy_stochastic_(q, q_new)
435
436
 
436
437
 
437
- def get_orthogonal_matrix(mat):
438
+ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
438
439
  """
439
440
  Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
440
441
  """
@@ -448,23 +449,29 @@ def get_orthogonal_matrix(mat):
448
449
  m = promote(m.data)
449
450
 
450
451
  device, dtype = m.device, m.dtype
451
- for modifier in (None, torch.double, 'cpu'):
452
- if modifier is not None:
453
- m = m.to(modifier)
452
+ eps = min_eps
453
+ while True:
454
454
  try:
455
- eigval, eigvec = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=m.dtype))
455
+ eye = torch.eye(m.shape[0], device=m.device, dtype=m.dtype)
456
+ eigval, eigvec = torch.linalg.eigh(m + eps * eye)
456
457
  eigvec = eigvec.to(device=device, dtype=dtype)
457
458
  break
458
459
  except torch.OutOfMemoryError:
459
- pass
460
+ if m.device.type == 'cpu':
461
+ raise
462
+ else:
463
+ m = m.cpu()
460
464
  except RuntimeError: # failed to compute eigenvalues
461
- continue
465
+ if m.dtype != torch.double:
466
+ m = m.double()
467
+ elif eps < max_eps:
468
+ eps = eps ** (2 / 3)
469
+ else:
470
+ raise
462
471
  clean()
463
- else:
464
- raise RuntimeError("Failed to compute eigenvalues.")
465
472
 
473
+ eigvec = eigvec.to(device=m.device, dtype=m.dtype)
466
474
  eigvec = torch.flip(eigvec, [1])
467
-
468
475
  final.append(eigvec)
469
476
 
470
477
  return final
@@ -1,26 +1,32 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: heavyball
3
- Version: 1.6.1
4
- Summary: Efficient optimizers
5
- Home-page: https://github.com/HomebrewML/HeavyBall
6
- Author: HeavyBall Authors
7
- Author-email: github.heavyball@nestler.sh
8
- License: BSD
9
- Classifier: Development Status :: 5 - Production/Stable
10
- Classifier: License :: OSI Approved :: BSD License
11
- Classifier: Programming Language :: Python
12
- Classifier: Programming Language :: Python :: 3.7
13
- Classifier: Programming Language :: Python :: 3.8
14
- Classifier: Programming Language :: Python :: 3.9
15
- Classifier: Topic :: Software Development :: Libraries
16
- Classifier: Topic :: Software Development :: Libraries :: Python Modules
3
+ Version: 1.6.3
4
+ Summary: Efficient Optimizers
5
+ Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
6
+ Project-URL: source, https://github.com/HomebrewML/HeavyBall
7
+ Project-URL: tracker, https://github.com/HomebrewML/HeavyBall/issues
8
+ Keywords: torch,optimizer,muon,soap,psgd
17
9
  Classifier: Intended Audience :: Developers
18
- Requires-Python: >=3.7
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: License :: OSI Approved :: BSD License
12
+ Classifier: Natural Language :: English
13
+ Classifier: Operating System :: OS Independent
14
+ Classifier: Programming Language :: Python :: 3
15
+ Requires-Python: >=3.9
19
16
  Description-Content-Type: text/markdown
20
17
  License-File: LICENSE
21
- Requires-Dist: opt-einsum
22
- Requires-Dist: torch
18
+ Requires-Dist: opt-einsum>=3.4.0
19
+ Requires-Dist: torch>=2.1.0
23
20
  Requires-Dist: numpy
21
+ Provides-Extra: dev
22
+ Requires-Dist: pre-commit; extra == "dev"
23
+ Requires-Dist: pytest; extra == "dev"
24
+ Requires-Dist: ruff; extra == "dev"
25
+ Requires-Dist: matplotlib; extra == "dev"
26
+ Requires-Dist: seaborn; extra == "dev"
27
+ Requires-Dist: hyperopt; extra == "dev"
28
+ Requires-Dist: pandas; extra == "dev"
29
+ Requires-Dist: typer; extra == "dev"
24
30
 
25
31
  # `heavyball`: Efficient Optimizers
26
32
 
@@ -1,6 +1,6 @@
1
1
  LICENSE
2
2
  README.md
3
- setup.py
3
+ pyproject.toml
4
4
  heavyball/__init__.py
5
5
  heavyball/chainable.py
6
6
  heavyball/utils.py
@@ -0,0 +1,13 @@
1
+ opt-einsum>=3.4.0
2
+ torch>=2.1.0
3
+ numpy
4
+
5
+ [dev]
6
+ pre-commit
7
+ pytest
8
+ ruff
9
+ matplotlib
10
+ seaborn
11
+ hyperopt
12
+ pandas
13
+ typer
@@ -0,0 +1,52 @@
1
+ [build-system]
2
+ requires = ["setuptools>=75.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "heavyball"
7
+ description = "Efficient Optimizers"
8
+ version = "1.6.3"
9
+ authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
10
+ classifiers = ["Intended Audience :: Developers",
11
+ "Intended Audience :: Science/Research",
12
+ "License :: OSI Approved :: BSD License",
13
+ "Natural Language :: English",
14
+ "Operating System :: OS Independent",
15
+ "Programming Language :: Python :: 3",
16
+ ]
17
+ dependencies = ["opt-einsum>=3.4.0",
18
+ "torch>=2.1.0",
19
+ "numpy",
20
+ ]
21
+ keywords = ["torch",
22
+ "optimizer",
23
+ "muon",
24
+ "soap",
25
+ "psgd",
26
+ ]
27
+ readme = "README.md"
28
+ requires-python = ">=3.9"
29
+
30
+ [project.optional-dependencies]
31
+ dev = ["pre-commit", "pytest", "ruff", "matplotlib", "seaborn", "hyperopt", "pandas", "typer"]
32
+
33
+ [project.urls]
34
+ source = "https://github.com/HomebrewML/HeavyBall"
35
+ tracker = "https://github.com/HomebrewML/HeavyBall/issues"
36
+
37
+ [tool.ruff]
38
+ line-length = 120
39
+
40
+ [tool.ruff.lint]
41
+ extend-select = ["I", "W"]
42
+ ignore = ["E741"]
43
+ preview = true
44
+
45
+ [tool.ruff.lint.isort]
46
+ relative-imports-order = "closest-to-furthest"
47
+
48
+ [tool.ruff.format]
49
+ preview = true
50
+
51
+ [tool.setuptools.packages.find]
52
+ include = ["heavyball*"]
@@ -1,3 +0,0 @@
1
- opt-einsum
2
- torch
3
- numpy
heavyball-1.6.1/setup.py DELETED
@@ -1,33 +0,0 @@
1
- import setuptools
2
-
3
-
4
- with open('README.md') as f:
5
- README = f.read()
6
-
7
- setuptools.setup(
8
- author="HeavyBall Authors",
9
- author_email="github.heavyball@nestler.sh",
10
- name='heavyball',
11
- license='BSD',
12
- description='Efficient optimizers',
13
- version='1.6.1',
14
- long_description=README,
15
- url='https://github.com/HomebrewML/HeavyBall',
16
- packages=setuptools.find_packages(),
17
- python_requires=">=3.7",
18
- long_description_content_type="text/markdown",
19
- install_requires=['opt-einsum', 'torch', 'numpy'],
20
- classifiers=[
21
- # Trove classifiers
22
- # (https://pypi.python.org/pypi?%3Aaction=list_classifiers)
23
- 'Development Status :: 5 - Production/Stable',
24
- 'License :: OSI Approved :: BSD License',
25
- 'Programming Language :: Python',
26
- 'Programming Language :: Python :: 3.7',
27
- 'Programming Language :: Python :: 3.8',
28
- 'Programming Language :: Python :: 3.9',
29
- 'Topic :: Software Development :: Libraries',
30
- 'Topic :: Software Development :: Libraries :: Python Modules',
31
- 'Intended Audience :: Developers',
32
- ],
33
- )
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes