heavyball 1.7.1.tar.gz → 1.7.2.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-1.7.1 → heavyball-1.7.2}/PKG-INFO +1 -1
- {heavyball-1.7.1 → heavyball-1.7.2}/heavyball/utils.py +83 -30
- {heavyball-1.7.1 → heavyball-1.7.2}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-1.7.1 → heavyball-1.7.2}/heavyball.egg-info/SOURCES.txt +1 -10
- {heavyball-1.7.1 → heavyball-1.7.2}/pyproject.toml +1 -1
- heavyball-1.7.1/heavyball/optimizations/__init__.py +0 -38
- heavyball-1.7.1/heavyball/optimizations/integrator.py +0 -169
- heavyball-1.7.1/heavyball/optimizations/optimizations.py +0 -329
- heavyball-1.7.1/tests/test_psgd_kron_line_optim.py +0 -141
- heavyball-1.7.1/tests/test_psgd_kron_regression.py +0 -46
- heavyball-1.7.1/tests/test_psgd_lra_regression.py +0 -87
- heavyball-1.7.1/tests/test_psgd_optimization.py +0 -190
- heavyball-1.7.1/tests/test_psgd_optimizations.py +0 -97
- heavyball-1.7.1/tests/test_psgd_training_performance.py +0 -253
- {heavyball-1.7.1 → heavyball-1.7.2}/LICENSE +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/README.md +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/heavyball/__init__.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/heavyball/chainable.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/setup.cfg +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/test/test_bf16_params.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/test/test_bf16_q.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/test/test_bf16_storage.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/test/test_caution.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/test/test_channels_last.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/test/test_closure.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/test/test_ema.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/test/test_foreach.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/test/test_hook.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/test/test_mars.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/test/test_memory.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/test/test_memory_leak.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/test/test_merge.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/test/test_no_grad.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/test/test_soap.py +0 -0
- {heavyball-1.7.1 → heavyball-1.7.2}/test/test_stochastic_updates.py +0 -0
heavyball/utils.py

```diff
@@ -50,7 +50,7 @@ def decorator(func):
     return _fn


-def decorator_knowngood(func: Callable):
+def decorator_knowngood(func: Callable, fullgraph: bool = True):
     compiled = None

     @functools.wraps(func)
@@ -59,7 +59,7 @@ def decorator_knowngood(func: Callable):
             return func(*args, **kwargs)
         nonlocal compiled
         if compiled is None:
-            compiled = torch.compile(fullgraph=
+            compiled = torch.compile(fullgraph=fullgraph, dynamic=dynamic, mode=compile_mode)(func)
         return compiled(*args, **kwargs)

     return _fn
```
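These two hunks add a `fullgraph` parameter to `decorator_knowngood` and forward it to `torch.compile` at first call, so individual functions can opt out of full-graph capture. Below is a minimal standalone sketch of the resulting shape; `dynamic` and `compile_mode` are stand-ins for the module-level settings in `heavyball.utils`, and the eager/early-return branches of the real decorator are omitted.

```python
import functools
from typing import Callable

import torch

# Assumed stand-ins for heavyball.utils' module-level configuration.
dynamic = False
compile_mode = "max-autotune-no-cudagraphs"


def decorator_knowngood(func: Callable, fullgraph: bool = True):
    compiled = None

    @functools.wraps(func)
    def _fn(*args, **kwargs):
        nonlocal compiled
        if compiled is None:
            # Compile lazily on first use; callers may now opt out of
            # full-graph capture for functions that graph-break.
            compiled = torch.compile(fullgraph=fullgraph, dynamic=dynamic, mode=compile_mode)(func)
        return compiled(*args, **kwargs)

    return _fn
```

A function that cannot be captured as a single graph could then be wrapped as `decorator_knowngood(fn, fullgraph=False)` instead of the plain decorator form.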
heavyball/utils.py

```diff
@@ -1332,6 +1332,7 @@ def divided_root(x: torch.Tensor, y: torch.Tensor, pow0: float, pow1: float, eps
 def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_max: float = 1e6):
     automatic_scale = True
     manual_hint = " Set it manually using `precond_init_scale=0.1`"
+
     if scale is not None:
         automatic_scale = False
         warn_once(
@@ -1345,12 +1346,16 @@ def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_m
         scale = mean_root(grad, 4) * scale_scale
     else:
         scale = divided_root(vector, hessian_vector, 2, 4) * scale_scale
+
     if isinstance(scale, torch.Tensor):
         scale = scale.item()  # slow, but necessary
+
     if np.isfinite(scale):
-        if scale > scale_max or scale < 1 / scale_max:
+        if scale > scale_max or scale < 1 / scale_max:  # fallthrough to later checks
             warn_once(f"The computed precond_init_scale {scale} is outside of the expected range.{manual_hint}")
-
+        else:
+            return scale
+
     if not automatic_scale:
         raise ValueError("The manually set precond_init_scale is not finite")

@@ -1361,6 +1366,10 @@ def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_m
         raise ValueError(f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}")
     if not torch.isfinite(x).all().item():
         raise ValueError("Grad or HVP is not finite")
+
+    if np.isfinite(scale):
+        return scale
+
     raise ValueError(f"Computed precond_init_scale is not finite.{manual_hint}")


```
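The net effect is that `precond_init_scale` now returns as soon as it has a usable scale: a finite in-range value is returned immediately, a finite but out-of-range value only triggers a warning before the later `np.isfinite` check returns it, and only non-finite values still raise. A rough sketch of that decision flow, with the grad/HVP diagnostics elided and hypothetical names (`resolve_scale`, `manually_set` are illustrative only, not heavyball API):

```python
import numpy as np

SCALE_MAX = 1e6  # mirrors the scale_max=1e6 default


def resolve_scale(scale: float, manually_set: bool) -> float:
    """Illustrative only: condensed control flow of precond_init_scale in 1.7.2."""
    if np.isfinite(scale):
        if scale > SCALE_MAX or scale < 1 / SCALE_MAX:
            print("warning: scale outside the expected range")  # warn_once(...) in heavyball
            # fall through to the later checks instead of returning here
        else:
            return scale
    if manually_set:
        raise ValueError("The manually set precond_init_scale is not finite")
    # ... grad/HVP zero- and finiteness checks run here in the real function ...
    if np.isfinite(scale):
        return scale
    raise ValueError("Computed precond_init_scale is not finite.")
```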
heavyball/utils.py

```diff
@@ -1634,35 +1643,58 @@ def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
     return flatten(vs), flatten(gs)


-@decorator_knowngood
 def casted_einsum(expr: str, *args: Tensor) -> Tensor:
     md = min_dtype(args)
     return torch.einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)


+@decorator_knowngood
+def _psgd_calc_scalars_(Qs: List[Tensor], conjB: Tensor):
+    triangular_qs = []
+    for i, q in enumerate(Qs):
+        q = promote(q)
+        if q.dim() <= 1:
+            shape = [1] * conjB.ndim
+            shape[i] = -1
+            conjB /= q.view(shape)
+        else:
+            triangular_qs.append((i, q))
+    return triangular_qs
+
+
+@decorator_knowngood
+def _reshape_conjB(solved: Tensor, original_shape: List[int], last_dim: int, new_shape: int):
+    solved = solved.reshape(original_shape)
+    solved.transpose(last_dim, -1)
+    return solved.reshape(new_shape).contiguous()
+
+
 def psgd_calc_A_and_conjB(exprA, G, Q, conjB):  # conjB ("V", "vector") == randn during hvp/whitening
     order = G.dim()
     if order > 1:
         conjB = conjB.view_as(G).permute(*range(1, order), 0)
     conjB = conjB.to(promote(G.dtype))
     A = casted_einsum(exprA, *Q, G)
-
-
-
-
-
-
-
-
-            conjB = conjB.transpose(i, -1)
+    solve = torch.compiler.disable(torch.linalg.solve_triangular)
+    original_shape = conjB.shape
+    prev_i = -1
+    for i, tri_q in _psgd_calc_scalars_(Q, conjB):
+        conjB = _reshape_conjB(conjB, original_shape, prev_i, [-1, tri_q.size(0)])
+        prev_i = i
+        conjB = solve(tri_q, conjB, upper=True, left=False)
+    conjB = _reshape_conjB(conjB, original_shape, prev_i, original_shape)
     return A, conjB


-
+@decorator_knowngood
+def _max_select(to_index: Tensor, to_argmax: Tensor):
+    idx = to_argmax.argmax()
+    return to_index.index_select(1, idx).flatten().contiguous()
+
+
+def psgd_lb(A: Tensor, max_abs: Tensor):
     A /= max_abs
-
-    i = torch.argmax(a0)
-    x = torch.index_select(A, 1, i).flatten().contiguous()
+    x = _max_select(A, torch.einsum("ij,ij->j", A, A))
     x = torch.einsum("i,ij->j", x, A)
     x /= x.norm()
     x = torch.einsum("j,kj->k", x, A)
```

(Several removed lines inside `psgd_calc_A_and_conjB` and `psgd_lb` are not rendered in the published diff and are shown blank above.)
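The rewritten `psgd_calc_A_and_conjB` folds the 1-D (scalar/diagonal) factors into `conjB` via `_psgd_calc_scalars_`, reshapes between solves with `_reshape_conjB`, and keeps the triangular solve itself out of the compiled graph by wrapping it with `torch.compiler.disable`. The solve is right-sided: with `left=False`, `torch.linalg.solve_triangular(Q, B, upper=True)` returns `X` such that `X @ Q = B`. A small self-contained sanity check of that relation (illustration only, not heavyball code):

```python
import torch

torch.manual_seed(0)
q = torch.randn(5, 5).triu() + 5 * torch.eye(5)  # well-conditioned upper-triangular factor
b = torch.randn(3, 5)

# left=False solves X @ q = b instead of q @ X = b.
x = torch.linalg.solve_triangular(q, b, upper=True, left=False)
print(torch.allclose(x @ q, b, atol=1e-5))  # True
```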
heavyball/utils.py

```diff
@@ -1671,30 +1703,51 @@ def psgd_lb(A, max_abs):
     return x


+@decorator_knowngood
+def _subtract_from_line_(state: Tensor, term: Tensor):
+    stochastic_add_([state], [triu_to_line([term])[0][1]], -1)
+
+
+@decorator_knowngood
+def _prescale_term_(term1: Tensor, fac: Tensor, norm: Tensor, lower_bound: Tensor):
+    out = term1.float().triu() * fac
+    out = out / torch.where(norm > 0, lower_bound, norm).clamp(tiny_bf16)
+    copy_stochastic_(term1, out)
+
+
+@decorator_knowngood
+def _compilable_stochastic_multiply_div_(x: Tensor, fac: Tensor, y: Tensor, z: Tensor):
+    copy_stochastic_(x, promote(x) * promote(fac) * promote(y) / promote(z).clamp(min=tiny_bf16))
+
+
+@decorator_knowngood
+def _compilable_add_sub_(x: Tensor, y: Tensor):
+    x = promote(x)
+    y = promote(y)
+    return x - y, x + y
+
+
 @decorator
 def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
     """Update Kronecker product preconditioner Q with pair (V, G)."""
     exprA, exprGs, _ = exprs
     A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)
+    precond_lr = scalar_guard(precond_lr, G)

     for q, exprG, o in zip(Q, exprGs, oq):
-        term1 =
-        term2 =
-        term1, term2 = term1
-        term1 *= precond_lr
+        term1 = torch.einsum(exprG, A, A)
+        term2 = torch.einsum(exprG, conjB, conjB)
+        term1, term2 = _compilable_add_sub_(term1, term2)
         norm = term2.norm(float("inf"))
         if q.dim() < 2:
-            term1
+            _compilable_stochastic_multiply_div_(term1, precond_lr, q, norm)
         else:
-
-            term1
-
+            lower_bound = psgd_lb(term2, norm)
+            _prescale_term_(term1, precond_lr, lower_bound, norm)
+            torch.mm(term1, q.to(term1.dtype), out=term1)
         if store_triu_as_line:
-
-            # Apply update directly to the tensor part of the state tuple o[1]
-            stochastic_add_(o[1], term1, -1)
+            _subtract_from_line_(q, term1)
         else:
-            # Apply update to the state tensor o
             stochastic_add_(o, term1, -1)


```

(Several removed lines in this hunk are truncated in the published diff and are reproduced above as rendered.)
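The inline term arithmetic in `psgd_update_precond` is now delegated to small compiled helpers: `_compilable_add_sub_` returns the (promoted) difference and sum of the two einsum terms, and for a 1-D factor `q` the update reduces to scaling `term1` by `precond_lr * q / max(norm, tiny)` via `_compilable_stochastic_multiply_div_`. The snippet below merely restates that diagonal-branch arithmetic on ordinary fp32 tensors for illustration; heavyball performs it with stochastic rounding into the bf16 state, and `tiny_bf16` is assumed here to be the smallest normal bfloat16 value, matching the name.

```python
import torch

tiny_bf16 = torch.finfo(torch.bfloat16).tiny  # assumption: matches heavyball.utils.tiny_bf16

term1 = torch.tensor([0.20, -0.10])  # einsum(exprG, A, A) - einsum(exprG, conjB, conjB)
q = torch.tensor([1.50, 0.80])       # 1-D (diagonal) preconditioner factor
precond_lr = torch.tensor(0.1)
norm = torch.tensor(0.9)             # term2.norm(float("inf"))

# Equivalent of _compilable_stochastic_multiply_div_(term1, precond_lr, q, norm)
term1 = term1 * precond_lr * q / norm.clamp(min=tiny_bf16)
print(term1)
```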
heavyball.egg-info/SOURCES.txt

```diff
@@ -9,9 +9,6 @@ heavyball.egg-info/SOURCES.txt
 heavyball.egg-info/dependency_links.txt
 heavyball.egg-info/requires.txt
 heavyball.egg-info/top_level.txt
-heavyball/optimizations/__init__.py
-heavyball/optimizations/integrator.py
-heavyball/optimizations/optimizations.py
 test/test_bf16_params.py
 test/test_bf16_q.py
 test/test_bf16_storage.py
@@ -27,10 +24,4 @@ test/test_memory_leak.py
 test/test_merge.py
 test/test_no_grad.py
 test/test_soap.py
-test/test_stochastic_updates.py
-tests/test_psgd_kron_line_optim.py
-tests/test_psgd_kron_regression.py
-tests/test_psgd_lra_regression.py
-tests/test_psgd_optimization.py
-tests/test_psgd_optimizations.py
-tests/test_psgd_training_performance.py
+test/test_stochastic_updates.py
```
pyproject.toml

```diff
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "heavyball"
 description = "Efficient Optimizers"
-version = "1.7.1"
+version = "1.7.2"
 authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
 classifiers = ["Intended Audience :: Developers",
     "Intended Audience :: Science/Research",
```
heavyball/optimizations/__init__.py (removed in 1.7.2)

```diff
@@ -1,38 +0,0 @@
-"""
-PSGD optimization module - optimized implementations of PSGD functions
-to improve execution speed while maintaining numerical equivalence.
-"""
-
-# Import optimized functions
-# Import integrator API
-from .integrator import (
-    enable_optimizations,
-    get_optimization_status,
-    restore_original_functions,
-)
-from .optimizations import (
-    # LRA optimizations
-    low_rank_mm_optimized,
-    lra_precond_optimized,
-    precond_grad_cached_optimized,
-    # KRON optimizations
-    psgd_calc_A_and_conjB_optimized,
-    psgd_precond_grad_optimized,
-    psgd_update_precond_optimized,
-    update_lra_precond_optimized,
-)
-
-__all__ = [
-    # Optimized functions
-    "low_rank_mm_optimized",
-    "update_lra_precond_optimized",
-    "lra_precond_optimized",
-    "psgd_calc_A_and_conjB_optimized",
-    "psgd_update_precond_optimized",
-    "psgd_precond_grad_optimized",
-    "precond_grad_cached_optimized",
-    # Integrator API
-    "enable_optimizations",
-    "restore_original_functions",
-    "get_optimization_status",
-]
```
heavyball/optimizations/integrator.py (removed in 1.7.2)

```diff
@@ -1,169 +0,0 @@
-"""
-Integration module to selectively enable optimized implementations
-of PSGD functions while maintaining API compatibility.
-"""
-
-import os
-import sys
-from typing import Any, Dict
-
-import torch
-
-from . import optimizations
-from .. import utils
-
-# Store original function references
-_original_functions = {}
-_optimized_functions = {}
-
-# Mapping of original functions to their optimized versions
-OPTIMIZATION_MAP = {
-    # LRA functions
-    utils.update_lra_precond_: optimizations.update_lra_precond_optimized,
-    utils.lra_precond: optimizations.lra_precond_optimized,
-    # KRON functions
-    utils.psgd_update_precond: optimizations.psgd_update_precond_optimized,
-    utils.psgd_precond_grad: optimizations.psgd_precond_grad_optimized,
-    utils.precond_grad_cached_: optimizations.precond_grad_cached_optimized,
-}
-
-# Config for enabling/disabling optimizations
-_config = {
-    "enabled": os.environ.get("HEAVYBALL_OPTIMIZE", "1") == "1",
-    "torch_compile_allowed": os.environ.get("HEAVYBALL_USE_COMPILE", "1") == "1",
-    "enable_lra": True,
-    "enable_kron": True,
-    "verbose": os.environ.get("HEAVYBALL_VERBOSE", "0") == "1",
-}
-
-
-def _apply_monkey_patch(original_func, optimized_func):
-    """Monkey patch a function with its optimized version."""
-    if original_func not in _original_functions:
-        _original_functions[original_func] = original_func
-
-    # Store reference to the optimized function
-    _optimized_functions[original_func] = optimized_func
-
-    # Get the module where the original function is defined
-    module = original_func.__module__
-    func_name = original_func.__name__
-
-    # Replace the function in its module
-    if hasattr(sys.modules[module], func_name):
-        setattr(sys.modules[module], func_name, optimized_func)
-
-        if _config["verbose"]:
-            print(f"Replaced {module}.{func_name} with optimized version")
-    else:
-        if _config["verbose"]:
-            print(f"Warning: Could not find {func_name} in module {module}")
-
-
-def enable_optimizations(
-    enable: bool = True, lra: bool = True, kron: bool = True, torch_compile: bool = True, verbose: bool = False
-):
-    """
-    Enable or disable PSGD optimizations.
-
-    Args:
-        enable: Whether to enable optimizations at all
-        lra: Whether to enable LRA-specific optimizations
-        kron: Whether to enable Kron-specific optimizations
-        torch_compile: Whether to allow torch.compile optimizations
-        verbose: Whether to print optimization status messages
-    """
-    _config["enabled"] = enable
-    _config["enable_lra"] = lra
-    _config["enable_kron"] = kron
-    _config["torch_compile_allowed"] = torch_compile
-    _config["verbose"] = verbose
-
-    if verbose:
-        print(f"PSGD Optimizations: {'enabled' if enable else 'disabled'}")
-        print(f"  - LRA optimizations: {'enabled' if lra else 'disabled'}")
-        print(f"  - KRON optimizations: {'enabled' if kron else 'disabled'}")
-        print(f"  - torch.compile: {'allowed' if torch_compile else 'disabled'}")
-
-    if not enable:
-        # Restore original functions
-        restore_original_functions()
-        return
-
-    # Apply optimizations based on config
-    for orig_func, opt_func in OPTIMIZATION_MAP.items():
-        # Skip LRA functions if disabled
-        if not _config["enable_lra"] and orig_func in [utils.update_lra_precond_, utils.lra_precond]:
-            continue
-
-        # Skip KRON functions if disabled
-        if not _config["enable_kron"] and orig_func in [
-            utils.psgd_update_precond,
-            utils.psgd_precond_grad,
-            utils.precond_grad_cached_,
-        ]:
-            continue
-
-        _apply_monkey_patch(orig_func, opt_func)
-
-    # Disable torch.compile if not allowed
-    if not _config["torch_compile_allowed"]:
-        # Monkey patch torch.compile to be a no-op
-        def _noop_compile(fn, **kwargs):
-            return fn
-
-        if not hasattr(torch, "_original_compile"):
-            torch._original_compile = torch.compile
-        torch.compile = _noop_compile
-        if verbose:
-            print("Disabled torch.compile (replaced with no-op)")
-    else:
-        # Restore original torch.compile
-        if hasattr(torch, "_original_compile"):
-            torch.compile = torch._original_compile
-            del torch._original_compile
-            if verbose:
-                print("Restored original torch.compile")
-
-
-def restore_original_functions():
-    """Restore all original function implementations."""
-    for orig_func, func_ref in _original_functions.items():
-        module = orig_func.__module__
-        func_name = orig_func.__name__
-
-        if hasattr(sys.modules[module], func_name):
-            setattr(sys.modules[module], func_name, func_ref)
-
-            if _config["verbose"]:
-                print(f"Restored original implementation of {module}.{func_name}")
-
-    # Also restore torch.compile if it was modified
-    if hasattr(torch, "_original_compile"):
-        torch.compile = torch._original_compile
-        del torch._original_compile
-        if _config["verbose"]:
-            print("Restored original torch.compile")
-
-
-def get_optimization_status() -> Dict[str, Any]:
-    """Get current optimization status."""
-    return {
-        "enabled": _config["enabled"],
-        "lra_enabled": _config["enable_lra"],
-        "kron_enabled": _config["enable_kron"],
-        "torch_compile_allowed": _config["torch_compile_allowed"],
-        "optimized_functions": list(_optimized_functions.keys()),
-        "original_functions": list(_original_functions.keys()),
-    }
-
-
-# Auto-initialize optimizations based on environment
-if os.environ.get("HEAVYBALL_AUTO_OPTIMIZE", "1") == "1":
-    enable_optimizations(
-        enable=_config["enabled"],
-        lra=_config["enable_lra"],
-        kron=_config["enable_kron"],
-        torch_compile=_config["torch_compile_allowed"],
-        verbose=_config["verbose"],
-    )
```
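For reference, the deleted `integrator` module was the opt-in monkey-patching layer: it swapped selected `heavyball.utils` functions for the `*_optimized` variants, could replace `torch.compile` with a no-op, and auto-enabled itself based on the `HEAVYBALL_OPTIMIZE`, `HEAVYBALL_USE_COMPILE`, and `HEAVYBALL_AUTO_OPTIMIZE` environment variables. Under 1.7.1 it was driven roughly as sketched below; the import path no longer exists in 1.7.2.

```python
# Works only against heavyball 1.7.1; the heavyball.optimizations package was removed in 1.7.2.
from heavyball.optimizations import (
    enable_optimizations,
    get_optimization_status,
    restore_original_functions,
)

# Patch only the LRA code paths, keep torch.compile available, and log what was replaced.
enable_optimizations(enable=True, lra=True, kron=False, torch_compile=True, verbose=True)
print(get_optimization_status())  # {"enabled": True, "lra_enabled": True, "kron_enabled": False, ...}

restore_original_functions()  # undo the monkey patches and restore torch.compile if it was replaced
```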