heavyball 1.7.1__py3-none-any.whl → 2.0.0.dev0__py3-none-any.whl
This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- heavyball/__init__.py +193 -16
- heavyball/chainable.py +338 -190
- heavyball/helpers.py +804 -0
- heavyball/utils.py +813 -252
- heavyball-2.0.0.dev0.dist-info/METADATA +109 -0
- heavyball-2.0.0.dev0.dist-info/RECORD +9 -0
- {heavyball-1.7.1.dist-info → heavyball-2.0.0.dev0.dist-info}/WHEEL +1 -1
- heavyball/optimizations/__init__.py +0 -38
- heavyball/optimizations/integrator.py +0 -169
- heavyball/optimizations/optimizations.py +0 -329
- heavyball-1.7.1.dist-info/METADATA +0 -939
- heavyball-1.7.1.dist-info/RECORD +0 -11
- {heavyball-1.7.1.dist-info → heavyball-2.0.0.dev0.dist-info}/licenses/LICENSE +0 -0
- {heavyball-1.7.1.dist-info → heavyball-2.0.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,109 @@
+Metadata-Version: 2.4
+Name: heavyball
+Version: 2.0.0.dev0
+Summary: Efficient Optimizers
+Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
+Project-URL: source, https://github.com/HomebrewML/HeavyBall
+Project-URL: tracker, https://github.com/HomebrewML/HeavyBall/issues
+Keywords: torch,optimizer,muon,soap,psgd
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: BSD License
+Classifier: Natural Language :: English
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python :: 3
+Requires-Python: >=3.9
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: opt-einsum>=3.4.0
+Requires-Dist: torch>=2.1.0
+Requires-Dist: numpy
+Provides-Extra: dev
+Requires-Dist: pre-commit; extra == "dev"
+Requires-Dist: pytest; extra == "dev"
+Requires-Dist: ruff; extra == "dev"
+Requires-Dist: matplotlib; extra == "dev"
+Requires-Dist: seaborn; extra == "dev"
+Requires-Dist: hyperopt; extra == "dev"
+Requires-Dist: pandas; extra == "dev"
+Requires-Dist: typer; extra == "dev"
+Dynamic: license-file
+
+# heavyball
+
+[][pypi]
+[][license]
+
+_High-performance, extensible, chainable optimizers for PyTorch._
+
+## Why heavyball
+
+- **Lightning-Fast Training**: Batched `foreach` operations deliver significant speedups on large models.
+- **Adaptive & Extensible**: Built-in AdamW, RMSprop, Schedule-Free algorithms, and PaLM-inspired schedules.
+- **Plug-and-Play**: Drop-in replacements for `torch.optim` with seamless integration.
+- **Customizable**: Chainable API lets you compose optimizers and transforms (MARS correction, cautious updates, orthogonal updates).
+- **Battle-Tested**: Extensive benchmarks and real-world examples included.
+
+## Key Features
+
+- Foreach-based optimizers: `ForeachAdamW`, `ForeachRMSprop`, `ForeachSFAdamW`, `Muon`, `ADOPT`, `MSAM`, …
+- Schedule-Free optimizers with dynamic learning rate adaptation.
+- Advanced update rules: MARS correction, cautious updates, PaLM beta2 scheduling.
+- Chainable transforms for custom optimization recipes.
+- Comprehensive benchmark suite (`benchmark/`).
+- Detailed documentation and example-driven tutorials.
+
+## Quickstart
+
+**Install:**
+```bash
+pip install heavyball
+```
+
+**Basic usage:**
+```python
+import torch
+from torch import nn
+from heavyball import ForeachAdamW
+
+model = nn.Sequential(
+    nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10)
+)
+optimizer = ForeachAdamW(model.parameters(), lr=1e-3)
+
+for data, target in dataloader:
+    optimizer.zero_grad()
+    output = model(data)
+    loss = torch.nn.functional.cross_entropy(output, target)
+    loss.backward()
+    optimizer.step()
+```
+
+## Benchmarks
+
+> Reproduce benchmarks with:
+> ```bash
+> python3 -m benchmark.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
+> ```
+
+
+## Contributing
+
+We welcome contributions! Please check the [issue tracker][tracker] and follow these steps:
+1. Fork the repo and create a feature branch.
+2. Install dev dependencies: `pip install -e .[dev]`.
+3. Run tests: `pytest`.
+4. Submit a pull request.
+
+## License
+
+BSD 3-Clause — see the [LICENSE](LICENSE) file.
+
+---
+<p align="center">
+Made by the HeavyBall team.
+</p>
+
+[pypi]: https://pypi.org/project/heavyball/
+[license]: LICENSE
+[tracker]: https://github.com/HomebrewML/HeavyBall/issues
@@ -0,0 +1,9 @@
+heavyball/__init__.py,sha256=BUKhEKbpIY-93nMd5CEkhKv6d6LbSAstY_WSpFGGzQg,26017
+heavyball/chainable.py,sha256=aoPG_MtVncDqZRKeKsIkmnTsYLhmF3_I_06ZMovgTKc,39190
+heavyball/helpers.py,sha256=gKUhzu38f8e7CdnBbR8M51g9w8w0Kft_RW3V2fMXKw0,30159
+heavyball/utils.py,sha256=jrFHeZ0MjkiUutuLTT8ky8G_kaDFxIZACsfK_tRVQaE,90204
+heavyball-2.0.0.dev0.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
+heavyball-2.0.0.dev0.dist-info/METADATA,sha256=Au0Wdlt4YpTMnVRZdH29Svjwbuqks4fxivdFRXWPP7I,4254
+heavyball-2.0.0.dev0.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+heavyball-2.0.0.dev0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-2.0.0.dev0.dist-info/RECORD,,
@@ -1,38 +0,0 @@
-"""
-PSGD optimization module - optimized implementations of PSGD functions
-to improve execution speed while maintaining numerical equivalence.
-"""
-
-# Import optimized functions
-# Import integrator API
-from .integrator import (
-    enable_optimizations,
-    get_optimization_status,
-    restore_original_functions,
-)
-from .optimizations import (
-    # LRA optimizations
-    low_rank_mm_optimized,
-    lra_precond_optimized,
-    precond_grad_cached_optimized,
-    # KRON optimizations
-    psgd_calc_A_and_conjB_optimized,
-    psgd_precond_grad_optimized,
-    psgd_update_precond_optimized,
-    update_lra_precond_optimized,
-)
-
-__all__ = [
-    # Optimized functions
-    "low_rank_mm_optimized",
-    "update_lra_precond_optimized",
-    "lra_precond_optimized",
-    "psgd_calc_A_and_conjB_optimized",
-    "psgd_update_precond_optimized",
-    "psgd_precond_grad_optimized",
-    "precond_grad_cached_optimized",
-    # Integrator API
-    "enable_optimizations",
-    "restore_original_functions",
-    "get_optimization_status",
-]
@@ -1,169 +0,0 @@
-"""
-Integration module to selectively enable optimized implementations
-of PSGD functions while maintaining API compatibility.
-"""
-
-import os
-import sys
-from typing import Any, Dict
-
-import torch
-
-from . import optimizations
-from .. import utils
-
-# Store original function references
-_original_functions = {}
-_optimized_functions = {}
-
-# Mapping of original functions to their optimized versions
-OPTIMIZATION_MAP = {
-    # LRA functions
-    utils.update_lra_precond_: optimizations.update_lra_precond_optimized,
-    utils.lra_precond: optimizations.lra_precond_optimized,
-    # KRON functions
-    utils.psgd_update_precond: optimizations.psgd_update_precond_optimized,
-    utils.psgd_precond_grad: optimizations.psgd_precond_grad_optimized,
-    utils.precond_grad_cached_: optimizations.precond_grad_cached_optimized,
-}
-
-# Config for enabling/disabling optimizations
-_config = {
-    "enabled": os.environ.get("HEAVYBALL_OPTIMIZE", "1") == "1",
-    "torch_compile_allowed": os.environ.get("HEAVYBALL_USE_COMPILE", "1") == "1",
-    "enable_lra": True,
-    "enable_kron": True,
-    "verbose": os.environ.get("HEAVYBALL_VERBOSE", "0") == "1",
-}
-
-
-def _apply_monkey_patch(original_func, optimized_func):
-    """Monkey patch a function with its optimized version."""
-    if original_func not in _original_functions:
-        _original_functions[original_func] = original_func
-
-    # Store reference to the optimized function
-    _optimized_functions[original_func] = optimized_func
-
-    # Get the module where the original function is defined
-    module = original_func.__module__
-    func_name = original_func.__name__
-
-    # Replace the function in its module
-    if hasattr(sys.modules[module], func_name):
-        setattr(sys.modules[module], func_name, optimized_func)
-
-        if _config["verbose"]:
-            print(f"Replaced {module}.{func_name} with optimized version")
-    else:
-        if _config["verbose"]:
-            print(f"Warning: Could not find {func_name} in module {module}")
-
-
-def enable_optimizations(
-    enable: bool = True, lra: bool = True, kron: bool = True, torch_compile: bool = True, verbose: bool = False
-):
-    """
-    Enable or disable PSGD optimizations.
-
-    Args:
-        enable: Whether to enable optimizations at all
-        lra: Whether to enable LRA-specific optimizations
-        kron: Whether to enable Kron-specific optimizations
-        torch_compile: Whether to allow torch.compile optimizations
-        verbose: Whether to print optimization status messages
-    """
-    _config["enabled"] = enable
-    _config["enable_lra"] = lra
-    _config["enable_kron"] = kron
-    _config["torch_compile_allowed"] = torch_compile
-    _config["verbose"] = verbose
-
-    if verbose:
-        print(f"PSGD Optimizations: {'enabled' if enable else 'disabled'}")
-        print(f"  - LRA optimizations: {'enabled' if lra else 'disabled'}")
-        print(f"  - KRON optimizations: {'enabled' if kron else 'disabled'}")
-        print(f"  - torch.compile: {'allowed' if torch_compile else 'disabled'}")
-
-    if not enable:
-        # Restore original functions
-        restore_original_functions()
-        return
-
-    # Apply optimizations based on config
-    for orig_func, opt_func in OPTIMIZATION_MAP.items():
-        # Skip LRA functions if disabled
-        if not _config["enable_lra"] and orig_func in [utils.update_lra_precond_, utils.lra_precond]:
-            continue
-
-        # Skip KRON functions if disabled
-        if not _config["enable_kron"] and orig_func in [
-            utils.psgd_update_precond,
-            utils.psgd_precond_grad,
-            utils.precond_grad_cached_,
-        ]:
-            continue
-
-        _apply_monkey_patch(orig_func, opt_func)
-
-    # Disable torch.compile if not allowed
-    if not _config["torch_compile_allowed"]:
-        # Monkey patch torch.compile to be a no-op
-        def _noop_compile(fn, **kwargs):
-            return fn
-
-        if not hasattr(torch, "_original_compile"):
-            torch._original_compile = torch.compile
-        torch.compile = _noop_compile
-        if verbose:
-            print("Disabled torch.compile (replaced with no-op)")
-    else:
-        # Restore original torch.compile
-        if hasattr(torch, "_original_compile"):
-            torch.compile = torch._original_compile
-            del torch._original_compile
-            if verbose:
-                print("Restored original torch.compile")
-
-
-def restore_original_functions():
-    """Restore all original function implementations."""
-    for orig_func, func_ref in _original_functions.items():
-        module = orig_func.__module__
-        func_name = orig_func.__name__
-
-        if hasattr(sys.modules[module], func_name):
-            setattr(sys.modules[module], func_name, func_ref)
-
-            if _config["verbose"]:
-                print(f"Restored original implementation of {module}.{func_name}")
-
-    # Also restore torch.compile if it was modified
-    if hasattr(torch, "_original_compile"):
-        torch.compile = torch._original_compile
-        del torch._original_compile
-        if _config["verbose"]:
-            print("Restored original torch.compile")
-
-
-def get_optimization_status() -> Dict[str, Any]:
-    """Get current optimization status."""
-    return {
-        "enabled": _config["enabled"],
-        "lra_enabled": _config["enable_lra"],
-        "kron_enabled": _config["enable_kron"],
-        "torch_compile_allowed": _config["torch_compile_allowed"],
-        "optimized_functions": list(_optimized_functions.keys()),
-        "original_functions": list(_original_functions.keys()),
-    }
-
-
-# Auto-initialize optimizations based on environment
-if os.environ.get("HEAVYBALL_AUTO_OPTIMIZE", "1") == "1":
-    enable_optimizations(
-        enable=_config["enabled"],
-        lra=_config["enable_lra"],
-        kron=_config["enable_kron"],
-        torch_compile=_config["torch_compile_allowed"],
-        verbose=_config["verbose"],
-    )
@@ -1,329 +0,0 @@
-import random
-from typing import List, Optional
-
-import torch
-from torch import Tensor
-
-from .. import utils
-from ..utils import decorator, decorator_knowngood, min_dtype, scalar_guard, tiny_bf16
-
-#############################
-# PSGD LRA OPTIMIZATIONS
-#############################
-
-
-@decorator
-def low_rank_mm_optimized(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
-    """Optimized version of low_rank_mm using fused operations and memory reuse"""
-    dtype = min_dtype([U, V, x])
-    # Convert only once and cache the result
-    U_dt, V_dt, x_dt = U.to(dtype), V.to(dtype), x.to(dtype)
-
-    # Use a more efficient implementation that avoids multiple conversions
-    # torch.bmm can be more efficient than einsum for this specific pattern
-    if U.dim() == 2:  # This is the common case (batch, rank)
-        # Shape of result: (batch, )
-        tmp = torch.mul(U_dt, x_dt.unsqueeze(-1)).sum(dim=0)  # (rank, )
-        result = torch.mv(V_dt, tmp)  # (batch, )
-        return result.to(x.dtype) + x
-    else:
-        # Fallback to original implementation for other dimensionalities
-        return x + torch.einsum("br,gr,g->b", U_dt, V_dt, x_dt).to(x.dtype)
-
-
-@torch.compile(mode="reduce-overhead")
-def update_lra_precond_core(
-    U: Tensor, V: Tensor, d: Tensor, vector: Tensor, hessian_vector: Tensor, eps: float, step: float, delayed: bool
-):
-    """Core computational part of update_lra_precond optimized with torch.compile"""
-    # Here we apply torch.compile to the computational bottleneck
-    # All inputs are already properly typed and processed
-
-    Qh = low_rank_mm_optimized(U, V, d * hessian_vector)
-    Ph = d * low_rank_mm_optimized(V, U, Qh)
-    rank = U.size(1)
-
-    # Cache VtU computation which is used multiple times
-    VtU = torch.einsum("br,bn->rn", V, U)  # (rank, rank)
-    I = torch.eye(rank, dtype=VtU.dtype, device=VtU.device)
-    IpVtU = I + VtU
-    invQtv = vector / d
-
-    # LU factorization to reuse computation
-    LU, pivots = torch.linalg.lu_factor(IpVtU)
-
-    # Compute vectors inline to reduce memory allocation
-    invQtv = invQtv - V @ torch.linalg.lu_solve(LU, pivots, (U.T @ invQtv).view(-1, 1), adjoint=True).flatten()
-    invPv = invQtv - U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
-    invPv = invPv / d
-
-    # Compute nabla D
-    nablaD = Ph * hessian_vector - vector * invPv
-
-    # Compute divisor more efficiently using fused operations
-    Ph_squared = Ph.square()
-    vector_squared = vector.square()
-    hv_squared = hessian_vector.square()
-    invPv_squared = invPv.square()
-
-    divisor = (Ph_squared + vector_squared) * (hv_squared + invPv_squared)
-    divisor = divisor.add(eps).sqrt().max()
-    d_step = step / divisor
-
-    # Compute for gradient update
-    a, b = Qh, invQtv
-
-    # Update either U or V, not both at the same time
-    precond_u = random.random() < 0.5
-    precond = V if precond_u else U
-
-    # Cache computations that get reused
-    atV = torch.einsum("b,br->r", a, precond)
-    btV = torch.einsum("b,br->r", b, precond)
-    atVVt = torch.einsum("r,br->b", atV, precond)
-    btVVt = torch.einsum("r,br->b", btV, precond)
-
-    # Compute step size
-    precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm() + eps)
-
-    # Update precond matrix
-    if precond_u:
-        a_new = torch.einsum("b,r,rg->bg", a, atV, IpVtU)
-        b_new = torch.einsum("b,r,rg->bg", b, btV, IpVtU)
-    else:
-        # Optimize with in-place operations where possible
-        a_new = a + torch.einsum("br,r->b", V, atV)
-        b_new = b + torch.einsum("br,r->b", V, btV)
-        a_new = torch.einsum("b,r->br", a_new, atV)
-        b_new = torch.einsum("b,r->br", b_new, btV)
-
-    # Return updated values
-    return d, nablaD, d_step, U if precond_u else V, b_new - a_new, precond_step, precond_u
-
-
-def update_lra_precond_optimized(
-    U: List[Tensor],
-    V: List[Tensor],
-    d: List[Tensor],
-    vector: Tensor,
-    hessian_vector: Tensor,
-    eps: float,
-    step: float,
-    delayed: bool,
-):
-    """
-    Optimized version of update_lra_precond_ with:
-    1. Reduced memory allocations
-    2. Fused operations
-    3. Torch.compile for core computations
-    4. Better caching of intermediate results
-    """
-    U_orig, V_orig, d_orig = U, V, d
-
-    # Flatten once
-    U_flat, V_flat, d_flat = utils.flatten(U, 1), utils.flatten(V, 1), utils.flatten(d)
-
-    # Convert dtype once
-    dtype = min_dtype([U_flat, V_flat, vector, hessian_vector])
-    U_dt = U_flat.to(dtype)
-    V_dt = V_flat.to(dtype)
-    vector_dt = vector.to(dtype)
-    hv_dt = hessian_vector.to(dtype)
-
-    # Convert scalar once
-    eps_tensor = scalar_guard(eps, vector)
-
-    try:
-        # Run optimized core computation with torch.compile
-        d_flat, nablaD, d_step, precond, update, precond_step, precond_u = update_lra_precond_core(
-            U_dt, V_dt, d_flat, vector_dt, hv_dt, eps, step, delayed
-        )
-
-        # Apply updates efficiently
-        utils.apply_flat_add(d_orig, d_flat * nablaD, -d_step)
-        utils.apply_flat_add(U_orig if precond_u else V_orig, update, precond_step)
-
-        # For immediate updates
-        if not delayed:
-            utils.stochastic_add_([d], [d_flat * nablaD], -d_step)
-            utils.stochastic_add_([U if precond_u else V], [update], precond_step)
-
-        return U_flat.to(U_orig[0].dtype), V_flat.to(V_orig[0].dtype), d_flat.to(d_orig[0].dtype)
-
-    except RuntimeError:
-        # Fallback to original implementation on failure
-        return utils.update_lra_precond_(U, V, d, vector, hessian_vector, eps, step, delayed)
-
-
-@decorator
-def lra_precond_optimized(U, V, d, g):
-    """
-    Optimized version of lra_precond using memory caching and fused operations
-    """
-    # Get the common dtype only once
-    dtype = min_dtype([U, V, d, g])
-
-    # Convert to this dtype once
-    U_dt, V_dt, d_dt, g_dt = U.to(dtype), V.to(dtype), d.to(dtype), g.to(dtype)
-
-    # First part: g_mid = d * g
-    g_mid = d_dt * g_dt
-
-    # Second part: Qh = low_rank_mm(U, V, g_mid)
-    # Use optimized low_rank_mm
-    Qh = low_rank_mm_optimized(U_dt, V_dt, g_mid)
-
-    # Third part: result = d * low_rank_mm(V, U, Qh)
-    result = d_dt * low_rank_mm_optimized(V_dt, U_dt, Qh)
-
-    # Return result in original dtype
-    return result.to(g.dtype)
-
-
-#############################
-# PSGD KRON OPTIMIZATIONS
-#############################
-
-
-@decorator
-def psgd_calc_A_and_conjB_optimized(exprA, G, Q, conjB):
-    """Optimized version of psgd_calc_A_and_conjB using torch.compile and memory reuse"""
-    order = G.dim()
-    if order > 1:
-        conjB = conjB.view_as(G).permute(*range(1, order), 0)
-
-    # Convert dtype once
-    G_dtype = utils.promote(G.dtype)
-    conjB = conjB.to(G_dtype)
-
-    # Compute A using einsum (could be cached if called multiple times with same Q, G)
-    A = utils.casted_einsum(exprA, *Q, G)
-
-    # Process each Q matrix with potential optimizations
-    for i, q in enumerate(Q):
-        q = utils.promote(q)
-        if q.dim() <= 1:
-            # Scalar case - use in-place division
-            conjB.div_(q)
-        else:
-            # Matrix case - use optimized triangular solve
-            # Reshape once and contiguous to optimize memory access
-            conjB_reshaped = conjB.reshape(-1, q.size(0)).contiguous()
-            solved = torch.linalg.solve_triangular(q, conjB_reshaped, upper=True, left=False)
-            conjB = solved.reshape_as(conjB)
-
-        # Only transpose if needed for next iteration
-        if i < order - 1:
-            conjB = conjB.transpose(i, -1)
-
-    return A, conjB
-
-
-@torch.compile(mode="reduce-overhead")
-def psgd_update_precond_core(Q, term1, term2, precond_lr, norm, q):
-    """Core computation of psgd_update_precond optimized with torch.compile"""
-    term1 *= precond_lr
-    if q.dim() < 2:
-        term1 *= q / norm.clamp_(min=tiny_bf16)
-    else:
-        torch.triu(term1, out=term1)
-        term1 /= torch.where(norm > 0, utils.psgd_lb(term2, norm), norm).clamp_(tiny_bf16)
-        term1 = torch.mm(term1, q)
-    return term1
-
-
-def psgd_update_precond_optimized(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
-    """Optimized version of psgd_update_precond with reduced allocations and torch.compile"""
-    exprA, exprGs, _ = exprs
-
-    # Use optimized A and conjB calculation
-    A, conjB = psgd_calc_A_and_conjB_optimized(exprA, G, Q, V)
-
-    # Process each Q matrix with optimizations
-    for q, exprG, o in zip(Q, exprGs, oq):
-        # Use optimized einsum implementations
-        term1 = utils.promote(torch.einsum(exprG, A, A))
-        term2 = utils.promote(torch.einsum(exprG, conjB, conjB))
-
-        # Compute the update using compiled core function
-        term1, term2 = term1 - term2, term1 + term2
-        norm = term2.norm(float("inf"))
-
-        try:
-            # Try to use the optimized core calculation
-            term1 = psgd_update_precond_core(Q, term1, term2, precond_lr, norm, q.to(term1.dtype))
-        except (RuntimeError, TypeError):
-            # Fallback to original implementation
-            term1 *= precond_lr
-            if q.dim() < 2:
-                term1 *= q.to(term1.dtype) / norm.clamp_(min=tiny_bf16)
-            else:
-                torch.triu(term1, out=term1)
-                term1 /= torch.where(norm > 0, utils.psgd_lb(term2, norm), norm).clamp_(tiny_bf16)
-                term1 = torch.mm(term1, q.to(term1.dtype))
-
-        # Convert to line format if needed
-        if store_triu_as_line:
-            term1 = utils.triu_to_line([term1])[0][1]
-            # Apply update directly
-            if o.dim() > 0:
-                o.add_(term1)
-            else:
-                o = term1
-        else:
-            # Apply update directly
-            o.add_(term1)
-
-
-@decorator_knowngood
-def psgd_precond_grad_optimized(
-    expr: str, ea: Tensor, *preconds: Tensor, caution: bool = False, grad: Optional[Tensor] = None
-):
-    """Optimized version of psgd_precond_grad with better memory management"""
-    if caution:
-        ea = utils._compilable_cautioning(grad, ea)
-
-    # Determine minimum dtype once
-    md = min_dtype(list(preconds) + [ea])
-
-    # Convert all tensors to the same dtype once
-    args = [q.to(md) for q in preconds]
-    ea_md = ea.to(md)
-
-    # Optimize the einsum operation by avoiding duplicate conversions
-    # and potentially making args contiguous if beneficial
-    args_contiguous = [arg.contiguous() if not arg.is_contiguous() else arg for arg in args]
-    args_double = args_contiguous + args_contiguous
-
-    # Call einsum once with the combined args list
-    new = torch.einsum(expr, *(args_double + [ea_md]))
-
-    # Convert result back to original dtype
-    return new.to(ea.dtype)
-
-
-@decorator_knowngood
-def precond_grad_cached_optimized(
-    expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None, cast: bool = True
-):
-    """Optimized version of precond_grad_cached_ with better memory management"""
-    if caution:
-        ea = utils._compilable_cautioning(grad, ea)
-
-    # Determine minimum dtype once
-    md = min_dtype(list(cached_q) + [ea])
-
-    # Convert all tensors to the same dtype once and make contiguous if needed
-    args = [q.to(md).contiguous() if not q.is_contiguous() else q.to(md) for q in cached_q]
-    ea_md = ea.to(md).contiguous() if not ea.is_contiguous() else ea.to(md)
-
-    # Add ea_md to args
-    args.append(ea_md)
-
-    # Call einsum once with the optimized args
-    new = torch.einsum(expr, *args)
-
-    # Convert result back if needed
-    if cast:
-        return new.to(ea.dtype)
-    return new