heavyball 2.1.3.tar.gz → 2.1.4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {heavyball-2.1.3 → heavyball-2.1.4}/PKG-INFO +1 -1
  2. {heavyball-2.1.3 → heavyball-2.1.4}/heavyball/__init__.py +56 -89
  3. {heavyball-2.1.3 → heavyball-2.1.4}/heavyball/chainable.py +6 -4
  4. {heavyball-2.1.3 → heavyball-2.1.4}/heavyball/helpers.py +88 -47
  5. {heavyball-2.1.3 → heavyball-2.1.4}/heavyball/utils.py +58 -51
  6. {heavyball-2.1.3 → heavyball-2.1.4}/heavyball.egg-info/PKG-INFO +1 -1
  7. {heavyball-2.1.3 → heavyball-2.1.4}/heavyball.egg-info/SOURCES.txt +1 -0
  8. {heavyball-2.1.3 → heavyball-2.1.4}/pyproject.toml +1 -1
  9. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_bf16_params.py +5 -0
  10. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_foreach.py +40 -5
  11. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_memory.py +0 -1
  12. heavyball-2.1.4/test/test_stochastic_utils_cpu.py +49 -0
  13. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_utils_cpu.py +22 -21
  14. {heavyball-2.1.3 → heavyball-2.1.4}/LICENSE +0 -0
  15. {heavyball-2.1.3 → heavyball-2.1.4}/README.md +0 -0
  16. {heavyball-2.1.3 → heavyball-2.1.4}/heavyball.egg-info/dependency_links.txt +0 -0
  17. {heavyball-2.1.3 → heavyball-2.1.4}/heavyball.egg-info/requires.txt +0 -0
  18. {heavyball-2.1.3 → heavyball-2.1.4}/heavyball.egg-info/top_level.txt +0 -0
  19. {heavyball-2.1.3 → heavyball-2.1.4}/setup.cfg +0 -0
  20. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_bf16_q.py +0 -0
  21. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_bf16_storage.py +0 -0
  22. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_caution.py +0 -0
  23. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_chainable_cpu.py +0 -0
  24. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_channels_last.py +0 -0
  25. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_clip.py +0 -0
  26. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_closure.py +0 -0
  27. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_cpu_features.py +0 -0
  28. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_ema.py +0 -0
  29. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_helpers_cpu.py +0 -0
  30. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_hook.py +0 -0
  31. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_mars.py +0 -0
  32. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_memory_leak.py +0 -0
  33. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_merge.py +0 -0
  34. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_migrate_cli.py +0 -0
  35. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_nd_param.py +0 -0
  36. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_no_grad.py +0 -0
  37. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_optimizer_cpu_smoke.py +0 -0
  38. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_psgd_precond_init_stability.py +0 -0
  39. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_save_restore.py +0 -0
  40. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_singular_values.py +0 -0
  41. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_soap.py +0 -0
  42. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_stochastic_updates.py +0 -0
  43. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_toy_training.py +0 -0
  44. {heavyball-2.1.3 → heavyball-2.1.4}/test/test_utils_property.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: heavyball
- Version: 2.1.3
+ Version: 2.1.4
  Summary: Efficient Optimizers
  Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
  Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -1,6 +1,6 @@
  import functools
  import math
- from typing import Optional
+ from typing import Optional, Type, Union

  import torch.optim

@@ -8,39 +8,6 @@ from . import chainable as C
  from . import utils


- class SAMWrapper(torch.optim.Optimizer):
-     def __init__(self, params, wrapped_optimizer: utils.StatefulOptimizer, ball: float = 0.1):
-         if not isinstance(wrapped_optimizer, utils.StatefulOptimizer):
-             raise ValueError(f"{wrapped_optimizer.__class__.__name__} is not a HeavyBall optimizer")
-         super().__init__(params, {"ball": ball})
-         self.wrapped_optimizer = wrapped_optimizer
-
-     @torch.no_grad()
-     def step(self, closure=None):
-         if closure is None:
-             raise ValueError("SAM requires closure")
-         with torch.enable_grad():
-             closure()
-         old_params = [utils.sam_step(group["params"], group["ball"]) for group in self.param_groups]
-
-         originaL_handle_closure = self.wrapped_optimizer._handle_closure
-
-         def _handle_closure(closure):
-             originaL_handle_closure(closure)
-             for group, old in zip(self.param_groups, old_params):
-                 utils.copy_stochastic_list_(group["params"], old)
-
-         try:
-             self.wrapped_optimizer._handle_closure = _handle_closure
-             loss = self.wrapped_optimizer.step(closure)
-         finally:
-             self.wrapped_optimizer._handle_closure = originaL_handle_closure
-         return loss
-
-     def zero_grad(self, set_to_none: bool = True):
-         self.wrapped_optimizer.zero_grad()
-
-
  class SGD(C.BaseOpt):
      def __init__(
          self,
@@ -778,7 +745,7 @@ class ForeachPSGDKron(C.BaseOpt):
          beta=None,
          betas=(0.9, 0.999),
          weight_decay=0.0,
-         preconditioner_update_probability=None,
+         preconditioner_update_probability=C.use_default,
          max_size_triangular=2048,
          min_ndim_triangular=2,
          memory_save_mode=None,
@@ -830,8 +797,8 @@ class ForeachPSGDKron(C.BaseOpt):
          if kwargs:
              utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")

-         self.precond_schedule = (
-             defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
+         self.precond_schedule = C.default(
+             defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
          )
          params = defaults.pop("params")

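The switch from `or` to a sentinel matters whenever a falsy but legitimate value (such as a fixed probability of 0.0) is passed explicitly: `or` would discard it, the sentinel keeps it. A standalone sketch of the pattern; `use_default` and `default` below are stand-ins with assumed semantics, not the actual heavyball.chainable helpers.

    # Toy re-creation of the sentinel-default pattern (assumed semantics).
    use_default = object()

    def default(value, fallback):
        return fallback if value is use_default else value

    # `or` silently replaces a legitimate falsy setting such as 0.0 ...
    assert (0.0 or "schedule") == "schedule"
    # ... while the sentinel only falls back when nothing was passed.
    assert default(0.0, "schedule") == 0.0
    assert default(use_default, "schedule") == "schedule"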
@@ -890,7 +857,7 @@ class ForeachPSGDLRA(C.BaseOpt):
          lr=0.001,
          beta=0.9,
          weight_decay=0.0,
-         preconditioner_update_probability=None,
+         preconditioner_update_probability=C.use_default,
          momentum_into_precond_update=True,
          rank: Optional[int] = None,
          warmup_steps: int = 0,
@@ -924,8 +891,8 @@ class ForeachPSGDLRA(C.BaseOpt):
          if kwargs:
              utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")

-         self.precond_schedule = (
-             defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
+         self.precond_schedule = C.default(
+             defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
          )
          params = defaults.pop("params")

@@ -960,6 +927,54 @@ class NewtonHybrid2PSGDLRA(ForeachNewtonPSGDLRA):
      hvp_interval = 2


+ class SAMWrapper(torch.optim.Optimizer):
+     def __init__(
+         self,
+         params,
+         wrapped_optimizer: Union[utils.StatefulOptimizer, Type[utils.StatefulOptimizer]] = ForeachAdamW,
+         ball: float = 0.1,
+     ):
+         params = list(params)
+         super().__init__(params, {"ball": ball})
+
+         if isinstance(wrapped_optimizer, type):
+             if not issubclass(wrapped_optimizer, utils.StatefulOptimizer):
+                 raise ValueError(f"{wrapped_optimizer.__name__} is not a HeavyBall optimizer")
+             wrapped_optimizer = wrapped_optimizer(params)
+         elif not isinstance(wrapped_optimizer, utils.StatefulOptimizer):
+             raise ValueError(f"{wrapped_optimizer.__class__.__name__} is not a HeavyBall optimizer")
+
+         self.wrapped_optimizer = wrapped_optimizer
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         if closure is None:
+             raise ValueError("SAM requires closure")
+         with torch.enable_grad():
+             closure()
+         old_params = [utils.sam_step(group["params"], group["ball"]) for group in self.param_groups]
+
+         original_handle_closure = self.wrapped_optimizer._handle_closure
+
+         def _handle_closure(closure):
+             try:
+                 _loss = original_handle_closure(closure)
+             finally:
+                 for group, old in zip(self.param_groups, old_params):
+                     utils.copy_stochastic_list_(group["params"], old)
+             return _loss
+
+         try:
+             self.wrapped_optimizer._handle_closure = _handle_closure
+             loss = self.wrapped_optimizer.step(closure)
+         finally:
+             self.wrapped_optimizer._handle_closure = original_handle_closure
+         return loss
+
+     def zero_grad(self, set_to_none: bool = True):
+         self.wrapped_optimizer.zero_grad(set_to_none=set_to_none)
+
+
  PalmForEachSoap = PaLMForeachSOAP
  PaLMSOAP = PaLMForeachSOAP
  PaLMSFAdamW = PaLMForeachSFAdamW
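The relocated SAMWrapper now also accepts an optimizer class and instantiates it on the same parameter list. A minimal usage sketch based on the constructor and step signature above; the model and data are made up for illustration.

    import torch
    from torch import nn

    import heavyball

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 1))
    # Passing the class lets SAMWrapper build the inner optimizer itself.
    opt = heavyball.SAMWrapper(model.parameters(), wrapped_optimizer=heavyball.ForeachAdamW, ball=0.1)
    x, y = torch.randn(16, 8), torch.randn(16, 1)

    def closure():
        opt.zero_grad()
        loss = (model(x) - y).square().mean()
        loss.backward()
        return loss

    loss = opt.step(closure)  # step() without a closure raises ValueError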
@@ -983,52 +998,4 @@ PSGDLRA = ForeachPSGDLRA
  NewtonPSGDLRA = ForeachNewtonPSGDLRA
  NewtonPSGDKron = ForeachCachedNewtonPSGD

- __all__ = [
-     "Muon",
-     "RMSprop",
-     "PrecondSchedulePaLMSOAP",
-     "PSGDKron",
-     "PurePSGD",
-     "DelayedPSGD",
-     "CachedPSGDKron",
-     "CachedDelayedPSGDKron",
-     "PalmForEachSoap",
-     "PaLMSOAP",
-     "PaLMSFAdamW",
-     "LaProp",
-     "ADOPT",
-     "PrecondScheduleSOAP",
-     "PrecondSchedulePaLMSOAP",
-     "RMSprop",
-     "MuonLaProp",
-     "ForeachSignLaProp",
-     "ForeachDelayedPSGDLRA",
-     "ForeachPSGDLRA",
-     "ForeachPSGDLRA",
-     "ForeachNewtonPSGDLRA",  #
-     "ForeachAdamW",
-     "ForeachSFAdamW",
-     "ForeachLaProp",
-     "ForeachADOPT",
-     "ForeachSOAP",
-     "ForeachPSGDKron",
-     "ForeachPurePSGD",
-     "ForeachDelayedPSGD",
-     "ForeachCachedPSGDKron",
-     "ForeachCachedDelayedPSGDKron",
-     "ForeachRMSprop",
-     "ForeachMuon",
-     "ForeachCachedNewtonPSGD",
-     "OrthoLaProp",
-     "LaPropOrtho",
-     "SignLaProp",
-     "DelayedPSGD",
-     "PSGDLRA",
-     "NewtonPSGDLRA",
-     "NewtonHybrid2PSGDLRA",
-     "NewtonHybrid2PSGDKron",
-     "MSAMLaProp",
-     "NewtonPSGDKron",
-     "ForeachAdamC",
-     "SGD",
- ]
+ __all__ = [k for k, v in globals().items() if isinstance(v, type) and issubclass(v, torch.optim.Optimizer)]
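Deriving `__all__` from the module globals keeps the export list in sync with newly added optimizer classes, and everything exported is guaranteed to be a `torch.optim.Optimizer` subclass, which is exactly how the test suite parametrizes over it. A small sketch:

    import torch

    import heavyball

    # Every exported name resolves to an optimizer class.
    for name in heavyball.__all__:
        assert issubclass(getattr(heavyball, name), torch.optim.Optimizer)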
@@ -62,8 +62,6 @@ class FunctionTransform:
              self._init(st, group, *a, **kwargs)
          except SkipUpdate:
              skip_update = True
-         except:
-             raise
          finally:
              if "is_initialized" not in st:
                  st["is_initialized"] = set()
@@ -499,6 +497,7 @@ def scale_by_suds(group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx
      precond_update, _ = utils.eigvecs_product_rank1(precond_update.flatten(), fisher_approx.flatten(), w)

      new_approx = utils.oja_update(fisher_approx.flatten().to(update.dtype), update.flatten(), group["precond_lr"])
+     new_approx = new_approx.view_as(fisher_approx)
      utils.copy_stochastic_(fisher_approx, new_approx)
      return precond_update

@@ -565,7 +564,7 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
      )
      state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
      state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
-     state["step"] = torch.zeros((), device=param.device, dtype=torch.int64)
+     state["step"] = torch.zeros((), device=param.device, dtype=torch.float64)  # torch casts int to float in ckpt load
      if group["adaptive"]:
          state["velocity"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
      if not cached:
@@ -750,7 +749,9 @@ def _update_psgd_precond(
      if isinstance(prob, float):
          float_prob = prob
      else:
-         float_prob = prob(group.get(f"cumulative_prob_{id(Q)}_prob_step", 1))
+         prob_step = group.get(f"cumulative_prob_{id(Q)}_prob_step", 1)
+         float_prob = prob(prob_step)
+         group[f"cumulative_prob_{id(Q)}_prob_step"] = prob_step + 1
      group["is_cached"] = should_use_cache = cached and float_prob < 0.5

      if precond is not None:
1086
1087
  if not group["foreach"] or len(p) == 1:
1087
1088
  for param, grad in zip(p, g):
1088
1089
  chain(self.state_, group, [grad], [param], *self.fns)
1090
+ group["caution"] = caution
1089
1091
  else:
1090
1092
  chain(self.state_, group, g, p, *self.fns)
1091
1093
 
@@ -1,5 +1,4 @@
- from __future__ import annotations
-
+ import copy
  import functools
  import math
  import threading
@@ -14,7 +13,7 @@ import pandas as pd
  import torch
  from hebo.design_space.design_space import DesignSpace
  from hebo.optimizers.hebo import HEBO
- from optuna._transform import _SearchSpaceTransform
+ from optuna._transform import _SearchSpaceTransform as SearchSpaceTransform
  from optuna.distributions import BaseDistribution, CategoricalDistribution, FloatDistribution, IntDistribution
  from optuna.samplers import BaseSampler, CmaEsSampler, RandomSampler
  from optuna.samplers._lazy_random_state import LazyRandomState
@@ -60,9 +59,9 @@ def manual_seed(seed: int | None = None) -> Generator[None, None, None]:
  class SimpleAPIBaseSampler(BaseSampler):
      def __init__(
          self,
-         search_space: dict[str, BaseDistribution] = None,
+         search_space: Optional[dict[str, BaseDistribution]] = None,
      ):
-         self.search_space = search_space
+         self.search_space = {} if search_space is None else dict(search_space)

      def suggest_all(self, trial: FrozenTrial):
          return {k: trial._suggest(k, dist) for k, dist in self.search_space.items()}
@@ -154,7 +153,7 @@ def _untransform_numerical_param_torch(


  @torch.no_grad()
- def untransform(self: _SearchSpaceTransform, trans_params: Tensor) -> dict[str, Any]:
+ def untransform(self: SearchSpaceTransform, trans_params: Tensor) -> dict[str, Any]:
      assert trans_params.shape == (self._raw_bounds.shape[0],)

      if self._transform_0_1:
@@ -182,29 +181,31 @@ class BoTorchSampler(SimpleAPIBaseSampler):

      def __init__(
          self,
-         search_space: dict[str, BaseDistribution] = None,
+         search_space: Optional[dict[str, BaseDistribution]] = None,
          *,
-         candidates_func: None = None,
-         constraints_func: None = None,
+         candidates_func: Optional[Callable[..., Tensor]] = None,
+         constraints_func: Optional[Callable[..., Tensor]] = None,
          n_startup_trials: int = 10,
          consider_running_trials: bool = False,
-         independent_sampler: None = None,
+         independent_sampler: Optional[BaseSampler] = None,
          seed: int | None = None,
          device: torch.device | str | None = None,
          trial_chunks: int = 128,
      ):
-         assert constraints_func is None
-         assert candidates_func is None
-         assert consider_running_trials is False
-         assert independent_sampler is None
-         self._candidates_func = None
-         self._independent_sampler = RandomSampler(seed=seed)
+         if constraints_func is not None:
+             raise NotImplementedError("constraints_func is currently not supported by BoTorchSampler.")
+         if consider_running_trials:
+             raise NotImplementedError("consider_running_trials is currently not supported by BoTorchSampler.")
+         if candidates_func is not None and not callable(candidates_func):
+             raise TypeError("candidates_func must be callable.")
+         self._candidates_func = candidates_func
+         self._independent_sampler = independent_sampler or RandomSampler(seed=seed)
          self._n_startup_trials = n_startup_trials
          self._seed = seed
          self.trial_chunks = trial_chunks

          self._study_id: int | None = None
-         self.search_space = search_space
+         self.search_space = {} if search_space is None else dict(search_space)
          if isinstance(device, str):
              device = torch.device(device)
          self._device = device or torch.device("cpu")
@@ -212,14 +213,24 @@ class BoTorchSampler(SimpleAPIBaseSampler):
          self._values = None
          self._params = None
          self._index = 0
+         self._bounds_dim: int | None = None

      def infer_relative_search_space(self, study: Study, trial: FrozenTrial) -> dict[str, BaseDistribution]:
          return self.search_space

      @torch.no_grad()
      def _preprocess_trials(
-         self, trans: _SearchSpaceTransform, study: Study, trials: list[FrozenTrial]
+         self, trans: SearchSpaceTransform, study: Study, trials: list[FrozenTrial]
      ) -> Tuple[int, Tensor, Tensor]:
+         bounds_dim = trans.bounds.shape[0]
+         if self._bounds_dim is not None and self._bounds_dim != bounds_dim:
+             self._values = None
+             self._params = None
+             self._index = 0
+             self.seen_trials = set()
+         if self._bounds_dim is None:
+             self._bounds_dim = bounds_dim
+
          new_trials = []
          for trial in trials:
              tid: int = trial._trial_id
@@ -230,6 +241,10 @@ class BoTorchSampler(SimpleAPIBaseSampler):

          n_objectives = len(study.directions)
          if not new_trials:
+             if self._values is None or self._params is None:
+                 empty_values = torch.zeros((0, n_objectives), dtype=torch.float64, device=self._device)
+                 empty_params = torch.zeros((0, bounds_dim), dtype=torch.float64, device=self._device)
+                 return n_objectives, empty_values, empty_params
              return n_objectives, self._values[: self._index], self._params[: self._index]

          n_completed_trials = len(trials)
@@ -246,18 +261,28 @@ class BoTorchSampler(SimpleAPIBaseSampler):
              if direction == StudyDirection.MINIMIZE:  # BoTorch always assumes maximization.
                  values[:, obj_idx] *= -1

-         if self._values is None:
+         bounds_dim = trans.bounds.shape[0]
+         cache_stale = (
+             self._values is None
+             or self._params is None
+             or self._values.size(1) != n_objectives
+             or self._params.size(1) != bounds_dim
+         )
+         if cache_stale:
              self._values = torch.zeros((self.trial_chunks, n_objectives), dtype=torch.float64, device=self._device)
-             self._params = torch.zeros(
-                 (self.trial_chunks, trans.bounds.shape[0]), dtype=torch.float64, device=self._device
-             )
+             self._params = torch.zeros((self.trial_chunks, bounds_dim), dtype=torch.float64, device=self._device)
+             self._index = 0
+             self.seen_trials = set()
+             self._bounds_dim = bounds_dim
          spillage = (self._index + n_completed_trials) - self._values.size(0)
          if spillage > 0:
              pad = int(math.ceil(spillage / self.trial_chunks) * self.trial_chunks)
              self._values = F.pad(self._values, (0, 0, 0, pad))
              self._params = F.pad(self._params, (0, 0, 0, pad))
-         self._values[self._index : self._index + n_completed_trials] = torch.from_numpy(values)
-         self._params[self._index : self._index + n_completed_trials] = torch.from_numpy(params)
+         values_tensor = torch.from_numpy(values).to(self._device)
+         params_tensor = torch.from_numpy(params).to(self._device)
+         self._values[self._index : self._index + n_completed_trials] = values_tensor
+         self._params[self._index : self._index + n_completed_trials] = params_tensor
          self._index += n_completed_trials

          return n_objectives, self._values[: self._index], self._params[: self._index]
@@ -276,7 +301,7 @@ class BoTorchSampler(SimpleAPIBaseSampler):
          if n_completed_trials < self._n_startup_trials:
              return {}

-         trans = _SearchSpaceTransform(search_space)
+         trans = SearchSpaceTransform(search_space)
          n_objectives, values, params = self._preprocess_trials(trans, study, completed_trials)

          if self._candidates_func is None:
@@ -379,10 +404,10 @@ class HEBOSampler(optunahub.samplers.SimpleBaseSampler, SimpleAPIBaseSampler):
          independent_sampler: BaseSampler | None = None,
      ) -> None:
          super().__init__(search_space, seed)
-         assert constant_liar is False
-         assert independent_sampler is None
+         if constant_liar:
+             raise NotImplementedError("constant_liar is not supported by HEBOSampler.")
          self._hebo = HEBO(_convert_to_hebo_design_space(search_space), scramble_seed=self._seed)
-         self._independent_sampler = optuna.samplers.RandomSampler(seed=seed)
+         self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
          self._rng = np.random.default_rng(seed)

      def sample_relative(
@@ -451,10 +476,12 @@ class FastINGO:
          learning_rate: Optional[float] = None,
          last_n: int = 4096,
          loco_step_size: float = 0.1,
-         device="cuda",
+         device: str | None = None,
          batchnorm_decay: float = 0.99,
          score_decay: float = 0.99,
      ) -> None:
+         if device is None:
+             device = _use_cuda()
          n_dimension = len(mean)
          if population_size is None:
              population_size = 4 + int(np.floor(3 * np.log(n_dimension)))
@@ -521,8 +548,14 @@ class FastINGO:
          if y.numel() <= 2:
              return

-         y = y + torch.where(y.min() <= 0, 1e-8 - y.min(), 0)
-         y = y.log()
+         min_y = y.min()
+         max_y = y.max()
+         if torch.isclose(max_y, min_y, rtol=0.0, atol=1e-12):
+             return
+
+         if min_y <= 0:
+             y = y + (1e-8 - min_y)
+         y = y.clamp_min_(1e-8).log()

          ema = -torch.arange(y.size(0), device=y.device, dtype=y.dtype)
          weight = self.batchnorm_decay**ema
@@ -583,7 +616,7 @@ class ImplicitNaturalGradientSampler(BaseSampler):
      def reseed_rng(self) -> None:
          self._independent_sampler.reseed_rng()
          if self._optimizer:
-             self._optimizer._rng.seed()
+             self._optimizer.generator.seed()

      def infer_relative_search_space(
          self, study: "optuna.Study", trial: "optuna.trial.FrozenTrial"
@@ -633,14 +666,11 @@ class ImplicitNaturalGradientSampler(BaseSampler):
              self._warn_independent_sampling = False
              return {}

-         trans = _SearchSpaceTransform(search_space)
+         trans = SearchSpaceTransform(search_space)

-         if self._optimizer is None:
+         if self._optimizer is None or self._optimizer.dim != len(trans.bounds):
              self._optimizer = self._init_optimizer(trans, population_size=self._population_size)
-
-         if self._optimizer.dim != len(trans.bounds):
-             self._warn_independent_sampling = False
-             return {}
+             self._param_queue.clear()

          solution_trials = [t for t in completed_trials if self._check_trial_is_generation(t)]
          for t in solution_trials:
@@ -651,7 +681,7 @@ class ImplicitNaturalGradientSampler(BaseSampler):

      def _init_optimizer(
          self,
-         trans: _SearchSpaceTransform,
+         trans: SearchSpaceTransform,
          population_size: Optional[int] = None,
      ) -> FastINGO:
          lower_bounds = trans.bounds[:, 0]
@@ -705,6 +735,7 @@ class ThreadLocalSampler(threading.local):


  def init_cmaes(study, seed, trials, search_space):
+     trials = copy.deepcopy(trials)
      trials.sort(key=lambda trial: trial.datetime_complete)
      return CmaEsSampler(seed=seed, source_trials=trials, lr_adapt=True)

@@ -716,8 +747,14 @@ def init_hebo(study, seed, trials, search_space):
      return sampler


+ def _use_cuda():
+     return "cuda" if torch.cuda.is_available() else "cpu"
+
+
  def init_botorch(study, seed, trials, search_space):
-     return BoTorchSampler(search_space=search_space, seed=seed, device="cuda")  # will automatically pull in latest data
+     return BoTorchSampler(
+         search_space=search_space, seed=seed, device=_use_cuda()
+     )  # will automatically pull in latest data


  def init_nsgaii(study, seed, trials, search_space):
@@ -739,17 +776,20 @@ class AutoSampler(BaseSampler):
      def __init__(
          self,
          samplers: Iterable[Tuple[int, Callable]] | None = None,
-         search_space: dict[str, BaseDistribution] = None,
+         search_space: Optional[dict[str, BaseDistribution]] = None,
          *,
          seed: int | None = None,
-         constraints_func: None = None,
+         constraints_func: Optional[Callable[..., Any]] = None,
      ) -> None:
-         assert constraints_func is None
+         if constraints_func is not None:
+             raise NotImplementedError("constraints_func is not supported by AutoSampler.")
          if samplers is None:
+             if search_space is None:
+                 raise ValueError("AutoSampler requires a search_space when using the default sampler schedule.")
              samplers = ((0, init_hebo), (100, init_nsgaii))
          self.sampler_indices = np.sort(np.array([x[0] for x in samplers], dtype=np.int32))
          self.samplers = [x[1] for x in sorted(samplers, key=lambda x: x[0])]
-         self.search_space = search_space
+         self.search_space = {} if search_space is None else dict(search_space)
          self._rng = LazyRandomState(seed)
          self._random_sampler = RandomSampler(seed=seed)
          self._thread_local_sampler = ThreadLocalSampler()
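The `samplers` argument is a schedule of `(completed_trial_threshold, factory)` pairs; the active factory is the last one whose threshold has been reached, and the new `new_index < 0` guard (further below) covers the case where no threshold is met yet. A small sketch of the index computation used by this class:

    import numpy as np

    # Mirrors the default schedule ((0, init_hebo), (100, init_nsgaii)).
    sampler_indices = np.sort(np.array([0, 100], dtype=np.int32))
    for completed in (0, 50, 100, 250):
        new_index = int((completed >= sampler_indices).sum() - 1)
        print(completed, "->", new_index)  # 0 -> 0, 50 -> 0, 100 -> 1, 250 -> 1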
@@ -792,7 +832,7 @@ class AutoSampler(BaseSampler):
          complete_trials = study._get_trials(deepcopy=False, states=(TrialState.COMPLETE,), use_cache=True)
          self._completed_trials = max(self._completed_trials, len(complete_trials))
          new_index = (self._completed_trials >= self.sampler_indices).sum() - 1
-         if new_index == self._current_index:
+         if new_index == self._current_index or new_index < 0:
              return
          self._current_index = new_index
          self._sampler = self.samplers[new_index](
@@ -805,7 +845,7 @@ class AutoSampler(BaseSampler):
      def sample_relative(
          self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
      ) -> dict[str, Any]:
-         return self._sampler.sample_relative(study, trial, self.search_space)
+         return self._sampler.sample_relative(study, trial, search_space or self.search_space)

      def sample_independent(
          self,
@@ -834,5 +874,6 @@ class AutoSampler(BaseSampler):
          state: TrialState,
          values: Sequence[float] | None,
      ) -> None:
-         assert state in [TrialState.COMPLETE, TrialState.FAIL, TrialState.PRUNED]
+         if state not in (TrialState.COMPLETE, TrialState.FAIL, TrialState.PRUNED):
+             raise ValueError(f"Unsupported trial state: {state}.")
          self._sampler.after_trial(study, trial, state, values)
@@ -343,7 +343,8 @@ def set_(dst: Tensor, src: Tensor):


  def clean():
-     torch.cuda.empty_cache()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
      gc.collect()


@@ -470,7 +471,7 @@ def msign(G: torch.Tensor, steps: int = 10, eps: float = 1e-7) -> torch.Tensor:

      if should_transpose:
          x = x.mT
-     return x.float()
+     return x.to(G.dtype)


  ###### END
@@ -670,9 +671,9 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
              final.append(None)
              continue

+         device, dtype = m.device, m.dtype
          m = promote(m.data)

-         device, dtype = m.device, m.dtype
          eps = min_eps
          while True:
              try:
@@ -700,7 +701,6 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
                  raise
              clean()

-         eigvec = eigvec.to(device=m.device, dtype=m.dtype)
          eigvec = torch.flip(eigvec, [1])
          final.append(eigvec)

@@ -1053,13 +1053,15 @@ class StatefulOptimizer(torch.optim.Optimizer):
      def get_groups(self, group):
          return [group]

-     @functools.lru_cache(maxsize=None)
      def state_(self, arg: Tensor, fail: bool = True):
-         if not fail and arg not in self.mapping:
-             return {}
-         if _tensor_key(arg) not in self.mapping_inverse:
+         key = _tensor_key(arg)
+         if key not in self.mapping_inverse:
              self._init_mapping()
-         state_param, index = self.mapping_inverse[_tensor_key(arg)]
+         if key not in self.mapping_inverse:
+             if not fail:
+                 return {}
+             raise KeyError("Tensor has no tracked state.")
+         state_param, index = self.mapping_inverse[key]
          if state_param not in self.state:
              self.state[state_param] = collections.defaultdict(dict)
          return self.state[state_param][index]
@@ -1147,7 +1149,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
              active_p = [p for p in group["params"]]

              if not active_p:
-                 return
+                 continue

              k = group["ema_step"] = group.get("ema_step", -1) + 1

@@ -1164,7 +1166,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
              active_p = [p for p in group["params"]]

              if not active_p:
-                 return
+                 continue

              for p in active_p:
                  if "param_ema" in self.state_(p):
@@ -1178,7 +1180,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
              active_p = [p for p in group["params"]]

              if not active_p:
-                 return
+                 continue

              for p in active_p:
                  if "param_ema" in self.state_(p):
@@ -1207,7 +1209,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
          for group in self.param_groups:
              for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
                  p.grad = grads.pop(0)
-                 stochastic_add_(g, p.grad, -1)  # technically, we have to divide by the scale here
+                 stochastic_add_divide_(g, p.grad, -1, torch.finfo(p.dtype).eps ** 0.5)
                  p.hessian_vector = g
                  p.data.copy_(p.orig)
                  del p.orig
@@ -1297,6 +1299,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
          self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule, self.precond_rng)
          loss = self._handle_closure(closure)

+         if self.use_ema:
+             self.ema_update()
          # we assume that parameters are constant and that there are no excessive recompiles
          with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
              for group in self.param_groups:
@@ -1304,8 +1308,6 @@ class StatefulOptimizer(torch.optim.Optimizer):
                  group["param_count"] = sum(p.numel() for p in group["params"])
                  group["is_preconditioning"] = self._is_preconditioning
                  self._step(group)
-             if self.use_ema:
-                 self.ema_update()
              for real, views in self.mapping.items():
                  for tensor in (real, *views):
                      for key in ("grad", "vector", "hessian_vector", "orig"):
@@ -1569,18 +1571,20 @@ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):

  @decorator_knowngood
  def stochastic_round_(ref: Tensor, source: Tensor | None = None):
-     if source is not None:
-         if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
-             return source
-         if ref.dtype != torch.bfloat16:
-             return source.to(ref.dtype)
-     else:
+     if source is None:
          source = ref
-     source = source.float()
-     result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
-     result.add_(source.view(dtype=torch.int32))
-     result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32
-     return result.view(dtype=torch.float32).bfloat16()
+     if ref.dtype != torch.bfloat16:
+         return source.to(ref.dtype)
+     if source.dtype == torch.bfloat16:
+         return source
+     if source.dtype in (torch.float16, torch.float32, torch.float64):
+         source = source.to(torch.float32)
+         noise = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
+         bits = source.view(dtype=torch.int32)
+         bits.add_(noise)
+         bits.bitwise_and_(-65536)  # FFFF0000 mask, preserves sign+exp+7 mantissa bits
+         return bits.view(dtype=torch.float32).bfloat16()
+     return source.to(ref.dtype)


  @decorator_knowngood
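The bit trick is unchanged: add uniform noise to the 16 bits that bfloat16 truncates, then mask them off, so each float32 value rounds up or down with probability proportional to its distance from the two neighbouring bfloat16 values. A self-contained sketch of the same technique (not the library function itself):

    import torch

    def stochastic_bf16(x: torch.Tensor) -> torch.Tensor:
        x = x.float().contiguous()
        noise = torch.randint_like(x, dtype=torch.int32, low=0, high=(1 << 16))
        bits = x.view(dtype=torch.int32) + noise   # jitter the truncated low 16 bits
        bits = bits & -65536                       # keep sign, exponent, 7 mantissa bits
        return bits.view(dtype=torch.float32).bfloat16()

    x = torch.randn(4096) * 3.0
    rounded = torch.stack([stochastic_bf16(x).float() for _ in range(256)]).mean(0)
    assert (rounded - x).abs().mean() < 5e-3  # unbiased on average, unlike plain truncation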
@@ -1913,7 +1917,8 @@ def update_lra_precond_(

      # LU factorization to reuse computation
      try:
-         LU, pivots = torch.linalg.lu_factor(IpVtU)
+         lu_matrix = promote(IpVtU)  # operate in fp32 when inputs are bf16/half
+         LU, pivots = torch.linalg.lu_factor(lu_matrix)
      except RuntimeError:
          # Error:
          # U[2,2] is zero and using it on lu_solve would result in a division by zero.
@@ -1923,8 +1928,13 @@ def update_lra_precond_(
          # So, we skip this step and reattempt on the next one
          return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)

-     invQtv = invQtv - V @ torch.linalg.lu_solve(LU, pivots, (U.T @ invQtv).view(-1, 1), adjoint=True).flatten()
-     invPv = U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
+     solve_dtype = LU.dtype
+     rhs = (U.T @ invQtv).view(-1, 1).to(solve_dtype)
+     correction = torch.linalg.lu_solve(LU, pivots, rhs, adjoint=True).to(V.dtype)
+     invQtv = invQtv - (V @ correction).flatten()
+     rhs = (V.T @ invQtv).view(-1, 1).to(solve_dtype)
+     solution = torch.linalg.lu_solve(LU, pivots, rhs).to(U.dtype)
+     invPv = (U @ solution).flatten()

      eps, step = scalar_guard(eps, step, vector)
      _compilable_d_step(d, d_orig, invQtv, vector, invPv, hessian_vector, Ph, eps, step, delayed)
@@ -2044,7 +2054,10 @@ def extract_from_flat_update(params: List[Tensor], update: Tensor):
  @decorator_knowngood
  def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
      last_dim = x[0].shape[-remaining:] if remaining else []
-     return torch.cat([i.reshape(-1, *last_dim) for i in x if i.numel()], 0)
+     tensors = [i.reshape(-1, *last_dim) for i in x if i.numel()]
+     if not tensors:
+         return torch.zeros((), dtype=x[0].device, device=x[0].device)
+     return torch.cat(tensors, 0)


  @decorator_knowngood
@@ -2116,16 +2129,6 @@ def psgd_calc_A_and_conjB(G: Tensor, Q, conjB: Tensor | None): # conjB ("V", "v
      return A, conjB


- @decorator_knowngood
- def _random_projection(x: Tensor, scale: Optional[Tensor]):
-     if scale is None:
-         scale = x.norm(float("inf")).clamp(min=1e-8)
-     k = 2 ** math.ceil(math.log2(math.log2(min(x.shape))))  # next-largest-power-of-2 of log2-of-size
-     norm = x.square().sum(0)
-     indices = torch.topk(norm, k, largest=True).indices
-     return x.index_select(1, indices).contiguous() / scale, scale
-
-
  def max_singular_value_exact(A, use_lobpcg: bool = False):
      try:
          if use_lobpcg:
@@ -2169,7 +2172,15 @@ def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
      """
      Adapted from @evanatyourservice
      """
-     Y, max_abs = _random_projection(A, max_abs)
+     if max_abs is None:
+         max_abs = A.norm(float("inf")).clamp(min=1e-8)
+
+     # cholesky uses random projection, but this uses topk -- topk is a warm start, which may converge to a biased result
+     k = 2 ** math.ceil(math.log2(math.log2(min(A.shape))))  # next-largest-power-of-2 of log2-of-size
+     norm = A.square().sum(0)
+     indices = torch.topk(norm, k, largest=True).indices
+     Y = A.index_select(1, indices).contiguous() / max_abs
+
      Q = inplace_orthogonal_(Y, precise_zeroth_power_mode)
      Q = Q / max_abs
      Z = A.T @ Q
@@ -2557,7 +2568,7 @@ def _psgd_quad_preconditioner_grad(GG: List[Tensor], Q: List[Tensor], numel: int
          else:
              scale = gg.size(0) / numel
              gg = 2 * torch.eye(gg.size(0), device=gg.device, dtype=gg.dtype) - gg * scale
-             update = q - gg @ q @ gg
+             update = q - casted_einsum("ab,cd,bc", gg, gg, q)
          out.append(update + update.T)  # make matrix symmetric
      return out

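`casted_einsum` presumably adds dtype handling (an assumption), but the contraction itself is the same quantity as before: with subscripts "ab,cd,bc" the free indices are a and d, so the result equals gg @ q @ gg. A quick check with plain torch.einsum:

    import torch

    gg = torch.randn(5, 5, dtype=torch.float64)
    q = torch.randn(5, 5, dtype=torch.float64)
    # Free indices a, d -> result[a, d] = sum_{b,c} gg[a,b] * q[b,c] * gg[c,d]
    assert torch.allclose(torch.einsum("ab,cd,bc", gg, gg, q), gg @ q @ gg)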
@@ -3111,7 +3122,7 @@ def pointwise_lr_adaptation(
  ):
      grads, update, state, delta = list_guard(grads, update, state, delta)
      lr_lr = scalar_guard(lr_lr, grads[0])
-     _compilable_lr_adapt_(grads, update, state, delta, lr_lr)
+     _compilable_pointwise_lr_adapt_(grads, update, state, delta, lr_lr)


  def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
@@ -3131,8 +3142,6 @@ def hook_optimizer_into_model(model, optimizer, *args, **kwargs):

  def fused_hook(parameters, optimizer, *args, **kwargs):
      parameters = list(parameters)
-     param_count = len(parameters)
-     seen_params = set()

      o = optimizer(parameters, *args, **kwargs)
      step_fn = o.step
@@ -3141,12 +3150,8 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
      )

      def _step(p: Tensor):
-         seen_params.add(p)
-
-         if len(seen_params) < param_count:
-             step_fn()
-             o.zero_grad()
-             seen_params.clear()
+         step_fn()
+         o.zero_grad()

      for p in parameters:
          p.register_post_accumulate_grad_hook(_step)
@@ -3171,6 +3176,8 @@ def sam_step(parameters, ball_size, adaptive: bool = True):
      old_params = []
      for p in parameters:
          old_params.append(p.detach().clone())
+         if not hasattr_none(p, "grad"):
+             continue
          grad = promote(p.grad)
          if adaptive:
              grad = grad * promote(p).square()
1
  Metadata-Version: 2.4
  Name: heavyball
- Version: 2.1.3
+ Version: 2.1.4
  Summary: Efficient Optimizers
  Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
  Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -36,6 +36,7 @@ test/test_save_restore.py
  test/test_save_restore.py
  test/test_singular_values.py
  test/test_soap.py
  test/test_stochastic_updates.py
+ test/test_stochastic_utils_cpu.py
  test/test_toy_training.py
  test/test_utils_cpu.py
  test/test_utils_property.py
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
  [project]
  name = "heavyball"
  description = "Efficient Optimizers"
- version = "2.1.3"
+ version = "2.1.4"
  authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
  classifiers = ["Intended Audience :: Developers",
      "Intended Audience :: Science/Research",
@@ -14,6 +14,11 @@ os.environ["TORCH_LOGS"] = "+recompiles"

  config.cache_size_limit = 128

+ pytestmark = pytest.mark.skipif(
+     not torch.cuda.is_available(),
+     reason="CUDA is required to run bf16 foreach parameter tests.",
+ )
+

  @pytest.mark.parametrize("opt", heavyball.__all__)
  @pytest.mark.parametrize("size,depth", [(256, 1)])
@@ -1,3 +1,5 @@
+ import os
+
  import pytest
  import torch
  from lightbench.utils import get_optim
@@ -15,13 +17,40 @@ def get_memory():
      return torch.cuda.memory_allocated()


+ def _read_int(name: str, default: int, *, minimum: int) -> int:
+     raw = os.environ.get(name)
+     if raw is None:
+         return default
+     try:
+         return max(minimum, int(raw))
+     except ValueError:
+         return default
+
+
+ DEFAULT_SIZE = _read_int("HB_FOREACH_TEST_SIZE", 128, minimum=1)
+ DEFAULT_DEPTH = _read_int("HB_FOREACH_TEST_DEPTH", 16, minimum=1)
+ DEFAULT_ITERATIONS = _read_int("HB_FOREACH_TEST_ITERATIONS", 64, minimum=1)
+ DEFAULT_OUTER = _read_int("HB_FOREACH_TEST_OUTER", 1, minimum=1)
+ DEFAULT_WARMUP = _read_int("HB_FOREACH_TEST_WARMUP", 1, minimum=0)
+
+
  @pytest.mark.parametrize("opt", heavyball.__all__)
- @pytest.mark.parametrize("size,depth", [(256, 128)])
- def test_foreach(opt, size, depth: int, iterations: int = 4096, outer_iterations: int = 2):
+ @pytest.mark.parametrize("size,depth", [(DEFAULT_SIZE, DEFAULT_DEPTH)])
+ def test_foreach(
+     opt,
+     size,
+     depth: int,
+     iterations: int = DEFAULT_ITERATIONS,
+     outer_iterations: int = DEFAULT_OUTER,
+     warmup_runs: int = DEFAULT_WARMUP,
+ ):
      set_torch()

      opt = getattr(heavyball, opt)

+     total_runs = warmup_runs + outer_iterations
+     assert total_runs >= 1
+
      peaks = []
      losses = []

@@ -30,7 +59,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 4096, outer_iterations
      peaks.append([])
      losses.append([])

-     for i in range(outer_iterations):
+     for i in range(total_runs):
          clean()
          model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
          clean()
@@ -56,8 +85,14 @@ def test_foreach(opt, size, depth: int, iterations: int = 4096, outer_iterations

          peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]

-         if i > 0:
-             peaks[-1].append(peak)
+         if i < warmup_runs:
+             continue
+
+         peaks[-1].append(peak)
+
+     if warmup_runs:
+         cutoff = warmup_runs * iterations
+         losses = [loss_list[cutoff:] for loss_list in losses]

      for p0, p1 in zip(*peaks):
          assert p0 > p1
@@ -19,7 +19,6 @@ expected_memory = {
      "adamw": {"after": 4, "peak": 5.1},
      "soap": {"after": 7, "peak": 14},
      "psgd": {"after": 4, "peak": 11.5},
-     "padam": {"after": 5, "peak": 11.4},
  }


@@ -0,0 +1,49 @@
+ import os
+
+ os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
+
+ import torch
+
+ from heavyball.utils import copy_stochastic_, stochastic_add_, stochastic_divide_with_eps_
+
+
+ def _average_stochastic_round(source: torch.Tensor, trials: int = 512) -> torch.Tensor:
+     dest = torch.empty_like(source, dtype=torch.bfloat16)
+     totals = torch.zeros_like(source, dtype=torch.float64)
+     for _ in range(trials):
+         copy_stochastic_(dest, source)
+         totals += dest.to(dtype=torch.float64)
+     return totals / trials
+
+
+ def test_copy_stochastic_round_is_close_to_source_mean():
+     torch.manual_seed(0x5566AA)
+     values = torch.randn(2048, dtype=torch.float32) * 3.0
+     averaged = _average_stochastic_round(values, trials=256)
+     delta = averaged - values.double()
+
+     # Stochastic round should stay close to the original float32 values.
+     assert delta.abs().mean().item() < 5e-3
+     assert delta.abs().max().item() < 2.5e-2
+
+
+ def test_stochastic_add_broadcasts_partner_lists():
+     torch.manual_seed(0x172893)
+     targets = [torch.zeros(4, dtype=torch.bfloat16) for _ in range(2)]
+     partner = [torch.linspace(-1.0, 1.0, 4, dtype=torch.float32)]
+
+     stochastic_add_(targets, partner, alpha=0.25)
+     expected = partner[0] * 0.25
+     for tensor in targets:
+         assert torch.allclose(tensor.float(), expected, atol=5e-3, rtol=0)
+
+
+ def test_stochastic_divide_with_eps_matches_float_result():
+     torch.manual_seed(0xABCDEF)
+     numerator = torch.randn(32, dtype=torch.bfloat16)
+     denominator = torch.rand(32, dtype=torch.bfloat16) + 0.05
+     result = numerator.clone()
+
+     stochastic_divide_with_eps_(result, denominator, eps=1e-3)
+     expected = numerator.float() / (denominator.float() + 1e-3)
+     assert torch.allclose(result.float(), expected, atol=2e-2, rtol=2e-2)
@@ -3,7 +3,6 @@ import random
  import warnings
  from copy import deepcopy

- import numpy as np
  import pytest
  import torch
  from torch import Tensor, nn
@@ -222,29 +221,31 @@ def test_stochastic_math_helpers_match_expected_results():
      assert torch.allclose(a, torch.full_like(a, expected), atol=1e-6)


- def test_stochastic_math_accuracy(steps: int = 100, items: int = 32, target_shift: float = 1.0):
-     """
-     TODO: Rework this test or stochastic rounding.
-     With target_shift=1, it passes. With target_shift != 1, it does not pass -- this is unexpected
-     """
+ def test_stochastic_math_accuracy():
      torch.manual_seed(0x172893)
-     rng = np.random.default_rng(0x213112)
+     items = 8
+     steps = 2048
+     increments = torch.full((items,), 1e-3, dtype=torch.float32)

-     accum_baseline = torch.zeros(items, dtype=torch.bfloat16) + target_shift
-     accum_stochastic = torch.zeros(items, dtype=torch.bfloat16) + target_shift
-     accum_groundtruth = torch.zeros(items, dtype=torch.float64) + target_shift
+     baseline = torch.zeros(items, dtype=torch.bfloat16)
+     stochastic = torch.zeros(items, dtype=torch.bfloat16)
+     ground_truth = torch.zeros(items, dtype=torch.float64)

-     add = 1 / (1 + 2 * torch.arange(items, dtype=torch.float32))
-     alphas = np.exp(-2 - 2 * rng.random((steps // 2,)))
-     for _ in range(2):
-         for alpha in alphas:
-             accum_baseline.add_(add, alpha=alpha)
-             accum_groundtruth.add_(add, alpha=alpha)
-             stochastic_add_(accum_stochastic, add, alpha=alpha)
-         assert (accum_baseline.double() - accum_groundtruth).norm().item() > (
-             accum_stochastic.double() - accum_groundtruth
-         ).norm().item()
-         alphas = -alphas
+     for _ in range(steps):
+         baseline.add_(increments)
+         ground_truth.add_(increments)
+         stochastic_add_(stochastic, increments)
+
+     baseline_error = torch.abs(baseline.float() - ground_truth.float()).mean().item()
+     stochastic_error = torch.abs(stochastic.float() - ground_truth.float()).mean().item()
+
+     assert baseline_error > 1.0
+     assert stochastic_error < 0.2
+     assert stochastic_error < baseline_error * 0.2
+
+     baseline_bias = abs(baseline.float().mean().item() - ground_truth.float().mean().item())
+     stochastic_bias = abs(stochastic.float().mean().item() - ground_truth.float().mean().item())
+     assert stochastic_bias < baseline_bias


  def test_disable_caution_scaling_toggles_behavior():