heavyball 2.1.2__py3-none-any.whl → 2.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/helpers.py CHANGED
@@ -3,7 +3,8 @@ from __future__ import annotations
3
3
  import functools
4
4
  import math
5
5
  import threading
6
- from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
6
+ from contextlib import contextmanager
7
+ from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union
7
8
 
8
9
  import numpy
9
10
  import numpy as np
@@ -11,7 +12,6 @@ import optuna
11
12
  import optunahub
12
13
  import pandas as pd
13
14
  import torch
14
- from botorch.utils.sampling import manual_seed
15
15
  from hebo.design_space.design_space import DesignSpace
16
16
  from hebo.optimizers.hebo import HEBO
17
17
  from optuna._transform import _SearchSpaceTransform
@@ -21,13 +21,6 @@ from optuna.samplers._lazy_random_state import LazyRandomState
21
21
  from optuna.study import Study
22
22
  from optuna.study._study_direction import StudyDirection
23
23
  from optuna.trial import FrozenTrial, TrialState
24
- from optuna_integration.botorch import (
25
- ehvi_candidates_func,
26
- logei_candidates_func,
27
- qehvi_candidates_func,
28
- qei_candidates_func,
29
- qparego_candidates_func,
30
- )
31
24
  from torch import Tensor
32
25
  from torch.nn import functional as F
33
26
 
@@ -37,6 +30,33 @@ _MAXINT32 = (1 << 31) - 1
37
30
  _SAMPLER_KEY = "auto:sampler"
38
31
 
39
32
 
33
+ @contextmanager
34
+ def manual_seed(seed: int | None = None) -> Generator[None, None, None]:
35
+ r"""
36
+ Context manager for manually setting the torch.random seed.
37
+
38
+ Args:
39
+ seed: The seed to set the random number generator to.
40
+
41
+ Returns:
42
+ Generator
43
+
44
+ Example:
45
+ >>> with manual_seed(1234):
46
+ >>> X = torch.rand(3)
47
+
48
+ copied as-is from https://github.com/meta-pytorch/botorch/blob/a42cd65f9b704cdb6f2ee64db99a022eb15295d5/botorch/utils/sampling.py#L53C1-L75C50 under the MIT License
49
+ """
50
+ old_state = torch.random.get_rng_state()
51
+ try:
52
+ if seed is not None:
53
+ torch.random.manual_seed(seed)
54
+ yield
55
+ finally:
56
+ if seed is not None:
57
+ torch.random.set_rng_state(old_state)
58
+
59
+
40
60
  class SimpleAPIBaseSampler(BaseSampler):
41
61
  def __init__(
42
62
  self,
@@ -65,6 +85,16 @@ def _get_default_candidates_func(
65
85
  """
66
86
  The original is available at https://github.com/optuna/optuna-integration/blob/156a8bc081322791015d2beefff9373ed7b24047/optuna_integration/botorch/botorch.py under the MIT License
67
87
  """
88
+
89
+ # lazy import
90
+ from optuna_integration.botorch import (
91
+ ehvi_candidates_func,
92
+ logei_candidates_func,
93
+ qehvi_candidates_func,
94
+ qei_candidates_func,
95
+ qparego_candidates_func,
96
+ )
97
+
68
98
  if n_objectives > 3 and not has_constraint and not consider_running_trials:
69
99
  return ehvi_candidates_func
70
100
  elif n_objectives > 3:
heavyball/utils.py CHANGED
@@ -47,7 +47,7 @@ _cudnn_double_backward_pattern = re.compile(
47
47
  )
48
48
  _torch_compile_double_backward_pattern = re.compile(r"compile.*does not currently support double backward")
49
49
  _fd_error = (
50
- "You can accelerate startup by globally enabling finite_differences first " #
50
+ "You can accelerate startup by globally enabling finite_differences first "
51
51
  "(via opt.finite_differences=True or by subclassing it)\n"
52
52
  "Original Error: "
53
53
  )
@@ -418,9 +418,13 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
418
418
 
419
419
 
420
420
  ###### START
421
- # Taken from https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
421
+ # Based on https://arxiv.org/pdf/2505.16932v3
422
+ # and https://github.com/NoahAmsel/PolarExpress/blob/5454910920ca8c65afda28820cdf9e49b9436ed0/polar_express.py#L69-L82
423
+ # and https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
424
+ #
422
425
  # under the MIT License
423
426
 
427
+ # Coefficients are from https://arxiv.org/pdf/2505.16932v3
424
428
  ABC_LIST: list[tuple[float, float, float]] = [
425
429
  (8.28721201814563, -23.595886519098837, 17.300387312530933),
426
430
  (4.107059111542203, -2.9478499167379106, 0.5448431082926601),
@@ -438,7 +442,7 @@ ABC_LIST_STABLE: list[tuple[float, float, float]] = [
438
442
  ] + [ABC_LIST[-1]]
439
443
 
440
444
 
441
- def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
445
+ def msign(G: torch.Tensor, steps: int = 10, eps: float = 1e-7) -> torch.Tensor:
442
446
  """
443
447
  Polar Express algorithm for the matrix sign function:
444
448
  https://arxiv.org/abs/2505.16932
@@ -450,7 +454,9 @@ def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
450
454
  if should_transpose:
451
455
  x = x.mT
452
456
 
453
- x /= x.norm(dim=(-2, -1), keepdim=True) * 1.01
457
+ # x = x / (x.norm(dim=(-2, -1), keepdim=True) * 1.01 + eps)
458
+ stochastic_divide_with_eps_(x, x.norm(dim=(-2, -1)) * 1.01, eps)
459
+
454
460
  for step in range(steps):
455
461
  a, b, c = ABC_LIST_STABLE[step] if step < len(ABC_LIST_STABLE) else ABC_LIST_STABLE[-1]
456
462
  s = x @ x.mT
@@ -464,7 +470,6 @@ def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
464
470
 
465
471
  if should_transpose:
466
472
  x = x.mT
467
- x = torch.nan_to_num(x)
468
473
  return x.float()
469
474
 
470
475
 
@@ -1585,7 +1590,7 @@ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
1585
1590
 
1586
1591
  def copy_stochastic_(target: Tensor, source: Tensor):
1587
1592
  if target.dtype == torch.bfloat16 and source.dtype in (torch.float16, torch.float32, torch.float64):
1588
- _compilable_copy_stochastic_(target, source.float())
1593
+ source = stochastic_round_(target, source)
1589
1594
  set_(target, source)
1590
1595
 
1591
1596
 
@@ -2412,10 +2417,11 @@ def bf16_matmul(x: Tensor, y: Tensor):
2412
2417
  def if_iscompiling(fn):
2413
2418
  base = getattr(torch, fn.__name__, None)
2414
2419
 
2415
- def _fn(x):
2416
- if torch.compiler.is_compiling() and hasattr(torch, fn.__name__):
2417
- return base(x)
2418
- return fn(x)
2420
+ @functools.wraps(fn)
2421
+ def _fn(*args, **kwargs):
2422
+ if torch.compiler.is_compiling() and base is not None:
2423
+ return base(*args, **kwargs)
2424
+ return fn(*args, **kwargs)
2419
2425
 
2420
2426
  return _fn
2421
2427
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: heavyball
3
- Version: 2.1.2
3
+ Version: 2.1.3
4
4
  Summary: Efficient Optimizers
5
5
  Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
6
6
  Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -21,6 +21,7 @@ Requires-Dist: numpy<2.0.0
21
21
  Provides-Extra: dev
22
22
  Requires-Dist: pre-commit; extra == "dev"
23
23
  Requires-Dist: pytest; extra == "dev"
24
+ Requires-Dist: hypothesis; extra == "dev"
24
25
  Requires-Dist: ruff; extra == "dev"
25
26
  Requires-Dist: matplotlib; extra == "dev"
26
27
  Requires-Dist: seaborn; extra == "dev"
@@ -0,0 +1,9 @@
1
+ heavyball/__init__.py,sha256=1BTb7G-VcfcMyS4EpuVnhE5DBp2fj_Zzs9EQr6slPzg,30491
2
+ heavyball/chainable.py,sha256=8S-7QRZYiy_ARhQ8uDu5G0Eg3ouT9Vcfk-rxbKlp4zI,42510
3
+ heavyball/helpers.py,sha256=is4Egdgoj2GUsBYdraItonqsoVIY9ZKP_VZl-hEnF1Y,31077
4
+ heavyball/utils.py,sha256=_AOFIkFyaMO39YjbvclkzivR-nKe_kLShRZda3rgMiA,104850
5
+ heavyball-2.1.3.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
6
+ heavyball-2.1.3.dist-info/METADATA,sha256=by35259YI9DvUQ8Vq958sHmAxqSVtPY5JoY5Hn0CccY,5088
7
+ heavyball-2.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
8
+ heavyball-2.1.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
9
+ heavyball-2.1.3.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- heavyball/__init__.py,sha256=1BTb7G-VcfcMyS4EpuVnhE5DBp2fj_Zzs9EQr6slPzg,30491
2
- heavyball/chainable.py,sha256=8S-7QRZYiy_ARhQ8uDu5G0Eg3ouT9Vcfk-rxbKlp4zI,42510
3
- heavyball/helpers.py,sha256=zk_S84wpGcvO9P6kn4UeaQUIDowHxcbM9qQITEm2g5I,30267
4
- heavyball/utils.py,sha256=Lx9XlfkyQbfYMPtqiA0rNIz4PXQe_bpLqKFby3upHMw,104514
5
- heavyball-2.1.2.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
6
- heavyball-2.1.2.dist-info/METADATA,sha256=EMM0OI4cPeaQlMkts2j9CCp9KxhJm-o_9VDNLm4ySQg,5046
7
- heavyball-2.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
8
- heavyball-2.1.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
9
- heavyball-2.1.2.dist-info/RECORD,,