heavyball 2.1.2__py3-none-any.whl → 2.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +56 -89
- heavyball/chainable.py +6 -4
- heavyball/helpers.py +127 -56
- heavyball/utils.py +74 -61
- {heavyball-2.1.2.dist-info → heavyball-2.1.4.dist-info}/METADATA +2 -1
- heavyball-2.1.4.dist-info/RECORD +9 -0
- heavyball-2.1.2.dist-info/RECORD +0 -9
- {heavyball-2.1.2.dist-info → heavyball-2.1.4.dist-info}/WHEEL +0 -0
- {heavyball-2.1.2.dist-info → heavyball-2.1.4.dist-info}/licenses/LICENSE +0 -0
- {heavyball-2.1.2.dist-info → heavyball-2.1.4.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -1,6 +1,6 @@
 import functools
 import math
-from typing import Optional
+from typing import Optional, Type, Union

 import torch.optim

@@ -8,39 +8,6 @@ from . import chainable as C
 from . import utils


-class SAMWrapper(torch.optim.Optimizer):
-    def __init__(self, params, wrapped_optimizer: utils.StatefulOptimizer, ball: float = 0.1):
-        if not isinstance(wrapped_optimizer, utils.StatefulOptimizer):
-            raise ValueError(f"{wrapped_optimizer.__class__.__name__} is not a HeavyBall optimizer")
-        super().__init__(params, {"ball": ball})
-        self.wrapped_optimizer = wrapped_optimizer
-
-    @torch.no_grad()
-    def step(self, closure=None):
-        if closure is None:
-            raise ValueError("SAM requires closure")
-        with torch.enable_grad():
-            closure()
-        old_params = [utils.sam_step(group["params"], group["ball"]) for group in self.param_groups]
-
-        originaL_handle_closure = self.wrapped_optimizer._handle_closure
-
-        def _handle_closure(closure):
-            originaL_handle_closure(closure)
-            for group, old in zip(self.param_groups, old_params):
-                utils.copy_stochastic_list_(group["params"], old)
-
-        try:
-            self.wrapped_optimizer._handle_closure = _handle_closure
-            loss = self.wrapped_optimizer.step(closure)
-        finally:
-            self.wrapped_optimizer._handle_closure = originaL_handle_closure
-        return loss
-
-    def zero_grad(self, set_to_none: bool = True):
-        self.wrapped_optimizer.zero_grad()
-
-
 class SGD(C.BaseOpt):
     def __init__(
         self,
@@ -778,7 +745,7 @@ class ForeachPSGDKron(C.BaseOpt):
         beta=None,
         betas=(0.9, 0.999),
         weight_decay=0.0,
-        preconditioner_update_probability=
+        preconditioner_update_probability=C.use_default,
         max_size_triangular=2048,
         min_ndim_triangular=2,
         memory_save_mode=None,
@@ -830,8 +797,8 @@ class ForeachPSGDKron(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")

-        self.precond_schedule = (
-            defaults.pop("preconditioner_update_probability")
+        self.precond_schedule = C.default(
+            defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
         )
         params = defaults.pop("params")

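A hedged illustration of what the new default handling above means for callers (only names visible in this diff are assumed; `C.use_default` is an internal sentinel that `C.default(...)` replaces with `utils.precond_update_prob_schedule()`):

    import torch
    import heavyball

    params = [torch.nn.Parameter(torch.randn(64, 64))]

    # Omitting preconditioner_update_probability now resolves to the built-in schedule.
    opt = heavyball.ForeachPSGDKron(params, lr=1e-3)

    # An explicit float (or a callable taking the step count) still overrides it.
    opt_fixed = heavyball.ForeachPSGDKron(params, lr=1e-3, preconditioner_update_probability=0.1)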
@@ -890,7 +857,7 @@ class ForeachPSGDLRA(C.BaseOpt):
         lr=0.001,
         beta=0.9,
         weight_decay=0.0,
-        preconditioner_update_probability=
+        preconditioner_update_probability=C.use_default,
         momentum_into_precond_update=True,
         rank: Optional[int] = None,
         warmup_steps: int = 0,
@@ -924,8 +891,8 @@ class ForeachPSGDLRA(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")

-        self.precond_schedule = (
-            defaults.pop("preconditioner_update_probability")
+        self.precond_schedule = C.default(
+            defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
         )
         params = defaults.pop("params")

@@ -960,6 +927,54 @@ class NewtonHybrid2PSGDLRA(ForeachNewtonPSGDLRA):
     hvp_interval = 2


+class SAMWrapper(torch.optim.Optimizer):
+    def __init__(
+        self,
+        params,
+        wrapped_optimizer: Union[utils.StatefulOptimizer, Type[utils.StatefulOptimizer]] = ForeachAdamW,
+        ball: float = 0.1,
+    ):
+        params = list(params)
+        super().__init__(params, {"ball": ball})
+
+        if isinstance(wrapped_optimizer, type):
+            if not issubclass(wrapped_optimizer, utils.StatefulOptimizer):
+                raise ValueError(f"{wrapped_optimizer.__name__} is not a HeavyBall optimizer")
+            wrapped_optimizer = wrapped_optimizer(params)
+        elif not isinstance(wrapped_optimizer, utils.StatefulOptimizer):
+            raise ValueError(f"{wrapped_optimizer.__class__.__name__} is not a HeavyBall optimizer")
+
+        self.wrapped_optimizer = wrapped_optimizer
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        if closure is None:
+            raise ValueError("SAM requires closure")
+        with torch.enable_grad():
+            closure()
+        old_params = [utils.sam_step(group["params"], group["ball"]) for group in self.param_groups]
+
+        original_handle_closure = self.wrapped_optimizer._handle_closure
+
+        def _handle_closure(closure):
+            try:
+                _loss = original_handle_closure(closure)
+            finally:
+                for group, old in zip(self.param_groups, old_params):
+                    utils.copy_stochastic_list_(group["params"], old)
+            return _loss
+
+        try:
+            self.wrapped_optimizer._handle_closure = _handle_closure
+            loss = self.wrapped_optimizer.step(closure)
+        finally:
+            self.wrapped_optimizer._handle_closure = original_handle_closure
+        return loss
+
+    def zero_grad(self, set_to_none: bool = True):
+        self.wrapped_optimizer.zero_grad(set_to_none=set_to_none)
+
+
 PalmForEachSoap = PaLMForeachSOAP
 PaLMSOAP = PaLMForeachSOAP
 PaLMSFAdamW = PaLMForeachSFAdamW
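Because the relocated SAMWrapper now also accepts an optimizer class (defaulting to ForeachAdamW) and restores parameters even when the wrapped closure raises, a minimal usage sketch might look like the following; the model and loss are illustrative only:

    import torch
    import heavyball

    model = torch.nn.Linear(8, 1)
    opt = heavyball.SAMWrapper(model.parameters(), wrapped_optimizer=heavyball.ForeachAdamW, ball=0.05)

    def closure():
        opt.zero_grad()
        loss = model(torch.randn(16, 8)).square().mean()
        loss.backward()
        return loss

    loss = opt.step(closure)  # SAM ascent step of size `ball`, then the wrapped optimizer's update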
@@ -983,52 +998,4 @@ PSGDLRA = ForeachPSGDLRA
 NewtonPSGDLRA = ForeachNewtonPSGDLRA
 NewtonPSGDKron = ForeachCachedNewtonPSGD

-__all__ = [
-    "Muon",
-    "RMSprop",
-    "PrecondSchedulePaLMSOAP",
-    "PSGDKron",
-    "PurePSGD",
-    "DelayedPSGD",
-    "CachedPSGDKron",
-    "CachedDelayedPSGDKron",
-    "PalmForEachSoap",
-    "PaLMSOAP",
-    "PaLMSFAdamW",
-    "LaProp",
-    "ADOPT",
-    "PrecondScheduleSOAP",
-    "PrecondSchedulePaLMSOAP",
-    "RMSprop",
-    "MuonLaProp",
-    "ForeachSignLaProp",
-    "ForeachDelayedPSGDLRA",
-    "ForeachPSGDLRA",
-    "ForeachPSGDLRA",
-    "ForeachNewtonPSGDLRA",  #
-    "ForeachAdamW",
-    "ForeachSFAdamW",
-    "ForeachLaProp",
-    "ForeachADOPT",
-    "ForeachSOAP",
-    "ForeachPSGDKron",
-    "ForeachPurePSGD",
-    "ForeachDelayedPSGD",
-    "ForeachCachedPSGDKron",
-    "ForeachCachedDelayedPSGDKron",
-    "ForeachRMSprop",
-    "ForeachMuon",
-    "ForeachCachedNewtonPSGD",
-    "OrthoLaProp",
-    "LaPropOrtho",
-    "SignLaProp",
-    "DelayedPSGD",
-    "PSGDLRA",
-    "NewtonPSGDLRA",
-    "NewtonHybrid2PSGDLRA",
-    "NewtonHybrid2PSGDKron",
-    "MSAMLaProp",
-    "NewtonPSGDKron",
-    "ForeachAdamC",
-    "SGD",
-]
+__all__ = [k for k, v in globals().items() if isinstance(v, type) and issubclass(v, torch.optim.Optimizer)]
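A small, hedged check of the new dynamic export list (the class names used are ones that appear elsewhere in this diff):

    # __all__ now collects every torch.optim.Optimizer subclass bound in the module,
    # so aliases and newly added classes are exported without manual bookkeeping.
    import heavyball

    assert "ForeachAdamW" in heavyball.__all__
    assert "SAMWrapper" in heavyball.__all__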
heavyball/chainable.py
CHANGED
@@ -62,8 +62,6 @@ class FunctionTransform:
                 self._init(st, group, *a, **kwargs)
             except SkipUpdate:
                 skip_update = True
-            except:
-                raise
             finally:
                 if "is_initialized" not in st:
                     st["is_initialized"] = set()
@@ -499,6 +497,7 @@ def scale_by_suds(group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx
         precond_update, _ = utils.eigvecs_product_rank1(precond_update.flatten(), fisher_approx.flatten(), w)

     new_approx = utils.oja_update(fisher_approx.flatten().to(update.dtype), update.flatten(), group["precond_lr"])
+    new_approx = new_approx.view_as(fisher_approx)
     utils.copy_stochastic_(fisher_approx, new_approx)
     return precond_update

@@ -565,7 +564,7 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
     )
     state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
     state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
-    state["step"] = torch.zeros((), device=param.device, dtype=torch.
+    state["step"] = torch.zeros((), device=param.device, dtype=torch.float64)  # torch casts int to float in ckpt load
     if group["adaptive"]:
         state["velocity"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
     if not cached:
@@ -750,7 +749,9 @@ def _update_psgd_precond(
     if isinstance(prob, float):
         float_prob = prob
     else:
-
+        prob_step = group.get(f"cumulative_prob_{id(Q)}_prob_step", 1)
+        float_prob = prob(prob_step)
+        group[f"cumulative_prob_{id(Q)}_prob_step"] = prob_step + 1
     group["is_cached"] = should_use_cache = cached and float_prob < 0.5

     if precond is not None:
@@ -1086,6 +1087,7 @@ class ChainOpt(utils.StatefulOptimizer):
         if not group["foreach"] or len(p) == 1:
             for param, grad in zip(p, g):
                 chain(self.state_, group, [grad], [param], *self.fns)
+            group["caution"] = caution
         else:
             chain(self.state_, group, g, p, *self.fns)

heavyball/helpers.py
CHANGED
@@ -1,9 +1,9 @@
-
-
+import copy
 import functools
 import math
 import threading
-from
+from contextlib import contextmanager
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union

 import numpy
 import numpy as np
@@ -11,23 +11,15 @@ import optuna
 import optunahub
 import pandas as pd
 import torch
-from botorch.utils.sampling import manual_seed
 from hebo.design_space.design_space import DesignSpace
 from hebo.optimizers.hebo import HEBO
-from optuna._transform import _SearchSpaceTransform
+from optuna._transform import _SearchSpaceTransform as SearchSpaceTransform
 from optuna.distributions import BaseDistribution, CategoricalDistribution, FloatDistribution, IntDistribution
 from optuna.samplers import BaseSampler, CmaEsSampler, RandomSampler
 from optuna.samplers._lazy_random_state import LazyRandomState
 from optuna.study import Study
 from optuna.study._study_direction import StudyDirection
 from optuna.trial import FrozenTrial, TrialState
-from optuna_integration.botorch import (
-    ehvi_candidates_func,
-    logei_candidates_func,
-    qehvi_candidates_func,
-    qei_candidates_func,
-    qparego_candidates_func,
-)
 from torch import Tensor
 from torch.nn import functional as F

@@ -37,12 +29,39 @@ _MAXINT32 = (1 << 31) - 1
 _SAMPLER_KEY = "auto:sampler"


+@contextmanager
+def manual_seed(seed: int | None = None) -> Generator[None, None, None]:
+    r"""
+    Contextmanager for manual setting the torch.random seed.
+
+    Args:
+        seed: The seed to set the random number generator to.
+
+    Returns:
+        Generator
+
+    Example:
+        >>> with manual_seed(1234):
+        >>>     X = torch.rand(3)
+
+    copied as-is from https://github.com/meta-pytorch/botorch/blob/a42cd65f9b704cdb6f2ee64db99a022eb15295d5/botorch/utils/sampling.py#L53C1-L75C50 under the MIT License
+    """
+    old_state = torch.random.get_rng_state()
+    try:
+        if seed is not None:
+            torch.random.manual_seed(seed)
+        yield
+    finally:
+        if seed is not None:
+            torch.random.set_rng_state(old_state)
+
+
 class SimpleAPIBaseSampler(BaseSampler):
     def __init__(
         self,
-        search_space: dict[str, BaseDistribution] = None,
+        search_space: Optional[dict[str, BaseDistribution]] = None,
     ):
-        self.search_space = search_space
+        self.search_space = {} if search_space is None else dict(search_space)

     def suggest_all(self, trial: FrozenTrial):
         return {k: trial._suggest(k, dist) for k, dist in self.search_space.items()}
@@ -65,6 +84,16 @@ def _get_default_candidates_func(
     """
     The original is available at https://github.com/optuna/optuna-integration/blob/156a8bc081322791015d2beefff9373ed7b24047/optuna_integration/botorch/botorch.py under the MIT License
     """
+
+    # lazy import
+    from optuna_integration.botorch import (
+        ehvi_candidates_func,
+        logei_candidates_func,
+        qehvi_candidates_func,
+        qei_candidates_func,
+        qparego_candidates_func,
+    )
+
     if n_objectives > 3 and not has_constraint and not consider_running_trials:
         return ehvi_candidates_func
     elif n_objectives > 3:
@@ -124,7 +153,7 @@ def _untransform_numerical_param_torch(


 @torch.no_grad()
-def untransform(self:
+def untransform(self: SearchSpaceTransform, trans_params: Tensor) -> dict[str, Any]:
     assert trans_params.shape == (self._raw_bounds.shape[0],)

     if self._transform_0_1:
@@ -152,29 +181,31 @@ class BoTorchSampler(SimpleAPIBaseSampler):

     def __init__(
         self,
-        search_space: dict[str, BaseDistribution] = None,
+        search_space: Optional[dict[str, BaseDistribution]] = None,
         *,
-        candidates_func:
-        constraints_func:
+        candidates_func: Optional[Callable[..., Tensor]] = None,
+        constraints_func: Optional[Callable[..., Tensor]] = None,
         n_startup_trials: int = 10,
         consider_running_trials: bool = False,
-        independent_sampler:
+        independent_sampler: Optional[BaseSampler] = None,
         seed: int | None = None,
         device: torch.device | str | None = None,
         trial_chunks: int = 128,
     ):
-
-
-
-
-
-
+        if constraints_func is not None:
+            raise NotImplementedError("constraints_func is currently not supported by BoTorchSampler.")
+        if consider_running_trials:
+            raise NotImplementedError("consider_running_trials is currently not supported by BoTorchSampler.")
+        if candidates_func is not None and not callable(candidates_func):
+            raise TypeError("candidates_func must be callable.")
+        self._candidates_func = candidates_func
+        self._independent_sampler = independent_sampler or RandomSampler(seed=seed)
         self._n_startup_trials = n_startup_trials
         self._seed = seed
         self.trial_chunks = trial_chunks

         self._study_id: int | None = None
-        self.search_space = search_space
+        self.search_space = {} if search_space is None else dict(search_space)
         if isinstance(device, str):
             device = torch.device(device)
         self._device = device or torch.device("cpu")
@@ -182,14 +213,24 @@ class BoTorchSampler(SimpleAPIBaseSampler):
         self._values = None
         self._params = None
         self._index = 0
+        self._bounds_dim: int | None = None

     def infer_relative_search_space(self, study: Study, trial: FrozenTrial) -> dict[str, BaseDistribution]:
         return self.search_space

     @torch.no_grad()
     def _preprocess_trials(
-        self, trans:
+        self, trans: SearchSpaceTransform, study: Study, trials: list[FrozenTrial]
     ) -> Tuple[int, Tensor, Tensor]:
+        bounds_dim = trans.bounds.shape[0]
+        if self._bounds_dim is not None and self._bounds_dim != bounds_dim:
+            self._values = None
+            self._params = None
+            self._index = 0
+            self.seen_trials = set()
+        if self._bounds_dim is None:
+            self._bounds_dim = bounds_dim
+
         new_trials = []
         for trial in trials:
             tid: int = trial._trial_id
@@ -200,6 +241,10 @@ class BoTorchSampler(SimpleAPIBaseSampler):

         n_objectives = len(study.directions)
         if not new_trials:
+            if self._values is None or self._params is None:
+                empty_values = torch.zeros((0, n_objectives), dtype=torch.float64, device=self._device)
+                empty_params = torch.zeros((0, bounds_dim), dtype=torch.float64, device=self._device)
+                return n_objectives, empty_values, empty_params
             return n_objectives, self._values[: self._index], self._params[: self._index]

         n_completed_trials = len(trials)
@@ -216,18 +261,28 @@ class BoTorchSampler(SimpleAPIBaseSampler):
             if direction == StudyDirection.MINIMIZE:  # BoTorch always assumes maximization.
                 values[:, obj_idx] *= -1

-
+        bounds_dim = trans.bounds.shape[0]
+        cache_stale = (
+            self._values is None
+            or self._params is None
+            or self._values.size(1) != n_objectives
+            or self._params.size(1) != bounds_dim
+        )
+        if cache_stale:
             self._values = torch.zeros((self.trial_chunks, n_objectives), dtype=torch.float64, device=self._device)
-            self._params = torch.zeros(
-
-            )
+            self._params = torch.zeros((self.trial_chunks, bounds_dim), dtype=torch.float64, device=self._device)
+            self._index = 0
+            self.seen_trials = set()
+            self._bounds_dim = bounds_dim
         spillage = (self._index + n_completed_trials) - self._values.size(0)
         if spillage > 0:
             pad = int(math.ceil(spillage / self.trial_chunks) * self.trial_chunks)
             self._values = F.pad(self._values, (0, 0, 0, pad))
             self._params = F.pad(self._params, (0, 0, 0, pad))
-
-
+        values_tensor = torch.from_numpy(values).to(self._device)
+        params_tensor = torch.from_numpy(params).to(self._device)
+        self._values[self._index : self._index + n_completed_trials] = values_tensor
+        self._params[self._index : self._index + n_completed_trials] = params_tensor
         self._index += n_completed_trials

         return n_objectives, self._values[: self._index], self._params[: self._index]
@@ -246,7 +301,7 @@ class BoTorchSampler(SimpleAPIBaseSampler):
         if n_completed_trials < self._n_startup_trials:
             return {}

-        trans =
+        trans = SearchSpaceTransform(search_space)
         n_objectives, values, params = self._preprocess_trials(trans, study, completed_trials)

         if self._candidates_func is None:
@@ -349,10 +404,10 @@ class HEBOSampler(optunahub.samplers.SimpleBaseSampler, SimpleAPIBaseSampler):
         independent_sampler: BaseSampler | None = None,
     ) -> None:
         super().__init__(search_space, seed)
-
-
+        if constant_liar:
+            raise NotImplementedError("constant_liar is not supported by HEBOSampler.")
         self._hebo = HEBO(_convert_to_hebo_design_space(search_space), scramble_seed=self._seed)
-        self._independent_sampler = optuna.samplers.RandomSampler(seed=seed)
+        self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
         self._rng = np.random.default_rng(seed)

     def sample_relative(
@@ -421,10 +476,12 @@ class FastINGO:
         learning_rate: Optional[float] = None,
         last_n: int = 4096,
         loco_step_size: float = 0.1,
-        device=
+        device: str | None = None,
         batchnorm_decay: float = 0.99,
         score_decay: float = 0.99,
     ) -> None:
+        if device is None:
+            device = _use_cuda()
         n_dimension = len(mean)
         if population_size is None:
             population_size = 4 + int(np.floor(3 * np.log(n_dimension)))
@@ -491,8 +548,14 @@ class FastINGO:
         if y.numel() <= 2:
             return

-
-
+        min_y = y.min()
+        max_y = y.max()
+        if torch.isclose(max_y, min_y, rtol=0.0, atol=1e-12):
+            return
+
+        if min_y <= 0:
+            y = y + (1e-8 - min_y)
+        y = y.clamp_min_(1e-8).log()

         ema = -torch.arange(y.size(0), device=y.device, dtype=y.dtype)
         weight = self.batchnorm_decay**ema
@@ -553,7 +616,7 @@ class ImplicitNaturalGradientSampler(BaseSampler):
     def reseed_rng(self) -> None:
         self._independent_sampler.reseed_rng()
         if self._optimizer:
-            self._optimizer.
+            self._optimizer.generator.seed()

     def infer_relative_search_space(
         self, study: "optuna.Study", trial: "optuna.trial.FrozenTrial"
@@ -603,14 +666,11 @@ class ImplicitNaturalGradientSampler(BaseSampler):
             self._warn_independent_sampling = False
             return {}

-        trans =
+        trans = SearchSpaceTransform(search_space)

-        if self._optimizer is None:
+        if self._optimizer is None or self._optimizer.dim != len(trans.bounds):
             self._optimizer = self._init_optimizer(trans, population_size=self._population_size)
-
-        if self._optimizer.dim != len(trans.bounds):
-            self._warn_independent_sampling = False
-            return {}
+            self._param_queue.clear()

         solution_trials = [t for t in completed_trials if self._check_trial_is_generation(t)]
         for t in solution_trials:
@@ -621,7 +681,7 @@ class ImplicitNaturalGradientSampler(BaseSampler):

     def _init_optimizer(
         self,
-        trans:
+        trans: SearchSpaceTransform,
         population_size: Optional[int] = None,
     ) -> FastINGO:
         lower_bounds = trans.bounds[:, 0]
@@ -675,6 +735,7 @@ class ThreadLocalSampler(threading.local):


 def init_cmaes(study, seed, trials, search_space):
+    trials = copy.deepcopy(trials)
     trials.sort(key=lambda trial: trial.datetime_complete)
     return CmaEsSampler(seed=seed, source_trials=trials, lr_adapt=True)

@@ -686,8 +747,14 @@ def init_hebo(study, seed, trials, search_space):
     return sampler


+def _use_cuda():
+    return "cuda" if torch.cuda.is_available() else "cpu"
+
+
 def init_botorch(study, seed, trials, search_space):
-    return BoTorchSampler(
+    return BoTorchSampler(
+        search_space=search_space, seed=seed, device=_use_cuda()
+    )  # will automatically pull in latest data


 def init_nsgaii(study, seed, trials, search_space):
@@ -709,17 +776,20 @@ class AutoSampler(BaseSampler):
     def __init__(
         self,
         samplers: Iterable[Tuple[int, Callable]] | None = None,
-        search_space: dict[str, BaseDistribution] = None,
+        search_space: Optional[dict[str, BaseDistribution]] = None,
         *,
         seed: int | None = None,
-        constraints_func:
+        constraints_func: Optional[Callable[..., Any]] = None,
     ) -> None:
-
+        if constraints_func is not None:
+            raise NotImplementedError("constraints_func is not supported by AutoSampler.")
         if samplers is None:
+            if search_space is None:
+                raise ValueError("AutoSampler requires a search_space when using the default sampler schedule.")
             samplers = ((0, init_hebo), (100, init_nsgaii))
         self.sampler_indices = np.sort(np.array([x[0] for x in samplers], dtype=np.int32))
         self.samplers = [x[1] for x in sorted(samplers, key=lambda x: x[0])]
-        self.search_space = search_space
+        self.search_space = {} if search_space is None else dict(search_space)
         self._rng = LazyRandomState(seed)
         self._random_sampler = RandomSampler(seed=seed)
         self._thread_local_sampler = ThreadLocalSampler()
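Similarly, a hedged sketch of the new AutoSampler validation path (the objective is illustrative; per the defaults above, the schedule hands off from HEBO to NSGA-II after 100 completed trials):

    import optuna
    from optuna.distributions import FloatDistribution
    from heavyball.helpers import AutoSampler

    space = {"x": FloatDistribution(-5.0, 5.0)}
    sampler = AutoSampler(search_space=space, seed=0)  # omitting search_space now raises ValueError
    study = optuna.create_study(sampler=sampler)
    study.optimize(lambda t: t.suggest_float("x", -5.0, 5.0) ** 2, n_trials=10)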
@@ -762,7 +832,7 @@ class AutoSampler(BaseSampler):
         complete_trials = study._get_trials(deepcopy=False, states=(TrialState.COMPLETE,), use_cache=True)
         self._completed_trials = max(self._completed_trials, len(complete_trials))
         new_index = (self._completed_trials >= self.sampler_indices).sum() - 1
-        if new_index == self._current_index:
+        if new_index == self._current_index or new_index < 0:
             return
         self._current_index = new_index
         self._sampler = self.samplers[new_index](
@@ -775,7 +845,7 @@ class AutoSampler(BaseSampler):
     def sample_relative(
         self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
     ) -> dict[str, Any]:
-        return self._sampler.sample_relative(study, trial, self.search_space)
+        return self._sampler.sample_relative(study, trial, search_space or self.search_space)

     def sample_independent(
         self,
@@ -804,5 +874,6 @@ class AutoSampler(BaseSampler):
         state: TrialState,
         values: Sequence[float] | None,
     ) -> None:
-
+        if state not in (TrialState.COMPLETE, TrialState.FAIL, TrialState.PRUNED):
+            raise ValueError(f"Unsupported trial state: {state}.")
         self._sampler.after_trial(study, trial, state, values)
heavyball/utils.py
CHANGED
@@ -47,7 +47,7 @@ _cudnn_double_backward_pattern = re.compile(
 )
 _torch_compile_double_backward_pattern = re.compile(r"compile.*does not currently support double backward")
 _fd_error = (
-    "You can accelerate startup by globally enabling finite_differences first "
+    "You can accelerate startup by globally enabling finite_differences first "
     "(via opt.finite_differences=True or by subclassing it)\n"
     "Original Error: "
 )
@@ -343,7 +343,8 @@ def set_(dst: Tensor, src: Tensor):


 def clean():
-    torch.cuda.
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
     gc.collect()


@@ -418,9 +419,13 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):


 ###### START
-#
+# Based on https://arxiv.org/pdf/2505.16932v3
+# and https://github.com/NoahAmsel/PolarExpress/blob/5454910920ca8c65afda28820cdf9e49b9436ed0/polar_express.py#L69-L82
+# and https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
+#
 # under the MIT License

+# Coefficients are from https://arxiv.org/pdf/2505.16932v3
 ABC_LIST: list[tuple[float, float, float]] = [
     (8.28721201814563, -23.595886519098837, 17.300387312530933),
     (4.107059111542203, -2.9478499167379106, 0.5448431082926601),
@@ -438,7 +443,7 @@ ABC_LIST_STABLE: list[tuple[float, float, float]] = [
 ] + [ABC_LIST[-1]]


-def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
+def msign(G: torch.Tensor, steps: int = 10, eps: float = 1e-7) -> torch.Tensor:
     """
     Polar Express algorithm for the matrix sign function:
     https://arxiv.org/abs/2505.16932
@@ -450,7 +455,9 @@ def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
     if should_transpose:
         x = x.mT

-    x
+    # x = x / (x.norm(dim=(-2, -1), keepdim=True) * 1.01 + eps)
+    stochastic_divide_with_eps_(x, x.norm(dim=(-2, -1)) * 1.01, eps)
+
     for step in range(steps):
         a, b, c = ABC_LIST_STABLE[step] if step < len(ABC_LIST_STABLE) else ABC_LIST_STABLE[-1]
         s = x @ x.mT
@@ -464,8 +471,7 @@ def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:

     if should_transpose:
         x = x.mT
-    x
-    return x.float()
+    return x.to(G.dtype)


 ###### END
@@ -665,9 +671,9 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
             final.append(None)
             continue

+        device, dtype = m.device, m.dtype
         m = promote(m.data)

-        device, dtype = m.device, m.dtype
         eps = min_eps
         while True:
             try:
@@ -695,7 +701,6 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
                 raise
             clean()

-        eigvec = eigvec.to(device=m.device, dtype=m.dtype)
         eigvec = torch.flip(eigvec, [1])
         final.append(eigvec)

@@ -1048,13 +1053,15 @@ class StatefulOptimizer(torch.optim.Optimizer):
     def get_groups(self, group):
         return [group]

-    @functools.lru_cache(maxsize=None)
     def state_(self, arg: Tensor, fail: bool = True):
-
-
-        if _tensor_key(arg) not in self.mapping_inverse:
+        key = _tensor_key(arg)
+        if key not in self.mapping_inverse:
             self._init_mapping()
-
+        if key not in self.mapping_inverse:
+            if not fail:
+                return {}
+            raise KeyError("Tensor has no tracked state.")
+        state_param, index = self.mapping_inverse[key]
         if state_param not in self.state:
             self.state[state_param] = collections.defaultdict(dict)
         return self.state[state_param][index]
@@ -1142,7 +1149,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
             active_p = [p for p in group["params"]]

             if not active_p:
-
+                continue

             k = group["ema_step"] = group.get("ema_step", -1) + 1
@@ -1159,7 +1166,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
             active_p = [p for p in group["params"]]

             if not active_p:
-
+                continue

             for p in active_p:
                 if "param_ema" in self.state_(p):
@@ -1173,7 +1180,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
             active_p = [p for p in group["params"]]

             if not active_p:
-
+                continue

             for p in active_p:
                 if "param_ema" in self.state_(p):
@@ -1202,7 +1209,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
         for group in self.param_groups:
             for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
                 p.grad = grads.pop(0)
-
+                stochastic_add_divide_(g, p.grad, -1, torch.finfo(p.dtype).eps ** 0.5)
                 p.hessian_vector = g
                 p.data.copy_(p.orig)
                 del p.orig
@@ -1292,6 +1299,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
         self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule, self.precond_rng)
         loss = self._handle_closure(closure)

+        if self.use_ema:
+            self.ema_update()
         # we assume that parameters are constant and that there are no excessive recompiles
         with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
             for group in self.param_groups:
@@ -1299,8 +1308,6 @@ class StatefulOptimizer(torch.optim.Optimizer):
                 group["param_count"] = sum(p.numel() for p in group["params"])
                 group["is_preconditioning"] = self._is_preconditioning
                 self._step(group)
-                if self.use_ema:
-                    self.ema_update()
         for real, views in self.mapping.items():
             for tensor in (real, *views):
                 for key in ("grad", "vector", "hessian_vector", "orig"):
@@ -1564,18 +1571,20 @@ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):

 @decorator_knowngood
 def stochastic_round_(ref: Tensor, source: Tensor | None = None):
-    if source is
-        if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
-            return source
-        if ref.dtype != torch.bfloat16:
-            return source.to(ref.dtype)
-    else:
+    if source is None:
         source = ref
-
-
-
-
-
+    if ref.dtype != torch.bfloat16:
+        return source.to(ref.dtype)
+    if source.dtype == torch.bfloat16:
+        return source
+    if source.dtype in (torch.float16, torch.float32, torch.float64):
+        source = source.to(torch.float32)
+        noise = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
+        bits = source.view(dtype=torch.int32)
+        bits.add_(noise)
+        bits.bitwise_and_(-65536)  # FFFF0000 mask, preserves sign+exp+7 mantissa bits
+        return bits.view(dtype=torch.float32).bfloat16()
+    return source.to(ref.dtype)


 @decorator_knowngood
@@ -1585,7 +1594,7 @@ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):

 def copy_stochastic_(target: Tensor, source: Tensor):
     if target.dtype == torch.bfloat16 and source.dtype in (torch.float16, torch.float32, torch.float64):
-
+        source = stochastic_round_(target, source)
     set_(target, source)


@@ -1908,7 +1917,8 @@ def update_lra_precond_(

     # LU factorization to reuse computation
     try:
-
+        lu_matrix = promote(IpVtU)  # operate in fp32 when inputs are bf16/half
+        LU, pivots = torch.linalg.lu_factor(lu_matrix)
     except RuntimeError:
         # Error:
         # U[2,2] is zero and using it on lu_solve would result in a division by zero.
@@ -1918,8 +1928,13 @@ def update_lra_precond_(
         # So, we skip this step and reattempt on the next one
         return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)

-
-
+    solve_dtype = LU.dtype
+    rhs = (U.T @ invQtv).view(-1, 1).to(solve_dtype)
+    correction = torch.linalg.lu_solve(LU, pivots, rhs, adjoint=True).to(V.dtype)
+    invQtv = invQtv - (V @ correction).flatten()
+    rhs = (V.T @ invQtv).view(-1, 1).to(solve_dtype)
+    solution = torch.linalg.lu_solve(LU, pivots, rhs).to(U.dtype)
+    invPv = (U @ solution).flatten()

     eps, step = scalar_guard(eps, step, vector)
     _compilable_d_step(d, d_orig, invQtv, vector, invPv, hessian_vector, Ph, eps, step, delayed)
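For orientation (this note is not part of the diff): `IpVtU` reads as I + VᵀU, which is exactly the small matrix whose factorization the Woodbury identity needs, so a single LU factorization can serve both solves above. The identity, for reference, in LaTeX:

    % Woodbury-style identity enabled by factoring (I + V^T U)
    (I + U V^\top)^{-1} \;=\; I \;-\; U\,(I + V^\top U)^{-1} V^\top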
@@ -2039,7 +2054,10 @@ def extract_from_flat_update(params: List[Tensor], update: Tensor):
 @decorator_knowngood
 def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
     last_dim = x[0].shape[-remaining:] if remaining else []
-
+    tensors = [i.reshape(-1, *last_dim) for i in x if i.numel()]
+    if not tensors:
+        return torch.zeros((), dtype=x[0].device, device=x[0].device)
+    return torch.cat(tensors, 0)


 @decorator_knowngood
@@ -2111,16 +2129,6 @@ def psgd_calc_A_and_conjB(G: Tensor, Q, conjB: Tensor | None):  # conjB ("V", "v
     return A, conjB


-@decorator_knowngood
-def _random_projection(x: Tensor, scale: Optional[Tensor]):
-    if scale is None:
-        scale = x.norm(float("inf")).clamp(min=1e-8)
-    k = 2 ** math.ceil(math.log2(math.log2(min(x.shape))))  # next-largest-power-of-2 of log2-of-size
-    norm = x.square().sum(0)
-    indices = torch.topk(norm, k, largest=True).indices
-    return x.index_select(1, indices).contiguous() / scale, scale
-
-
 def max_singular_value_exact(A, use_lobpcg: bool = False):
     try:
         if use_lobpcg:
@@ -2164,7 +2172,15 @@ def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
     """
     Adapted from @evanatyourservice
     """
-
+    if max_abs is None:
+        max_abs = A.norm(float("inf")).clamp(min=1e-8)
+
+    # cholesky uses random projection, but this uses topk -- topk is a warm start, which may converge to a biased result
+    k = 2 ** math.ceil(math.log2(math.log2(min(A.shape))))  # next-largest-power-of-2 of log2-of-size
+    norm = A.square().sum(0)
+    indices = torch.topk(norm, k, largest=True).indices
+    Y = A.index_select(1, indices).contiguous() / max_abs
+
     Q = inplace_orthogonal_(Y, precise_zeroth_power_mode)
     Q = Q / max_abs
     Z = A.T @ Q
@@ -2412,10 +2428,11 @@ def bf16_matmul(x: Tensor, y: Tensor):
 def if_iscompiling(fn):
     base = getattr(torch, fn.__name__, None)

-
-
-
-
+    @functools.wraps(fn)
+    def _fn(*args, **kwargs):
+        if torch.compiler.is_compiling() and base is not None:
+            return base(*args, **kwargs)
+        return fn(*args, **kwargs)

     return _fn

@@ -2551,7 +2568,7 @@ def _psgd_quad_preconditioner_grad(GG: List[Tensor], Q: List[Tensor], numel: int
         else:
             scale = gg.size(0) / numel
             gg = 2 * torch.eye(gg.size(0), device=gg.device, dtype=gg.dtype) - gg * scale
-            update = q - gg
+            update = q - casted_einsum("ab,cd,bc", gg, gg, q)
         out.append(update + update.T)  # make matrix symmetric
     return out

@@ -3105,7 +3122,7 @@ def pointwise_lr_adaptation(
 ):
     grads, update, state, delta = list_guard(grads, update, state, delta)
     lr_lr = scalar_guard(lr_lr, grads[0])
-
+    _compilable_pointwise_lr_adapt_(grads, update, state, delta, lr_lr)


 def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
@@ -3125,8 +3142,6 @@ def hook_optimizer_into_model(model, optimizer, *args, **kwargs):

 def fused_hook(parameters, optimizer, *args, **kwargs):
     parameters = list(parameters)
-    param_count = len(parameters)
-    seen_params = set()

     o = optimizer(parameters, *args, **kwargs)
     step_fn = o.step
@@ -3135,12 +3150,8 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
     )

     def _step(p: Tensor):
-
-
-        if len(seen_params) < param_count:
-            step_fn()
-            o.zero_grad()
-            seen_params.clear()
+        step_fn()
+        o.zero_grad()

     for p in parameters:
         p.register_post_accumulate_grad_hook(_step)
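The simplified fused_hook now runs a step and zeroes gradients on every post-accumulate-grad hook rather than counting parameters first. A hedged usage sketch (model and hyperparameters are illustrative):

    import torch
    import heavyball
    from heavyball.utils import fused_hook

    model = torch.nn.Linear(4, 4)
    fused_hook(model.parameters(), heavyball.ForeachAdamW, lr=1e-3)

    loss = model(torch.randn(2, 4)).sum()
    loss.backward()  # each parameter's hook now triggers step() + zero_grad() on the fused optimizer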
@@ -3165,6 +3176,8 @@ def sam_step(parameters, ball_size, adaptive: bool = True):
     old_params = []
     for p in parameters:
         old_params.append(p.detach().clone())
+        if not hasattr_none(p, "grad"):
+            continue
         grad = promote(p.grad)
         if adaptive:
             grad = grad * promote(p).square()
{heavyball-2.1.2.dist-info → heavyball-2.1.4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: heavyball
-Version: 2.1.2
+Version: 2.1.4
 Summary: Efficient Optimizers
 Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
 Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -21,6 +21,7 @@ Requires-Dist: numpy<2.0.0
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
 Requires-Dist: pytest; extra == "dev"
+Requires-Dist: hypothesis; extra == "dev"
 Requires-Dist: ruff; extra == "dev"
 Requires-Dist: matplotlib; extra == "dev"
 Requires-Dist: seaborn; extra == "dev"
heavyball-2.1.4.dist-info/RECORD
ADDED
@@ -0,0 +1,9 @@
+heavyball/__init__.py,sha256=9VgWebob-zO7hKg_KmQuSOB4Z_Rh-gCDs_V2TTfQKSo,30123
+heavyball/chainable.py,sha256=O8QiHJ-E5RD-fzo3iulSHgvKgtRZ1Lff2ls3iLmXcoI,42695
+heavyball/helpers.py,sha256=eiotfrJz4V6ewfF9ZboC_JEUi_TCmO195uT6sqqohTE,33429
+heavyball/utils.py,sha256=u4RFOdmYkhsjPE4M_N53oDnuh-vHbvRHc6OLQTEeq-c,105239
+heavyball-2.1.4.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
+heavyball-2.1.4.dist-info/METADATA,sha256=MxDWUcqFgMWmG3FXtf0UVzDy9qsAWea4tPgnDnx9wXQ,5088
+heavyball-2.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+heavyball-2.1.4.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-2.1.4.dist-info/RECORD,,
heavyball-2.1.2.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-heavyball/__init__.py,sha256=1BTb7G-VcfcMyS4EpuVnhE5DBp2fj_Zzs9EQr6slPzg,30491
-heavyball/chainable.py,sha256=8S-7QRZYiy_ARhQ8uDu5G0Eg3ouT9Vcfk-rxbKlp4zI,42510
-heavyball/helpers.py,sha256=zk_S84wpGcvO9P6kn4UeaQUIDowHxcbM9qQITEm2g5I,30267
-heavyball/utils.py,sha256=Lx9XlfkyQbfYMPtqiA0rNIz4PXQe_bpLqKFby3upHMw,104514
-heavyball-2.1.2.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
-heavyball-2.1.2.dist-info/METADATA,sha256=EMM0OI4cPeaQlMkts2j9CCp9KxhJm-o_9VDNLm4ySQg,5046
-heavyball-2.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-heavyball-2.1.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-2.1.2.dist-info/RECORD,,
{heavyball-2.1.2.dist-info → heavyball-2.1.4.dist-info}/WHEEL
File without changes
{heavyball-2.1.2.dist-info → heavyball-2.1.4.dist-info}/licenses/LICENSE
File without changes
{heavyball-2.1.2.dist-info → heavyball-2.1.4.dist-info}/top_level.txt
File without changes