heavyball 2.1.2__tar.gz → 2.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-2.1.2 → heavyball-2.1.4}/PKG-INFO +2 -1
- {heavyball-2.1.2 → heavyball-2.1.4}/heavyball/__init__.py +56 -89
- {heavyball-2.1.2 → heavyball-2.1.4}/heavyball/chainable.py +6 -4
- {heavyball-2.1.2 → heavyball-2.1.4}/heavyball/helpers.py +127 -56
- {heavyball-2.1.2 → heavyball-2.1.4}/heavyball/utils.py +74 -61
- {heavyball-2.1.2 → heavyball-2.1.4}/heavyball.egg-info/PKG-INFO +2 -1
- {heavyball-2.1.2 → heavyball-2.1.4}/heavyball.egg-info/SOURCES.txt +8 -1
- {heavyball-2.1.2 → heavyball-2.1.4}/heavyball.egg-info/requires.txt +1 -0
- {heavyball-2.1.2 → heavyball-2.1.4}/pyproject.toml +2 -2
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_bf16_params.py +5 -1
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_bf16_q.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_bf16_storage.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_caution.py +0 -1
- heavyball-2.1.4/test/test_chainable_cpu.py +65 -0
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_channels_last.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_closure.py +0 -1
- heavyball-2.1.4/test/test_cpu_features.py +134 -0
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_ema.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_foreach.py +40 -6
- heavyball-2.1.4/test/test_helpers_cpu.py +107 -0
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_hook.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_mars.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_memory.py +0 -2
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_memory_leak.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_merge.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_nd_param.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_no_grad.py +0 -1
- heavyball-2.1.4/test/test_optimizer_cpu_smoke.py +65 -0
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_save_restore.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_singular_values.py +1 -1
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_stochastic_updates.py +0 -1
- heavyball-2.1.4/test/test_stochastic_utils_cpu.py +49 -0
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_toy_training.py +4 -4
- heavyball-2.1.4/test/test_utils_cpu.py +296 -0
- heavyball-2.1.4/test/test_utils_property.py +281 -0
- {heavyball-2.1.2 → heavyball-2.1.4}/LICENSE +0 -0
- {heavyball-2.1.2 → heavyball-2.1.4}/README.md +0 -0
- {heavyball-2.1.2 → heavyball-2.1.4}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-2.1.2 → heavyball-2.1.4}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-2.1.2 → heavyball-2.1.4}/setup.cfg +0 -0
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_clip.py +0 -0
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_migrate_cli.py +0 -0
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_psgd_precond_init_stability.py +0 -0
- {heavyball-2.1.2 → heavyball-2.1.4}/test/test_soap.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: heavyball
|
3
|
-
Version: 2.1.2
|
3
|
+
Version: 2.1.4
|
4
4
|
Summary: Efficient Optimizers
|
5
5
|
Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
|
6
6
|
Project-URL: source, https://github.com/HomebrewML/HeavyBall
|
@@ -21,6 +21,7 @@ Requires-Dist: numpy<2.0.0
|
|
21
21
|
Provides-Extra: dev
|
22
22
|
Requires-Dist: pre-commit; extra == "dev"
|
23
23
|
Requires-Dist: pytest; extra == "dev"
|
24
|
+
Requires-Dist: hypothesis; extra == "dev"
|
24
25
|
Requires-Dist: ruff; extra == "dev"
|
25
26
|
Requires-Dist: matplotlib; extra == "dev"
|
26
27
|
Requires-Dist: seaborn; extra == "dev"
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import functools
|
2
2
|
import math
|
3
|
-
from typing import Optional
|
3
|
+
from typing import Optional, Type, Union
|
4
4
|
|
5
5
|
import torch.optim
|
6
6
|
|
@@ -8,39 +8,6 @@ from . import chainable as C
|
|
8
8
|
from . import utils
|
9
9
|
|
10
10
|
|
11
|
-
class SAMWrapper(torch.optim.Optimizer):
|
12
|
-
def __init__(self, params, wrapped_optimizer: utils.StatefulOptimizer, ball: float = 0.1):
|
13
|
-
if not isinstance(wrapped_optimizer, utils.StatefulOptimizer):
|
14
|
-
raise ValueError(f"{wrapped_optimizer.__class__.__name__} is not a HeavyBall optimizer")
|
15
|
-
super().__init__(params, {"ball": ball})
|
16
|
-
self.wrapped_optimizer = wrapped_optimizer
|
17
|
-
|
18
|
-
@torch.no_grad()
|
19
|
-
def step(self, closure=None):
|
20
|
-
if closure is None:
|
21
|
-
raise ValueError("SAM requires closure")
|
22
|
-
with torch.enable_grad():
|
23
|
-
closure()
|
24
|
-
old_params = [utils.sam_step(group["params"], group["ball"]) for group in self.param_groups]
|
25
|
-
|
26
|
-
originaL_handle_closure = self.wrapped_optimizer._handle_closure
|
27
|
-
|
28
|
-
def _handle_closure(closure):
|
29
|
-
originaL_handle_closure(closure)
|
30
|
-
for group, old in zip(self.param_groups, old_params):
|
31
|
-
utils.copy_stochastic_list_(group["params"], old)
|
32
|
-
|
33
|
-
try:
|
34
|
-
self.wrapped_optimizer._handle_closure = _handle_closure
|
35
|
-
loss = self.wrapped_optimizer.step(closure)
|
36
|
-
finally:
|
37
|
-
self.wrapped_optimizer._handle_closure = originaL_handle_closure
|
38
|
-
return loss
|
39
|
-
|
40
|
-
def zero_grad(self, set_to_none: bool = True):
|
41
|
-
self.wrapped_optimizer.zero_grad()
|
42
|
-
|
43
|
-
|
44
11
|
class SGD(C.BaseOpt):
|
45
12
|
def __init__(
|
46
13
|
self,
|
@@ -778,7 +745,7 @@ class ForeachPSGDKron(C.BaseOpt):
|
|
778
745
|
beta=None,
|
779
746
|
betas=(0.9, 0.999),
|
780
747
|
weight_decay=0.0,
|
781
|
-
preconditioner_update_probability=
|
748
|
+
preconditioner_update_probability=C.use_default,
|
782
749
|
max_size_triangular=2048,
|
783
750
|
min_ndim_triangular=2,
|
784
751
|
memory_save_mode=None,
|
@@ -830,8 +797,8 @@ class ForeachPSGDKron(C.BaseOpt):
|
|
830
797
|
if kwargs:
|
831
798
|
utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
|
832
799
|
|
833
|
-
self.precond_schedule = (
|
834
|
-
defaults.pop("preconditioner_update_probability")
|
800
|
+
self.precond_schedule = C.default(
|
801
|
+
defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
|
835
802
|
)
|
836
803
|
params = defaults.pop("params")
|
837
804
|
|
@@ -890,7 +857,7 @@ class ForeachPSGDLRA(C.BaseOpt):
|
|
890
857
|
lr=0.001,
|
891
858
|
beta=0.9,
|
892
859
|
weight_decay=0.0,
|
893
|
-
preconditioner_update_probability=
|
860
|
+
preconditioner_update_probability=C.use_default,
|
894
861
|
momentum_into_precond_update=True,
|
895
862
|
rank: Optional[int] = None,
|
896
863
|
warmup_steps: int = 0,
|
@@ -924,8 +891,8 @@ class ForeachPSGDLRA(C.BaseOpt):
|
|
924
891
|
if kwargs:
|
925
892
|
utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
|
926
893
|
|
927
|
-
self.precond_schedule = (
|
928
|
-
defaults.pop("preconditioner_update_probability")
|
894
|
+
self.precond_schedule = C.default(
|
895
|
+
defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
|
929
896
|
)
|
930
897
|
params = defaults.pop("params")
|
931
898
|
|
@@ -960,6 +927,54 @@ class NewtonHybrid2PSGDLRA(ForeachNewtonPSGDLRA):
|
|
960
927
|
hvp_interval = 2
|
961
928
|
|
962
929
|
|
930
|
+
class SAMWrapper(torch.optim.Optimizer):
|
931
|
+
def __init__(
|
932
|
+
self,
|
933
|
+
params,
|
934
|
+
wrapped_optimizer: Union[utils.StatefulOptimizer, Type[utils.StatefulOptimizer]] = ForeachAdamW,
|
935
|
+
ball: float = 0.1,
|
936
|
+
):
|
937
|
+
params = list(params)
|
938
|
+
super().__init__(params, {"ball": ball})
|
939
|
+
|
940
|
+
if isinstance(wrapped_optimizer, type):
|
941
|
+
if not issubclass(wrapped_optimizer, utils.StatefulOptimizer):
|
942
|
+
raise ValueError(f"{wrapped_optimizer.__name__} is not a HeavyBall optimizer")
|
943
|
+
wrapped_optimizer = wrapped_optimizer(params)
|
944
|
+
elif not isinstance(wrapped_optimizer, utils.StatefulOptimizer):
|
945
|
+
raise ValueError(f"{wrapped_optimizer.__class__.__name__} is not a HeavyBall optimizer")
|
946
|
+
|
947
|
+
self.wrapped_optimizer = wrapped_optimizer
|
948
|
+
|
949
|
+
@torch.no_grad()
|
950
|
+
def step(self, closure=None):
|
951
|
+
if closure is None:
|
952
|
+
raise ValueError("SAM requires closure")
|
953
|
+
with torch.enable_grad():
|
954
|
+
closure()
|
955
|
+
old_params = [utils.sam_step(group["params"], group["ball"]) for group in self.param_groups]
|
956
|
+
|
957
|
+
original_handle_closure = self.wrapped_optimizer._handle_closure
|
958
|
+
|
959
|
+
def _handle_closure(closure):
|
960
|
+
try:
|
961
|
+
_loss = original_handle_closure(closure)
|
962
|
+
finally:
|
963
|
+
for group, old in zip(self.param_groups, old_params):
|
964
|
+
utils.copy_stochastic_list_(group["params"], old)
|
965
|
+
return _loss
|
966
|
+
|
967
|
+
try:
|
968
|
+
self.wrapped_optimizer._handle_closure = _handle_closure
|
969
|
+
loss = self.wrapped_optimizer.step(closure)
|
970
|
+
finally:
|
971
|
+
self.wrapped_optimizer._handle_closure = original_handle_closure
|
972
|
+
return loss
|
973
|
+
|
974
|
+
def zero_grad(self, set_to_none: bool = True):
|
975
|
+
self.wrapped_optimizer.zero_grad(set_to_none=set_to_none)
|
976
|
+
|
977
|
+
|
963
978
|
PalmForEachSoap = PaLMForeachSOAP
|
964
979
|
PaLMSOAP = PaLMForeachSOAP
|
965
980
|
PaLMSFAdamW = PaLMForeachSFAdamW
|
@@ -983,52 +998,4 @@ PSGDLRA = ForeachPSGDLRA
|
|
983
998
|
NewtonPSGDLRA = ForeachNewtonPSGDLRA
|
984
999
|
NewtonPSGDKron = ForeachCachedNewtonPSGD
|
985
1000
|
|
986
|
-
__all__ = [
|
987
|
-
"Muon",
|
988
|
-
"RMSprop",
|
989
|
-
"PrecondSchedulePaLMSOAP",
|
990
|
-
"PSGDKron",
|
991
|
-
"PurePSGD",
|
992
|
-
"DelayedPSGD",
|
993
|
-
"CachedPSGDKron",
|
994
|
-
"CachedDelayedPSGDKron",
|
995
|
-
"PalmForEachSoap",
|
996
|
-
"PaLMSOAP",
|
997
|
-
"PaLMSFAdamW",
|
998
|
-
"LaProp",
|
999
|
-
"ADOPT",
|
1000
|
-
"PrecondScheduleSOAP",
|
1001
|
-
"PrecondSchedulePaLMSOAP",
|
1002
|
-
"RMSprop",
|
1003
|
-
"MuonLaProp",
|
1004
|
-
"ForeachSignLaProp",
|
1005
|
-
"ForeachDelayedPSGDLRA",
|
1006
|
-
"ForeachPSGDLRA",
|
1007
|
-
"ForeachPSGDLRA",
|
1008
|
-
"ForeachNewtonPSGDLRA", #
|
1009
|
-
"ForeachAdamW",
|
1010
|
-
"ForeachSFAdamW",
|
1011
|
-
"ForeachLaProp",
|
1012
|
-
"ForeachADOPT",
|
1013
|
-
"ForeachSOAP",
|
1014
|
-
"ForeachPSGDKron",
|
1015
|
-
"ForeachPurePSGD",
|
1016
|
-
"ForeachDelayedPSGD",
|
1017
|
-
"ForeachCachedPSGDKron",
|
1018
|
-
"ForeachCachedDelayedPSGDKron",
|
1019
|
-
"ForeachRMSprop",
|
1020
|
-
"ForeachMuon",
|
1021
|
-
"ForeachCachedNewtonPSGD",
|
1022
|
-
"OrthoLaProp",
|
1023
|
-
"LaPropOrtho",
|
1024
|
-
"SignLaProp",
|
1025
|
-
"DelayedPSGD",
|
1026
|
-
"PSGDLRA",
|
1027
|
-
"NewtonPSGDLRA",
|
1028
|
-
"NewtonHybrid2PSGDLRA",
|
1029
|
-
"NewtonHybrid2PSGDKron",
|
1030
|
-
"MSAMLaProp",
|
1031
|
-
"NewtonPSGDKron",
|
1032
|
-
"ForeachAdamC",
|
1033
|
-
"SGD",
|
1034
|
-
]
|
1001
|
+
__all__ = [k for k, v in globals().items() if isinstance(v, type) and issubclass(v, torch.optim.Optimizer)]
|
@@ -62,8 +62,6 @@ class FunctionTransform:
|
|
62
62
|
self._init(st, group, *a, **kwargs)
|
63
63
|
except SkipUpdate:
|
64
64
|
skip_update = True
|
65
|
-
except:
|
66
|
-
raise
|
67
65
|
finally:
|
68
66
|
if "is_initialized" not in st:
|
69
67
|
st["is_initialized"] = set()
|
@@ -499,6 +497,7 @@ def scale_by_suds(group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx
|
|
499
497
|
precond_update, _ = utils.eigvecs_product_rank1(precond_update.flatten(), fisher_approx.flatten(), w)
|
500
498
|
|
501
499
|
new_approx = utils.oja_update(fisher_approx.flatten().to(update.dtype), update.flatten(), group["precond_lr"])
|
500
|
+
new_approx = new_approx.view_as(fisher_approx)
|
502
501
|
utils.copy_stochastic_(fisher_approx, new_approx)
|
503
502
|
return precond_update
|
504
503
|
|
@@ -565,7 +564,7 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
|
|
565
564
|
)
|
566
565
|
state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
|
567
566
|
state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
|
568
|
-
state["step"] = torch.zeros((), device=param.device, dtype=torch.
|
567
|
+
state["step"] = torch.zeros((), device=param.device, dtype=torch.float64) # torch casts int to float in ckpt load
|
569
568
|
if group["adaptive"]:
|
570
569
|
state["velocity"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
|
571
570
|
if not cached:
|
@@ -750,7 +749,9 @@ def _update_psgd_precond(
|
|
750
749
|
if isinstance(prob, float):
|
751
750
|
float_prob = prob
|
752
751
|
else:
|
753
|
-
|
752
|
+
prob_step = group.get(f"cumulative_prob_{id(Q)}_prob_step", 1)
|
753
|
+
float_prob = prob(prob_step)
|
754
|
+
group[f"cumulative_prob_{id(Q)}_prob_step"] = prob_step + 1
|
754
755
|
group["is_cached"] = should_use_cache = cached and float_prob < 0.5
|
755
756
|
|
756
757
|
if precond is not None:
|
@@ -1086,6 +1087,7 @@ class ChainOpt(utils.StatefulOptimizer):
|
|
1086
1087
|
if not group["foreach"] or len(p) == 1:
|
1087
1088
|
for param, grad in zip(p, g):
|
1088
1089
|
chain(self.state_, group, [grad], [param], *self.fns)
|
1090
|
+
group["caution"] = caution
|
1089
1091
|
else:
|
1090
1092
|
chain(self.state_, group, g, p, *self.fns)
|
1091
1093
|
|
@@ -1,9 +1,9 @@
|
|
1
|
-
|
2
|
-
|
1
|
+
import copy
|
3
2
|
import functools
|
4
3
|
import math
|
5
4
|
import threading
|
6
|
-
from
|
5
|
+
from contextlib import contextmanager
|
6
|
+
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union
|
7
7
|
|
8
8
|
import numpy
|
9
9
|
import numpy as np
|
@@ -11,23 +11,15 @@ import optuna
|
|
11
11
|
import optunahub
|
12
12
|
import pandas as pd
|
13
13
|
import torch
|
14
|
-
from botorch.utils.sampling import manual_seed
|
15
14
|
from hebo.design_space.design_space import DesignSpace
|
16
15
|
from hebo.optimizers.hebo import HEBO
|
17
|
-
from optuna._transform import _SearchSpaceTransform
|
16
|
+
from optuna._transform import _SearchSpaceTransform as SearchSpaceTransform
|
18
17
|
from optuna.distributions import BaseDistribution, CategoricalDistribution, FloatDistribution, IntDistribution
|
19
18
|
from optuna.samplers import BaseSampler, CmaEsSampler, RandomSampler
|
20
19
|
from optuna.samplers._lazy_random_state import LazyRandomState
|
21
20
|
from optuna.study import Study
|
22
21
|
from optuna.study._study_direction import StudyDirection
|
23
22
|
from optuna.trial import FrozenTrial, TrialState
|
24
|
-
from optuna_integration.botorch import (
|
25
|
-
ehvi_candidates_func,
|
26
|
-
logei_candidates_func,
|
27
|
-
qehvi_candidates_func,
|
28
|
-
qei_candidates_func,
|
29
|
-
qparego_candidates_func,
|
30
|
-
)
|
31
23
|
from torch import Tensor
|
32
24
|
from torch.nn import functional as F
|
33
25
|
|
@@ -37,12 +29,39 @@ _MAXINT32 = (1 << 31) - 1
|
|
37
29
|
_SAMPLER_KEY = "auto:sampler"
|
38
30
|
|
39
31
|
|
32
|
+
@contextmanager
|
33
|
+
def manual_seed(seed: int | None = None) -> Generator[None, None, None]:
|
34
|
+
r"""
|
35
|
+
Contextmanager for manual setting the torch.random seed.
|
36
|
+
|
37
|
+
Args:
|
38
|
+
seed: The seed to set the random number generator to.
|
39
|
+
|
40
|
+
Returns:
|
41
|
+
Generator
|
42
|
+
|
43
|
+
Example:
|
44
|
+
>>> with manual_seed(1234):
|
45
|
+
>>> X = torch.rand(3)
|
46
|
+
|
47
|
+
copied as-is from https://github.com/meta-pytorch/botorch/blob/a42cd65f9b704cdb6f2ee64db99a022eb15295d5/botorch/utils/sampling.py#L53C1-L75C50 under the MIT License
|
48
|
+
"""
|
49
|
+
old_state = torch.random.get_rng_state()
|
50
|
+
try:
|
51
|
+
if seed is not None:
|
52
|
+
torch.random.manual_seed(seed)
|
53
|
+
yield
|
54
|
+
finally:
|
55
|
+
if seed is not None:
|
56
|
+
torch.random.set_rng_state(old_state)
|
57
|
+
|
58
|
+
|
40
59
|
class SimpleAPIBaseSampler(BaseSampler):
|
41
60
|
def __init__(
|
42
61
|
self,
|
43
|
-
search_space: dict[str, BaseDistribution] = None,
|
62
|
+
search_space: Optional[dict[str, BaseDistribution]] = None,
|
44
63
|
):
|
45
|
-
self.search_space = search_space
|
64
|
+
self.search_space = {} if search_space is None else dict(search_space)
|
46
65
|
|
47
66
|
def suggest_all(self, trial: FrozenTrial):
|
48
67
|
return {k: trial._suggest(k, dist) for k, dist in self.search_space.items()}
|
@@ -65,6 +84,16 @@ def _get_default_candidates_func(
|
|
65
84
|
"""
|
66
85
|
The original is available at https://github.com/optuna/optuna-integration/blob/156a8bc081322791015d2beefff9373ed7b24047/optuna_integration/botorch/botorch.py under the MIT License
|
67
86
|
"""
|
87
|
+
|
88
|
+
# lazy import
|
89
|
+
from optuna_integration.botorch import (
|
90
|
+
ehvi_candidates_func,
|
91
|
+
logei_candidates_func,
|
92
|
+
qehvi_candidates_func,
|
93
|
+
qei_candidates_func,
|
94
|
+
qparego_candidates_func,
|
95
|
+
)
|
96
|
+
|
68
97
|
if n_objectives > 3 and not has_constraint and not consider_running_trials:
|
69
98
|
return ehvi_candidates_func
|
70
99
|
elif n_objectives > 3:
|
@@ -124,7 +153,7 @@ def _untransform_numerical_param_torch(
|
|
124
153
|
|
125
154
|
|
126
155
|
@torch.no_grad()
|
127
|
-
def untransform(self:
|
156
|
+
def untransform(self: SearchSpaceTransform, trans_params: Tensor) -> dict[str, Any]:
|
128
157
|
assert trans_params.shape == (self._raw_bounds.shape[0],)
|
129
158
|
|
130
159
|
if self._transform_0_1:
|
@@ -152,29 +181,31 @@ class BoTorchSampler(SimpleAPIBaseSampler):
|
|
152
181
|
|
153
182
|
def __init__(
|
154
183
|
self,
|
155
|
-
search_space: dict[str, BaseDistribution] = None,
|
184
|
+
search_space: Optional[dict[str, BaseDistribution]] = None,
|
156
185
|
*,
|
157
|
-
candidates_func:
|
158
|
-
constraints_func:
|
186
|
+
candidates_func: Optional[Callable[..., Tensor]] = None,
|
187
|
+
constraints_func: Optional[Callable[..., Tensor]] = None,
|
159
188
|
n_startup_trials: int = 10,
|
160
189
|
consider_running_trials: bool = False,
|
161
|
-
independent_sampler:
|
190
|
+
independent_sampler: Optional[BaseSampler] = None,
|
162
191
|
seed: int | None = None,
|
163
192
|
device: torch.device | str | None = None,
|
164
193
|
trial_chunks: int = 128,
|
165
194
|
):
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
195
|
+
if constraints_func is not None:
|
196
|
+
raise NotImplementedError("constraints_func is currently not supported by BoTorchSampler.")
|
197
|
+
if consider_running_trials:
|
198
|
+
raise NotImplementedError("consider_running_trials is currently not supported by BoTorchSampler.")
|
199
|
+
if candidates_func is not None and not callable(candidates_func):
|
200
|
+
raise TypeError("candidates_func must be callable.")
|
201
|
+
self._candidates_func = candidates_func
|
202
|
+
self._independent_sampler = independent_sampler or RandomSampler(seed=seed)
|
172
203
|
self._n_startup_trials = n_startup_trials
|
173
204
|
self._seed = seed
|
174
205
|
self.trial_chunks = trial_chunks
|
175
206
|
|
176
207
|
self._study_id: int | None = None
|
177
|
-
self.search_space = search_space
|
208
|
+
self.search_space = {} if search_space is None else dict(search_space)
|
178
209
|
if isinstance(device, str):
|
179
210
|
device = torch.device(device)
|
180
211
|
self._device = device or torch.device("cpu")
|
@@ -182,14 +213,24 @@ class BoTorchSampler(SimpleAPIBaseSampler):
|
|
182
213
|
self._values = None
|
183
214
|
self._params = None
|
184
215
|
self._index = 0
|
216
|
+
self._bounds_dim: int | None = None
|
185
217
|
|
186
218
|
def infer_relative_search_space(self, study: Study, trial: FrozenTrial) -> dict[str, BaseDistribution]:
|
187
219
|
return self.search_space
|
188
220
|
|
189
221
|
@torch.no_grad()
|
190
222
|
def _preprocess_trials(
|
191
|
-
self, trans:
|
223
|
+
self, trans: SearchSpaceTransform, study: Study, trials: list[FrozenTrial]
|
192
224
|
) -> Tuple[int, Tensor, Tensor]:
|
225
|
+
bounds_dim = trans.bounds.shape[0]
|
226
|
+
if self._bounds_dim is not None and self._bounds_dim != bounds_dim:
|
227
|
+
self._values = None
|
228
|
+
self._params = None
|
229
|
+
self._index = 0
|
230
|
+
self.seen_trials = set()
|
231
|
+
if self._bounds_dim is None:
|
232
|
+
self._bounds_dim = bounds_dim
|
233
|
+
|
193
234
|
new_trials = []
|
194
235
|
for trial in trials:
|
195
236
|
tid: int = trial._trial_id
|
@@ -200,6 +241,10 @@ class BoTorchSampler(SimpleAPIBaseSampler):
|
|
200
241
|
|
201
242
|
n_objectives = len(study.directions)
|
202
243
|
if not new_trials:
|
244
|
+
if self._values is None or self._params is None:
|
245
|
+
empty_values = torch.zeros((0, n_objectives), dtype=torch.float64, device=self._device)
|
246
|
+
empty_params = torch.zeros((0, bounds_dim), dtype=torch.float64, device=self._device)
|
247
|
+
return n_objectives, empty_values, empty_params
|
203
248
|
return n_objectives, self._values[: self._index], self._params[: self._index]
|
204
249
|
|
205
250
|
n_completed_trials = len(trials)
|
@@ -216,18 +261,28 @@ class BoTorchSampler(SimpleAPIBaseSampler):
|
|
216
261
|
if direction == StudyDirection.MINIMIZE: # BoTorch always assumes maximization.
|
217
262
|
values[:, obj_idx] *= -1
|
218
263
|
|
219
|
-
|
264
|
+
bounds_dim = trans.bounds.shape[0]
|
265
|
+
cache_stale = (
|
266
|
+
self._values is None
|
267
|
+
or self._params is None
|
268
|
+
or self._values.size(1) != n_objectives
|
269
|
+
or self._params.size(1) != bounds_dim
|
270
|
+
)
|
271
|
+
if cache_stale:
|
220
272
|
self._values = torch.zeros((self.trial_chunks, n_objectives), dtype=torch.float64, device=self._device)
|
221
|
-
self._params = torch.zeros(
|
222
|
-
|
223
|
-
)
|
273
|
+
self._params = torch.zeros((self.trial_chunks, bounds_dim), dtype=torch.float64, device=self._device)
|
274
|
+
self._index = 0
|
275
|
+
self.seen_trials = set()
|
276
|
+
self._bounds_dim = bounds_dim
|
224
277
|
spillage = (self._index + n_completed_trials) - self._values.size(0)
|
225
278
|
if spillage > 0:
|
226
279
|
pad = int(math.ceil(spillage / self.trial_chunks) * self.trial_chunks)
|
227
280
|
self._values = F.pad(self._values, (0, 0, 0, pad))
|
228
281
|
self._params = F.pad(self._params, (0, 0, 0, pad))
|
229
|
-
|
230
|
-
|
282
|
+
values_tensor = torch.from_numpy(values).to(self._device)
|
283
|
+
params_tensor = torch.from_numpy(params).to(self._device)
|
284
|
+
self._values[self._index : self._index + n_completed_trials] = values_tensor
|
285
|
+
self._params[self._index : self._index + n_completed_trials] = params_tensor
|
231
286
|
self._index += n_completed_trials
|
232
287
|
|
233
288
|
return n_objectives, self._values[: self._index], self._params[: self._index]
|
@@ -246,7 +301,7 @@ class BoTorchSampler(SimpleAPIBaseSampler):
|
|
246
301
|
if n_completed_trials < self._n_startup_trials:
|
247
302
|
return {}
|
248
303
|
|
249
|
-
trans =
|
304
|
+
trans = SearchSpaceTransform(search_space)
|
250
305
|
n_objectives, values, params = self._preprocess_trials(trans, study, completed_trials)
|
251
306
|
|
252
307
|
if self._candidates_func is None:
|
@@ -349,10 +404,10 @@ class HEBOSampler(optunahub.samplers.SimpleBaseSampler, SimpleAPIBaseSampler):
|
|
349
404
|
independent_sampler: BaseSampler | None = None,
|
350
405
|
) -> None:
|
351
406
|
super().__init__(search_space, seed)
|
352
|
-
|
353
|
-
|
407
|
+
if constant_liar:
|
408
|
+
raise NotImplementedError("constant_liar is not supported by HEBOSampler.")
|
354
409
|
self._hebo = HEBO(_convert_to_hebo_design_space(search_space), scramble_seed=self._seed)
|
355
|
-
self._independent_sampler = optuna.samplers.RandomSampler(seed=seed)
|
410
|
+
self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
|
356
411
|
self._rng = np.random.default_rng(seed)
|
357
412
|
|
358
413
|
def sample_relative(
|
@@ -421,10 +476,12 @@ class FastINGO:
|
|
421
476
|
learning_rate: Optional[float] = None,
|
422
477
|
last_n: int = 4096,
|
423
478
|
loco_step_size: float = 0.1,
|
424
|
-
device=
|
479
|
+
device: str | None = None,
|
425
480
|
batchnorm_decay: float = 0.99,
|
426
481
|
score_decay: float = 0.99,
|
427
482
|
) -> None:
|
483
|
+
if device is None:
|
484
|
+
device = _use_cuda()
|
428
485
|
n_dimension = len(mean)
|
429
486
|
if population_size is None:
|
430
487
|
population_size = 4 + int(np.floor(3 * np.log(n_dimension)))
|
@@ -491,8 +548,14 @@ class FastINGO:
|
|
491
548
|
if y.numel() <= 2:
|
492
549
|
return
|
493
550
|
|
494
|
-
|
495
|
-
|
551
|
+
min_y = y.min()
|
552
|
+
max_y = y.max()
|
553
|
+
if torch.isclose(max_y, min_y, rtol=0.0, atol=1e-12):
|
554
|
+
return
|
555
|
+
|
556
|
+
if min_y <= 0:
|
557
|
+
y = y + (1e-8 - min_y)
|
558
|
+
y = y.clamp_min_(1e-8).log()
|
496
559
|
|
497
560
|
ema = -torch.arange(y.size(0), device=y.device, dtype=y.dtype)
|
498
561
|
weight = self.batchnorm_decay**ema
|
@@ -553,7 +616,7 @@ class ImplicitNaturalGradientSampler(BaseSampler):
|
|
553
616
|
def reseed_rng(self) -> None:
|
554
617
|
self._independent_sampler.reseed_rng()
|
555
618
|
if self._optimizer:
|
556
|
-
self._optimizer.
|
619
|
+
self._optimizer.generator.seed()
|
557
620
|
|
558
621
|
def infer_relative_search_space(
|
559
622
|
self, study: "optuna.Study", trial: "optuna.trial.FrozenTrial"
|
@@ -603,14 +666,11 @@ class ImplicitNaturalGradientSampler(BaseSampler):
|
|
603
666
|
self._warn_independent_sampling = False
|
604
667
|
return {}
|
605
668
|
|
606
|
-
trans =
|
669
|
+
trans = SearchSpaceTransform(search_space)
|
607
670
|
|
608
|
-
if self._optimizer is None:
|
671
|
+
if self._optimizer is None or self._optimizer.dim != len(trans.bounds):
|
609
672
|
self._optimizer = self._init_optimizer(trans, population_size=self._population_size)
|
610
|
-
|
611
|
-
if self._optimizer.dim != len(trans.bounds):
|
612
|
-
self._warn_independent_sampling = False
|
613
|
-
return {}
|
673
|
+
self._param_queue.clear()
|
614
674
|
|
615
675
|
solution_trials = [t for t in completed_trials if self._check_trial_is_generation(t)]
|
616
676
|
for t in solution_trials:
|
@@ -621,7 +681,7 @@ class ImplicitNaturalGradientSampler(BaseSampler):
|
|
621
681
|
|
622
682
|
def _init_optimizer(
|
623
683
|
self,
|
624
|
-
trans:
|
684
|
+
trans: SearchSpaceTransform,
|
625
685
|
population_size: Optional[int] = None,
|
626
686
|
) -> FastINGO:
|
627
687
|
lower_bounds = trans.bounds[:, 0]
|
@@ -675,6 +735,7 @@ class ThreadLocalSampler(threading.local):
|
|
675
735
|
|
676
736
|
|
677
737
|
def init_cmaes(study, seed, trials, search_space):
|
738
|
+
trials = copy.deepcopy(trials)
|
678
739
|
trials.sort(key=lambda trial: trial.datetime_complete)
|
679
740
|
return CmaEsSampler(seed=seed, source_trials=trials, lr_adapt=True)
|
680
741
|
|
@@ -686,8 +747,14 @@ def init_hebo(study, seed, trials, search_space):
|
|
686
747
|
return sampler
|
687
748
|
|
688
749
|
|
750
|
+
def _use_cuda():
|
751
|
+
return "cuda" if torch.cuda.is_available() else "cpu"
|
752
|
+
|
753
|
+
|
689
754
|
def init_botorch(study, seed, trials, search_space):
|
690
|
-
return BoTorchSampler(
|
755
|
+
return BoTorchSampler(
|
756
|
+
search_space=search_space, seed=seed, device=_use_cuda()
|
757
|
+
) # will automatically pull in latest data
|
691
758
|
|
692
759
|
|
693
760
|
def init_nsgaii(study, seed, trials, search_space):
|
@@ -709,17 +776,20 @@ class AutoSampler(BaseSampler):
|
|
709
776
|
def __init__(
|
710
777
|
self,
|
711
778
|
samplers: Iterable[Tuple[int, Callable]] | None = None,
|
712
|
-
search_space: dict[str, BaseDistribution] = None,
|
779
|
+
search_space: Optional[dict[str, BaseDistribution]] = None,
|
713
780
|
*,
|
714
781
|
seed: int | None = None,
|
715
|
-
constraints_func:
|
782
|
+
constraints_func: Optional[Callable[..., Any]] = None,
|
716
783
|
) -> None:
|
717
|
-
|
784
|
+
if constraints_func is not None:
|
785
|
+
raise NotImplementedError("constraints_func is not supported by AutoSampler.")
|
718
786
|
if samplers is None:
|
787
|
+
if search_space is None:
|
788
|
+
raise ValueError("AutoSampler requires a search_space when using the default sampler schedule.")
|
719
789
|
samplers = ((0, init_hebo), (100, init_nsgaii))
|
720
790
|
self.sampler_indices = np.sort(np.array([x[0] for x in samplers], dtype=np.int32))
|
721
791
|
self.samplers = [x[1] for x in sorted(samplers, key=lambda x: x[0])]
|
722
|
-
self.search_space = search_space
|
792
|
+
self.search_space = {} if search_space is None else dict(search_space)
|
723
793
|
self._rng = LazyRandomState(seed)
|
724
794
|
self._random_sampler = RandomSampler(seed=seed)
|
725
795
|
self._thread_local_sampler = ThreadLocalSampler()
|
@@ -762,7 +832,7 @@ class AutoSampler(BaseSampler):
|
|
762
832
|
complete_trials = study._get_trials(deepcopy=False, states=(TrialState.COMPLETE,), use_cache=True)
|
763
833
|
self._completed_trials = max(self._completed_trials, len(complete_trials))
|
764
834
|
new_index = (self._completed_trials >= self.sampler_indices).sum() - 1
|
765
|
-
if new_index == self._current_index:
|
835
|
+
if new_index == self._current_index or new_index < 0:
|
766
836
|
return
|
767
837
|
self._current_index = new_index
|
768
838
|
self._sampler = self.samplers[new_index](
|
@@ -775,7 +845,7 @@ class AutoSampler(BaseSampler):
|
|
775
845
|
def sample_relative(
|
776
846
|
self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
|
777
847
|
) -> dict[str, Any]:
|
778
|
-
return self._sampler.sample_relative(study, trial, self.search_space)
|
848
|
+
return self._sampler.sample_relative(study, trial, search_space or self.search_space)
|
779
849
|
|
780
850
|
def sample_independent(
|
781
851
|
self,
|
@@ -804,5 +874,6 @@ class AutoSampler(BaseSampler):
|
|
804
874
|
state: TrialState,
|
805
875
|
values: Sequence[float] | None,
|
806
876
|
) -> None:
|
807
|
-
|
877
|
+
if state not in (TrialState.COMPLETE, TrialState.FAIL, TrialState.PRUNED):
|
878
|
+
raise ValueError(f"Unsupported trial state: {state}.")
|
808
879
|
self._sampler.after_trial(study, trial, state, values)
|