heavyball-2.1.2-py3-none-any.whl → heavyball-2.1.4-py3-none-any.whl

heavyball/__init__.py CHANGED
@@ -1,6 +1,6 @@
  import functools
  import math
- from typing import Optional
+ from typing import Optional, Type, Union

  import torch.optim

@@ -8,39 +8,6 @@ from . import chainable as C
  from . import utils


- class SAMWrapper(torch.optim.Optimizer):
-     def __init__(self, params, wrapped_optimizer: utils.StatefulOptimizer, ball: float = 0.1):
-         if not isinstance(wrapped_optimizer, utils.StatefulOptimizer):
-             raise ValueError(f"{wrapped_optimizer.__class__.__name__} is not a HeavyBall optimizer")
-         super().__init__(params, {"ball": ball})
-         self.wrapped_optimizer = wrapped_optimizer
-
-     @torch.no_grad()
-     def step(self, closure=None):
-         if closure is None:
-             raise ValueError("SAM requires closure")
-         with torch.enable_grad():
-             closure()
-         old_params = [utils.sam_step(group["params"], group["ball"]) for group in self.param_groups]
-
-         originaL_handle_closure = self.wrapped_optimizer._handle_closure
-
-         def _handle_closure(closure):
-             originaL_handle_closure(closure)
-             for group, old in zip(self.param_groups, old_params):
-                 utils.copy_stochastic_list_(group["params"], old)
-
-         try:
-             self.wrapped_optimizer._handle_closure = _handle_closure
-             loss = self.wrapped_optimizer.step(closure)
-         finally:
-             self.wrapped_optimizer._handle_closure = originaL_handle_closure
-         return loss
-
-     def zero_grad(self, set_to_none: bool = True):
-         self.wrapped_optimizer.zero_grad()
-
-
  class SGD(C.BaseOpt):
      def __init__(
          self,
@@ -778,7 +745,7 @@ class ForeachPSGDKron(C.BaseOpt):
          beta=None,
          betas=(0.9, 0.999),
          weight_decay=0.0,
-         preconditioner_update_probability=None,
+         preconditioner_update_probability=C.use_default,
          max_size_triangular=2048,
          min_ndim_triangular=2,
          memory_save_mode=None,
@@ -830,8 +797,8 @@ class ForeachPSGDKron(C.BaseOpt):
          if kwargs:
              utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")

-         self.precond_schedule = (
-             defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
+         self.precond_schedule = C.default(
+             defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
          )
          params = defaults.pop("params")

@@ -890,7 +857,7 @@ class ForeachPSGDLRA(C.BaseOpt):
          lr=0.001,
          beta=0.9,
          weight_decay=0.0,
-         preconditioner_update_probability=None,
+         preconditioner_update_probability=C.use_default,
          momentum_into_precond_update=True,
          rank: Optional[int] = None,
          warmup_steps: int = 0,
@@ -924,8 +891,8 @@ class ForeachPSGDLRA(C.BaseOpt):
          if kwargs:
              utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")

-         self.precond_schedule = (
-             defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule()
+         self.precond_schedule = C.default(
+             defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
          )
          params = defaults.pop("params")

@@ -960,6 +927,54 @@ class NewtonHybrid2PSGDLRA(ForeachNewtonPSGDLRA):
      hvp_interval = 2


+ class SAMWrapper(torch.optim.Optimizer):
+     def __init__(
+         self,
+         params,
+         wrapped_optimizer: Union[utils.StatefulOptimizer, Type[utils.StatefulOptimizer]] = ForeachAdamW,
+         ball: float = 0.1,
+     ):
+         params = list(params)
+         super().__init__(params, {"ball": ball})
+
+         if isinstance(wrapped_optimizer, type):
+             if not issubclass(wrapped_optimizer, utils.StatefulOptimizer):
+                 raise ValueError(f"{wrapped_optimizer.__name__} is not a HeavyBall optimizer")
+             wrapped_optimizer = wrapped_optimizer(params)
+         elif not isinstance(wrapped_optimizer, utils.StatefulOptimizer):
+             raise ValueError(f"{wrapped_optimizer.__class__.__name__} is not a HeavyBall optimizer")
+
+         self.wrapped_optimizer = wrapped_optimizer
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         if closure is None:
+             raise ValueError("SAM requires closure")
+         with torch.enable_grad():
+             closure()
+         old_params = [utils.sam_step(group["params"], group["ball"]) for group in self.param_groups]
+
+         original_handle_closure = self.wrapped_optimizer._handle_closure
+
+         def _handle_closure(closure):
+             try:
+                 _loss = original_handle_closure(closure)
+             finally:
+                 for group, old in zip(self.param_groups, old_params):
+                     utils.copy_stochastic_list_(group["params"], old)
+             return _loss
+
+         try:
+             self.wrapped_optimizer._handle_closure = _handle_closure
+             loss = self.wrapped_optimizer.step(closure)
+         finally:
+             self.wrapped_optimizer._handle_closure = original_handle_closure
+         return loss
+
+     def zero_grad(self, set_to_none: bool = True):
+         self.wrapped_optimizer.zero_grad(set_to_none=set_to_none)
+
+
  PalmForEachSoap = PaLMForeachSOAP
  PaLMSOAP = PaLMForeachSOAP
  PaLMSFAdamW = PaLMForeachSFAdamW
@@ -983,52 +998,4 @@ PSGDLRA = ForeachPSGDLRA
  NewtonPSGDLRA = ForeachNewtonPSGDLRA
  NewtonPSGDKron = ForeachCachedNewtonPSGD

- __all__ = [
-     "Muon",
-     "RMSprop",
-     "PrecondSchedulePaLMSOAP",
-     "PSGDKron",
-     "PurePSGD",
-     "DelayedPSGD",
-     "CachedPSGDKron",
-     "CachedDelayedPSGDKron",
-     "PalmForEachSoap",
-     "PaLMSOAP",
-     "PaLMSFAdamW",
-     "LaProp",
-     "ADOPT",
-     "PrecondScheduleSOAP",
-     "PrecondSchedulePaLMSOAP",
-     "RMSprop",
-     "MuonLaProp",
-     "ForeachSignLaProp",
-     "ForeachDelayedPSGDLRA",
-     "ForeachPSGDLRA",
-     "ForeachPSGDLRA",
-     "ForeachNewtonPSGDLRA",  #
-     "ForeachAdamW",
-     "ForeachSFAdamW",
-     "ForeachLaProp",
-     "ForeachADOPT",
-     "ForeachSOAP",
-     "ForeachPSGDKron",
-     "ForeachPurePSGD",
-     "ForeachDelayedPSGD",
-     "ForeachCachedPSGDKron",
-     "ForeachCachedDelayedPSGDKron",
-     "ForeachRMSprop",
-     "ForeachMuon",
-     "ForeachCachedNewtonPSGD",
-     "OrthoLaProp",
-     "LaPropOrtho",
-     "SignLaProp",
-     "DelayedPSGD",
-     "PSGDLRA",
-     "NewtonPSGDLRA",
-     "NewtonHybrid2PSGDLRA",
-     "NewtonHybrid2PSGDKron",
-     "MSAMLaProp",
-     "NewtonPSGDKron",
-     "ForeachAdamC",
-     "SGD",
- ]
+ __all__ = [k for k, v in globals().items() if isinstance(v, type) and issubclass(v, torch.optim.Optimizer)]
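Two functional notes on the `__init__.py` changes above, each with a small hedged sketch.

First, the re-added `SAMWrapper` now defaults to wrapping `ForeachAdamW`, accepts either an optimizer class or an already-constructed HeavyBall optimizer, restores the pre-perturbation parameters even when the inner closure fails, and forwards `set_to_none` in `zero_grad`. A minimal usage sketch; `model`, `loss_fn`, and the random batch are placeholders, and the wrapped optimizer's default hyperparameters are assumed to be acceptable:

```python
import torch

import heavyball

model = torch.nn.Linear(4, 1)
loss_fn = torch.nn.MSELoss()
# Passing the class lets SAMWrapper construct the inner optimizer itself.
opt = heavyball.SAMWrapper(model.parameters(), wrapped_optimizer=heavyball.ForeachAdamW, ball=0.1)


def closure():
    opt.zero_grad()
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = loss_fn(model(x), y)
    loss.backward()
    return loss


loss = opt.step(closure)  # step() raises ValueError when no closure is supplied
```

Second, `preconditioner_update_probability` now defaults to the sentinel `C.use_default` and is resolved with `C.default(...)` instead of `value or fallback`. A self-contained sketch of the sentinel pattern (it mirrors the idea, not heavyball's exact implementation) shows the practical difference: an explicit falsy setting such as `0.0` is no longer silently replaced by the fallback schedule.

```python
use_default = object()  # unique sentinel meaning "caller passed nothing"


def default(value, fallback):
    return fallback if value is use_default else value


fallback_prob = 0.3  # stand-in for utils.precond_update_prob_schedule()
assert default(use_default, fallback_prob) == 0.3  # unset -> fallback
assert default(0.0, fallback_prob) == 0.0          # explicit 0.0 survives; `0.0 or fallback_prob` would not
```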
heavyball/chainable.py CHANGED
@@ -62,8 +62,6 @@ class FunctionTransform:
              self._init(st, group, *a, **kwargs)
          except SkipUpdate:
              skip_update = True
-         except:
-             raise
          finally:
              if "is_initialized" not in st:
                  st["is_initialized"] = set()
@@ -499,6 +497,7 @@ def scale_by_suds(group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx
      precond_update, _ = utils.eigvecs_product_rank1(precond_update.flatten(), fisher_approx.flatten(), w)

      new_approx = utils.oja_update(fisher_approx.flatten().to(update.dtype), update.flatten(), group["precond_lr"])
+     new_approx = new_approx.view_as(fisher_approx)
      utils.copy_stochastic_(fisher_approx, new_approx)
      return precond_update

@@ -565,7 +564,7 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
      )
      state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
      state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
-     state["step"] = torch.zeros((), device=param.device, dtype=torch.int64)
+     state["step"] = torch.zeros((), device=param.device, dtype=torch.float64)  # torch casts int to float in ckpt load
      if group["adaptive"]:
          state["velocity"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
      if not cached:
@@ -750,7 +749,9 @@ def _update_psgd_precond(
      if isinstance(prob, float):
          float_prob = prob
      else:
-         float_prob = prob(group.get(f"cumulative_prob_{id(Q)}_prob_step", 1))
+         prob_step = group.get(f"cumulative_prob_{id(Q)}_prob_step", 1)
+         float_prob = prob(prob_step)
+         group[f"cumulative_prob_{id(Q)}_prob_step"] = prob_step + 1
      group["is_cached"] = should_use_cache = cached and float_prob < 0.5

      if precond is not None:
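The hunk above also starts advancing the per-`Q` `cumulative_prob_{id(Q)}_prob_step` counter each time a callable `preconditioner_update_probability` is evaluated, so a decaying schedule actually sees increasing step numbers. A rough sketch of such a schedule, shown only to illustrate the effect; the constants and shape below are assumptions, not heavyball's actual `utils.precond_update_prob_schedule`:

```python
def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
    # Stand-in decaying schedule: update the preconditioner often early, rarely later.
    def _schedule(n):
        if n < flat_start:
            return max_prob
        return max(min_prob, max_prob * 10 ** (-decay * (n - flat_start)))

    return _schedule


prob = precond_update_prob_schedule()
print(prob(1), prob(1000), prob(5000))  # 1.0, ~0.32, 0.03 -- a frozen counter would stay at prob(1) forever
```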
@@ -1086,6 +1087,7 @@ class ChainOpt(utils.StatefulOptimizer):
          if not group["foreach"] or len(p) == 1:
              for param, grad in zip(p, g):
                  chain(self.state_, group, [grad], [param], *self.fns)
+             group["caution"] = caution
          else:
              chain(self.state_, group, g, p, *self.fns)

heavyball/helpers.py CHANGED
@@ -1,9 +1,9 @@
- from __future__ import annotations
-
+ import copy
  import functools
  import math
  import threading
- from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+ from contextlib import contextmanager
+ from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union

  import numpy
  import numpy as np
@@ -11,23 +11,15 @@ import optuna
  import optunahub
  import pandas as pd
  import torch
- from botorch.utils.sampling import manual_seed
  from hebo.design_space.design_space import DesignSpace
  from hebo.optimizers.hebo import HEBO
- from optuna._transform import _SearchSpaceTransform
+ from optuna._transform import _SearchSpaceTransform as SearchSpaceTransform
  from optuna.distributions import BaseDistribution, CategoricalDistribution, FloatDistribution, IntDistribution
  from optuna.samplers import BaseSampler, CmaEsSampler, RandomSampler
  from optuna.samplers._lazy_random_state import LazyRandomState
  from optuna.study import Study
  from optuna.study._study_direction import StudyDirection
  from optuna.trial import FrozenTrial, TrialState
- from optuna_integration.botorch import (
-     ehvi_candidates_func,
-     logei_candidates_func,
-     qehvi_candidates_func,
-     qei_candidates_func,
-     qparego_candidates_func,
- )
  from torch import Tensor
  from torch.nn import functional as F

@@ -37,12 +29,39 @@ _MAXINT32 = (1 << 31) - 1
  _SAMPLER_KEY = "auto:sampler"


+ @contextmanager
+ def manual_seed(seed: int | None = None) -> Generator[None, None, None]:
+     r"""
+     Contextmanager for manual setting the torch.random seed.
+
+     Args:
+         seed: The seed to set the random number generator to.
+
+     Returns:
+         Generator
+
+     Example:
+         >>> with manual_seed(1234):
+         >>>     X = torch.rand(3)
+
+     copied as-is from https://github.com/meta-pytorch/botorch/blob/a42cd65f9b704cdb6f2ee64db99a022eb15295d5/botorch/utils/sampling.py#L53C1-L75C50 under the MIT License
+     """
+     old_state = torch.random.get_rng_state()
+     try:
+         if seed is not None:
+             torch.random.manual_seed(seed)
+         yield
+     finally:
+         if seed is not None:
+             torch.random.set_rng_state(old_state)
+
+
  class SimpleAPIBaseSampler(BaseSampler):
      def __init__(
          self,
-         search_space: dict[str, BaseDistribution] = None,
+         search_space: Optional[dict[str, BaseDistribution]] = None,
      ):
-         self.search_space = search_space
+         self.search_space = {} if search_space is None else dict(search_space)

      def suggest_all(self, trial: FrozenTrial):
          return {k: trial._suggest(k, dist) for k, dist in self.search_space.items()}
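The `manual_seed` context manager vendored above (from BoTorch) seeds `torch.random` only inside the `with` block and restores the previous RNG state on exit when a seed was given. A usage sketch; it assumes the optional tuning dependencies that `heavyball.helpers` imports at module level (optuna, optunahub, hebo, pandas) are installed:

```python
import torch

from heavyball.helpers import manual_seed

with manual_seed(1234):
    a = torch.rand(3)
with manual_seed(1234):
    b = torch.rand(3)

assert torch.equal(a, b)  # same seed, same draw
_ = torch.rand(3)         # the global RNG state was restored after each block
```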
@@ -65,6 +84,16 @@ def _get_default_candidates_func(
      """
      The original is available at https://github.com/optuna/optuna-integration/blob/156a8bc081322791015d2beefff9373ed7b24047/optuna_integration/botorch/botorch.py under the MIT License
      """
+
+     # lazy import
+     from optuna_integration.botorch import (
+         ehvi_candidates_func,
+         logei_candidates_func,
+         qehvi_candidates_func,
+         qei_candidates_func,
+         qparego_candidates_func,
+     )
+
      if n_objectives > 3 and not has_constraint and not consider_running_trials:
          return ehvi_candidates_func
      elif n_objectives > 3:
@@ -124,7 +153,7 @@


  @torch.no_grad()
- def untransform(self: _SearchSpaceTransform, trans_params: Tensor) -> dict[str, Any]:
+ def untransform(self: SearchSpaceTransform, trans_params: Tensor) -> dict[str, Any]:
      assert trans_params.shape == (self._raw_bounds.shape[0],)

      if self._transform_0_1:
@@ -152,29 +181,31 @@ class BoTorchSampler(SimpleAPIBaseSampler):

      def __init__(
          self,
-         search_space: dict[str, BaseDistribution] = None,
+         search_space: Optional[dict[str, BaseDistribution]] = None,
          *,
-         candidates_func: None = None,
-         constraints_func: None = None,
+         candidates_func: Optional[Callable[..., Tensor]] = None,
+         constraints_func: Optional[Callable[..., Tensor]] = None,
          n_startup_trials: int = 10,
          consider_running_trials: bool = False,
-         independent_sampler: None = None,
+         independent_sampler: Optional[BaseSampler] = None,
          seed: int | None = None,
          device: torch.device | str | None = None,
          trial_chunks: int = 128,
      ):
-         assert constraints_func is None
-         assert candidates_func is None
-         assert consider_running_trials is False
-         assert independent_sampler is None
-         self._candidates_func = None
-         self._independent_sampler = RandomSampler(seed=seed)
+         if constraints_func is not None:
+             raise NotImplementedError("constraints_func is currently not supported by BoTorchSampler.")
+         if consider_running_trials:
+             raise NotImplementedError("consider_running_trials is currently not supported by BoTorchSampler.")
+         if candidates_func is not None and not callable(candidates_func):
+             raise TypeError("candidates_func must be callable.")
+         self._candidates_func = candidates_func
+         self._independent_sampler = independent_sampler or RandomSampler(seed=seed)
          self._n_startup_trials = n_startup_trials
          self._seed = seed
          self.trial_chunks = trial_chunks

          self._study_id: int | None = None
-         self.search_space = search_space
+         self.search_space = {} if search_space is None else dict(search_space)
          if isinstance(device, str):
              device = torch.device(device)
          self._device = device or torch.device("cpu")
@@ -182,14 +213,24 @@ class BoTorchSampler(SimpleAPIBaseSampler):
          self._values = None
          self._params = None
          self._index = 0
+         self._bounds_dim: int | None = None

      def infer_relative_search_space(self, study: Study, trial: FrozenTrial) -> dict[str, BaseDistribution]:
          return self.search_space

      @torch.no_grad()
      def _preprocess_trials(
-         self, trans: _SearchSpaceTransform, study: Study, trials: list[FrozenTrial]
+         self, trans: SearchSpaceTransform, study: Study, trials: list[FrozenTrial]
      ) -> Tuple[int, Tensor, Tensor]:
+         bounds_dim = trans.bounds.shape[0]
+         if self._bounds_dim is not None and self._bounds_dim != bounds_dim:
+             self._values = None
+             self._params = None
+             self._index = 0
+             self.seen_trials = set()
+         if self._bounds_dim is None:
+             self._bounds_dim = bounds_dim
+
          new_trials = []
          for trial in trials:
              tid: int = trial._trial_id
@@ -200,6 +241,10 @@ class BoTorchSampler(SimpleAPIBaseSampler):

          n_objectives = len(study.directions)
          if not new_trials:
+             if self._values is None or self._params is None:
+                 empty_values = torch.zeros((0, n_objectives), dtype=torch.float64, device=self._device)
+                 empty_params = torch.zeros((0, bounds_dim), dtype=torch.float64, device=self._device)
+                 return n_objectives, empty_values, empty_params
              return n_objectives, self._values[: self._index], self._params[: self._index]

          n_completed_trials = len(trials)
@@ -216,18 +261,28 @@ class BoTorchSampler(SimpleAPIBaseSampler):
              if direction == StudyDirection.MINIMIZE:  # BoTorch always assumes maximization.
                  values[:, obj_idx] *= -1

-         if self._values is None:
+         bounds_dim = trans.bounds.shape[0]
+         cache_stale = (
+             self._values is None
+             or self._params is None
+             or self._values.size(1) != n_objectives
+             or self._params.size(1) != bounds_dim
+         )
+         if cache_stale:
              self._values = torch.zeros((self.trial_chunks, n_objectives), dtype=torch.float64, device=self._device)
-             self._params = torch.zeros(
-                 (self.trial_chunks, trans.bounds.shape[0]), dtype=torch.float64, device=self._device
-             )
+             self._params = torch.zeros((self.trial_chunks, bounds_dim), dtype=torch.float64, device=self._device)
+             self._index = 0
+             self.seen_trials = set()
+             self._bounds_dim = bounds_dim
          spillage = (self._index + n_completed_trials) - self._values.size(0)
          if spillage > 0:
              pad = int(math.ceil(spillage / self.trial_chunks) * self.trial_chunks)
              self._values = F.pad(self._values, (0, 0, 0, pad))
              self._params = F.pad(self._params, (0, 0, 0, pad))
-         self._values[self._index : self._index + n_completed_trials] = torch.from_numpy(values)
-         self._params[self._index : self._index + n_completed_trials] = torch.from_numpy(params)
+         values_tensor = torch.from_numpy(values).to(self._device)
+         params_tensor = torch.from_numpy(params).to(self._device)
+         self._values[self._index : self._index + n_completed_trials] = values_tensor
+         self._params[self._index : self._index + n_completed_trials] = params_tensor
          self._index += n_completed_trials

          return n_objectives, self._values[: self._index], self._params[: self._index]
@@ -246,7 +301,7 @@ class BoTorchSampler(SimpleAPIBaseSampler):
          if n_completed_trials < self._n_startup_trials:
              return {}

-         trans = _SearchSpaceTransform(search_space)
+         trans = SearchSpaceTransform(search_space)
          n_objectives, values, params = self._preprocess_trials(trans, study, completed_trials)

          if self._candidates_func is None:
@@ -349,10 +404,10 @@ class HEBOSampler(optunahub.samplers.SimpleBaseSampler, SimpleAPIBaseSampler):
          independent_sampler: BaseSampler | None = None,
      ) -> None:
          super().__init__(search_space, seed)
-         assert constant_liar is False
-         assert independent_sampler is None
+         if constant_liar:
+             raise NotImplementedError("constant_liar is not supported by HEBOSampler.")
          self._hebo = HEBO(_convert_to_hebo_design_space(search_space), scramble_seed=self._seed)
-         self._independent_sampler = optuna.samplers.RandomSampler(seed=seed)
+         self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
          self._rng = np.random.default_rng(seed)

      def sample_relative(
@@ -421,10 +476,12 @@ class FastINGO:
          learning_rate: Optional[float] = None,
          last_n: int = 4096,
          loco_step_size: float = 0.1,
-         device="cuda",
+         device: str | None = None,
          batchnorm_decay: float = 0.99,
          score_decay: float = 0.99,
      ) -> None:
+         if device is None:
+             device = _use_cuda()
          n_dimension = len(mean)
          if population_size is None:
              population_size = 4 + int(np.floor(3 * np.log(n_dimension)))
@@ -491,8 +548,14 @@ class FastINGO:
          if y.numel() <= 2:
              return

-         y = y + torch.where(y.min() <= 0, 1e-8 - y.min(), 0)
-         y = y.log()
+         min_y = y.min()
+         max_y = y.max()
+         if torch.isclose(max_y, min_y, rtol=0.0, atol=1e-12):
+             return
+
+         if min_y <= 0:
+             y = y + (1e-8 - min_y)
+         y = y.clamp_min_(1e-8).log()

          ema = -torch.arange(y.size(0), device=y.device, dtype=y.dtype)
          weight = self.batchnorm_decay**ema
@@ -553,7 +616,7 @@ class ImplicitNaturalGradientSampler(BaseSampler):
      def reseed_rng(self) -> None:
          self._independent_sampler.reseed_rng()
          if self._optimizer:
-             self._optimizer._rng.seed()
+             self._optimizer.generator.seed()

      def infer_relative_search_space(
          self, study: "optuna.Study", trial: "optuna.trial.FrozenTrial"
@@ -603,14 +666,11 @@ class ImplicitNaturalGradientSampler(BaseSampler):
              self._warn_independent_sampling = False
              return {}

-         trans = _SearchSpaceTransform(search_space)
+         trans = SearchSpaceTransform(search_space)

-         if self._optimizer is None:
+         if self._optimizer is None or self._optimizer.dim != len(trans.bounds):
              self._optimizer = self._init_optimizer(trans, population_size=self._population_size)
-
-         if self._optimizer.dim != len(trans.bounds):
-             self._warn_independent_sampling = False
-             return {}
+             self._param_queue.clear()

          solution_trials = [t for t in completed_trials if self._check_trial_is_generation(t)]
          for t in solution_trials:
@@ -621,7 +681,7 @@ class ImplicitNaturalGradientSampler(BaseSampler):

      def _init_optimizer(
          self,
-         trans: _SearchSpaceTransform,
+         trans: SearchSpaceTransform,
          population_size: Optional[int] = None,
      ) -> FastINGO:
          lower_bounds = trans.bounds[:, 0]
@@ -675,6 +735,7 @@ class ThreadLocalSampler(threading.local):


  def init_cmaes(study, seed, trials, search_space):
+     trials = copy.deepcopy(trials)
      trials.sort(key=lambda trial: trial.datetime_complete)
      return CmaEsSampler(seed=seed, source_trials=trials, lr_adapt=True)

@@ -686,8 +747,14 @@ def init_hebo(study, seed, trials, search_space):
      return sampler


+ def _use_cuda():
+     return "cuda" if torch.cuda.is_available() else "cpu"
+
+
  def init_botorch(study, seed, trials, search_space):
-     return BoTorchSampler(search_space=search_space, seed=seed, device="cuda")  # will automatically pull in latest data
+     return BoTorchSampler(
+         search_space=search_space, seed=seed, device=_use_cuda()
+     )  # will automatically pull in latest data


  def init_nsgaii(study, seed, trials, search_space):
@@ -709,17 +776,20 @@ class AutoSampler(BaseSampler):
      def __init__(
          self,
          samplers: Iterable[Tuple[int, Callable]] | None = None,
-         search_space: dict[str, BaseDistribution] = None,
+         search_space: Optional[dict[str, BaseDistribution]] = None,
          *,
          seed: int | None = None,
-         constraints_func: None = None,
+         constraints_func: Optional[Callable[..., Any]] = None,
      ) -> None:
-         assert constraints_func is None
+         if constraints_func is not None:
+             raise NotImplementedError("constraints_func is not supported by AutoSampler.")
          if samplers is None:
+             if search_space is None:
+                 raise ValueError("AutoSampler requires a search_space when using the default sampler schedule.")
              samplers = ((0, init_hebo), (100, init_nsgaii))
          self.sampler_indices = np.sort(np.array([x[0] for x in samplers], dtype=np.int32))
          self.samplers = [x[1] for x in sorted(samplers, key=lambda x: x[0])]
-         self.search_space = search_space
+         self.search_space = {} if search_space is None else dict(search_space)
          self._rng = LazyRandomState(seed)
          self._random_sampler = RandomSampler(seed=seed)
          self._thread_local_sampler = ThreadLocalSampler()
@@ -762,7 +832,7 @@ class AutoSampler(BaseSampler):
          complete_trials = study._get_trials(deepcopy=False, states=(TrialState.COMPLETE,), use_cache=True)
          self._completed_trials = max(self._completed_trials, len(complete_trials))
          new_index = (self._completed_trials >= self.sampler_indices).sum() - 1
-         if new_index == self._current_index:
+         if new_index == self._current_index or new_index < 0:
              return
          self._current_index = new_index
          self._sampler = self.samplers[new_index](
@@ -775,7 +845,7 @@ class AutoSampler(BaseSampler):
      def sample_relative(
          self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
      ) -> dict[str, Any]:
-         return self._sampler.sample_relative(study, trial, self.search_space)
+         return self._sampler.sample_relative(study, trial, search_space or self.search_space)

      def sample_independent(
          self,
@@ -804,5 +874,6 @@ class AutoSampler(BaseSampler):
          state: TrialState,
          values: Sequence[float] | None,
      ) -> None:
-         assert state in [TrialState.COMPLETE, TrialState.FAIL, TrialState.PRUNED]
+         if state not in (TrialState.COMPLETE, TrialState.FAIL, TrialState.PRUNED):
+             raise ValueError(f"Unsupported trial state: {state}.")
          self._sampler.after_trial(study, trial, state, values)
heavyball/utils.py CHANGED
@@ -47,7 +47,7 @@ _cudnn_double_backward_pattern = re.compile(
  )
  _torch_compile_double_backward_pattern = re.compile(r"compile.*does not currently support double backward")
  _fd_error = (
-     "You can accelerate startup by globally enabling finite_differences first "  #
+     "You can accelerate startup by globally enabling finite_differences first "
      "(via opt.finite_differences=True or by subclassing it)\n"
      "Original Error: "
  )
@@ -343,7 +343,8 @@ def set_(dst: Tensor, src: Tensor):


  def clean():
-     torch.cuda.empty_cache()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
      gc.collect()


@@ -418,9 +419,13 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):


  ###### START
- # Taken from https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
+ # Based on https://arxiv.org/pdf/2505.16932v3
+ # and https://github.com/NoahAmsel/PolarExpress/blob/5454910920ca8c65afda28820cdf9e49b9436ed0/polar_express.py#L69-L82
+ # and https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
+ #
  # under the MIT License

+ # Coefficients are from https://arxiv.org/pdf/2505.16932v3
  ABC_LIST: list[tuple[float, float, float]] = [
      (8.28721201814563, -23.595886519098837, 17.300387312530933),
      (4.107059111542203, -2.9478499167379106, 0.5448431082926601),
@@ -438,7 +443,7 @@ ABC_LIST_STABLE: list[tuple[float, float, float]] = [
  ] + [ABC_LIST[-1]]


- def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
+ def msign(G: torch.Tensor, steps: int = 10, eps: float = 1e-7) -> torch.Tensor:
      """
      Polar Express algorithm for the matrix sign function:
      https://arxiv.org/abs/2505.16932
@@ -450,7 +455,9 @@ def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
      if should_transpose:
          x = x.mT

-     x /= x.norm(dim=(-2, -1), keepdim=True) * 1.01
+     # x = x / (x.norm(dim=(-2, -1), keepdim=True) * 1.01 + eps)
+     stochastic_divide_with_eps_(x, x.norm(dim=(-2, -1)) * 1.01, eps)
+
      for step in range(steps):
          a, b, c = ABC_LIST_STABLE[step] if step < len(ABC_LIST_STABLE) else ABC_LIST_STABLE[-1]
          s = x @ x.mT
@@ -464,8 +471,7 @@ def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:

      if should_transpose:
          x = x.mT
-     x = torch.nan_to_num(x)
-     return x.float()
+     return x.to(G.dtype)


  ###### END
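`msign` now takes an explicit `eps`, normalizes via `stochastic_divide_with_eps_`, and returns the result in the input dtype instead of `nan_to_num(...).float()`. A quick sanity-check sketch; the printed error bound is an assumption (the iteration runs in reduced precision internally), not a documented guarantee:

```python
import torch

from heavyball.utils import msign

G = torch.randn(64, 64)
U = msign(G, steps=10)  # approximate orthogonal polar factor of G
err = (U @ U.mT - torch.eye(64)).abs().max()
print(U.dtype, err)     # dtype matches G (float32); err should be small, roughly 1e-2 or below
```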
@@ -665,9 +671,9 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
              final.append(None)
              continue

+         device, dtype = m.device, m.dtype
          m = promote(m.data)

-         device, dtype = m.device, m.dtype
          eps = min_eps
          while True:
              try:
@@ -695,7 +701,6 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
                      raise
              clean()

-         eigvec = eigvec.to(device=m.device, dtype=m.dtype)
          eigvec = torch.flip(eigvec, [1])
          final.append(eigvec)

@@ -1048,13 +1053,15 @@ class StatefulOptimizer(torch.optim.Optimizer):
      def get_groups(self, group):
          return [group]

-     @functools.lru_cache(maxsize=None)
      def state_(self, arg: Tensor, fail: bool = True):
-         if not fail and arg not in self.mapping:
-             return {}
-         if _tensor_key(arg) not in self.mapping_inverse:
+         key = _tensor_key(arg)
+         if key not in self.mapping_inverse:
              self._init_mapping()
-         state_param, index = self.mapping_inverse[_tensor_key(arg)]
+         if key not in self.mapping_inverse:
+             if not fail:
+                 return {}
+             raise KeyError("Tensor has no tracked state.")
+         state_param, index = self.mapping_inverse[key]
          if state_param not in self.state:
              self.state[state_param] = collections.defaultdict(dict)
          return self.state[state_param][index]
@@ -1142,7 +1149,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
              active_p = [p for p in group["params"]]

              if not active_p:
-                 return
+                 continue

              k = group["ema_step"] = group.get("ema_step", -1) + 1

@@ -1159,7 +1166,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
              active_p = [p for p in group["params"]]

              if not active_p:
-                 return
+                 continue

              for p in active_p:
                  if "param_ema" in self.state_(p):
@@ -1173,7 +1180,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
              active_p = [p for p in group["params"]]

              if not active_p:
-                 return
+                 continue

              for p in active_p:
                  if "param_ema" in self.state_(p):
@@ -1202,7 +1209,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
          for group in self.param_groups:
              for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
                  p.grad = grads.pop(0)
-                 stochastic_add_(g, p.grad, -1)  # technically, we have to divide by the scale here
+                 stochastic_add_divide_(g, p.grad, -1, torch.finfo(p.dtype).eps ** 0.5)
                  p.hessian_vector = g
                  p.data.copy_(p.orig)
                  del p.orig
@@ -1292,6 +1299,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
          self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule, self.precond_rng)
          loss = self._handle_closure(closure)

+         if self.use_ema:
+             self.ema_update()
          # we assume that parameters are constant and that there are no excessive recompiles
          with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
              for group in self.param_groups:
@@ -1299,8 +1308,6 @@ class StatefulOptimizer(torch.optim.Optimizer):
                  group["param_count"] = sum(p.numel() for p in group["params"])
                  group["is_preconditioning"] = self._is_preconditioning
                  self._step(group)
-         if self.use_ema:
-             self.ema_update()
          for real, views in self.mapping.items():
              for tensor in (real, *views):
                  for key in ("grad", "vector", "hessian_vector", "orig"):
@@ -1564,18 +1571,20 @@ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):

  @decorator_knowngood
  def stochastic_round_(ref: Tensor, source: Tensor | None = None):
-     if source is not None:
-         if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
-             return source
-         if ref.dtype != torch.bfloat16:
-             return source.to(ref.dtype)
-     else:
+     if source is None:
          source = ref
-     source = source.float()
-     result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
-     result.add_(source.view(dtype=torch.int32))
-     result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32
-     return result.view(dtype=torch.float32).bfloat16()
+     if ref.dtype != torch.bfloat16:
+         return source.to(ref.dtype)
+     if source.dtype == torch.bfloat16:
+         return source
+     if source.dtype in (torch.float16, torch.float32, torch.float64):
+         source = source.to(torch.float32)
+         noise = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
+         bits = source.view(dtype=torch.int32)
+         bits.add_(noise)
+         bits.bitwise_and_(-65536)  # FFFF0000 mask, preserves sign+exp+7 mantissa bits
+         return bits.view(dtype=torch.float32).bfloat16()
+     return source.to(ref.dtype)


  @decorator_knowngood
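The rewritten `stochastic_round_` makes the bfloat16 path explicit: add 16 bits of uniform noise below the bf16 cut-off, truncate, and reinterpret, so a value rounds up with probability proportional to the discarded fraction. A standalone sketch of the same bit trick (illustration only, not the library function itself):

```python
import torch


def stochastic_round_to_bf16(x: torch.Tensor) -> torch.Tensor:
    x = x.to(torch.float32)
    noise = torch.randint_like(x, dtype=torch.int32, low=0, high=1 << 16)
    bits = x.view(dtype=torch.int32) + noise  # perturb the 16 low bits that bfloat16 drops
    bits = bits & -65536                      # 0xFFFF0000: keep sign, exponent, top 7 mantissa bits
    return bits.view(dtype=torch.float32).bfloat16()


x = torch.full((100_000,), 1.0 + 2.0 ** -10)       # not exactly representable in bfloat16
print(stochastic_round_to_bf16(x).float().mean())  # ~1.00098 in expectation; plain truncation gives 1.0
```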
@@ -1585,7 +1594,7 @@ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):

  def copy_stochastic_(target: Tensor, source: Tensor):
      if target.dtype == torch.bfloat16 and source.dtype in (torch.float16, torch.float32, torch.float64):
-         _compilable_copy_stochastic_(target, source.float())
+         source = stochastic_round_(target, source)
      set_(target, source)


@@ -1908,7 +1917,8 @@ def update_lra_precond_(

      # LU factorization to reuse computation
      try:
-         LU, pivots = torch.linalg.lu_factor(IpVtU)
+         lu_matrix = promote(IpVtU)  # operate in fp32 when inputs are bf16/half
+         LU, pivots = torch.linalg.lu_factor(lu_matrix)
      except RuntimeError:
          # Error:
          # U[2,2] is zero and using it on lu_solve would result in a division by zero.
@@ -1918,8 +1928,13 @@ def update_lra_precond_(
          # So, we skip this step and reattempt on the next one
          return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)

-     invQtv = invQtv - V @ torch.linalg.lu_solve(LU, pivots, (U.T @ invQtv).view(-1, 1), adjoint=True).flatten()
-     invPv = U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
+     solve_dtype = LU.dtype
+     rhs = (U.T @ invQtv).view(-1, 1).to(solve_dtype)
+     correction = torch.linalg.lu_solve(LU, pivots, rhs, adjoint=True).to(V.dtype)
+     invQtv = invQtv - (V @ correction).flatten()
+     rhs = (V.T @ invQtv).view(-1, 1).to(solve_dtype)
+     solution = torch.linalg.lu_solve(LU, pivots, rhs).to(U.dtype)
+     invPv = (U @ solution).flatten()

      eps, step = scalar_guard(eps, step, vector)
      _compilable_d_step(d, d_orig, invQtv, vector, invPv, hessian_vector, Ph, eps, step, delayed)
@@ -2039,7 +2054,10 @@ def extract_from_flat_update(params: List[Tensor], update: Tensor):
  @decorator_knowngood
  def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
      last_dim = x[0].shape[-remaining:] if remaining else []
-     return torch.cat([i.reshape(-1, *last_dim) for i in x if i.numel()], 0)
+     tensors = [i.reshape(-1, *last_dim) for i in x if i.numel()]
+     if not tensors:
+         return torch.zeros((), dtype=x[0].device, device=x[0].device)
+     return torch.cat(tensors, 0)


  @decorator_knowngood
@@ -2111,16 +2129,6 @@ def psgd_calc_A_and_conjB(G: Tensor, Q, conjB: Tensor | None): # conjB ("V", "v
      return A, conjB


- @decorator_knowngood
- def _random_projection(x: Tensor, scale: Optional[Tensor]):
-     if scale is None:
-         scale = x.norm(float("inf")).clamp(min=1e-8)
-     k = 2 ** math.ceil(math.log2(math.log2(min(x.shape))))  # next-largest-power-of-2 of log2-of-size
-     norm = x.square().sum(0)
-     indices = torch.topk(norm, k, largest=True).indices
-     return x.index_select(1, indices).contiguous() / scale, scale
-
-
  def max_singular_value_exact(A, use_lobpcg: bool = False):
      try:
          if use_lobpcg:
@@ -2164,7 +2172,15 @@ def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
      """
      Adapted from @evanatyourservice
      """
-     Y, max_abs = _random_projection(A, max_abs)
+     if max_abs is None:
+         max_abs = A.norm(float("inf")).clamp(min=1e-8)
+
+     # cholesky uses random projection, but this uses topk -- topk is a warm start, which may converge to a biased result
+     k = 2 ** math.ceil(math.log2(math.log2(min(A.shape))))  # next-largest-power-of-2 of log2-of-size
+     norm = A.square().sum(0)
+     indices = torch.topk(norm, k, largest=True).indices
+     Y = A.index_select(1, indices).contiguous() / max_abs
+
      Q = inplace_orthogonal_(Y, precise_zeroth_power_mode)
      Q = Q / max_abs
      Z = A.T @ Q
@@ -2412,10 +2428,11 @@ def bf16_matmul(x: Tensor, y: Tensor):
  def if_iscompiling(fn):
      base = getattr(torch, fn.__name__, None)

-     def _fn(x):
-         if torch.compiler.is_compiling() and hasattr(torch, fn.__name__):
-             return base(x)
-         return fn(x)
+     @functools.wraps(fn)
+     def _fn(*args, **kwargs):
+         if torch.compiler.is_compiling() and base is not None:
+             return base(*args, **kwargs)
+         return fn(*args, **kwargs)

      return _fn

@@ -2551,7 +2568,7 @@ def _psgd_quad_preconditioner_grad(GG: List[Tensor], Q: List[Tensor], numel: int
          else:
              scale = gg.size(0) / numel
              gg = 2 * torch.eye(gg.size(0), device=gg.device, dtype=gg.dtype) - gg * scale
-             update = q - gg @ q @ gg
+             update = q - casted_einsum("ab,cd,bc", gg, gg, q)
              out.append(update + update.T)  # make matrix symmetric
      return out

@@ -3105,7 +3122,7 @@ def pointwise_lr_adaptation(
  ):
      grads, update, state, delta = list_guard(grads, update, state, delta)
      lr_lr = scalar_guard(lr_lr, grads[0])
-     _compilable_lr_adapt_(grads, update, state, delta, lr_lr)
+     _compilable_pointwise_lr_adapt_(grads, update, state, delta, lr_lr)


  def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
@@ -3125,8 +3142,6 @@ def hook_optimizer_into_model(model, optimizer, *args, **kwargs):

  def fused_hook(parameters, optimizer, *args, **kwargs):
      parameters = list(parameters)
-     param_count = len(parameters)
-     seen_params = set()

      o = optimizer(parameters, *args, **kwargs)
      step_fn = o.step
@@ -3135,12 +3150,8 @@
      )

      def _step(p: Tensor):
-         seen_params.add(p)
-
-         if len(seen_params) < param_count:
-             step_fn()
-             o.zero_grad()
-             seen_params.clear()
+         step_fn()
+         o.zero_grad()

      for p in parameters:
          p.register_post_accumulate_grad_hook(_step)
@@ -3165,6 +3176,8 @@ def sam_step(parameters, ball_size, adaptive: bool = True):
      old_params = []
      for p in parameters:
          old_params.append(p.detach().clone())
+         if not hasattr_none(p, "grad"):
+             continue
          grad = promote(p.grad)
          if adaptive:
              grad = grad * promote(p).square()
heavyball-2.1.2.dist-info/METADATA → heavyball-2.1.4.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: heavyball
- Version: 2.1.2
+ Version: 2.1.4
  Summary: Efficient Optimizers
  Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
  Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -21,6 +21,7 @@ Requires-Dist: numpy<2.0.0
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
  Requires-Dist: pytest; extra == "dev"
+ Requires-Dist: hypothesis; extra == "dev"
  Requires-Dist: ruff; extra == "dev"
  Requires-Dist: matplotlib; extra == "dev"
  Requires-Dist: seaborn; extra == "dev"
heavyball-2.1.4.dist-info/RECORD ADDED
@@ -0,0 +1,9 @@
+ heavyball/__init__.py,sha256=9VgWebob-zO7hKg_KmQuSOB4Z_Rh-gCDs_V2TTfQKSo,30123
+ heavyball/chainable.py,sha256=O8QiHJ-E5RD-fzo3iulSHgvKgtRZ1Lff2ls3iLmXcoI,42695
+ heavyball/helpers.py,sha256=eiotfrJz4V6ewfF9ZboC_JEUi_TCmO195uT6sqqohTE,33429
+ heavyball/utils.py,sha256=u4RFOdmYkhsjPE4M_N53oDnuh-vHbvRHc6OLQTEeq-c,105239
+ heavyball-2.1.4.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
+ heavyball-2.1.4.dist-info/METADATA,sha256=MxDWUcqFgMWmG3FXtf0UVzDy9qsAWea4tPgnDnx9wXQ,5088
+ heavyball-2.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ heavyball-2.1.4.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+ heavyball-2.1.4.dist-info/RECORD,,
heavyball-2.1.2.dist-info/RECORD DELETED
@@ -1,9 +0,0 @@
- heavyball/__init__.py,sha256=1BTb7G-VcfcMyS4EpuVnhE5DBp2fj_Zzs9EQr6slPzg,30491
- heavyball/chainable.py,sha256=8S-7QRZYiy_ARhQ8uDu5G0Eg3ouT9Vcfk-rxbKlp4zI,42510
- heavyball/helpers.py,sha256=zk_S84wpGcvO9P6kn4UeaQUIDowHxcbM9qQITEm2g5I,30267
- heavyball/utils.py,sha256=Lx9XlfkyQbfYMPtqiA0rNIz4PXQe_bpLqKFby3upHMw,104514
- heavyball-2.1.2.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
- heavyball-2.1.2.dist-info/METADATA,sha256=EMM0OI4cPeaQlMkts2j9CCp9KxhJm-o_9VDNLm4ySQg,5046
- heavyball-2.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- heavyball-2.1.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
- heavyball-2.1.2.dist-info/RECORD,,