heavyball 2.1.2__py3-none-any.whl → 2.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/helpers.py +39 -9
- heavyball/utils.py +16 -10
- {heavyball-2.1.2.dist-info → heavyball-2.1.3.dist-info}/METADATA +2 -1
- heavyball-2.1.3.dist-info/RECORD +9 -0
- heavyball-2.1.2.dist-info/RECORD +0 -9
- {heavyball-2.1.2.dist-info → heavyball-2.1.3.dist-info}/WHEEL +0 -0
- {heavyball-2.1.2.dist-info → heavyball-2.1.3.dist-info}/licenses/LICENSE +0 -0
- {heavyball-2.1.2.dist-info → heavyball-2.1.3.dist-info}/top_level.txt +0 -0
heavyball/helpers.py
CHANGED
@@ -3,7 +3,8 @@ from __future__ import annotations
|
|
3
3
|
import functools
|
4
4
|
import math
|
5
5
|
import threading
|
6
|
-
from
|
6
|
+
from contextlib import contextmanager
|
7
|
+
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union
|
7
8
|
|
8
9
|
import numpy
|
9
10
|
import numpy as np
|
@@ -11,7 +12,6 @@ import optuna
|
|
11
12
|
import optunahub
|
12
13
|
import pandas as pd
|
13
14
|
import torch
|
14
|
-
from botorch.utils.sampling import manual_seed
|
15
15
|
from hebo.design_space.design_space import DesignSpace
|
16
16
|
from hebo.optimizers.hebo import HEBO
|
17
17
|
from optuna._transform import _SearchSpaceTransform
|
@@ -21,13 +21,6 @@ from optuna.samplers._lazy_random_state import LazyRandomState
|
|
21
21
|
from optuna.study import Study
|
22
22
|
from optuna.study._study_direction import StudyDirection
|
23
23
|
from optuna.trial import FrozenTrial, TrialState
|
24
|
-
from optuna_integration.botorch import (
|
25
|
-
ehvi_candidates_func,
|
26
|
-
logei_candidates_func,
|
27
|
-
qehvi_candidates_func,
|
28
|
-
qei_candidates_func,
|
29
|
-
qparego_candidates_func,
|
30
|
-
)
|
31
24
|
from torch import Tensor
|
32
25
|
from torch.nn import functional as F
|
33
26
|
|
@@ -37,6 +30,33 @@ _MAXINT32 = (1 << 31) - 1
|
|
37
30
|
_SAMPLER_KEY = "auto:sampler"
|
38
31
|
|
39
32
|
|
33
|
+
@contextmanager
|
34
|
+
def manual_seed(seed: int | None = None) -> Generator[None, None, None]:
|
35
|
+
r"""
|
36
|
+
Contextmanager for manual setting the torch.random seed.
|
37
|
+
|
38
|
+
Args:
|
39
|
+
seed: The seed to set the random number generator to.
|
40
|
+
|
41
|
+
Returns:
|
42
|
+
Generator
|
43
|
+
|
44
|
+
Example:
|
45
|
+
>>> with manual_seed(1234):
|
46
|
+
>>> X = torch.rand(3)
|
47
|
+
|
48
|
+
copied as-is from https://github.com/meta-pytorch/botorch/blob/a42cd65f9b704cdb6f2ee64db99a022eb15295d5/botorch/utils/sampling.py#L53C1-L75C50 under the MIT License
|
49
|
+
"""
|
50
|
+
old_state = torch.random.get_rng_state()
|
51
|
+
try:
|
52
|
+
if seed is not None:
|
53
|
+
torch.random.manual_seed(seed)
|
54
|
+
yield
|
55
|
+
finally:
|
56
|
+
if seed is not None:
|
57
|
+
torch.random.set_rng_state(old_state)
|
58
|
+
|
59
|
+
|
40
60
|
class SimpleAPIBaseSampler(BaseSampler):
|
41
61
|
def __init__(
|
42
62
|
self,
|
@@ -65,6 +85,16 @@ def _get_default_candidates_func(
|
|
65
85
|
"""
|
66
86
|
The original is available at https://github.com/optuna/optuna-integration/blob/156a8bc081322791015d2beefff9373ed7b24047/optuna_integration/botorch/botorch.py under the MIT License
|
67
87
|
"""
|
88
|
+
|
89
|
+
# lazy import
|
90
|
+
from optuna_integration.botorch import (
|
91
|
+
ehvi_candidates_func,
|
92
|
+
logei_candidates_func,
|
93
|
+
qehvi_candidates_func,
|
94
|
+
qei_candidates_func,
|
95
|
+
qparego_candidates_func,
|
96
|
+
)
|
97
|
+
|
68
98
|
if n_objectives > 3 and not has_constraint and not consider_running_trials:
|
69
99
|
return ehvi_candidates_func
|
70
100
|
elif n_objectives > 3:
|
heavyball/utils.py
CHANGED
@@ -47,7 +47,7 @@ _cudnn_double_backward_pattern = re.compile(
|
|
47
47
|
)
|
48
48
|
_torch_compile_double_backward_pattern = re.compile(r"compile.*does not currently support double backward")
|
49
49
|
_fd_error = (
|
50
|
-
"You can accelerate startup by globally enabling finite_differences first "
|
50
|
+
"You can accelerate startup by globally enabling finite_differences first "
|
51
51
|
"(via opt.finite_differences=True or by subclassing it)\n"
|
52
52
|
"Original Error: "
|
53
53
|
)
|
@@ -418,9 +418,13 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
|
|
418
418
|
|
419
419
|
|
420
420
|
###### START
|
421
|
-
#
|
421
|
+
# Based on https://arxiv.org/pdf/2505.16932v3
|
422
|
+
# and https://github.com/NoahAmsel/PolarExpress/blob/5454910920ca8c65afda28820cdf9e49b9436ed0/polar_express.py#L69-L82
|
423
|
+
# and https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
|
424
|
+
#
|
422
425
|
# under the MIT License
|
423
426
|
|
427
|
+
# Coefficients are from https://arxiv.org/pdf/2505.16932v3
|
424
428
|
ABC_LIST: list[tuple[float, float, float]] = [
|
425
429
|
(8.28721201814563, -23.595886519098837, 17.300387312530933),
|
426
430
|
(4.107059111542203, -2.9478499167379106, 0.5448431082926601),
|
@@ -438,7 +442,7 @@ ABC_LIST_STABLE: list[tuple[float, float, float]] = [
|
|
438
442
|
] + [ABC_LIST[-1]]
|
439
443
|
|
440
444
|
|
441
|
-
def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
|
445
|
+
def msign(G: torch.Tensor, steps: int = 10, eps: float = 1e-7) -> torch.Tensor:
|
442
446
|
"""
|
443
447
|
Polar Express algorithm for the matrix sign function:
|
444
448
|
https://arxiv.org/abs/2505.16932
|
@@ -450,7 +454,9 @@ def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
|
|
450
454
|
if should_transpose:
|
451
455
|
x = x.mT
|
452
456
|
|
453
|
-
x
|
457
|
+
# x = x / (x.norm(dim=(-2, -1), keepdim=True) * 1.01 + eps)
|
458
|
+
stochastic_divide_with_eps_(x, x.norm(dim=(-2, -1)) * 1.01, eps)
|
459
|
+
|
454
460
|
for step in range(steps):
|
455
461
|
a, b, c = ABC_LIST_STABLE[step] if step < len(ABC_LIST_STABLE) else ABC_LIST_STABLE[-1]
|
456
462
|
s = x @ x.mT
|
@@ -464,7 +470,6 @@ def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
|
|
464
470
|
|
465
471
|
if should_transpose:
|
466
472
|
x = x.mT
|
467
|
-
x = torch.nan_to_num(x)
|
468
473
|
return x.float()
|
469
474
|
|
470
475
|
|
@@ -1585,7 +1590,7 @@ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
|
|
1585
1590
|
|
1586
1591
|
def copy_stochastic_(target: Tensor, source: Tensor):
|
1587
1592
|
if target.dtype == torch.bfloat16 and source.dtype in (torch.float16, torch.float32, torch.float64):
|
1588
|
-
|
1593
|
+
source = stochastic_round_(target, source)
|
1589
1594
|
set_(target, source)
|
1590
1595
|
|
1591
1596
|
|
@@ -2412,10 +2417,11 @@ def bf16_matmul(x: Tensor, y: Tensor):
|
|
2412
2417
|
def if_iscompiling(fn):
|
2413
2418
|
base = getattr(torch, fn.__name__, None)
|
2414
2419
|
|
2415
|
-
|
2416
|
-
|
2417
|
-
|
2418
|
-
|
2420
|
+
@functools.wraps(fn)
|
2421
|
+
def _fn(*args, **kwargs):
|
2422
|
+
if torch.compiler.is_compiling() and base is not None:
|
2423
|
+
return base(*args, **kwargs)
|
2424
|
+
return fn(*args, **kwargs)
|
2419
2425
|
|
2420
2426
|
return _fn
|
2421
2427
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: heavyball
|
3
|
-
Version: 2.1.
|
3
|
+
Version: 2.1.3
|
4
4
|
Summary: Efficient Optimizers
|
5
5
|
Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
|
6
6
|
Project-URL: source, https://github.com/HomebrewML/HeavyBall
|
@@ -21,6 +21,7 @@ Requires-Dist: numpy<2.0.0
|
|
21
21
|
Provides-Extra: dev
|
22
22
|
Requires-Dist: pre-commit; extra == "dev"
|
23
23
|
Requires-Dist: pytest; extra == "dev"
|
24
|
+
Requires-Dist: hypothesis; extra == "dev"
|
24
25
|
Requires-Dist: ruff; extra == "dev"
|
25
26
|
Requires-Dist: matplotlib; extra == "dev"
|
26
27
|
Requires-Dist: seaborn; extra == "dev"
|
@@ -0,0 +1,9 @@
|
|
1
|
+
heavyball/__init__.py,sha256=1BTb7G-VcfcMyS4EpuVnhE5DBp2fj_Zzs9EQr6slPzg,30491
|
2
|
+
heavyball/chainable.py,sha256=8S-7QRZYiy_ARhQ8uDu5G0Eg3ouT9Vcfk-rxbKlp4zI,42510
|
3
|
+
heavyball/helpers.py,sha256=is4Egdgoj2GUsBYdraItonqsoVIY9ZKP_VZl-hEnF1Y,31077
|
4
|
+
heavyball/utils.py,sha256=_AOFIkFyaMO39YjbvclkzivR-nKe_kLShRZda3rgMiA,104850
|
5
|
+
heavyball-2.1.3.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
|
6
|
+
heavyball-2.1.3.dist-info/METADATA,sha256=by35259YI9DvUQ8Vq958sHmAxqSVtPY5JoY5Hn0CccY,5088
|
7
|
+
heavyball-2.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
8
|
+
heavyball-2.1.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
9
|
+
heavyball-2.1.3.dist-info/RECORD,,
|
heavyball-2.1.2.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
|
|
1
|
-
heavyball/__init__.py,sha256=1BTb7G-VcfcMyS4EpuVnhE5DBp2fj_Zzs9EQr6slPzg,30491
|
2
|
-
heavyball/chainable.py,sha256=8S-7QRZYiy_ARhQ8uDu5G0Eg3ouT9Vcfk-rxbKlp4zI,42510
|
3
|
-
heavyball/helpers.py,sha256=zk_S84wpGcvO9P6kn4UeaQUIDowHxcbM9qQITEm2g5I,30267
|
4
|
-
heavyball/utils.py,sha256=Lx9XlfkyQbfYMPtqiA0rNIz4PXQe_bpLqKFby3upHMw,104514
|
5
|
-
heavyball-2.1.2.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
|
6
|
-
heavyball-2.1.2.dist-info/METADATA,sha256=EMM0OI4cPeaQlMkts2j9CCp9KxhJm-o_9VDNLm4ySQg,5046
|
7
|
-
heavyball-2.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
8
|
-
heavyball-2.1.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
9
|
-
heavyball-2.1.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|