heavyball 1.7.1.tar.gz → 1.7.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {heavyball-1.7.1 → heavyball-1.7.2}/PKG-INFO +1 -1
  2. {heavyball-1.7.1 → heavyball-1.7.2}/heavyball/utils.py +83 -30
  3. {heavyball-1.7.1 → heavyball-1.7.2}/heavyball.egg-info/PKG-INFO +1 -1
  4. {heavyball-1.7.1 → heavyball-1.7.2}/heavyball.egg-info/SOURCES.txt +1 -10
  5. {heavyball-1.7.1 → heavyball-1.7.2}/pyproject.toml +1 -1
  6. heavyball-1.7.1/heavyball/optimizations/__init__.py +0 -38
  7. heavyball-1.7.1/heavyball/optimizations/integrator.py +0 -169
  8. heavyball-1.7.1/heavyball/optimizations/optimizations.py +0 -329
  9. heavyball-1.7.1/tests/test_psgd_kron_line_optim.py +0 -141
  10. heavyball-1.7.1/tests/test_psgd_kron_regression.py +0 -46
  11. heavyball-1.7.1/tests/test_psgd_lra_regression.py +0 -87
  12. heavyball-1.7.1/tests/test_psgd_optimization.py +0 -190
  13. heavyball-1.7.1/tests/test_psgd_optimizations.py +0 -97
  14. heavyball-1.7.1/tests/test_psgd_training_performance.py +0 -253
  15. {heavyball-1.7.1 → heavyball-1.7.2}/LICENSE +0 -0
  16. {heavyball-1.7.1 → heavyball-1.7.2}/README.md +0 -0
  17. {heavyball-1.7.1 → heavyball-1.7.2}/heavyball/__init__.py +0 -0
  18. {heavyball-1.7.1 → heavyball-1.7.2}/heavyball/chainable.py +0 -0
  19. {heavyball-1.7.1 → heavyball-1.7.2}/heavyball.egg-info/dependency_links.txt +0 -0
  20. {heavyball-1.7.1 → heavyball-1.7.2}/heavyball.egg-info/requires.txt +0 -0
  21. {heavyball-1.7.1 → heavyball-1.7.2}/heavyball.egg-info/top_level.txt +0 -0
  22. {heavyball-1.7.1 → heavyball-1.7.2}/setup.cfg +0 -0
  23. {heavyball-1.7.1 → heavyball-1.7.2}/test/test_bf16_params.py +0 -0
  24. {heavyball-1.7.1 → heavyball-1.7.2}/test/test_bf16_q.py +0 -0
  25. {heavyball-1.7.1 → heavyball-1.7.2}/test/test_bf16_storage.py +0 -0
  26. {heavyball-1.7.1 → heavyball-1.7.2}/test/test_caution.py +0 -0
  27. {heavyball-1.7.1 → heavyball-1.7.2}/test/test_channels_last.py +0 -0
  28. {heavyball-1.7.1 → heavyball-1.7.2}/test/test_closure.py +0 -0
  29. {heavyball-1.7.1 → heavyball-1.7.2}/test/test_ema.py +0 -0
  30. {heavyball-1.7.1 → heavyball-1.7.2}/test/test_foreach.py +0 -0
  31. {heavyball-1.7.1 → heavyball-1.7.2}/test/test_hook.py +0 -0
  32. {heavyball-1.7.1 → heavyball-1.7.2}/test/test_mars.py +0 -0
  33. {heavyball-1.7.1 → heavyball-1.7.2}/test/test_memory.py +0 -0
  34. {heavyball-1.7.1 → heavyball-1.7.2}/test/test_memory_leak.py +0 -0
  35. {heavyball-1.7.1 → heavyball-1.7.2}/test/test_merge.py +0 -0
  36. {heavyball-1.7.1 → heavyball-1.7.2}/test/test_no_grad.py +0 -0
  37. {heavyball-1.7.1 → heavyball-1.7.2}/test/test_soap.py +0 -0
  38. {heavyball-1.7.1 → heavyball-1.7.2}/test/test_stochastic_updates.py +0 -0
{heavyball-1.7.1 → heavyball-1.7.2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: heavyball
-Version: 1.7.1
+Version: 1.7.2
 Summary: Efficient Optimizers
 Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
 Project-URL: source, https://github.com/HomebrewML/HeavyBall
{heavyball-1.7.1 → heavyball-1.7.2}/heavyball/utils.py

@@ -50,7 +50,7 @@ def decorator(func):
     return _fn


-def decorator_knowngood(func: Callable):
+def decorator_knowngood(func: Callable, fullgraph: bool = True):
     compiled = None

     @functools.wraps(func)
@@ -59,7 +59,7 @@ def decorator_knowngood(func: Callable):
             return func(*args, **kwargs)
         nonlocal compiled
         if compiled is None:
-            compiled = torch.compile(fullgraph=True, dynamic=dynamic, mode=compile_mode)(func)
+            compiled = torch.compile(fullgraph=fullgraph, dynamic=dynamic, mode=compile_mode)(func)
         return compiled(*args, **kwargs)

     return _fn
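
Aside: the new `fullgraph` flag controls whether `torch.compile` must capture the whole function or may tolerate graph breaks. Below is a minimal, self-contained sketch of the compile-on-first-use pattern this hunk modifies; `compile_mode`, `dynamic`, and the example `scaled_sum` are illustrative placeholders, not part of the package.

```python
import functools

import torch

# Stand-ins for heavyball's module-level compile settings (illustrative only).
compile_mode = "default"
dynamic = False


def decorator_knowngood(func, fullgraph: bool = True):
    compiled = None

    @functools.wraps(func)
    def _fn(*args, **kwargs):
        # Run eagerly if we are already inside a compiled region or compilation is off.
        if torch.compiler.is_compiling() or compile_mode is None:
            return func(*args, **kwargs)
        nonlocal compiled
        if compiled is None:
            # fullgraph=False allows graph breaks inside the decorated function.
            compiled = torch.compile(fullgraph=fullgraph, dynamic=dynamic, mode=compile_mode)(func)
        return compiled(*args, **kwargs)

    return _fn


def scaled_sum(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return (x + y).sum()


# Hypothetical usage: opt out of full-graph capture for this function.
scaled_sum = decorator_knowngood(scaled_sum, fullgraph=False)
print(scaled_sum(torch.rand(4), torch.rand(4)))
```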
@@ -1332,6 +1332,7 @@ def divided_root(x: torch.Tensor, y: torch.Tensor, pow0: float, pow1: float, eps
 def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_max: float = 1e6):
     automatic_scale = True
     manual_hint = " Set it manually using `precond_init_scale=0.1`"
+
     if scale is not None:
         automatic_scale = False
         warn_once(
@@ -1345,12 +1346,16 @@ def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_m
         scale = mean_root(grad, 4) * scale_scale
     else:
         scale = divided_root(vector, hessian_vector, 2, 4) * scale_scale
+
     if isinstance(scale, torch.Tensor):
         scale = scale.item()  # slow, but necessary
+
     if np.isfinite(scale):
-        if scale > scale_max or scale < 1 / scale_max:
+        if scale > scale_max or scale < 1 / scale_max:  # fallthrough to later checks
             warn_once(f"The computed precond_init_scale {scale} is outside of the expected range.{manual_hint}")
-        return scale
+        else:
+            return scale
+
     if not automatic_scale:
         raise ValueError("The manually set precond_init_scale is not finite")

@@ -1361,6 +1366,10 @@ def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_m
             raise ValueError(f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}")
         if not torch.isfinite(x).all().item():
             raise ValueError("Grad or HVP is not finite")
+
+    if np.isfinite(scale):
+        return scale
+
     raise ValueError(f"Computed precond_init_scale is not finite.{manual_hint}")


@@ -1634,35 +1643,58 @@ def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
     return flatten(vs), flatten(gs)


-@decorator_knowngood
 def casted_einsum(expr: str, *args: Tensor) -> Tensor:
     md = min_dtype(args)
     return torch.einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)


+@decorator_knowngood
+def _psgd_calc_scalars_(Qs: List[Tensor], conjB: Tensor):
+    triangular_qs = []
+    for i, q in enumerate(Qs):
+        q = promote(q)
+        if q.dim() <= 1:
+            shape = [1] * conjB.ndim
+            shape[i] = -1
+            conjB /= q.view(shape)
+        else:
+            triangular_qs.append((i, q))
+    return triangular_qs
+
+
+@decorator_knowngood
+def _reshape_conjB(solved: Tensor, original_shape: List[int], last_dim: int, new_shape: int):
+    solved = solved.reshape(original_shape)
+    solved.transpose(last_dim, -1)
+    return solved.reshape(new_shape).contiguous()
+
+
 def psgd_calc_A_and_conjB(exprA, G, Q, conjB):  # conjB ("V", "vector") == randn during hvp/whitening
     order = G.dim()
     if order > 1:
         conjB = conjB.view_as(G).permute(*range(1, order), 0)
     conjB = conjB.to(promote(G.dtype))
     A = casted_einsum(exprA, *Q, G)
-    for i, q in enumerate(Q):
-        q = promote(q)
-        if q.dim() <= 1:
-            conjB /= q
-        else:
-            solved = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)).contiguous(), upper=True, left=False)
-            conjB = solved.reshape_as(conjB)
-        if i < order - 1:
-            conjB = conjB.transpose(i, -1)
+    solve = torch.compiler.disable(torch.linalg.solve_triangular)
+    original_shape = conjB.shape
+    prev_i = -1
+    for i, tri_q in _psgd_calc_scalars_(Q, conjB):
+        conjB = _reshape_conjB(conjB, original_shape, prev_i, [-1, tri_q.size(0)])
+        prev_i = i
+        conjB = solve(tri_q, conjB, upper=True, left=False)
+    conjB = _reshape_conjB(conjB, original_shape, prev_i, original_shape)
     return A, conjB


-def psgd_lb(A, max_abs):
+@decorator_knowngood
+def _max_select(to_index: Tensor, to_argmax: Tensor):
+    idx = to_argmax.argmax()
+    return to_index.index_select(1, idx).flatten().contiguous()
+
+
+def psgd_lb(A: Tensor, max_abs: Tensor):
     A /= max_abs
-    a0 = torch.einsum("ij,ij->j", A, A)
-    i = torch.argmax(a0)
-    x = torch.index_select(A, 1, i).flatten().contiguous()
+    x = _max_select(A, torch.einsum("ij,ij->j", A, A))
     x = torch.einsum("i,ij->j", x, A)
     x /= x.norm()
     x = torch.einsum("j,kj->k", x, A)
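
Aside: the refactor above pulls the triangular solve out of the compiled region by wrapping it with `torch.compiler.disable`, so only scalar divisions and reshapes stay in the graph. A small sketch of that mechanism, with illustrative shapes and data:

```python
import torch

# torch.compiler.disable wraps a callable so it always runs eagerly and is not
# traced by torch.compile; this mirrors how the hunk treats solve_triangular.
solve = torch.compiler.disable(torch.linalg.solve_triangular)

q = torch.eye(4) + torch.triu(torch.rand(4, 4))  # upper-triangular, nonsingular factor
b = torch.rand(8, 4)                             # right-hand sides as rows

# left=False solves x @ q = b, matching the call in the diff.
x = solve(q, b, upper=True, left=False)
print(torch.allclose(x @ q, b, atol=1e-5))
```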
@@ -1671,30 +1703,51 @@ def psgd_lb(A, max_abs):
     return x


+@decorator_knowngood
+def _subtract_from_line_(state: Tensor, term: Tensor):
+    stochastic_add_([state], [triu_to_line([term])[0][1]], -1)
+
+
+@decorator_knowngood
+def _prescale_term_(term1: Tensor, fac: Tensor, norm: Tensor, lower_bound: Tensor):
+    out = term1.float().triu() * fac
+    out = out / torch.where(norm > 0, lower_bound, norm).clamp(tiny_bf16)
+    copy_stochastic_(term1, out)
+
+
+@decorator_knowngood
+def _compilable_stochastic_multiply_div_(x: Tensor, fac: Tensor, y: Tensor, z: Tensor):
+    copy_stochastic_(x, promote(x) * promote(fac) * promote(y) / promote(z).clamp(min=tiny_bf16))
+
+
+@decorator_knowngood
+def _compilable_add_sub_(x: Tensor, y: Tensor):
+    x = promote(x)
+    y = promote(y)
+    return x - y, x + y
+
+
 @decorator
 def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
     """Update Kronecker product preconditioner Q with pair (V, G)."""
     exprA, exprGs, _ = exprs
     A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)
+    precond_lr = scalar_guard(precond_lr, G)

     for q, exprG, o in zip(Q, exprGs, oq):
-        term1 = promote(torch.einsum(exprG, A, A))
-        term2 = promote(torch.einsum(exprG, conjB, conjB))
-        term1, term2 = term1 - term2, term1 + term2
-        term1 *= precond_lr
+        term1 = torch.einsum(exprG, A, A)
+        term2 = torch.einsum(exprG, conjB, conjB)
+        term1, term2 = _compilable_add_sub_(term1, term2)
         norm = term2.norm(float("inf"))
         if q.dim() < 2:
-            term1 *= q.to(term1.dtype) / norm.clamp_(min=tiny_bf16)
+            _compilable_stochastic_multiply_div_(term1, precond_lr, q, norm)
         else:
-            torch.triu(term1, out=term1)
-            term1 /= torch.where(norm > 0, psgd_lb(term2, norm), norm).clamp_(tiny_bf16)
-            term1 = torch.mm(term1, q.to(term1.dtype))
+            lower_bound = psgd_lb(term2, norm)
+            _prescale_term_(term1, precond_lr, lower_bound, norm)
+            torch.mm(term1, q.to(term1.dtype), out=term1)
         if store_triu_as_line:
-            term1 = triu_to_line([term1])[0][1]  # Convert update to line format
-            # Apply update directly to the tensor part of the state tuple o[1]
-            stochastic_add_(o[1], term1, -1)
+            _subtract_from_line_(q, term1)
         else:
-            # Apply update to the state tensor o
             stochastic_add_(o, term1, -1)

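Aside: the inline difference/sum of the two einsum terms is replaced by the compiled `_compilable_add_sub_` helper. A short check illustrating the intended equivalence, with `promote()` approximated here by a plain float32 upcast (an assumption for the sketch, not the package's exact helper):

```python
import torch


def _compilable_add_sub_(x: torch.Tensor, y: torch.Tensor):
    # Body copied from the hunk; promote() approximated by .float().
    x = x.float()
    y = y.float()
    return x - y, x + y


term1 = torch.rand(5, 5, dtype=torch.bfloat16)
term2 = torch.rand(5, 5, dtype=torch.bfloat16)

# The removed inline expression from the old loop body.
old_pair = (term1.float() - term2.float(), term1.float() + term2.float())
new_pair = _compilable_add_sub_(term1, term2)

print(all(torch.equal(a, b) for a, b in zip(old_pair, new_pair)))
```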
{heavyball-1.7.1 → heavyball-1.7.2}/heavyball.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: heavyball
-Version: 1.7.1
+Version: 1.7.2
 Summary: Efficient Optimizers
 Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
 Project-URL: source, https://github.com/HomebrewML/HeavyBall
{heavyball-1.7.1 → heavyball-1.7.2}/heavyball.egg-info/SOURCES.txt

@@ -9,9 +9,6 @@ heavyball.egg-info/SOURCES.txt
 heavyball.egg-info/dependency_links.txt
 heavyball.egg-info/requires.txt
 heavyball.egg-info/top_level.txt
-heavyball/optimizations/__init__.py
-heavyball/optimizations/integrator.py
-heavyball/optimizations/optimizations.py
 test/test_bf16_params.py
 test/test_bf16_q.py
 test/test_bf16_storage.py
@@ -27,10 +24,4 @@ test/test_memory_leak.py
 test/test_merge.py
 test/test_no_grad.py
 test/test_soap.py
-test/test_stochastic_updates.py
-tests/test_psgd_kron_line_optim.py
-tests/test_psgd_kron_regression.py
-tests/test_psgd_lra_regression.py
-tests/test_psgd_optimization.py
-tests/test_psgd_optimizations.py
-tests/test_psgd_training_performance.py
+test/test_stochastic_updates.py
{heavyball-1.7.1 → heavyball-1.7.2}/pyproject.toml

@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "heavyball"
 description = "Efficient Optimizers"
-version = "1.7.1"
+version = "1.7.2"
 authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
 classifiers = ["Intended Audience :: Developers",
     "Intended Audience :: Science/Research",
heavyball-1.7.1/heavyball/optimizations/__init__.py (deleted)

@@ -1,38 +0,0 @@
-"""
-PSGD optimization module - optimized implementations of PSGD functions
-to improve execution speed while maintaining numerical equivalence.
-"""
-
-# Import optimized functions
-# Import integrator API
-from .integrator import (
-    enable_optimizations,
-    get_optimization_status,
-    restore_original_functions,
-)
-from .optimizations import (
-    # LRA optimizations
-    low_rank_mm_optimized,
-    lra_precond_optimized,
-    precond_grad_cached_optimized,
-    # KRON optimizations
-    psgd_calc_A_and_conjB_optimized,
-    psgd_precond_grad_optimized,
-    psgd_update_precond_optimized,
-    update_lra_precond_optimized,
-)
-
-__all__ = [
-    # Optimized functions
-    "low_rank_mm_optimized",
-    "update_lra_precond_optimized",
-    "lra_precond_optimized",
-    "psgd_calc_A_and_conjB_optimized",
-    "psgd_update_precond_optimized",
-    "psgd_precond_grad_optimized",
-    "precond_grad_cached_optimized",
-    # Integrator API
-    "enable_optimizations",
-    "restore_original_functions",
-    "get_optimization_status",
-]
heavyball-1.7.1/heavyball/optimizations/integrator.py (deleted)

@@ -1,169 +0,0 @@
-"""
-Integration module to selectively enable optimized implementations
-of PSGD functions while maintaining API compatibility.
-"""
-
-import os
-import sys
-from typing import Any, Dict
-
-import torch
-
-from . import optimizations
-from .. import utils
-
-# Store original function references
-_original_functions = {}
-_optimized_functions = {}
-
-# Mapping of original functions to their optimized versions
-OPTIMIZATION_MAP = {
-    # LRA functions
-    utils.update_lra_precond_: optimizations.update_lra_precond_optimized,
-    utils.lra_precond: optimizations.lra_precond_optimized,
-    # KRON functions
-    utils.psgd_update_precond: optimizations.psgd_update_precond_optimized,
-    utils.psgd_precond_grad: optimizations.psgd_precond_grad_optimized,
-    utils.precond_grad_cached_: optimizations.precond_grad_cached_optimized,
-}
-
-# Config for enabling/disabling optimizations
-_config = {
-    "enabled": os.environ.get("HEAVYBALL_OPTIMIZE", "1") == "1",
-    "torch_compile_allowed": os.environ.get("HEAVYBALL_USE_COMPILE", "1") == "1",
-    "enable_lra": True,
-    "enable_kron": True,
-    "verbose": os.environ.get("HEAVYBALL_VERBOSE", "0") == "1",
-}
-
-
-def _apply_monkey_patch(original_func, optimized_func):
-    """Monkey patch a function with its optimized version."""
-    if original_func not in _original_functions:
-        _original_functions[original_func] = original_func
-
-    # Store reference to the optimized function
-    _optimized_functions[original_func] = optimized_func
-
-    # Get the module where the original function is defined
-    module = original_func.__module__
-    func_name = original_func.__name__
-
-    # Replace the function in its module
-    if hasattr(sys.modules[module], func_name):
-        setattr(sys.modules[module], func_name, optimized_func)
-
-        if _config["verbose"]:
-            print(f"Replaced {module}.{func_name} with optimized version")
-    else:
-        if _config["verbose"]:
-            print(f"Warning: Could not find {func_name} in module {module}")
-
-
-def enable_optimizations(
-    enable: bool = True, lra: bool = True, kron: bool = True, torch_compile: bool = True, verbose: bool = False
-):
-    """
-    Enable or disable PSGD optimizations.
-
-    Args:
-        enable: Whether to enable optimizations at all
-        lra: Whether to enable LRA-specific optimizations
-        kron: Whether to enable Kron-specific optimizations
-        torch_compile: Whether to allow torch.compile optimizations
-        verbose: Whether to print optimization status messages
-    """
-    _config["enabled"] = enable
-    _config["enable_lra"] = lra
-    _config["enable_kron"] = kron
-    _config["torch_compile_allowed"] = torch_compile
-    _config["verbose"] = verbose
-
-    if verbose:
-        print(f"PSGD Optimizations: {'enabled' if enable else 'disabled'}")
-        print(f" - LRA optimizations: {'enabled' if lra else 'disabled'}")
-        print(f" - KRON optimizations: {'enabled' if kron else 'disabled'}")
-        print(f" - torch.compile: {'allowed' if torch_compile else 'disabled'}")
-
-    if not enable:
-        # Restore original functions
-        restore_original_functions()
-        return
-
-    # Apply optimizations based on config
-    for orig_func, opt_func in OPTIMIZATION_MAP.items():
-        # Skip LRA functions if disabled
-        if not _config["enable_lra"] and orig_func in [utils.update_lra_precond_, utils.lra_precond]:
-            continue
-
-        # Skip KRON functions if disabled
-        if not _config["enable_kron"] and orig_func in [
-            utils.psgd_update_precond,
-            utils.psgd_precond_grad,
-            utils.precond_grad_cached_,
-        ]:
-            continue
-
-        _apply_monkey_patch(orig_func, opt_func)
-
-    # Disable torch.compile if not allowed
-    if not _config["torch_compile_allowed"]:
-        # Monkey patch torch.compile to be a no-op
-        def _noop_compile(fn, **kwargs):
-            return fn
-
-        if not hasattr(torch, "_original_compile"):
-            torch._original_compile = torch.compile
-        torch.compile = _noop_compile
-        if verbose:
-            print("Disabled torch.compile (replaced with no-op)")
-    else:
-        # Restore original torch.compile
-        if hasattr(torch, "_original_compile"):
-            torch.compile = torch._original_compile
-            del torch._original_compile
-            if verbose:
-                print("Restored original torch.compile")
-
-
-def restore_original_functions():
-    """Restore all original function implementations."""
-    for orig_func, func_ref in _original_functions.items():
-        module = orig_func.__module__
-        func_name = orig_func.__name__
-
-        if hasattr(sys.modules[module], func_name):
-            setattr(sys.modules[module], func_name, func_ref)
-
-            if _config["verbose"]:
-                print(f"Restored original implementation of {module}.{func_name}")
-
-    # Also restore torch.compile if it was modified
-    if hasattr(torch, "_original_compile"):
-        torch.compile = torch._original_compile
-        del torch._original_compile
-        if _config["verbose"]:
-            print("Restored original torch.compile")
-
-
-def get_optimization_status() -> Dict[str, Any]:
-    """Get current optimization status."""
-    return {
-        "enabled": _config["enabled"],
-        "lra_enabled": _config["enable_lra"],
-        "kron_enabled": _config["enable_kron"],
-        "torch_compile_allowed": _config["torch_compile_allowed"],
-        "optimized_functions": list(_optimized_functions.keys()),
-        "original_functions": list(_original_functions.keys()),
-    }
-
-
-# Auto-initialize optimizations based on environment
-if os.environ.get("HEAVYBALL_AUTO_OPTIMIZE", "1") == "1":
-    enable_optimizations(
-        enable=_config["enabled"],
-        lra=_config["enable_lra"],
-        kron=_config["enable_kron"],
-        torch_compile=_config["torch_compile_allowed"],
-        verbose=_config["verbose"],
-    )