heavyball 2.1.3__tar.gz → 2.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-2.1.3 → heavyball-2.1.4}/PKG-INFO +1 -1
- {heavyball-2.1.3 → heavyball-2.1.4}/heavyball/__init__.py +56 -89
- {heavyball-2.1.3 → heavyball-2.1.4}/heavyball/chainable.py +6 -4
- {heavyball-2.1.3 → heavyball-2.1.4}/heavyball/helpers.py +88 -47
- {heavyball-2.1.3 → heavyball-2.1.4}/heavyball/utils.py +58 -51
- {heavyball-2.1.3 → heavyball-2.1.4}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-2.1.3 → heavyball-2.1.4}/heavyball.egg-info/SOURCES.txt +1 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/pyproject.toml +1 -1
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_bf16_params.py +5 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_foreach.py +40 -5
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_memory.py +0 -1
- heavyball-2.1.4/test/test_stochastic_utils_cpu.py +49 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_utils_cpu.py +22 -21
- {heavyball-2.1.3 → heavyball-2.1.4}/LICENSE +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/README.md +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/setup.cfg +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_bf16_q.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_bf16_storage.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_caution.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_chainable_cpu.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_channels_last.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_clip.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_closure.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_cpu_features.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_ema.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_helpers_cpu.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_hook.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_mars.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_memory_leak.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_merge.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_migrate_cli.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_nd_param.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_no_grad.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_optimizer_cpu_smoke.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_psgd_precond_init_stability.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_save_restore.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_singular_values.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_soap.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_stochastic_updates.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_toy_training.py +0 -0
- {heavyball-2.1.3 → heavyball-2.1.4}/test/test_utils_property.py +0 -0

{heavyball-2.1.3 → heavyball-2.1.4}/heavyball/__init__.py
@@ -1,6 +1,6 @@
 import functools
 import math
-from typing import Optional
+from typing import Optional, Type, Union

 import torch.optim

@@ -8,39 +8,6 @@ from . import chainable as C
 from . import utils


-class SAMWrapper(torch.optim.Optimizer):
-    def __init__(self, params, wrapped_optimizer: utils.StatefulOptimizer, ball: float = 0.1):
-        if not isinstance(wrapped_optimizer, utils.StatefulOptimizer):
-            raise ValueError(f"{wrapped_optimizer.__class__.__name__} is not a HeavyBall optimizer")
-        super().__init__(params, {"ball": ball})
-        self.wrapped_optimizer = wrapped_optimizer
-
-    @torch.no_grad()
-    def step(self, closure=None):
-        if closure is None:
-            raise ValueError("SAM requires closure")
-        with torch.enable_grad():
-            closure()
-        old_params = [utils.sam_step(group["params"], group["ball"]) for group in self.param_groups]
-
-        originaL_handle_closure = self.wrapped_optimizer._handle_closure
-
-        def _handle_closure(closure):
-            originaL_handle_closure(closure)
-            for group, old in zip(self.param_groups, old_params):
-                utils.copy_stochastic_list_(group["params"], old)
-
-        try:
-            self.wrapped_optimizer._handle_closure = _handle_closure
-            loss = self.wrapped_optimizer.step(closure)
-        finally:
-            self.wrapped_optimizer._handle_closure = originaL_handle_closure
-        return loss
-
-    def zero_grad(self, set_to_none: bool = True):
-        self.wrapped_optimizer.zero_grad()
-
-
 class SGD(C.BaseOpt):
     def __init__(
         self,
@@ -778,7 +745,7 @@ class ForeachPSGDKron(C.BaseOpt):
         beta=None,
         betas=(0.9, 0.999),
         weight_decay=0.0,
-        preconditioner_update_probability=
+        preconditioner_update_probability=C.use_default,
         max_size_triangular=2048,
         min_ndim_triangular=2,
         memory_save_mode=None,
@@ -830,8 +797,8 @@ class ForeachPSGDKron(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")

-        self.precond_schedule = (
-            defaults.pop("preconditioner_update_probability")
+        self.precond_schedule = C.default(
+            defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
         )
         params = defaults.pop("params")

@@ -890,7 +857,7 @@ class ForeachPSGDLRA(C.BaseOpt):
         lr=0.001,
         beta=0.9,
         weight_decay=0.0,
-        preconditioner_update_probability=
+        preconditioner_update_probability=C.use_default,
         momentum_into_precond_update=True,
         rank: Optional[int] = None,
         warmup_steps: int = 0,
@@ -924,8 +891,8 @@ class ForeachPSGDLRA(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")

-        self.precond_schedule = (
-            defaults.pop("preconditioner_update_probability")
+        self.precond_schedule = C.default(
+            defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
         )
         params = defaults.pop("params")

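
Both `ForeachPSGDKron` and `ForeachPSGDLRA` now mark `preconditioner_update_probability` with the `C.use_default` sentinel and resolve it through `C.default`, so an explicit float or callable passed by the caller is kept while the sentinel falls back to `utils.precond_update_prob_schedule()`. A usage sketch, assuming only the standard `params`/`lr` arguments plus the keyword visible in these hunks:

    import torch
    from torch import nn

    import heavyball

    model = nn.Linear(8, 8)

    # Explicit float: update the preconditioner on roughly 10% of steps.
    opt_fixed = heavyball.ForeachPSGDKron(
        model.parameters(), lr=1e-3, preconditioner_update_probability=0.1
    )

    # Omitting the argument now resolves to utils.precond_update_prob_schedule(),
    # a step-dependent schedule, instead of a value baked into the signature.
    opt_scheduled = heavyball.ForeachPSGDKron(model.parameters(), lr=1e-3)
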
@@ -960,6 +927,54 @@ class NewtonHybrid2PSGDLRA(ForeachNewtonPSGDLRA):
     hvp_interval = 2


+class SAMWrapper(torch.optim.Optimizer):
+    def __init__(
+        self,
+        params,
+        wrapped_optimizer: Union[utils.StatefulOptimizer, Type[utils.StatefulOptimizer]] = ForeachAdamW,
+        ball: float = 0.1,
+    ):
+        params = list(params)
+        super().__init__(params, {"ball": ball})
+
+        if isinstance(wrapped_optimizer, type):
+            if not issubclass(wrapped_optimizer, utils.StatefulOptimizer):
+                raise ValueError(f"{wrapped_optimizer.__name__} is not a HeavyBall optimizer")
+            wrapped_optimizer = wrapped_optimizer(params)
+        elif not isinstance(wrapped_optimizer, utils.StatefulOptimizer):
+            raise ValueError(f"{wrapped_optimizer.__class__.__name__} is not a HeavyBall optimizer")
+
+        self.wrapped_optimizer = wrapped_optimizer
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        if closure is None:
+            raise ValueError("SAM requires closure")
+        with torch.enable_grad():
+            closure()
+        old_params = [utils.sam_step(group["params"], group["ball"]) for group in self.param_groups]
+
+        original_handle_closure = self.wrapped_optimizer._handle_closure
+
+        def _handle_closure(closure):
+            try:
+                _loss = original_handle_closure(closure)
+            finally:
+                for group, old in zip(self.param_groups, old_params):
+                    utils.copy_stochastic_list_(group["params"], old)
+            return _loss
+
+        try:
+            self.wrapped_optimizer._handle_closure = _handle_closure
+            loss = self.wrapped_optimizer.step(closure)
+        finally:
+            self.wrapped_optimizer._handle_closure = original_handle_closure
+        return loss
+
+    def zero_grad(self, set_to_none: bool = True):
+        self.wrapped_optimizer.zero_grad(set_to_none=set_to_none)
+
+
 PalmForEachSoap = PaLMForeachSOAP
 PaLMSOAP = PaLMForeachSOAP
 PaLMSFAdamW = PaLMForeachSFAdamW
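
`SAMWrapper` returns here, below the concrete optimizers, so its `wrapped_optimizer` argument can default to `ForeachAdamW` and accept either an optimizer class (instantiated on the same parameter list) or an existing instance. A short usage sketch built only from the constructor and `step` shown above; the model, data, and loss are placeholders:

    import torch
    from torch import nn

    import heavyball

    model = nn.Linear(16, 1)
    data, target = torch.randn(32, 16), torch.randn(32, 1)

    # Pass a HeavyBall optimizer class; SAMWrapper instantiates it on the same params.
    opt = heavyball.SAMWrapper(model.parameters(), wrapped_optimizer=heavyball.ForeachAdamW, ball=0.05)

    def closure():
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(data), target)
        loss.backward()
        return loss

    # SAM needs the closure so the wrapped optimizer can re-evaluate the loss
    # at the perturbed weights before the original parameters are restored.
    loss = opt.step(closure)
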
@@ -983,52 +998,4 @@ PSGDLRA = ForeachPSGDLRA
 NewtonPSGDLRA = ForeachNewtonPSGDLRA
 NewtonPSGDKron = ForeachCachedNewtonPSGD

-__all__ = [
-    "Muon",
-    "RMSprop",
-    "PrecondSchedulePaLMSOAP",
-    "PSGDKron",
-    "PurePSGD",
-    "DelayedPSGD",
-    "CachedPSGDKron",
-    "CachedDelayedPSGDKron",
-    "PalmForEachSoap",
-    "PaLMSOAP",
-    "PaLMSFAdamW",
-    "LaProp",
-    "ADOPT",
-    "PrecondScheduleSOAP",
-    "PrecondSchedulePaLMSOAP",
-    "RMSprop",
-    "MuonLaProp",
-    "ForeachSignLaProp",
-    "ForeachDelayedPSGDLRA",
-    "ForeachPSGDLRA",
-    "ForeachPSGDLRA",
-    "ForeachNewtonPSGDLRA",  #
-    "ForeachAdamW",
-    "ForeachSFAdamW",
-    "ForeachLaProp",
-    "ForeachADOPT",
-    "ForeachSOAP",
-    "ForeachPSGDKron",
-    "ForeachPurePSGD",
-    "ForeachDelayedPSGD",
-    "ForeachCachedPSGDKron",
-    "ForeachCachedDelayedPSGDKron",
-    "ForeachRMSprop",
-    "ForeachMuon",
-    "ForeachCachedNewtonPSGD",
-    "OrthoLaProp",
-    "LaPropOrtho",
-    "SignLaProp",
-    "DelayedPSGD",
-    "PSGDLRA",
-    "NewtonPSGDLRA",
-    "NewtonHybrid2PSGDLRA",
-    "NewtonHybrid2PSGDKron",
-    "MSAMLaProp",
-    "NewtonPSGDKron",
-    "ForeachAdamC",
-    "SGD",
-]
+__all__ = [k for k, v in globals().items() if isinstance(v, type) and issubclass(v, torch.optim.Optimizer)]
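
The hand-maintained export list, which had drifted and contained duplicates such as "RMSprop" and "ForeachPSGDLRA", is replaced by a comprehension over the module globals: every `torch.optim.Optimizer` subclass defined or aliased in `heavyball` is exported automatically. A quick way to inspect the result:

    import torch

    import heavyball

    print(sorted(heavyball.__all__))  # includes aliases such as PSGDKron and NewtonPSGDLRA
    assert all(
        issubclass(getattr(heavyball, name), torch.optim.Optimizer) for name in heavyball.__all__
    )
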

{heavyball-2.1.3 → heavyball-2.1.4}/heavyball/chainable.py
@@ -62,8 +62,6 @@ class FunctionTransform:
             self._init(st, group, *a, **kwargs)
         except SkipUpdate:
             skip_update = True
-        except:
-            raise
         finally:
             if "is_initialized" not in st:
                 st["is_initialized"] = set()
@@ -499,6 +497,7 @@ def scale_by_suds(group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx
     precond_update, _ = utils.eigvecs_product_rank1(precond_update.flatten(), fisher_approx.flatten(), w)

     new_approx = utils.oja_update(fisher_approx.flatten().to(update.dtype), update.flatten(), group["precond_lr"])
+    new_approx = new_approx.view_as(fisher_approx)
     utils.copy_stochastic_(fisher_approx, new_approx)
     return precond_update

@@ -565,7 +564,7 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
     )
     state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
     state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
-    state["step"] = torch.zeros((), device=param.device, dtype=torch.
+    state["step"] = torch.zeros((), device=param.device, dtype=torch.float64)  # torch casts int to float in ckpt load
     if group["adaptive"]:
         state["velocity"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
     if not cached:
@@ -750,7 +749,9 @@ def _update_psgd_precond(
     if isinstance(prob, float):
         float_prob = prob
     else:
-
+        prob_step = group.get(f"cumulative_prob_{id(Q)}_prob_step", 1)
+        float_prob = prob(prob_step)
+        group[f"cumulative_prob_{id(Q)}_prob_step"] = prob_step + 1
     group["is_cached"] = should_use_cache = cached and float_prob < 0.5

     if precond is not None:
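
When the update probability is a callable schedule rather than a float, the group now keeps a dedicated step counter per preconditioner (keyed by `id(Q)`), evaluates the schedule at that step, and then increments the counter. A small standalone illustration of the same bookkeeping, with a toy schedule standing in for `utils.precond_update_prob_schedule()`:

    # Toy schedule: update often at first, rarely later (illustrative only).
    def toy_schedule(step: int) -> float:
        return max(0.03, 1.0 / step)

    group = {}
    Q = object()  # stands in for the preconditioner factors of one parameter

    for _ in range(3):
        prob_step = group.get(f"cumulative_prob_{id(Q)}_prob_step", 1)
        float_prob = toy_schedule(prob_step)
        group[f"cumulative_prob_{id(Q)}_prob_step"] = prob_step + 1
        print(prob_step, float_prob)  # 1 1.0, then 2 0.5, then 3 0.333...
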
@@ -1086,6 +1087,7 @@ class ChainOpt(utils.StatefulOptimizer):
         if not group["foreach"] or len(p) == 1:
             for param, grad in zip(p, g):
                 chain(self.state_, group, [grad], [param], *self.fns)
+            group["caution"] = caution
         else:
             chain(self.state_, group, g, p, *self.fns)


{heavyball-2.1.3 → heavyball-2.1.4}/heavyball/helpers.py
@@ -1,5 +1,4 @@
-
-
+import copy
 import functools
 import math
 import threading
@@ -14,7 +13,7 @@ import pandas as pd
 import torch
 from hebo.design_space.design_space import DesignSpace
 from hebo.optimizers.hebo import HEBO
-from optuna._transform import _SearchSpaceTransform
+from optuna._transform import _SearchSpaceTransform as SearchSpaceTransform
 from optuna.distributions import BaseDistribution, CategoricalDistribution, FloatDistribution, IntDistribution
 from optuna.samplers import BaseSampler, CmaEsSampler, RandomSampler
 from optuna.samplers._lazy_random_state import LazyRandomState
@@ -60,9 +59,9 @@ def manual_seed(seed: int | None = None) -> Generator[None, None, None]:
 class SimpleAPIBaseSampler(BaseSampler):
     def __init__(
         self,
-        search_space: dict[str, BaseDistribution] = None,
+        search_space: Optional[dict[str, BaseDistribution]] = None,
     ):
-        self.search_space = search_space
+        self.search_space = {} if search_space is None else dict(search_space)

     def suggest_all(self, trial: FrozenTrial):
         return {k: trial._suggest(k, dist) for k, dist in self.search_space.items()}
@@ -154,7 +153,7 @@ def _untransform_numerical_param_torch(


 @torch.no_grad()
-def untransform(self:
+def untransform(self: SearchSpaceTransform, trans_params: Tensor) -> dict[str, Any]:
     assert trans_params.shape == (self._raw_bounds.shape[0],)

     if self._transform_0_1:
@@ -182,29 +181,31 @@ class BoTorchSampler(SimpleAPIBaseSampler):

     def __init__(
         self,
-        search_space: dict[str, BaseDistribution] = None,
+        search_space: Optional[dict[str, BaseDistribution]] = None,
         *,
-        candidates_func:
-        constraints_func:
+        candidates_func: Optional[Callable[..., Tensor]] = None,
+        constraints_func: Optional[Callable[..., Tensor]] = None,
         n_startup_trials: int = 10,
         consider_running_trials: bool = False,
-        independent_sampler:
+        independent_sampler: Optional[BaseSampler] = None,
         seed: int | None = None,
         device: torch.device | str | None = None,
         trial_chunks: int = 128,
     ):
-
-
-
-
-
-
+        if constraints_func is not None:
+            raise NotImplementedError("constraints_func is currently not supported by BoTorchSampler.")
+        if consider_running_trials:
+            raise NotImplementedError("consider_running_trials is currently not supported by BoTorchSampler.")
+        if candidates_func is not None and not callable(candidates_func):
+            raise TypeError("candidates_func must be callable.")
+        self._candidates_func = candidates_func
+        self._independent_sampler = independent_sampler or RandomSampler(seed=seed)
         self._n_startup_trials = n_startup_trials
         self._seed = seed
         self.trial_chunks = trial_chunks

         self._study_id: int | None = None
-        self.search_space = search_space
+        self.search_space = {} if search_space is None else dict(search_space)
         if isinstance(device, str):
             device = torch.device(device)
         self._device = device or torch.device("cpu")
@@ -212,14 +213,24 @@ class BoTorchSampler(SimpleAPIBaseSampler):
         self._values = None
         self._params = None
         self._index = 0
+        self._bounds_dim: int | None = None

     def infer_relative_search_space(self, study: Study, trial: FrozenTrial) -> dict[str, BaseDistribution]:
         return self.search_space

     @torch.no_grad()
     def _preprocess_trials(
-        self, trans:
+        self, trans: SearchSpaceTransform, study: Study, trials: list[FrozenTrial]
     ) -> Tuple[int, Tensor, Tensor]:
+        bounds_dim = trans.bounds.shape[0]
+        if self._bounds_dim is not None and self._bounds_dim != bounds_dim:
+            self._values = None
+            self._params = None
+            self._index = 0
+            self.seen_trials = set()
+        if self._bounds_dim is None:
+            self._bounds_dim = bounds_dim
+
         new_trials = []
         for trial in trials:
             tid: int = trial._trial_id
@@ -230,6 +241,10 @@ class BoTorchSampler(SimpleAPIBaseSampler):

         n_objectives = len(study.directions)
         if not new_trials:
+            if self._values is None or self._params is None:
+                empty_values = torch.zeros((0, n_objectives), dtype=torch.float64, device=self._device)
+                empty_params = torch.zeros((0, bounds_dim), dtype=torch.float64, device=self._device)
+                return n_objectives, empty_values, empty_params
             return n_objectives, self._values[: self._index], self._params[: self._index]

         n_completed_trials = len(trials)
@@ -246,18 +261,28 @@ class BoTorchSampler(SimpleAPIBaseSampler):
             if direction == StudyDirection.MINIMIZE:  # BoTorch always assumes maximization.
                 values[:, obj_idx] *= -1

-
+        bounds_dim = trans.bounds.shape[0]
+        cache_stale = (
+            self._values is None
+            or self._params is None
+            or self._values.size(1) != n_objectives
+            or self._params.size(1) != bounds_dim
+        )
+        if cache_stale:
             self._values = torch.zeros((self.trial_chunks, n_objectives), dtype=torch.float64, device=self._device)
-            self._params = torch.zeros(
-
-            )
+            self._params = torch.zeros((self.trial_chunks, bounds_dim), dtype=torch.float64, device=self._device)
+            self._index = 0
+            self.seen_trials = set()
+            self._bounds_dim = bounds_dim
         spillage = (self._index + n_completed_trials) - self._values.size(0)
         if spillage > 0:
             pad = int(math.ceil(spillage / self.trial_chunks) * self.trial_chunks)
             self._values = F.pad(self._values, (0, 0, 0, pad))
             self._params = F.pad(self._params, (0, 0, 0, pad))
-
-
+        values_tensor = torch.from_numpy(values).to(self._device)
+        params_tensor = torch.from_numpy(params).to(self._device)
+        self._values[self._index : self._index + n_completed_trials] = values_tensor
+        self._params[self._index : self._index + n_completed_trials] = params_tensor
         self._index += n_completed_trials

         return n_objectives, self._values[: self._index], self._params[: self._index]
@@ -276,7 +301,7 @@ class BoTorchSampler(SimpleAPIBaseSampler):
         if n_completed_trials < self._n_startup_trials:
             return {}

-        trans =
+        trans = SearchSpaceTransform(search_space)
         n_objectives, values, params = self._preprocess_trials(trans, study, completed_trials)

         if self._candidates_func is None:
@@ -379,10 +404,10 @@ class HEBOSampler(optunahub.samplers.SimpleBaseSampler, SimpleAPIBaseSampler):
         independent_sampler: BaseSampler | None = None,
     ) -> None:
         super().__init__(search_space, seed)
-
-
+        if constant_liar:
+            raise NotImplementedError("constant_liar is not supported by HEBOSampler.")
         self._hebo = HEBO(_convert_to_hebo_design_space(search_space), scramble_seed=self._seed)
-        self._independent_sampler = optuna.samplers.RandomSampler(seed=seed)
+        self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
         self._rng = np.random.default_rng(seed)

     def sample_relative(
@@ -451,10 +476,12 @@ class FastINGO:
         learning_rate: Optional[float] = None,
         last_n: int = 4096,
         loco_step_size: float = 0.1,
-        device=
+        device: str | None = None,
         batchnorm_decay: float = 0.99,
         score_decay: float = 0.99,
     ) -> None:
+        if device is None:
+            device = _use_cuda()
         n_dimension = len(mean)
         if population_size is None:
             population_size = 4 + int(np.floor(3 * np.log(n_dimension)))
@@ -521,8 +548,14 @@ class FastINGO:
         if y.numel() <= 2:
             return

-
-
+        min_y = y.min()
+        max_y = y.max()
+        if torch.isclose(max_y, min_y, rtol=0.0, atol=1e-12):
+            return
+
+        if min_y <= 0:
+            y = y + (1e-8 - min_y)
+        y = y.clamp_min_(1e-8).log()

         ema = -torch.arange(y.size(0), device=y.device, dtype=y.dtype)
         weight = self.batchnorm_decay**ema
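
`FastINGO` now guards its objective transform: if all observed scores are numerically identical it returns early, and otherwise it shifts the scores so the minimum sits at 1e-8 before taking the log, which keeps the log defined for zero or negative objectives. The same transform in isolation:

    import torch

    y = torch.tensor([-2.0, 0.0, 3.0])

    min_y, max_y = y.min(), y.max()
    if not torch.isclose(max_y, min_y, rtol=0.0, atol=1e-12):
        if min_y <= 0:
            y = y + (1e-8 - min_y)    # smallest entry becomes 1e-8
        y = y.clamp_min_(1e-8).log()  # safe to take the log now
    print(y)  # roughly tensor([-18.4207, 0.6931, 1.6094])
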
@@ -583,7 +616,7 @@ class ImplicitNaturalGradientSampler(BaseSampler):
     def reseed_rng(self) -> None:
         self._independent_sampler.reseed_rng()
         if self._optimizer:
-            self._optimizer.
+            self._optimizer.generator.seed()

     def infer_relative_search_space(
         self, study: "optuna.Study", trial: "optuna.trial.FrozenTrial"
@@ -633,14 +666,11 @@ class ImplicitNaturalGradientSampler(BaseSampler):
             self._warn_independent_sampling = False
             return {}

-        trans =
+        trans = SearchSpaceTransform(search_space)

-        if self._optimizer is None:
+        if self._optimizer is None or self._optimizer.dim != len(trans.bounds):
             self._optimizer = self._init_optimizer(trans, population_size=self._population_size)
-
-        if self._optimizer.dim != len(trans.bounds):
-            self._warn_independent_sampling = False
-            return {}
+            self._param_queue.clear()

         solution_trials = [t for t in completed_trials if self._check_trial_is_generation(t)]
         for t in solution_trials:
@@ -651,7 +681,7 @@ class ImplicitNaturalGradientSampler(BaseSampler):

     def _init_optimizer(
         self,
-        trans:
+        trans: SearchSpaceTransform,
         population_size: Optional[int] = None,
     ) -> FastINGO:
         lower_bounds = trans.bounds[:, 0]
@@ -705,6 +735,7 @@ class ThreadLocalSampler(threading.local):


 def init_cmaes(study, seed, trials, search_space):
+    trials = copy.deepcopy(trials)
     trials.sort(key=lambda trial: trial.datetime_complete)
     return CmaEsSampler(seed=seed, source_trials=trials, lr_adapt=True)

@@ -716,8 +747,14 @@ def init_hebo(study, seed, trials, search_space):
     return sampler


+def _use_cuda():
+    return "cuda" if torch.cuda.is_available() else "cpu"
+
+
 def init_botorch(study, seed, trials, search_space):
-    return BoTorchSampler(
+    return BoTorchSampler(
+        search_space=search_space, seed=seed, device=_use_cuda()
+    )  # will automatically pull in latest data


 def init_nsgaii(study, seed, trials, search_space):
@@ -739,17 +776,20 @@ class AutoSampler(BaseSampler):
     def __init__(
         self,
         samplers: Iterable[Tuple[int, Callable]] | None = None,
-        search_space: dict[str, BaseDistribution] = None,
+        search_space: Optional[dict[str, BaseDistribution]] = None,
         *,
         seed: int | None = None,
-        constraints_func:
+        constraints_func: Optional[Callable[..., Any]] = None,
     ) -> None:
-
+        if constraints_func is not None:
+            raise NotImplementedError("constraints_func is not supported by AutoSampler.")
         if samplers is None:
+            if search_space is None:
+                raise ValueError("AutoSampler requires a search_space when using the default sampler schedule.")
             samplers = ((0, init_hebo), (100, init_nsgaii))
         self.sampler_indices = np.sort(np.array([x[0] for x in samplers], dtype=np.int32))
         self.samplers = [x[1] for x in sorted(samplers, key=lambda x: x[0])]
-        self.search_space = search_space
+        self.search_space = {} if search_space is None else dict(search_space)
         self._rng = LazyRandomState(seed)
         self._random_sampler = RandomSampler(seed=seed)
         self._thread_local_sampler = ThreadLocalSampler()
@@ -792,7 +832,7 @@ class AutoSampler(BaseSampler):
         complete_trials = study._get_trials(deepcopy=False, states=(TrialState.COMPLETE,), use_cache=True)
         self._completed_trials = max(self._completed_trials, len(complete_trials))
         new_index = (self._completed_trials >= self.sampler_indices).sum() - 1
-        if new_index == self._current_index:
+        if new_index == self._current_index or new_index < 0:
             return
         self._current_index = new_index
         self._sampler = self.samplers[new_index](
@@ -805,7 +845,7 @@ class AutoSampler(BaseSampler):
     def sample_relative(
         self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
     ) -> dict[str, Any]:
-        return self._sampler.sample_relative(study, trial, self.search_space)
+        return self._sampler.sample_relative(study, trial, search_space or self.search_space)

     def sample_independent(
         self,
@@ -834,5 +874,6 @@ class AutoSampler(BaseSampler):
         state: TrialState,
         values: Sequence[float] | None,
     ) -> None:
-
+        if state not in (TrialState.COMPLETE, TrialState.FAIL, TrialState.PRUNED):
+            raise ValueError(f"Unsupported trial state: {state}.")
         self._sampler.after_trial(study, trial, state, values)

{heavyball-2.1.3 → heavyball-2.1.4}/heavyball/utils.py
@@ -343,7 +343,8 @@ def set_(dst: Tensor, src: Tensor):


 def clean():
-    torch.cuda.
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
     gc.collect()


@@ -470,7 +471,7 @@ def msign(G: torch.Tensor, steps: int = 10, eps: float = 1e-7) -> torch.Tensor:

     if should_transpose:
         x = x.mT
-    return x.
+    return x.to(G.dtype)


 ###### END
@@ -670,9 +671,9 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
             final.append(None)
             continue

+        device, dtype = m.device, m.dtype
         m = promote(m.data)

-        device, dtype = m.device, m.dtype
         eps = min_eps
         while True:
             try:
@@ -700,7 +701,6 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
                 raise
             clean()

-        eigvec = eigvec.to(device=m.device, dtype=m.dtype)
         eigvec = torch.flip(eigvec, [1])
         final.append(eigvec)

@@ -1053,13 +1053,15 @@ class StatefulOptimizer(torch.optim.Optimizer):
     def get_groups(self, group):
         return [group]

-    @functools.lru_cache(maxsize=None)
     def state_(self, arg: Tensor, fail: bool = True):
-
-
-        if _tensor_key(arg) not in self.mapping_inverse:
+        key = _tensor_key(arg)
+        if key not in self.mapping_inverse:
             self._init_mapping()
-
+        if key not in self.mapping_inverse:
+            if not fail:
+                return {}
+            raise KeyError("Tensor has no tracked state.")
+        state_param, index = self.mapping_inverse[key]
         if state_param not in self.state:
             self.state[state_param] = collections.defaultdict(dict)
         return self.state[state_param][index]
@@ -1147,7 +1149,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
             active_p = [p for p in group["params"]]

             if not active_p:
-
+                continue

             k = group["ema_step"] = group.get("ema_step", -1) + 1

@@ -1164,7 +1166,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
             active_p = [p for p in group["params"]]

             if not active_p:
-
+                continue

             for p in active_p:
                 if "param_ema" in self.state_(p):
@@ -1178,7 +1180,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
             active_p = [p for p in group["params"]]

             if not active_p:
-
+                continue

             for p in active_p:
                 if "param_ema" in self.state_(p):
@@ -1207,7 +1209,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
         for group in self.param_groups:
             for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
                 p.grad = grads.pop(0)
-
+                stochastic_add_divide_(g, p.grad, -1, torch.finfo(p.dtype).eps ** 0.5)
                 p.hessian_vector = g
                 p.data.copy_(p.orig)
                 del p.orig
@@ -1297,6 +1299,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
         self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule, self.precond_rng)
         loss = self._handle_closure(closure)

+        if self.use_ema:
+            self.ema_update()
         # we assume that parameters are constant and that there are no excessive recompiles
         with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
             for group in self.param_groups:
@@ -1304,8 +1308,6 @@ class StatefulOptimizer(torch.optim.Optimizer):
                 group["param_count"] = sum(p.numel() for p in group["params"])
                 group["is_preconditioning"] = self._is_preconditioning
                 self._step(group)
-                if self.use_ema:
-                    self.ema_update()
                 for real, views in self.mapping.items():
                     for tensor in (real, *views):
                         for key in ("grad", "vector", "hessian_vector", "orig"):
@@ -1569,18 +1571,20 @@ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):

 @decorator_knowngood
 def stochastic_round_(ref: Tensor, source: Tensor | None = None):
-    if source is
-        if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
-            return source
-        if ref.dtype != torch.bfloat16:
-            return source.to(ref.dtype)
-    else:
+    if source is None:
         source = ref
-
-
-
-
-
+    if ref.dtype != torch.bfloat16:
+        return source.to(ref.dtype)
+    if source.dtype == torch.bfloat16:
+        return source
+    if source.dtype in (torch.float16, torch.float32, torch.float64):
+        source = source.to(torch.float32)
+        noise = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
+        bits = source.view(dtype=torch.int32)
+        bits.add_(noise)
+        bits.bitwise_and_(-65536)  # FFFF0000 mask, preserves sign+exp+7 mantissa bits
+        return bits.view(dtype=torch.float32).bfloat16()
+    return source.to(ref.dtype)


 @decorator_knowngood
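
The rewritten `stochastic_round_` spells out the fast paths (non-bf16 targets and already-bf16 sources are returned or cast directly) and keeps the randomized path only for fp16/fp32/fp64 sources: it adds 16 bits of uniform noise to the float32 bit pattern and truncates the low 16 bits, which is an unbiased round to bfloat16. A self-contained sketch of the same bit trick, independent of heavyball and for illustration only:

    import torch

    def stochastic_bf16(x: torch.Tensor) -> torch.Tensor:
        # bfloat16 keeps the top 16 bits of the float32 pattern (sign, exponent, 7 mantissa bits).
        x = x.to(torch.float32)
        bits = x.view(dtype=torch.int32)
        noise = torch.randint_like(bits, dtype=torch.int32, low=0, high=1 << 16)
        # Uniform noise below the truncation point makes the carry into the kept bits
        # happen with probability equal to the discarded fraction.
        rounded = (bits + noise) & -65536  # 0xFFFF0000 mask
        return rounded.view(dtype=torch.float32).bfloat16()

    x = torch.full((100_000,), 1.00390625)  # exactly halfway between bf16 neighbours 1.0 and 1.0078125
    print(stochastic_bf16(x).float().mean())  # close to 1.0039; deterministic rounding always gives 1.0 here
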
@@ -1913,7 +1917,8 @@ def update_lra_precond_(

     # LU factorization to reuse computation
     try:
-
+        lu_matrix = promote(IpVtU)  # operate in fp32 when inputs are bf16/half
+        LU, pivots = torch.linalg.lu_factor(lu_matrix)
     except RuntimeError:
         # Error:
         # U[2,2] is zero and using it on lu_solve would result in a division by zero.
@@ -1923,8 +1928,13 @@ def update_lra_precond_(
         # So, we skip this step and reattempt on the next one
         return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)

-
-
+    solve_dtype = LU.dtype
+    rhs = (U.T @ invQtv).view(-1, 1).to(solve_dtype)
+    correction = torch.linalg.lu_solve(LU, pivots, rhs, adjoint=True).to(V.dtype)
+    invQtv = invQtv - (V @ correction).flatten()
+    rhs = (V.T @ invQtv).view(-1, 1).to(solve_dtype)
+    solution = torch.linalg.lu_solve(LU, pivots, rhs).to(U.dtype)
+    invPv = (U @ solution).flatten()

     eps, step = scalar_guard(eps, step, vector)
     _compilable_d_step(d, d_orig, invQtv, vector, invPv, hessian_vector, Ph, eps, step, delayed)
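
Both `lu_solve` calls reuse the single factorization of `IpVtU` (read: I plus V^T U). This appears to be the standard low-rank inverse trick: by the Woodbury identity,

    (I + U V^T)^-1 = I - U (I + V^T U)^-1 V^T

(with its transpose handled by the `adjoint=True` solve), so only a system of the preconditioner's rank has to be factorized each step rather than anything of the parameter dimension. The added `promote` call simply runs that factorization and the two solves in fp32 when U and V are stored in bf16 or fp16.
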
@@ -2044,7 +2054,10 @@ def extract_from_flat_update(params: List[Tensor], update: Tensor):
 @decorator_knowngood
 def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
     last_dim = x[0].shape[-remaining:] if remaining else []
-
+    tensors = [i.reshape(-1, *last_dim) for i in x if i.numel()]
+    if not tensors:
+        return torch.zeros((), dtype=x[0].device, device=x[0].device)
+    return torch.cat(tensors, 0)


 @decorator_knowngood
@@ -2116,16 +2129,6 @@ def psgd_calc_A_and_conjB(G: Tensor, Q, conjB: Tensor | None):  # conjB ("V", "v
     return A, conjB


-@decorator_knowngood
-def _random_projection(x: Tensor, scale: Optional[Tensor]):
-    if scale is None:
-        scale = x.norm(float("inf")).clamp(min=1e-8)
-    k = 2 ** math.ceil(math.log2(math.log2(min(x.shape))))  # next-largest-power-of-2 of log2-of-size
-    norm = x.square().sum(0)
-    indices = torch.topk(norm, k, largest=True).indices
-    return x.index_select(1, indices).contiguous() / scale, scale
-
-
 def max_singular_value_exact(A, use_lobpcg: bool = False):
     try:
         if use_lobpcg:
@@ -2169,7 +2172,15 @@ def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
     """
     Adapted from @evanatyourservice
     """
-
+    if max_abs is None:
+        max_abs = A.norm(float("inf")).clamp(min=1e-8)
+
+    # cholesky uses random projection, but this uses topk -- topk is a warm start, which may converge to a biased result
+    k = 2 ** math.ceil(math.log2(math.log2(min(A.shape))))  # next-largest-power-of-2 of log2-of-size
+    norm = A.square().sum(0)
+    indices = torch.topk(norm, k, largest=True).indices
+    Y = A.index_select(1, indices).contiguous() / max_abs
+
     Q = inplace_orthogonal_(Y, precise_zeroth_power_mode)
     Q = Q / max_abs
     Z = A.T @ Q
@@ -2557,7 +2568,7 @@ def _psgd_quad_preconditioner_grad(GG: List[Tensor], Q: List[Tensor], numel: int
         else:
             scale = gg.size(0) / numel
             gg = 2 * torch.eye(gg.size(0), device=gg.device, dtype=gg.dtype) - gg * scale
-            update = q - gg
+            update = q - casted_einsum("ab,cd,bc", gg, gg, q)
         out.append(update + update.T)  # make matrix symmetric
     return out

@@ -3111,7 +3122,7 @@ def pointwise_lr_adaptation(
 ):
     grads, update, state, delta = list_guard(grads, update, state, delta)
     lr_lr = scalar_guard(lr_lr, grads[0])
-
+    _compilable_pointwise_lr_adapt_(grads, update, state, delta, lr_lr)


 def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
@@ -3131,8 +3142,6 @@ def hook_optimizer_into_model(model, optimizer, *args, **kwargs):


 def fused_hook(parameters, optimizer, *args, **kwargs):
     parameters = list(parameters)
-    param_count = len(parameters)
-    seen_params = set()

     o = optimizer(parameters, *args, **kwargs)
     step_fn = o.step
@@ -3141,12 +3150,8 @@
     )

     def _step(p: Tensor):
-
-
-        if len(seen_params) < param_count:
-            step_fn()
-            o.zero_grad()
-            seen_params.clear()
+        step_fn()
+        o.zero_grad()

     for p in parameters:
         p.register_post_accumulate_grad_hook(_step)
@@ -3171,6 +3176,8 @@ def sam_step(parameters, ball_size, adaptive: bool = True):
     old_params = []
     for p in parameters:
         old_params.append(p.detach().clone())
+        if not hasattr_none(p, "grad"):
+            continue
         grad = promote(p.grad)
         if adaptive:
             grad = grad * promote(p).square()

{heavyball-2.1.3 → heavyball-2.1.4}/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "heavyball"
 description = "Efficient Optimizers"
-version = "2.1.3"
+version = "2.1.4"
 authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
 classifiers = ["Intended Audience :: Developers",
     "Intended Audience :: Science/Research",

{heavyball-2.1.3 → heavyball-2.1.4}/test/test_bf16_params.py
@@ -14,6 +14,11 @@ os.environ["TORCH_LOGS"] = "+recompiles"

 config.cache_size_limit = 128

+pytestmark = pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="CUDA is required to run bf16 foreach parameter tests.",
+)
+

 @pytest.mark.parametrize("opt", heavyball.__all__)
 @pytest.mark.parametrize("size,depth", [(256, 1)])

{heavyball-2.1.3 → heavyball-2.1.4}/test/test_foreach.py
@@ -1,3 +1,5 @@
+import os
+
 import pytest
 import torch
 from lightbench.utils import get_optim
@@ -15,13 +17,40 @@ def get_memory():
     return torch.cuda.memory_allocated()


+def _read_int(name: str, default: int, *, minimum: int) -> int:
+    raw = os.environ.get(name)
+    if raw is None:
+        return default
+    try:
+        return max(minimum, int(raw))
+    except ValueError:
+        return default
+
+
+DEFAULT_SIZE = _read_int("HB_FOREACH_TEST_SIZE", 128, minimum=1)
+DEFAULT_DEPTH = _read_int("HB_FOREACH_TEST_DEPTH", 16, minimum=1)
+DEFAULT_ITERATIONS = _read_int("HB_FOREACH_TEST_ITERATIONS", 64, minimum=1)
+DEFAULT_OUTER = _read_int("HB_FOREACH_TEST_OUTER", 1, minimum=1)
+DEFAULT_WARMUP = _read_int("HB_FOREACH_TEST_WARMUP", 1, minimum=0)
+
+
 @pytest.mark.parametrize("opt", heavyball.__all__)
-@pytest.mark.parametrize("size,depth", [(
-def test_foreach(
+@pytest.mark.parametrize("size,depth", [(DEFAULT_SIZE, DEFAULT_DEPTH)])
+def test_foreach(
+    opt,
+    size,
+    depth: int,
+    iterations: int = DEFAULT_ITERATIONS,
+    outer_iterations: int = DEFAULT_OUTER,
+    warmup_runs: int = DEFAULT_WARMUP,
+):
     set_torch()

     opt = getattr(heavyball, opt)

+    total_runs = warmup_runs + outer_iterations
+    assert total_runs >= 1
+
     peaks = []
     losses = []

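
The benchmark dimensions are no longer hard-coded: each knob reads an `HB_FOREACH_TEST_*` environment variable, clamps it to a minimum, and falls back to the default when the variable is unset or not an integer. A quick check of that behaviour, reusing the helper exactly as added above:

    import os

    def _read_int(name: str, default: int, *, minimum: int) -> int:  # copied from the hunk above
        raw = os.environ.get(name)
        if raw is None:
            return default
        try:
            return max(minimum, int(raw))
        except ValueError:
            return default

    os.environ["HB_FOREACH_TEST_DEPTH"] = "0"      # below the minimum
    os.environ["HB_FOREACH_TEST_WARMUP"] = "many"  # not an integer

    print(_read_int("HB_FOREACH_TEST_SIZE", 128, minimum=1))   # 128 (unset, default)
    print(_read_int("HB_FOREACH_TEST_DEPTH", 16, minimum=1))   # 1 (clamped to minimum)
    print(_read_int("HB_FOREACH_TEST_WARMUP", 1, minimum=0))   # 1 (ValueError, default)
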
@@ -30,7 +59,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 4096, outer_iterations
         peaks.append([])
         losses.append([])

-        for i in range(
+        for i in range(total_runs):
             clean()
             model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
             clean()
@@ -56,8 +85,14 @@ def test_foreach(opt, size, depth: int, iterations: int = 4096, outer_iterations

             peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]

-            if i
-
+            if i < warmup_runs:
+                continue
+
+            peaks[-1].append(peak)
+
+    if warmup_runs:
+        cutoff = warmup_runs * iterations
+        losses = [loss_list[cutoff:] for loss_list in losses]

     for p0, p1 in zip(*peaks):
         assert p0 > p1

heavyball-2.1.4/test/test_stochastic_utils_cpu.py (new file)
@@ -0,0 +1,49 @@
+import os
+
+os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
+
+import torch
+
+from heavyball.utils import copy_stochastic_, stochastic_add_, stochastic_divide_with_eps_
+
+
+def _average_stochastic_round(source: torch.Tensor, trials: int = 512) -> torch.Tensor:
+    dest = torch.empty_like(source, dtype=torch.bfloat16)
+    totals = torch.zeros_like(source, dtype=torch.float64)
+    for _ in range(trials):
+        copy_stochastic_(dest, source)
+        totals += dest.to(dtype=torch.float64)
+    return totals / trials
+
+
+def test_copy_stochastic_round_is_close_to_source_mean():
+    torch.manual_seed(0x5566AA)
+    values = torch.randn(2048, dtype=torch.float32) * 3.0
+    averaged = _average_stochastic_round(values, trials=256)
+    delta = averaged - values.double()
+
+    # Stochastic round should stay close to the original float32 values.
+    assert delta.abs().mean().item() < 5e-3
+    assert delta.abs().max().item() < 2.5e-2
+
+
+def test_stochastic_add_broadcasts_partner_lists():
+    torch.manual_seed(0x172893)
+    targets = [torch.zeros(4, dtype=torch.bfloat16) for _ in range(2)]
+    partner = [torch.linspace(-1.0, 1.0, 4, dtype=torch.float32)]
+
+    stochastic_add_(targets, partner, alpha=0.25)
+    expected = partner[0] * 0.25
+    for tensor in targets:
+        assert torch.allclose(tensor.float(), expected, atol=5e-3, rtol=0)
+
+
+def test_stochastic_divide_with_eps_matches_float_result():
+    torch.manual_seed(0xABCDEF)
+    numerator = torch.randn(32, dtype=torch.bfloat16)
+    denominator = torch.rand(32, dtype=torch.bfloat16) + 0.05
+    result = numerator.clone()
+
+    stochastic_divide_with_eps_(result, denominator, eps=1e-3)
+    expected = numerator.float() / (denominator.float() + 1e-3)
+    assert torch.allclose(result.float(), expected, atol=2e-2, rtol=2e-2)
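
The first test in the new file leans on the fact that stochastic rounding is unbiased: a value x between adjacent bf16 neighbours lo and hi is rounded up with probability (x - lo) / (hi - lo), so

    E[round_s(x)] = lo + (hi - lo) * (x - lo) / (hi - lo) = x,

and averaging 256 independent roundings shrinks the per-element error by roughly sqrt(256) = 16, which is what the 5e-3 mean tolerance counts on.
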

{heavyball-2.1.3 → heavyball-2.1.4}/test/test_utils_cpu.py
@@ -3,7 +3,6 @@ import random
 import warnings
 from copy import deepcopy

-import numpy as np
 import pytest
 import torch
 from torch import Tensor, nn
@@ -222,29 +221,31 @@ def test_stochastic_math_helpers_match_expected_results():
     assert torch.allclose(a, torch.full_like(a, expected), atol=1e-6)


-def test_stochastic_math_accuracy(
-    """
-    TODO: Rework this test or stochastic rounding.
-    With target_shift=1, it passes. With target_shift != 1, it does not pass -- this is unexpected
-    """
+def test_stochastic_math_accuracy():
     torch.manual_seed(0x172893)
-
+    items = 8
+    steps = 2048
+    increments = torch.full((items,), 1e-3, dtype=torch.float32)

-
-
-
+    baseline = torch.zeros(items, dtype=torch.bfloat16)
+    stochastic = torch.zeros(items, dtype=torch.bfloat16)
+    ground_truth = torch.zeros(items, dtype=torch.float64)

-
-
-
-
-
-
-
-
-
-
-
+    for _ in range(steps):
+        baseline.add_(increments)
+        ground_truth.add_(increments)
+        stochastic_add_(stochastic, increments)
+
+    baseline_error = torch.abs(baseline.float() - ground_truth.float()).mean().item()
+    stochastic_error = torch.abs(stochastic.float() - ground_truth.float()).mean().item()
+
+    assert baseline_error > 1.0
+    assert stochastic_error < 0.2
+    assert stochastic_error < baseline_error * 0.2
+
+    baseline_bias = abs(baseline.float().mean().item() - ground_truth.float().mean().item())
+    stochastic_bias = abs(stochastic.float().mean().item() - ground_truth.float().mean().item())
+    assert stochastic_bias < baseline_bias


 def test_disable_caution_scaling_toggles_behavior():
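
The thresholds in the rewritten accuracy test follow from the bfloat16 format. With 8 bits of significand precision, the spacing between representable values in [0.5, 1) is 2^-8 ≈ 0.0039, so once the plain bf16 accumulator reaches about 0.5 an increment of 1e-3 is under half a ULP and every further `add_` rounds back to the same value:

    true sum       = 2048 * 1e-3 = 2.048
    bf16 baseline  ≈ 0.5 (stalls once 1e-3 < ULP/2 = 0.00195)  ->  error ≈ 1.5 > 1.0
    stochastic add keeps growing in expectation                ->  error < 0.2

which is what the three error asserts encode; the final check verifies that the stochastic accumulator's mean is also less biased than the deterministic baseline's.
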
|