heavyball 2.1.2.tar.gz → 2.1.4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {heavyball-2.1.2 → heavyball-2.1.4}/PKG-INFO +2 -1
  2. {heavyball-2.1.2 → heavyball-2.1.4}/heavyball/__init__.py +56 -89
  3. {heavyball-2.1.2 → heavyball-2.1.4}/heavyball/chainable.py +6 -4
  4. {heavyball-2.1.2 → heavyball-2.1.4}/heavyball/helpers.py +127 -56
  5. {heavyball-2.1.2 → heavyball-2.1.4}/heavyball/utils.py +74 -61
  6. {heavyball-2.1.2 → heavyball-2.1.4}/heavyball.egg-info/PKG-INFO +2 -1
  7. {heavyball-2.1.2 → heavyball-2.1.4}/heavyball.egg-info/SOURCES.txt +8 -1
  8. {heavyball-2.1.2 → heavyball-2.1.4}/heavyball.egg-info/requires.txt +1 -0
  9. {heavyball-2.1.2 → heavyball-2.1.4}/pyproject.toml +2 -2
  10. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_bf16_params.py +5 -1
  11. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_bf16_q.py +0 -1
  12. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_bf16_storage.py +0 -1
  13. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_caution.py +0 -1
  14. heavyball-2.1.4/test/test_chainable_cpu.py +65 -0
  15. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_channels_last.py +0 -1
  16. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_closure.py +0 -1
  17. heavyball-2.1.4/test/test_cpu_features.py +134 -0
  18. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_ema.py +0 -1
  19. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_foreach.py +40 -6
  20. heavyball-2.1.4/test/test_helpers_cpu.py +107 -0
  21. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_hook.py +0 -1
  22. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_mars.py +0 -1
  23. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_memory.py +0 -2
  24. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_memory_leak.py +0 -1
  25. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_merge.py +0 -1
  26. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_nd_param.py +0 -1
  27. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_no_grad.py +0 -1
  28. heavyball-2.1.4/test/test_optimizer_cpu_smoke.py +65 -0
  29. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_save_restore.py +0 -1
  30. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_singular_values.py +1 -1
  31. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_stochastic_updates.py +0 -1
  32. heavyball-2.1.4/test/test_stochastic_utils_cpu.py +49 -0
  33. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_toy_training.py +4 -4
  34. heavyball-2.1.4/test/test_utils_cpu.py +296 -0
  35. heavyball-2.1.4/test/test_utils_property.py +281 -0
  36. {heavyball-2.1.2 → heavyball-2.1.4}/LICENSE +0 -0
  37. {heavyball-2.1.2 → heavyball-2.1.4}/README.md +0 -0
  38. {heavyball-2.1.2 → heavyball-2.1.4}/heavyball.egg-info/dependency_links.txt +0 -0
  39. {heavyball-2.1.2 → heavyball-2.1.4}/heavyball.egg-info/top_level.txt +0 -0
  40. {heavyball-2.1.2 → heavyball-2.1.4}/setup.cfg +0 -0
  41. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_clip.py +0 -0
  42. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_migrate_cli.py +0 -0
  43. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_psgd_precond_init_stability.py +0 -0
  44. {heavyball-2.1.2 → heavyball-2.1.4}/test/test_soap.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: heavyball
- Version: 2.1.2
+ Version: 2.1.4
  Summary: Efficient Optimizers
  Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
  Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -21,6 +21,7 @@ Requires-Dist: numpy<2.0.0
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
  Requires-Dist: pytest; extra == "dev"
+ Requires-Dist: hypothesis; extra == "dev"
  Requires-Dist: ruff; extra == "dev"
  Requires-Dist: matplotlib; extra == "dev"
  Requires-Dist: seaborn; extra == "dev"
@@ -1,6 +1,6 @@
  import functools
  import math
- from typing import Optional
+ from typing import Optional, Type, Union

  import torch.optim

@@ -8,39 +8,6 @@ from . import chainable as C
  from . import utils


- class SAMWrapper(torch.optim.Optimizer):
-     def __init__(self, params, wrapped_optimizer: utils.StatefulOptimizer, ball: float = 0.1):
-         if not isinstance(wrapped_optimizer, utils.StatefulOptimizer):
-             raise ValueError(f"{wrapped_optimizer.__class__.__name__} is not a HeavyBall optimizer")
-         super().__init__(params, {"ball": ball})
-         self.wrapped_optimizer = wrapped_optimizer
-
-     @torch.no_grad()
-     def step(self, closure=None):
-         if closure is None:
-             raise ValueError("SAM requires closure")
-         with torch.enable_grad():
-             closure()
-         old_params = [utils.sam_step(group["params"], group["ball"]) for group in self.param_groups]
-
-         originaL_handle_closure = self.wrapped_optimizer._handle_closure
-
-         def _handle_closure(closure):
-             originaL_handle_closure(closure)
-             for group, old in zip(self.param_groups, old_params):
-                 utils.copy_stochastic_list_(group["params"], old)
-
-         try:
-             self.wrapped_optimizer._handle_closure = _handle_closure
-             loss = self.wrapped_optimizer.step(closure)
-         finally:
-             self.wrapped_optimizer._handle_closure = originaL_handle_closure
-         return loss
-
-     def zero_grad(self, set_to_none: bool = True):
-         self.wrapped_optimizer.zero_grad()
-
-
  class SGD(C.BaseOpt):
      def __init__(
          self,
@@ -778,7 +745,7 @@ class ForeachPSGDKron(C.BaseOpt):
          beta=None,
          betas=(0.9, 0.999),
          weight_decay=0.0,
-         preconditioner_update_probability=None,
+         preconditioner_update_probability=C.use_default,
          max_size_triangular=2048,
          min_ndim_triangular=2,
          memory_save_mode=None,
@@ -830,8 +797,8 @@ class ForeachPSGDKron(C.BaseOpt):
          if kwargs:
              utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")

-         self.precond_schedule = (
-             defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
+         self.precond_schedule = C.default(
+             defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
          )
          params = defaults.pop("params")

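The sentinel avoids a pitfall of the previous or-based default: a falsy but intentional value (for example a constant update probability of 0.0) would have been silently replaced by the built-in schedule. A minimal sketch of the assumed C.default(value, fallback) semantics:

    use_default = object()  # stand-in for C.use_default

    def default(value, fallback):
        # Only the sentinel triggers the fallback; 0.0 or False survive untouched.
        return fallback if value is use_default else value

    print(default(use_default, 0.25))  # 0.25 -> fall back to the built-in schedule
    print(default(0.0, 0.25))          # 0.0  -> kept, whereas `value or fallback` would discard it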
@@ -890,7 +857,7 @@ class ForeachPSGDLRA(C.BaseOpt):
          lr=0.001,
          beta=0.9,
          weight_decay=0.0,
-         preconditioner_update_probability=None,
+         preconditioner_update_probability=C.use_default,
          momentum_into_precond_update=True,
          rank: Optional[int] = None,
          warmup_steps: int = 0,
@@ -924,8 +891,8 @@ class ForeachPSGDLRA(C.BaseOpt):
          if kwargs:
              utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")

-         self.precond_schedule = (
-             defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
+         self.precond_schedule = C.default(
+             defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
          )
          params = defaults.pop("params")

@@ -960,6 +927,54 @@ class NewtonHybrid2PSGDLRA(ForeachNewtonPSGDLRA):
      hvp_interval = 2


+ class SAMWrapper(torch.optim.Optimizer):
+     def __init__(
+         self,
+         params,
+         wrapped_optimizer: Union[utils.StatefulOptimizer, Type[utils.StatefulOptimizer]] = ForeachAdamW,
+         ball: float = 0.1,
+     ):
+         params = list(params)
+         super().__init__(params, {"ball": ball})
+
+         if isinstance(wrapped_optimizer, type):
+             if not issubclass(wrapped_optimizer, utils.StatefulOptimizer):
+                 raise ValueError(f"{wrapped_optimizer.__name__} is not a HeavyBall optimizer")
+             wrapped_optimizer = wrapped_optimizer(params)
+         elif not isinstance(wrapped_optimizer, utils.StatefulOptimizer):
+             raise ValueError(f"{wrapped_optimizer.__class__.__name__} is not a HeavyBall optimizer")
+
+         self.wrapped_optimizer = wrapped_optimizer
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         if closure is None:
+             raise ValueError("SAM requires closure")
+         with torch.enable_grad():
+             closure()
+         old_params = [utils.sam_step(group["params"], group["ball"]) for group in self.param_groups]
+
+         original_handle_closure = self.wrapped_optimizer._handle_closure
+
+         def _handle_closure(closure):
+             try:
+                 _loss = original_handle_closure(closure)
+             finally:
+                 for group, old in zip(self.param_groups, old_params):
+                     utils.copy_stochastic_list_(group["params"], old)
+             return _loss
+
+         try:
+             self.wrapped_optimizer._handle_closure = _handle_closure
+             loss = self.wrapped_optimizer.step(closure)
+         finally:
+             self.wrapped_optimizer._handle_closure = original_handle_closure
+         return loss
+
+     def zero_grad(self, set_to_none: bool = True):
+         self.wrapped_optimizer.zero_grad(set_to_none=set_to_none)
+
+
  PalmForEachSoap = PaLMForeachSOAP
  PaLMSOAP = PaLMForeachSOAP
  PaLMSFAdamW = PaLMForeachSFAdamW
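The ForeachAdamW default in the new signature is evaluated at class-definition time, which is why SAMWrapper now sits below the concrete optimizers instead of at the top of the module. A minimal usage sketch (toy model and objective; SAM requires a closure so the loss can be re-evaluated at the perturbed weights):

    import torch
    import heavyball

    model = torch.nn.Linear(4, 1)
    opt = heavyball.SAMWrapper(model.parameters(), wrapped_optimizer=heavyball.ForeachAdamW, ball=0.05)

    def closure():
        opt.zero_grad()
        loss = model(torch.randn(8, 4)).square().mean()  # toy objective
        loss.backward()
        return loss

    for _ in range(3):
        opt.step(closure)  # raises ValueError when called without a closure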
@@ -983,52 +998,4 @@ PSGDLRA = ForeachPSGDLRA
  NewtonPSGDLRA = ForeachNewtonPSGDLRA
  NewtonPSGDKron = ForeachCachedNewtonPSGD

- __all__ = [
-     "Muon",
-     "RMSprop",
-     "PrecondSchedulePaLMSOAP",
-     "PSGDKron",
-     "PurePSGD",
-     "DelayedPSGD",
-     "CachedPSGDKron",
-     "CachedDelayedPSGDKron",
-     "PalmForEachSoap",
-     "PaLMSOAP",
-     "PaLMSFAdamW",
-     "LaProp",
-     "ADOPT",
-     "PrecondScheduleSOAP",
-     "PrecondSchedulePaLMSOAP",
-     "RMSprop",
-     "MuonLaProp",
-     "ForeachSignLaProp",
-     "ForeachDelayedPSGDLRA",
-     "ForeachPSGDLRA",
-     "ForeachPSGDLRA",
-     "ForeachNewtonPSGDLRA",  #
-     "ForeachAdamW",
-     "ForeachSFAdamW",
-     "ForeachLaProp",
-     "ForeachADOPT",
-     "ForeachSOAP",
-     "ForeachPSGDKron",
-     "ForeachPurePSGD",
-     "ForeachDelayedPSGD",
-     "ForeachCachedPSGDKron",
-     "ForeachCachedDelayedPSGDKron",
-     "ForeachRMSprop",
-     "ForeachMuon",
-     "ForeachCachedNewtonPSGD",
-     "OrthoLaProp",
-     "LaPropOrtho",
-     "SignLaProp",
-     "DelayedPSGD",
-     "PSGDLRA",
-     "NewtonPSGDLRA",
-     "NewtonHybrid2PSGDLRA",
-     "NewtonHybrid2PSGDKron",
-     "MSAMLaProp",
-     "NewtonPSGDKron",
-     "ForeachAdamC",
-     "SGD",
- ]
+ __all__ = [k for k, v in globals().items() if isinstance(v, type) and issubclass(v, torch.optim.Optimizer)]
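The hand-maintained export list is replaced by a comprehension over the module globals, so every torch.optim.Optimizer subclass bound at module level (including the aliases above) is exported automatically. A self-contained illustration of what such a comprehension collects:

    import torch

    class Helper:  # not an Optimizer subclass -> excluded
        pass

    class ToyOpt(torch.optim.SGD):  # Optimizer subclass -> included
        pass

    ToyOptAlias = ToyOpt  # aliases are separate global names, so they are exported too

    __all__ = [k for k, v in globals().items() if isinstance(v, type) and issubclass(v, torch.optim.Optimizer)]
    print(__all__)  # ['ToyOpt', 'ToyOptAlias']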
@@ -62,8 +62,6 @@ class FunctionTransform:
              self._init(st, group, *a, **kwargs)
          except SkipUpdate:
              skip_update = True
-         except:
-             raise
          finally:
              if "is_initialized" not in st:
                  st["is_initialized"] = set()
@@ -499,6 +497,7 @@ def scale_by_suds(group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx
      precond_update, _ = utils.eigvecs_product_rank1(precond_update.flatten(), fisher_approx.flatten(), w)

      new_approx = utils.oja_update(fisher_approx.flatten().to(update.dtype), update.flatten(), group["precond_lr"])
+     new_approx = new_approx.view_as(fisher_approx)
      utils.copy_stochastic_(fisher_approx, new_approx)
      return precond_update

@@ -565,7 +564,7 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
      )
      state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
      state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
-     state["step"] = torch.zeros((), device=param.device, dtype=torch.int64)
+     state["step"] = torch.zeros((), device=param.device, dtype=torch.float64)  # torch casts int to float in ckpt load
      if group["adaptive"]:
          state["velocity"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
      if not cached:
@@ -750,7 +749,9 @@ def _update_psgd_precond(
      if isinstance(prob, float):
          float_prob = prob
      else:
-         float_prob = prob(group.get(f"cumulative_prob_{id(Q)}_prob_step", 1))
+         prob_step = group.get(f"cumulative_prob_{id(Q)}_prob_step", 1)
+         float_prob = prob(prob_step)
+         group[f"cumulative_prob_{id(Q)}_prob_step"] = prob_step + 1
      group["is_cached"] = should_use_cache = cached and float_prob < 0.5

      if precond is not None:
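The added write-back advances the per-Q step counter that a callable schedule is evaluated at. A small sketch of the get/call/increment pattern, with a made-up decaying schedule standing in for utils.precond_update_prob_schedule() and an illustrative key name:

    def schedule(step, floor=0.03):
        return max(floor, 0.9 ** step)  # hypothetical decaying update probability

    group = {}
    for _ in range(3):
        prob_step = group.get("cumulative_prob_example_prob_step", 1)
        float_prob = schedule(prob_step)
        group["cumulative_prob_example_prob_step"] = prob_step + 1
        print(prob_step, round(float_prob, 3))
    # 1 0.9
    # 2 0.81
    # 3 0.729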
@@ -1086,6 +1087,7 @@ class ChainOpt(utils.StatefulOptimizer):
          if not group["foreach"] or len(p) == 1:
              for param, grad in zip(p, g):
                  chain(self.state_, group, [grad], [param], *self.fns)
+                 group["caution"] = caution
          else:
              chain(self.state_, group, g, p, *self.fns)

@@ -1,9 +1,9 @@
- from __future__ import annotations
-
+ import copy
  import functools
  import math
  import threading
- from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+ from contextlib import contextmanager
+ from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union

  import numpy
  import numpy as np
@@ -11,23 +11,15 @@ import optuna
  import optunahub
  import pandas as pd
  import torch
- from botorch.utils.sampling import manual_seed
  from hebo.design_space.design_space import DesignSpace
  from hebo.optimizers.hebo import HEBO
- from optuna._transform import _SearchSpaceTransform
+ from optuna._transform import _SearchSpaceTransform as SearchSpaceTransform
  from optuna.distributions import BaseDistribution, CategoricalDistribution, FloatDistribution, IntDistribution
  from optuna.samplers import BaseSampler, CmaEsSampler, RandomSampler
  from optuna.samplers._lazy_random_state import LazyRandomState
  from optuna.study import Study
  from optuna.study._study_direction import StudyDirection
  from optuna.trial import FrozenTrial, TrialState
- from optuna_integration.botorch import (
-     ehvi_candidates_func,
-     logei_candidates_func,
-     qehvi_candidates_func,
-     qei_candidates_func,
-     qparego_candidates_func,
- )
  from torch import Tensor
  from torch.nn import functional as F

@@ -37,12 +29,39 @@ _MAXINT32 = (1 << 31) - 1
  _SAMPLER_KEY = "auto:sampler"


+ @contextmanager
+ def manual_seed(seed: int | None = None) -> Generator[None, None, None]:
+     r"""
+     Contextmanager for manual setting the torch.random seed.
+
+     Args:
+         seed: The seed to set the random number generator to.
+
+     Returns:
+         Generator
+
+     Example:
+         >>> with manual_seed(1234):
+         >>>     X = torch.rand(3)
+
+     copied as-is from https://github.com/meta-pytorch/botorch/blob/a42cd65f9b704cdb6f2ee64db99a022eb15295d5/botorch/utils/sampling.py#L53C1-L75C50 under the MIT License
+     """
+     old_state = torch.random.get_rng_state()
+     try:
+         if seed is not None:
+             torch.random.manual_seed(seed)
+         yield
+     finally:
+         if seed is not None:
+             torch.random.set_rng_state(old_state)
+
+
  class SimpleAPIBaseSampler(BaseSampler):
      def __init__(
          self,
-         search_space: dict[str, BaseDistribution] = None,
+         search_space: Optional[dict[str, BaseDistribution]] = None,
      ):
-         self.search_space = search_space
+         self.search_space = {} if search_space is None else dict(search_space)

      def suggest_all(self, trial: FrozenTrial):
          return {k: trial._suggest(k, dist) for k, dist in self.search_space.items()}
@@ -65,6 +84,16 @@ def _get_default_candidates_func(
      """
      The original is available at https://github.com/optuna/optuna-integration/blob/156a8bc081322791015d2beefff9373ed7b24047/optuna_integration/botorch/botorch.py under the MIT License
      """
+
+     # lazy import
+     from optuna_integration.botorch import (
+         ehvi_candidates_func,
+         logei_candidates_func,
+         qehvi_candidates_func,
+         qei_candidates_func,
+         qparego_candidates_func,
+     )
+
      if n_objectives > 3 and not has_constraint and not consider_running_trials:
          return ehvi_candidates_func
      elif n_objectives > 3:
@@ -124,7 +153,7 @@


  @torch.no_grad()
- def untransform(self: _SearchSpaceTransform, trans_params: Tensor) -> dict[str, Any]:
+ def untransform(self: SearchSpaceTransform, trans_params: Tensor) -> dict[str, Any]:
      assert trans_params.shape == (self._raw_bounds.shape[0],)

      if self._transform_0_1:
@@ -152,29 +181,31 @@ class BoTorchSampler(SimpleAPIBaseSampler):

      def __init__(
          self,
-         search_space: dict[str, BaseDistribution] = None,
+         search_space: Optional[dict[str, BaseDistribution]] = None,
          *,
-         candidates_func: None = None,
-         constraints_func: None = None,
+         candidates_func: Optional[Callable[..., Tensor]] = None,
+         constraints_func: Optional[Callable[..., Tensor]] = None,
          n_startup_trials: int = 10,
          consider_running_trials: bool = False,
-         independent_sampler: None = None,
+         independent_sampler: Optional[BaseSampler] = None,
          seed: int | None = None,
          device: torch.device | str | None = None,
          trial_chunks: int = 128,
      ):
-         assert constraints_func is None
-         assert candidates_func is None
-         assert consider_running_trials is False
-         assert independent_sampler is None
-         self._candidates_func = None
-         self._independent_sampler = RandomSampler(seed=seed)
+         if constraints_func is not None:
+             raise NotImplementedError("constraints_func is currently not supported by BoTorchSampler.")
+         if consider_running_trials:
+             raise NotImplementedError("consider_running_trials is currently not supported by BoTorchSampler.")
+         if candidates_func is not None and not callable(candidates_func):
+             raise TypeError("candidates_func must be callable.")
+         self._candidates_func = candidates_func
+         self._independent_sampler = independent_sampler or RandomSampler(seed=seed)
          self._n_startup_trials = n_startup_trials
          self._seed = seed
          self.trial_chunks = trial_chunks

          self._study_id: int | None = None
-         self.search_space = search_space
+         self.search_space = {} if search_space is None else dict(search_space)
          if isinstance(device, str):
              device = torch.device(device)
          self._device = device or torch.device("cpu")
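With the asserts replaced by explicit errors, a custom candidates_func or independent sampler can now be injected. A small construction sketch (hypothetical search space; assumes BoTorchSampler is imported from heavyball.helpers and that the optional HEBO/BoTorch dependencies are installed):

    import optuna
    from optuna.distributions import FloatDistribution
    from heavyball.helpers import BoTorchSampler

    space = {"lr": FloatDistribution(1e-5, 1e-1, log=True)}
    sampler = BoTorchSampler(search_space=space, n_startup_trials=5, seed=0, device="cpu")
    study = optuna.create_study(sampler=sampler)
    study.optimize(lambda t: (t.suggest_float("lr", 1e-5, 1e-1, log=True) - 1e-3) ** 2, n_trials=10)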
@@ -182,14 +213,24 @@ class BoTorchSampler(SimpleAPIBaseSampler):
          self._values = None
          self._params = None
          self._index = 0
+         self._bounds_dim: int | None = None

      def infer_relative_search_space(self, study: Study, trial: FrozenTrial) -> dict[str, BaseDistribution]:
          return self.search_space

      @torch.no_grad()
      def _preprocess_trials(
-         self, trans: _SearchSpaceTransform, study: Study, trials: list[FrozenTrial]
+         self, trans: SearchSpaceTransform, study: Study, trials: list[FrozenTrial]
      ) -> Tuple[int, Tensor, Tensor]:
+         bounds_dim = trans.bounds.shape[0]
+         if self._bounds_dim is not None and self._bounds_dim != bounds_dim:
+             self._values = None
+             self._params = None
+             self._index = 0
+             self.seen_trials = set()
+         if self._bounds_dim is None:
+             self._bounds_dim = bounds_dim
+
          new_trials = []
          for trial in trials:
              tid: int = trial._trial_id
@@ -200,6 +241,10 @@ class BoTorchSampler(SimpleAPIBaseSampler):

          n_objectives = len(study.directions)
          if not new_trials:
+             if self._values is None or self._params is None:
+                 empty_values = torch.zeros((0, n_objectives), dtype=torch.float64, device=self._device)
+                 empty_params = torch.zeros((0, bounds_dim), dtype=torch.float64, device=self._device)
+                 return n_objectives, empty_values, empty_params
              return n_objectives, self._values[: self._index], self._params[: self._index]

          n_completed_trials = len(trials)
@@ -216,18 +261,28 @@ class BoTorchSampler(SimpleAPIBaseSampler):
              if direction == StudyDirection.MINIMIZE:  # BoTorch always assumes maximization.
                  values[:, obj_idx] *= -1

-         if self._values is None:
+         bounds_dim = trans.bounds.shape[0]
+         cache_stale = (
+             self._values is None
+             or self._params is None
+             or self._values.size(1) != n_objectives
+             or self._params.size(1) != bounds_dim
+         )
+         if cache_stale:
              self._values = torch.zeros((self.trial_chunks, n_objectives), dtype=torch.float64, device=self._device)
-             self._params = torch.zeros(
-                 (self.trial_chunks, trans.bounds.shape[0]), dtype=torch.float64, device=self._device
-             )
+             self._params = torch.zeros((self.trial_chunks, bounds_dim), dtype=torch.float64, device=self._device)
+             self._index = 0
+             self.seen_trials = set()
+             self._bounds_dim = bounds_dim
          spillage = (self._index + n_completed_trials) - self._values.size(0)
          if spillage > 0:
              pad = int(math.ceil(spillage / self.trial_chunks) * self.trial_chunks)
              self._values = F.pad(self._values, (0, 0, 0, pad))
              self._params = F.pad(self._params, (0, 0, 0, pad))
-         self._values[self._index : self._index + n_completed_trials] = torch.from_numpy(values)
-         self._params[self._index : self._index + n_completed_trials] = torch.from_numpy(params)
+         values_tensor = torch.from_numpy(values).to(self._device)
+         params_tensor = torch.from_numpy(params).to(self._device)
+         self._values[self._index : self._index + n_completed_trials] = values_tensor
+         self._params[self._index : self._index + n_completed_trials] = params_tensor
          self._index += n_completed_trials

          return n_objectives, self._values[: self._index], self._params[: self._index]
@@ -246,7 +301,7 @@ class BoTorchSampler(SimpleAPIBaseSampler):
          if n_completed_trials < self._n_startup_trials:
              return {}

-         trans = _SearchSpaceTransform(search_space)
+         trans = SearchSpaceTransform(search_space)
          n_objectives, values, params = self._preprocess_trials(trans, study, completed_trials)

          if self._candidates_func is None:
@@ -349,10 +404,10 @@ class HEBOSampler(optunahub.samplers.SimpleBaseSampler, SimpleAPIBaseSampler):
          independent_sampler: BaseSampler | None = None,
      ) -> None:
          super().__init__(search_space, seed)
-         assert constant_liar is False
-         assert independent_sampler is None
+         if constant_liar:
+             raise NotImplementedError("constant_liar is not supported by HEBOSampler.")
          self._hebo = HEBO(_convert_to_hebo_design_space(search_space), scramble_seed=self._seed)
-         self._independent_sampler = optuna.samplers.RandomSampler(seed=seed)
+         self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
          self._rng = np.random.default_rng(seed)

      def sample_relative(
@@ -421,10 +476,12 @@ class FastINGO:
          learning_rate: Optional[float] = None,
          last_n: int = 4096,
          loco_step_size: float = 0.1,
-         device="cuda",
+         device: str | None = None,
          batchnorm_decay: float = 0.99,
          score_decay: float = 0.99,
      ) -> None:
+         if device is None:
+             device = _use_cuda()
          n_dimension = len(mean)
          if population_size is None:
              population_size = 4 + int(np.floor(3 * np.log(n_dimension)))
@@ -491,8 +548,14 @@ class FastINGO:
          if y.numel() <= 2:
              return

-         y = y + torch.where(y.min() <= 0, 1e-8 - y.min(), 0)
-         y = y.log()
+         min_y = y.min()
+         max_y = y.max()
+         if torch.isclose(max_y, min_y, rtol=0.0, atol=1e-12):
+             return
+
+         if min_y <= 0:
+             y = y + (1e-8 - min_y)
+         y = y.clamp_min_(1e-8).log()

          ema = -torch.arange(y.size(0), device=y.device, dtype=y.dtype)
          weight = self.batchnorm_decay**ema
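The rewritten guard returns early when all objective values are numerically identical and shifts non-positive values before taking the log. A standalone sketch of just that transformation (illustrative, outside the class):

    import torch

    def safe_log_objectives(y: torch.Tensor) -> torch.Tensor | None:
        min_y, max_y = y.min(), y.max()
        if torch.isclose(max_y, min_y, rtol=0.0, atol=1e-12):
            return None  # nothing to rank: all objective values are effectively equal
        if min_y <= 0:
            y = y + (1e-8 - min_y)  # shift so the smallest value becomes 1e-8
        return y.clamp_min(1e-8).log()

    print(safe_log_objectives(torch.tensor([0.0, 1.0, 3.0])))  # finite logs, no -inf
    print(safe_log_objectives(torch.tensor([2.0, 2.0, 2.0])))  # None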
@@ -553,7 +616,7 @@ class ImplicitNaturalGradientSampler(BaseSampler):
      def reseed_rng(self) -> None:
          self._independent_sampler.reseed_rng()
          if self._optimizer:
-             self._optimizer._rng.seed()
+             self._optimizer.generator.seed()

      def infer_relative_search_space(
          self, study: "optuna.Study", trial: "optuna.trial.FrozenTrial"
@@ -603,14 +666,11 @@ class ImplicitNaturalGradientSampler(BaseSampler):
              self._warn_independent_sampling = False
              return {}

-         trans = _SearchSpaceTransform(search_space)
+         trans = SearchSpaceTransform(search_space)

-         if self._optimizer is None:
+         if self._optimizer is None or self._optimizer.dim != len(trans.bounds):
              self._optimizer = self._init_optimizer(trans, population_size=self._population_size)
-
-         if self._optimizer.dim != len(trans.bounds):
-             self._warn_independent_sampling = False
-             return {}
+             self._param_queue.clear()

          solution_trials = [t for t in completed_trials if self._check_trial_is_generation(t)]
          for t in solution_trials:
@@ -621,7 +681,7 @@ class ImplicitNaturalGradientSampler(BaseSampler):

      def _init_optimizer(
          self,
-         trans: _SearchSpaceTransform,
+         trans: SearchSpaceTransform,
          population_size: Optional[int] = None,
      ) -> FastINGO:
          lower_bounds = trans.bounds[:, 0]
@@ -675,6 +735,7 @@ class ThreadLocalSampler(threading.local):


  def init_cmaes(study, seed, trials, search_space):
+     trials = copy.deepcopy(trials)
      trials.sort(key=lambda trial: trial.datetime_complete)
      return CmaEsSampler(seed=seed, source_trials=trials, lr_adapt=True)

@@ -686,8 +747,14 @@ def init_hebo(study, seed, trials, search_space):
      return sampler


+ def _use_cuda():
+     return "cuda" if torch.cuda.is_available() else "cpu"
+
+
  def init_botorch(study, seed, trials, search_space):
-     return BoTorchSampler(search_space=search_space, seed=seed, device="cuda")  # will automatically pull in latest data
+     return BoTorchSampler(
+         search_space=search_space, seed=seed, device=_use_cuda()
+     )  # will automatically pull in latest data


  def init_nsgaii(study, seed, trials, search_space):
@@ -709,17 +776,20 @@ class AutoSampler(BaseSampler):
      def __init__(
          self,
          samplers: Iterable[Tuple[int, Callable]] | None = None,
-         search_space: dict[str, BaseDistribution] = None,
+         search_space: Optional[dict[str, BaseDistribution]] = None,
          *,
          seed: int | None = None,
-         constraints_func: None = None,
+         constraints_func: Optional[Callable[..., Any]] = None,
      ) -> None:
-         assert constraints_func is None
+         if constraints_func is not None:
+             raise NotImplementedError("constraints_func is not supported by AutoSampler.")
          if samplers is None:
+             if search_space is None:
+                 raise ValueError("AutoSampler requires a search_space when using the default sampler schedule.")
              samplers = ((0, init_hebo), (100, init_nsgaii))
          self.sampler_indices = np.sort(np.array([x[0] for x in samplers], dtype=np.int32))
          self.samplers = [x[1] for x in sorted(samplers, key=lambda x: x[0])]
-         self.search_space = search_space
+         self.search_space = {} if search_space is None else dict(search_space)
          self._rng = LazyRandomState(seed)
          self._random_sampler = RandomSampler(seed=seed)
          self._thread_local_sampler = ThreadLocalSampler()
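With the default sampler schedule, AutoSampler now requires an explicit search_space. A small construction sketch (hypothetical space; assumes AutoSampler is imported from heavyball.helpers):

    import optuna
    from optuna.distributions import FloatDistribution, IntDistribution
    from heavyball.helpers import AutoSampler

    space = {
        "lr": FloatDistribution(1e-5, 1e-1, log=True),
        "warmup_steps": IntDistribution(0, 100),
    }
    # default schedule: HEBO from trial 0, NSGA-II once 100 trials have completed
    sampler = AutoSampler(search_space=space, seed=42)
    study = optuna.create_study(sampler=sampler)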
@@ -762,7 +832,7 @@ class AutoSampler(BaseSampler):
          complete_trials = study._get_trials(deepcopy=False, states=(TrialState.COMPLETE,), use_cache=True)
          self._completed_trials = max(self._completed_trials, len(complete_trials))
          new_index = (self._completed_trials >= self.sampler_indices).sum() - 1
-         if new_index == self._current_index:
+         if new_index == self._current_index or new_index < 0:
              return
          self._current_index = new_index
          self._sampler = self.samplers[new_index](
@@ -775,7 +845,7 @@ class AutoSampler(BaseSampler):
      def sample_relative(
          self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
      ) -> dict[str, Any]:
-         return self._sampler.sample_relative(study, trial, self.search_space)
+         return self._sampler.sample_relative(study, trial, search_space or self.search_space)

      def sample_independent(
          self,
@@ -804,5 +874,6 @@ class AutoSampler(BaseSampler):
          state: TrialState,
          values: Sequence[float] | None,
      ) -> None:
-         assert state in [TrialState.COMPLETE, TrialState.FAIL, TrialState.PRUNED]
+         if state not in (TrialState.COMPLETE, TrialState.FAIL, TrialState.PRUNED):
+             raise ValueError(f"Unsupported trial state: {state}.")
          self._sampler.after_trial(study, trial, state, values)