heavyball 1.7.1-py3-none-any.whl → 2.0.0-py3-none-any.whl

This diff shows the content changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
@@ -0,0 +1,122 @@
1
+ Metadata-Version: 2.4
2
+ Name: heavyball
3
+ Version: 2.0.0
4
+ Summary: Efficient Optimizers
5
+ Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
6
+ Project-URL: source, https://github.com/HomebrewML/HeavyBall
7
+ Project-URL: tracker, https://github.com/HomebrewML/HeavyBall/issues
8
+ Keywords: torch,optimizer,muon,soap,psgd
9
+ Classifier: Intended Audience :: Developers
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: License :: OSI Approved :: BSD License
12
+ Classifier: Natural Language :: English
13
+ Classifier: Operating System :: OS Independent
14
+ Classifier: Programming Language :: Python :: 3
15
+ Requires-Python: >=3.9
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+ Requires-Dist: opt-einsum>=3.4.0
19
+ Requires-Dist: torch>=2.7.0
20
+ Requires-Dist: numpy
21
+ Provides-Extra: dev
22
+ Requires-Dist: pre-commit; extra == "dev"
23
+ Requires-Dist: pytest; extra == "dev"
24
+ Requires-Dist: ruff; extra == "dev"
25
+ Requires-Dist: matplotlib; extra == "dev"
26
+ Requires-Dist: seaborn; extra == "dev"
27
+ Requires-Dist: hyperopt; extra == "dev"
28
+ Requires-Dist: pandas; extra == "dev"
29
+ Requires-Dist: typer; extra == "dev"
30
+ Requires-Dist: optuna; extra == "dev"
31
+ Requires-Dist: optunahub; extra == "dev"
32
+ Requires-Dist: botorch; extra == "dev"
33
+ Requires-Dist: hebo; extra == "dev"
34
+ Dynamic: license-file
35
+
36
+ # heavyball
37
+
38
+ [![PyPI version](https://img.shields.io/pypi/v/heavyball?color=blue)][pypi]
39
+ [![License](https://img.shields.io/badge/license-BSD--3--Clause-blue.svg)][license]
40
+
41
+ _High-performance, extensible, chainable optimizers for PyTorch._
42
+
43
+ ## Why heavyball
44
+
45
+ - **Lightning-Fast Training**: Batched `foreach` operations cut per-tensor overhead, delivering significant speedups on models with many parameter tensors.
46
+ - **Adaptive & Extensible**: Built-in AdamW, RMSprop, Schedule-Free algorithms, and PaLM-inspired schedules.
47
+ - **Plug-and-Play**: Drop-in replacements for `torch.optim` optimizers; swap the constructor and keep the rest of your training loop.
48
+ - **Customizable**: Chainable API lets you compose optimizers and transforms (MARS correction, cautious updates, orthogonal updates).
49
+ - **Battle-Tested**: Extensive benchmarks and real-world examples included.
50
+
51
+ ## Key Features
52
+
53
+ - Foreach-based optimizers: `ForeachAdamW`, `ForeachRMSprop`, `ForeachSFAdamW`, `Muon`, `ADOPT`, `MSAM` (Momentum SAM), …
54
+ - Schedule-Free optimizers that adapt during training without a hand-tuned learning-rate schedule.
55
+ - Advanced update rules: MARS correction, cautious updates, PaLM beta2 scheduling.
56
+ - Chainable transforms for custom optimization recipes.
57
+ - Comprehensive benchmark suite (`benchmark/`).
58
+ - Detailed documentation and example-driven tutorials.
59
+
60
+ ## Quickstart
61
+
62
+ **Install:**
63
+ ```bash
64
+ pip install heavyball
65
+ ```
66
+
67
+ **Basic usage:**
68
+ ```python
69
+ import torch
70
+ from torch import nn
71
+ from heavyball import ForeachAdamW
72
+
73
+ model = nn.Sequential(
74
+ nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10)
75
+ )
76
+ optimizer = ForeachAdamW(model.parameters(), lr=1e-3)
77
+
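+ # `dataloader` can be any iterable of (data, target) batches, e.g. a torch.utils.data.DataLoader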
78
+ for data, target in dataloader:
79
+ optimizer.zero_grad()
80
+ output = model(data)
81
+ loss = torch.nn.functional.cross_entropy(output, target)
82
+ loss.backward()
83
+ optimizer.step()
84
+ ```
85
+
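The other optimizers listed under Key Features are meant as drop-in swaps. A minimal sketch, assuming those classes are importable from the top-level package like `ForeachAdamW` and accept `params` and `lr` in their constructors (several expose extra hyperparameters, so check the class signatures):

```python
from heavyball import ForeachRMSprop  # or Muon, ForeachSFAdamW, ADOPT, ...

# Same model and training loop as above; only the optimizer constructor changes.
optimizer = ForeachRMSprop(model.parameters(), lr=1e-3)
```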
86
+ ## Benchmarks
87
+
88
+ > Reproduce benchmarks with:
89
+ > ```bash
90
+ > python3 -m benchmark.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
91
+ > ```
92
+
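For a quicker local sanity check, the same CLI can be run with a smaller budget. The flag values below are illustrative; only the flags themselves come from the full command above:

```bash
python3 -m benchmark.run_all_benchmarks --opt AdamW --opt Muon \
    --dtype float32 --steps 1000 --trials 8 --parallelism 8 --seeds 1 \
    --difficulties trivial --difficulties easy --timeout 60
```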
93
+ ## Migrating from HeavyBall 1.x
94
+
95
+ - Read the detailed [2.0.0 migration notes](docs/heavyball-2.0.0-migration.md) for an end-to-end checklist, including guidance for restoring legacy behaviour when needed.
96
+ - Use `scripts/migrate_optimizer_state.py` to rewrite pre-2.0 optimizer checkpoints:
97
+ ```bash
98
+ python scripts/migrate_optimizer_state.py path/to/checkpoint.pt heavyball.ForeachAdamW --state-key optimizer
99
+ ```
100
+ The utility renames legacy state entries, fans them out per parameter view, and injects the HeavyBall metadata block expected by 2.0.0.
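Once migrated, the checkpoint loads through the usual PyTorch state-dict API. A minimal sketch, assuming the optimizer state lives under the `optimizer` key (matching `--state-key optimizer` above) and reusing the Quickstart model:

```python
import torch
from torch import nn
from heavyball import ForeachAdamW

# Rebuild the same model/optimizer configuration that produced the checkpoint.
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
optimizer = ForeachAdamW(model.parameters(), lr=1e-3)

checkpoint = torch.load("path/to/checkpoint.pt", map_location="cpu")
optimizer.load_state_dict(checkpoint["optimizer"])  # state rewritten by the migration script
```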
101
+
102
+
103
+ ## Contributing
104
+
105
+ We welcome contributions! Please check the [issue tracker][tracker] and follow these steps:
106
+ 1. Fork the repo and create a feature branch.
107
+ 2. Install dev dependencies: `pip install -e .[dev]`.
108
+ 3. Run tests: `pytest`.
109
+ 4. Submit a pull request.
110
+
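The same steps as one shell sketch (assuming the standard GitHub fork workflow; `pre-commit` ships in the `dev` extra, so installing its hooks is optional but convenient):

```bash
git clone https://github.com/HomebrewML/HeavyBall.git && cd HeavyBall  # or clone your fork
git checkout -b my-feature
pip install -e ".[dev]"
pre-commit install   # set up lint hooks (pre-commit is a dev dependency)
pytest               # run the test suite before opening a pull request
```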
111
+ ## License
112
+
113
+ BSD 3-Clause — see the [LICENSE](LICENSE) file.
114
+
115
+ ---
116
+ <p align="center">
117
+ Made by the HeavyBall team.
118
+ </p>
119
+
120
+ [pypi]: https://pypi.org/project/heavyball/
121
+ [license]: LICENSE
122
+ [tracker]: https://github.com/HomebrewML/HeavyBall/issues
@@ -0,0 +1,9 @@
1
+ heavyball/__init__.py,sha256=cabACszT-lZaLPkN2rANFxLxkpcDPb8q4p7iL8jfTtQ,28274
2
+ heavyball/chainable.py,sha256=l7uzXKMbJxKn-kgMR9In8BxYKuvIO_y8uL6P5b0LZo8,41250
3
+ heavyball/helpers.py,sha256=zk_S84wpGcvO9P6kn4UeaQUIDowHxcbM9qQITEm2g5I,30267
4
+ heavyball/utils.py,sha256=sK85OOhmPHvAbUjZhkgX5tRfE3ECai0Yx4Zvt4p8z2Q,97794
5
+ heavyball-2.0.0.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
6
+ heavyball-2.0.0.dist-info/METADATA,sha256=9JKPFmnvaMT9oigMckzaqI5DjUN0hJj3ePzc1ok3Ia8,4973
7
+ heavyball-2.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
8
+ heavyball-2.0.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
9
+ heavyball-2.0.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.0)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,38 +0,0 @@
1
- """
2
- PSGD optimization module - optimized implementations of PSGD functions
3
- to improve execution speed while maintaining numerical equivalence.
4
- """
5
-
6
- # Import optimized functions
7
- # Import integrator API
8
- from .integrator import (
9
- enable_optimizations,
10
- get_optimization_status,
11
- restore_original_functions,
12
- )
13
- from .optimizations import (
14
- # LRA optimizations
15
- low_rank_mm_optimized,
16
- lra_precond_optimized,
17
- precond_grad_cached_optimized,
18
- # KRON optimizations
19
- psgd_calc_A_and_conjB_optimized,
20
- psgd_precond_grad_optimized,
21
- psgd_update_precond_optimized,
22
- update_lra_precond_optimized,
23
- )
24
-
25
- __all__ = [
26
- # Optimized functions
27
- "low_rank_mm_optimized",
28
- "update_lra_precond_optimized",
29
- "lra_precond_optimized",
30
- "psgd_calc_A_and_conjB_optimized",
31
- "psgd_update_precond_optimized",
32
- "psgd_precond_grad_optimized",
33
- "precond_grad_cached_optimized",
34
- # Integrator API
35
- "enable_optimizations",
36
- "restore_original_functions",
37
- "get_optimization_status",
38
- ]
@@ -1,169 +0,0 @@
1
- """
2
- Integration module to selectively enable optimized implementations
3
- of PSGD functions while maintaining API compatibility.
4
- """
5
-
6
- import os
7
- import sys
8
- from typing import Any, Dict
9
-
10
- import torch
11
-
12
- from . import optimizations
13
- from .. import utils
14
-
15
- # Store original function references
16
- _original_functions = {}
17
- _optimized_functions = {}
18
-
19
- # Mapping of original functions to their optimized versions
20
- OPTIMIZATION_MAP = {
21
- # LRA functions
22
- utils.update_lra_precond_: optimizations.update_lra_precond_optimized,
23
- utils.lra_precond: optimizations.lra_precond_optimized,
24
- # KRON functions
25
- utils.psgd_update_precond: optimizations.psgd_update_precond_optimized,
26
- utils.psgd_precond_grad: optimizations.psgd_precond_grad_optimized,
27
- utils.precond_grad_cached_: optimizations.precond_grad_cached_optimized,
28
- }
29
-
30
- # Config for enabling/disabling optimizations
31
- _config = {
32
- "enabled": os.environ.get("HEAVYBALL_OPTIMIZE", "1") == "1",
33
- "torch_compile_allowed": os.environ.get("HEAVYBALL_USE_COMPILE", "1") == "1",
34
- "enable_lra": True,
35
- "enable_kron": True,
36
- "verbose": os.environ.get("HEAVYBALL_VERBOSE", "0") == "1",
37
- }
38
-
39
-
40
- def _apply_monkey_patch(original_func, optimized_func):
41
- """Monkey patch a function with its optimized version."""
42
- if original_func not in _original_functions:
43
- _original_functions[original_func] = original_func
44
-
45
- # Store reference to the optimized function
46
- _optimized_functions[original_func] = optimized_func
47
-
48
- # Get the module where the original function is defined
49
- module = original_func.__module__
50
- func_name = original_func.__name__
51
-
52
- # Replace the function in its module
53
- if hasattr(sys.modules[module], func_name):
54
- setattr(sys.modules[module], func_name, optimized_func)
55
-
56
- if _config["verbose"]:
57
- print(f"Replaced {module}.{func_name} with optimized version")
58
- else:
59
- if _config["verbose"]:
60
- print(f"Warning: Could not find {func_name} in module {module}")
61
-
62
-
63
- def enable_optimizations(
64
- enable: bool = True, lra: bool = True, kron: bool = True, torch_compile: bool = True, verbose: bool = False
65
- ):
66
- """
67
- Enable or disable PSGD optimizations.
68
-
69
- Args:
70
- enable: Whether to enable optimizations at all
71
- lra: Whether to enable LRA-specific optimizations
72
- kron: Whether to enable Kron-specific optimizations
73
- torch_compile: Whether to allow torch.compile optimizations
74
- verbose: Whether to print optimization status messages
75
- """
76
- _config["enabled"] = enable
77
- _config["enable_lra"] = lra
78
- _config["enable_kron"] = kron
79
- _config["torch_compile_allowed"] = torch_compile
80
- _config["verbose"] = verbose
81
-
82
- if verbose:
83
- print(f"PSGD Optimizations: {'enabled' if enable else 'disabled'}")
84
- print(f" - LRA optimizations: {'enabled' if lra else 'disabled'}")
85
- print(f" - KRON optimizations: {'enabled' if kron else 'disabled'}")
86
- print(f" - torch.compile: {'allowed' if torch_compile else 'disabled'}")
87
-
88
- if not enable:
89
- # Restore original functions
90
- restore_original_functions()
91
- return
92
-
93
- # Apply optimizations based on config
94
- for orig_func, opt_func in OPTIMIZATION_MAP.items():
95
- # Skip LRA functions if disabled
96
- if not _config["enable_lra"] and orig_func in [utils.update_lra_precond_, utils.lra_precond]:
97
- continue
98
-
99
- # Skip KRON functions if disabled
100
- if not _config["enable_kron"] and orig_func in [
101
- utils.psgd_update_precond,
102
- utils.psgd_precond_grad,
103
- utils.precond_grad_cached_,
104
- ]:
105
- continue
106
-
107
- _apply_monkey_patch(orig_func, opt_func)
108
-
109
- # Disable torch.compile if not allowed
110
- if not _config["torch_compile_allowed"]:
111
- # Monkey patch torch.compile to be a no-op
112
- def _noop_compile(fn, **kwargs):
113
- return fn
114
-
115
- if not hasattr(torch, "_original_compile"):
116
- torch._original_compile = torch.compile
117
- torch.compile = _noop_compile
118
- if verbose:
119
- print("Disabled torch.compile (replaced with no-op)")
120
- else:
121
- # Restore original torch.compile
122
- if hasattr(torch, "_original_compile"):
123
- torch.compile = torch._original_compile
124
- del torch._original_compile
125
- if verbose:
126
- print("Restored original torch.compile")
127
-
128
-
129
- def restore_original_functions():
130
- """Restore all original function implementations."""
131
- for orig_func, func_ref in _original_functions.items():
132
- module = orig_func.__module__
133
- func_name = orig_func.__name__
134
-
135
- if hasattr(sys.modules[module], func_name):
136
- setattr(sys.modules[module], func_name, func_ref)
137
-
138
- if _config["verbose"]:
139
- print(f"Restored original implementation of {module}.{func_name}")
140
-
141
- # Also restore torch.compile if it was modified
142
- if hasattr(torch, "_original_compile"):
143
- torch.compile = torch._original_compile
144
- del torch._original_compile
145
- if _config["verbose"]:
146
- print("Restored original torch.compile")
147
-
148
-
149
- def get_optimization_status() -> Dict[str, Any]:
150
- """Get current optimization status."""
151
- return {
152
- "enabled": _config["enabled"],
153
- "lra_enabled": _config["enable_lra"],
154
- "kron_enabled": _config["enable_kron"],
155
- "torch_compile_allowed": _config["torch_compile_allowed"],
156
- "optimized_functions": list(_optimized_functions.keys()),
157
- "original_functions": list(_original_functions.keys()),
158
- }
159
-
160
-
161
- # Auto-initialize optimizations based on environment
162
- if os.environ.get("HEAVYBALL_AUTO_OPTIMIZE", "1") == "1":
163
- enable_optimizations(
164
- enable=_config["enabled"],
165
- lra=_config["enable_lra"],
166
- kron=_config["enable_kron"],
167
- torch_compile=_config["torch_compile_allowed"],
168
- verbose=_config["verbose"],
169
- )
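For reference, the module deleted above exposed a small runtime toggle for these monkey patches. Under 1.x it could be driven roughly like this (the `heavyball.optimizations` import path is inferred from the relative imports above; the whole subpackage is gone in 2.0.0):

```python
# HeavyBall 1.x only: this subpackage was removed in 2.0.0.
from heavyball.optimizations import enable_optimizations, get_optimization_status

# Enable the LRA patches, skip the KRON ones, and disallow torch.compile.
enable_optimizations(enable=True, lra=True, kron=False, torch_compile=False, verbose=True)
print(get_optimization_status())  # reports which utils functions are currently patched
```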
@@ -1,329 +0,0 @@
1
- import random
2
- from typing import List, Optional
3
-
4
- import torch
5
- from torch import Tensor
6
-
7
- from .. import utils
8
- from ..utils import decorator, decorator_knowngood, min_dtype, scalar_guard, tiny_bf16
9
-
10
- #############################
11
- # PSGD LRA OPTIMIZATIONS
12
- #############################
13
-
14
-
15
- @decorator
16
- def low_rank_mm_optimized(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
17
- """Optimized version of low_rank_mm using fused operations and memory reuse"""
18
- dtype = min_dtype([U, V, x])
19
- # Convert only once and cache the result
20
- U_dt, V_dt, x_dt = U.to(dtype), V.to(dtype), x.to(dtype)
21
-
22
- # Use a more efficient implementation that avoids multiple conversions
23
- # torch.bmm can be more efficient than einsum for this specific pattern
24
- if U.dim() == 2: # This is the common case (batch, rank)
25
- # Shape of result: (batch, )
26
- tmp = torch.mul(U_dt, x_dt.unsqueeze(-1)).sum(dim=0) # (rank, )
27
- result = torch.mv(V_dt, tmp) # (batch, )
28
- return result.to(x.dtype) + x
29
- else:
30
- # Fallback to original implementation for other dimensionalities
31
- return x + torch.einsum("br,gr,g->b", U_dt, V_dt, x_dt).to(x.dtype)
32
-
33
-
34
- @torch.compile(mode="reduce-overhead")
35
- def update_lra_precond_core(
36
- U: Tensor, V: Tensor, d: Tensor, vector: Tensor, hessian_vector: Tensor, eps: float, step: float, delayed: bool
37
- ):
38
- """Core computational part of update_lra_precond optimized with torch.compile"""
39
- # Here we apply torch.compile to the computational bottleneck
40
- # All inputs are already properly typed and processed
41
-
42
- Qh = low_rank_mm_optimized(U, V, d * hessian_vector)
43
- Ph = d * low_rank_mm_optimized(V, U, Qh)
44
- rank = U.size(1)
45
-
46
- # Cache VtU computation which is used multiple times
47
- VtU = torch.einsum("br,bn->rn", V, U) # (rank, rank)
48
- I = torch.eye(rank, dtype=VtU.dtype, device=VtU.device)
49
- IpVtU = I + VtU
50
- invQtv = vector / d
51
-
52
- # LU factorization to reuse computation
53
- LU, pivots = torch.linalg.lu_factor(IpVtU)
54
-
55
- # Compute vectors inline to reduce memory allocation
56
- invQtv = invQtv - V @ torch.linalg.lu_solve(LU, pivots, (U.T @ invQtv).view(-1, 1), adjoint=True).flatten()
57
- invPv = invQtv - U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
58
- invPv = invPv / d
59
-
60
- # Compute nabla D
61
- nablaD = Ph * hessian_vector - vector * invPv
62
-
63
- # Compute divisor more efficiently using fused operations
64
- Ph_squared = Ph.square()
65
- vector_squared = vector.square()
66
- hv_squared = hessian_vector.square()
67
- invPv_squared = invPv.square()
68
-
69
- divisor = (Ph_squared + vector_squared) * (hv_squared + invPv_squared)
70
- divisor = divisor.add(eps).sqrt().max()
71
- d_step = step / divisor
72
-
73
- # Compute for gradient update
74
- a, b = Qh, invQtv
75
-
76
- # Update either U or V, not both at the same time
77
- precond_u = random.random() < 0.5
78
- precond = V if precond_u else U
79
-
80
- # Cache computations that get reused
81
- atV = torch.einsum("b,br->r", a, precond)
82
- btV = torch.einsum("b,br->r", b, precond)
83
- atVVt = torch.einsum("r,br->b", atV, precond)
84
- btVVt = torch.einsum("r,br->b", btV, precond)
85
-
86
- # Compute step size
87
- precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm() + eps)
88
-
89
- # Update precond matrix
90
- if precond_u:
91
- a_new = torch.einsum("b,r,rg->bg", a, atV, IpVtU)
92
- b_new = torch.einsum("b,r,rg->bg", b, btV, IpVtU)
93
- else:
94
- # Optimize with in-place operations where possible
95
- a_new = a + torch.einsum("br,r->b", V, atV)
96
- b_new = b + torch.einsum("br,r->b", V, btV)
97
- a_new = torch.einsum("b,r->br", a_new, atV)
98
- b_new = torch.einsum("b,r->br", b_new, btV)
99
-
100
- # Return updated values
101
- return d, nablaD, d_step, U if precond_u else V, b_new - a_new, precond_step, precond_u
102
-
103
-
104
- def update_lra_precond_optimized(
105
- U: List[Tensor],
106
- V: List[Tensor],
107
- d: List[Tensor],
108
- vector: Tensor,
109
- hessian_vector: Tensor,
110
- eps: float,
111
- step: float,
112
- delayed: bool,
113
- ):
114
- """
115
- Optimized version of update_lra_precond_ with:
116
- 1. Reduced memory allocations
117
- 2. Fused operations
118
- 3. Torch.compile for core computations
119
- 4. Better caching of intermediate results
120
- """
121
- U_orig, V_orig, d_orig = U, V, d
122
-
123
- # Flatten once
124
- U_flat, V_flat, d_flat = utils.flatten(U, 1), utils.flatten(V, 1), utils.flatten(d)
125
-
126
- # Convert dtype once
127
- dtype = min_dtype([U_flat, V_flat, vector, hessian_vector])
128
- U_dt = U_flat.to(dtype)
129
- V_dt = V_flat.to(dtype)
130
- vector_dt = vector.to(dtype)
131
- hv_dt = hessian_vector.to(dtype)
132
-
133
- # Convert scalar once
134
- eps_tensor = scalar_guard(eps, vector)
135
-
136
- try:
137
- # Run optimized core computation with torch.compile
138
- d_flat, nablaD, d_step, precond, update, precond_step, precond_u = update_lra_precond_core(
139
- U_dt, V_dt, d_flat, vector_dt, hv_dt, eps, step, delayed
140
- )
141
-
142
- # Apply updates efficiently
143
- utils.apply_flat_add(d_orig, d_flat * nablaD, -d_step)
144
- utils.apply_flat_add(U_orig if precond_u else V_orig, update, precond_step)
145
-
146
- # For immediate updates
147
- if not delayed:
148
- utils.stochastic_add_([d], [d_flat * nablaD], -d_step)
149
- utils.stochastic_add_([U if precond_u else V], [update], precond_step)
150
-
151
- return U_flat.to(U_orig[0].dtype), V_flat.to(V_orig[0].dtype), d_flat.to(d_orig[0].dtype)
152
-
153
- except RuntimeError:
154
- # Fallback to original implementation on failure
155
- return utils.update_lra_precond_(U, V, d, vector, hessian_vector, eps, step, delayed)
156
-
157
-
158
- @decorator
159
- def lra_precond_optimized(U, V, d, g):
160
- """
161
- Optimized version of lra_precond using memory caching and fused operations
162
- """
163
- # Get the common dtype only once
164
- dtype = min_dtype([U, V, d, g])
165
-
166
- # Convert to this dtype once
167
- U_dt, V_dt, d_dt, g_dt = U.to(dtype), V.to(dtype), d.to(dtype), g.to(dtype)
168
-
169
- # First part: g_mid = d * g
170
- g_mid = d_dt * g_dt
171
-
172
- # Second part: Qh = low_rank_mm(U, V, g_mid)
173
- # Use optimized low_rank_mm
174
- Qh = low_rank_mm_optimized(U_dt, V_dt, g_mid)
175
-
176
- # Third part: result = d * low_rank_mm(V, U, Qh)
177
- result = d_dt * low_rank_mm_optimized(V_dt, U_dt, Qh)
178
-
179
- # Return result in original dtype
180
- return result.to(g.dtype)
181
-
182
-
183
- #############################
184
- # PSGD KRON OPTIMIZATIONS
185
- #############################
186
-
187
-
188
- @decorator
189
- def psgd_calc_A_and_conjB_optimized(exprA, G, Q, conjB):
190
- """Optimized version of psgd_calc_A_and_conjB using torch.compile and memory reuse"""
191
- order = G.dim()
192
- if order > 1:
193
- conjB = conjB.view_as(G).permute(*range(1, order), 0)
194
-
195
- # Convert dtype once
196
- G_dtype = utils.promote(G.dtype)
197
- conjB = conjB.to(G_dtype)
198
-
199
- # Compute A using einsum (could be cached if called multiple times with same Q, G)
200
- A = utils.casted_einsum(exprA, *Q, G)
201
-
202
- # Process each Q matrix with potential optimizations
203
- for i, q in enumerate(Q):
204
- q = utils.promote(q)
205
- if q.dim() <= 1:
206
- # Scalar case - use in-place division
207
- conjB.div_(q)
208
- else:
209
- # Matrix case - use optimized triangular solve
210
- # Reshape once and contiguous to optimize memory access
211
- conjB_reshaped = conjB.reshape(-1, q.size(0)).contiguous()
212
- solved = torch.linalg.solve_triangular(q, conjB_reshaped, upper=True, left=False)
213
- conjB = solved.reshape_as(conjB)
214
-
215
- # Only transpose if needed for next iteration
216
- if i < order - 1:
217
- conjB = conjB.transpose(i, -1)
218
-
219
- return A, conjB
220
-
221
-
222
- @torch.compile(mode="reduce-overhead")
223
- def psgd_update_precond_core(Q, term1, term2, precond_lr, norm, q):
224
- """Core computation of psgd_update_precond optimized with torch.compile"""
225
- term1 *= precond_lr
226
- if q.dim() < 2:
227
- term1 *= q / norm.clamp_(min=tiny_bf16)
228
- else:
229
- torch.triu(term1, out=term1)
230
- term1 /= torch.where(norm > 0, utils.psgd_lb(term2, norm), norm).clamp_(tiny_bf16)
231
- term1 = torch.mm(term1, q)
232
- return term1
233
-
234
-
235
- def psgd_update_precond_optimized(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
236
- """Optimized version of psgd_update_precond with reduced allocations and torch.compile"""
237
- exprA, exprGs, _ = exprs
238
-
239
- # Use optimized A and conjB calculation
240
- A, conjB = psgd_calc_A_and_conjB_optimized(exprA, G, Q, V)
241
-
242
- # Process each Q matrix with optimizations
243
- for q, exprG, o in zip(Q, exprGs, oq):
244
- # Use optimized einsum implementations
245
- term1 = utils.promote(torch.einsum(exprG, A, A))
246
- term2 = utils.promote(torch.einsum(exprG, conjB, conjB))
247
-
248
- # Compute the update using compiled core function
249
- term1, term2 = term1 - term2, term1 + term2
250
- norm = term2.norm(float("inf"))
251
-
252
- try:
253
- # Try to use the optimized core calculation
254
- term1 = psgd_update_precond_core(Q, term1, term2, precond_lr, norm, q.to(term1.dtype))
255
- except (RuntimeError, TypeError):
256
- # Fallback to original implementation
257
- term1 *= precond_lr
258
- if q.dim() < 2:
259
- term1 *= q.to(term1.dtype) / norm.clamp_(min=tiny_bf16)
260
- else:
261
- torch.triu(term1, out=term1)
262
- term1 /= torch.where(norm > 0, utils.psgd_lb(term2, norm), norm).clamp_(tiny_bf16)
263
- term1 = torch.mm(term1, q.to(term1.dtype))
264
-
265
- # Convert to line format if needed
266
- if store_triu_as_line:
267
- term1 = utils.triu_to_line([term1])[0][1]
268
- # Apply update directly
269
- if o.dim() > 0:
270
- o.add_(term1)
271
- else:
272
- o = term1
273
- else:
274
- # Apply update directly
275
- o.add_(term1)
276
-
277
-
278
- @decorator_knowngood
279
- def psgd_precond_grad_optimized(
280
- expr: str, ea: Tensor, *preconds: Tensor, caution: bool = False, grad: Optional[Tensor] = None
281
- ):
282
- """Optimized version of psgd_precond_grad with better memory management"""
283
- if caution:
284
- ea = utils._compilable_cautioning(grad, ea)
285
-
286
- # Determine minimum dtype once
287
- md = min_dtype(list(preconds) + [ea])
288
-
289
- # Convert all tensors to the same dtype once
290
- args = [q.to(md) for q in preconds]
291
- ea_md = ea.to(md)
292
-
293
- # Optimize the einsum operation by avoiding duplicate conversions
294
- # and potentially making args contiguous if beneficial
295
- args_contiguous = [arg.contiguous() if not arg.is_contiguous() else arg for arg in args]
296
- args_double = args_contiguous + args_contiguous
297
-
298
- # Call einsum once with the combined args list
299
- new = torch.einsum(expr, *(args_double + [ea_md]))
300
-
301
- # Convert result back to original dtype
302
- return new.to(ea.dtype)
303
-
304
-
305
- @decorator_knowngood
306
- def precond_grad_cached_optimized(
307
- expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None, cast: bool = True
308
- ):
309
- """Optimized version of precond_grad_cached_ with better memory management"""
310
- if caution:
311
- ea = utils._compilable_cautioning(grad, ea)
312
-
313
- # Determine minimum dtype once
314
- md = min_dtype(list(cached_q) + [ea])
315
-
316
- # Convert all tensors to the same dtype once and make contiguous if needed
317
- args = [q.to(md).contiguous() if not q.is_contiguous() else q.to(md) for q in cached_q]
318
- ea_md = ea.to(md).contiguous() if not ea.is_contiguous() else ea.to(md)
319
-
320
- # Add ea_md to args
321
- args.append(ea_md)
322
-
323
- # Call einsum once with the optimized args
324
- new = torch.einsum(expr, *args)
325
-
326
- # Convert result back if needed
327
- if cast:
328
- return new.to(ea.dtype)
329
- return new