heavyball 2.1.1__py3-none-any.whl → 2.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/helpers.py +39 -9
- heavyball/utils.py +91 -22
- {heavyball-2.1.1.dist-info → heavyball-2.1.3.dist-info}/METADATA +4 -3
- heavyball-2.1.3.dist-info/RECORD +9 -0
- heavyball-2.1.1.dist-info/RECORD +0 -9
- {heavyball-2.1.1.dist-info → heavyball-2.1.3.dist-info}/WHEEL +0 -0
- {heavyball-2.1.1.dist-info → heavyball-2.1.3.dist-info}/licenses/LICENSE +0 -0
- {heavyball-2.1.1.dist-info → heavyball-2.1.3.dist-info}/top_level.txt +0 -0
heavyball/helpers.py
CHANGED
@@ -3,7 +3,8 @@ from __future__ import annotations
|
|
3
3
|
import functools
|
4
4
|
import math
|
5
5
|
import threading
|
6
|
-
from
|
6
|
+
from contextlib import contextmanager
|
7
|
+
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union
|
7
8
|
|
8
9
|
import numpy
|
9
10
|
import numpy as np
|
@@ -11,7 +12,6 @@ import optuna
|
|
11
12
|
import optunahub
|
12
13
|
import pandas as pd
|
13
14
|
import torch
|
14
|
-
from botorch.utils.sampling import manual_seed
|
15
15
|
from hebo.design_space.design_space import DesignSpace
|
16
16
|
from hebo.optimizers.hebo import HEBO
|
17
17
|
from optuna._transform import _SearchSpaceTransform
|
@@ -21,13 +21,6 @@ from optuna.samplers._lazy_random_state import LazyRandomState
|
|
21
21
|
from optuna.study import Study
|
22
22
|
from optuna.study._study_direction import StudyDirection
|
23
23
|
from optuna.trial import FrozenTrial, TrialState
|
24
|
-
from optuna_integration.botorch import (
|
25
|
-
ehvi_candidates_func,
|
26
|
-
logei_candidates_func,
|
27
|
-
qehvi_candidates_func,
|
28
|
-
qei_candidates_func,
|
29
|
-
qparego_candidates_func,
|
30
|
-
)
|
31
24
|
from torch import Tensor
|
32
25
|
from torch.nn import functional as F
|
33
26
|
|
@@ -37,6 +30,33 @@ _MAXINT32 = (1 << 31) - 1
|
|
37
30
|
_SAMPLER_KEY = "auto:sampler"
|
38
31
|
|
39
32
|
|
33
|
+
@contextmanager
|
34
|
+
def manual_seed(seed: int | None = None) -> Generator[None, None, None]:
|
35
|
+
r"""
|
36
|
+
Contextmanager for manual setting the torch.random seed.
|
37
|
+
|
38
|
+
Args:
|
39
|
+
seed: The seed to set the random number generator to.
|
40
|
+
|
41
|
+
Returns:
|
42
|
+
Generator
|
43
|
+
|
44
|
+
Example:
|
45
|
+
>>> with manual_seed(1234):
|
46
|
+
>>> X = torch.rand(3)
|
47
|
+
|
48
|
+
copied as-is from https://github.com/meta-pytorch/botorch/blob/a42cd65f9b704cdb6f2ee64db99a022eb15295d5/botorch/utils/sampling.py#L53C1-L75C50 under the MIT License
|
49
|
+
"""
|
50
|
+
old_state = torch.random.get_rng_state()
|
51
|
+
try:
|
52
|
+
if seed is not None:
|
53
|
+
torch.random.manual_seed(seed)
|
54
|
+
yield
|
55
|
+
finally:
|
56
|
+
if seed is not None:
|
57
|
+
torch.random.set_rng_state(old_state)
|
58
|
+
|
59
|
+
|
40
60
|
class SimpleAPIBaseSampler(BaseSampler):
|
41
61
|
def __init__(
|
42
62
|
self,
|
@@ -65,6 +85,16 @@ def _get_default_candidates_func(
|
|
65
85
|
"""
|
66
86
|
The original is available at https://github.com/optuna/optuna-integration/blob/156a8bc081322791015d2beefff9373ed7b24047/optuna_integration/botorch/botorch.py under the MIT License
|
67
87
|
"""
|
88
|
+
|
89
|
+
# lazy import
|
90
|
+
from optuna_integration.botorch import (
|
91
|
+
ehvi_candidates_func,
|
92
|
+
logei_candidates_func,
|
93
|
+
qehvi_candidates_func,
|
94
|
+
qei_candidates_func,
|
95
|
+
qparego_candidates_func,
|
96
|
+
)
|
97
|
+
|
68
98
|
if n_objectives > 3 and not has_constraint and not consider_running_trials:
|
69
99
|
return ehvi_candidates_func
|
70
100
|
elif n_objectives > 3:
|
heavyball/utils.py
CHANGED
@@ -27,6 +27,7 @@ class ZerothPowerMode(enum.Enum):
|
|
27
27
|
qr = "qr"
|
28
28
|
svd = "svd"
|
29
29
|
legacy_svd = "legacy_svd"
|
30
|
+
thinky_polar_express = "thinky_polar_express"
|
30
31
|
|
31
32
|
|
32
33
|
class OrthoScaleMode(enum.Enum):
|
@@ -46,7 +47,7 @@ _cudnn_double_backward_pattern = re.compile(
|
|
46
47
|
)
|
47
48
|
_torch_compile_double_backward_pattern = re.compile(r"compile.*does not currently support double backward")
|
48
49
|
_fd_error = (
|
49
|
-
"You can accelerate startup by globally enabling finite_differences first "
|
50
|
+
"You can accelerate startup by globally enabling finite_differences first "
|
50
51
|
"(via opt.finite_differences=True or by subclassing it)\n"
|
51
52
|
"Original Error: "
|
52
53
|
)
|
@@ -390,12 +391,12 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
|
|
390
391
|
) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
|
391
392
|
assert steps == 5
|
392
393
|
G = G.clone()
|
393
|
-
|
394
|
+
x = G if G.dtype == torch.float64 else stochastic_round_(G)
|
394
395
|
if G.size(-2) > G.size(-1):
|
395
|
-
|
396
|
+
x = x.mT
|
396
397
|
|
397
398
|
# X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
|
398
|
-
stochastic_divide_with_eps_(
|
399
|
+
stochastic_divide_with_eps_(x, G.norm(dim=(-2, -1)), eps) # ensure top singular value <= 1
|
399
400
|
# Perform the NS iterations
|
400
401
|
for a, b, c in [
|
401
402
|
(4.0848, -6.8946, 2.9270),
|
@@ -404,13 +405,75 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
|
|
404
405
|
(2.8769, -3.1427, 1.2046),
|
405
406
|
(2.8366, -3.0525, 1.2012),
|
406
407
|
]:
|
407
|
-
|
408
|
-
|
409
|
-
|
408
|
+
s = x @ x.mT
|
409
|
+
y = c * s
|
410
|
+
y.diagonal(dim1=-2, dim2=-1).add_(b)
|
411
|
+
y = y @ s
|
412
|
+
y.diagonal(dim1=-2, dim2=-1).add_(a)
|
413
|
+
x = y @ x
|
410
414
|
|
411
415
|
if G.size(-2) > G.size(-1):
|
412
|
-
|
413
|
-
return
|
416
|
+
x = x.mT
|
417
|
+
return x.to(G.dtype)
|
418
|
+
|
419
|
+
|
420
|
+
###### START
# Polar Express coefficient tables and the ``msign`` routine that uses them.
#
# Based on https://arxiv.org/pdf/2505.16932v3
# and https://github.com/NoahAmsel/PolarExpress/blob/5454910920ca8c65afda28820cdf9e49b9436ed0/polar_express.py#L69-L82
# and https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
#
# under the MIT License

# Per-iteration (a, b, c) coefficients of the odd quintic p(x) = a*x + b*x^3 + c*x^5.
# Coefficients are from https://arxiv.org/pdf/2505.16932v3
ABC_LIST: list[tuple[float, float, float]] = [
    (8.28721201814563, -23.595886519098837, 17.300387312530933),
    (4.107059111542203, -2.9478499167379106, 0.5448431082926601),
    (3.9486908534822946, -2.908902115962949, 0.5518191394370137),
    (3.3184196573706015, -2.488488024314874, 0.51004894012372),
    (2.300652019954817, -1.6689039845747493, 0.4188073119525673),
    (1.891301407787398, -1.2679958271945868, 0.37680408948524835),
    (1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
    (1.875, -1.25, 0.375),
]

# Safety factor for numerical stability. Dividing the x, x^3 and x^5 coefficients
# by 1.01, 1.01**3 and 1.01**5 gives q(x) = p(x / 1.01). The last polynomial is
# kept unscaled.
ABC_LIST_STABLE: list[tuple[float, float, float]] = [
    *((a / 1.01, b / 1.01**3, c / 1.01**5) for a, b, c in ABC_LIST[:-1]),
    ABC_LIST[-1],
]
|
443
|
+
|
444
|
+
|
445
|
+
def msign(G: torch.Tensor, steps: int = 10, eps: float = 1e-7) -> torch.Tensor:
    """
    Polar Express algorithm for the matrix sign function:
    https://arxiv.org/abs/2505.16932

    Approximates the orthogonal polar factor of ``G`` over its last two dimensions.
    Leading dimensions are treated as a batch. The input is Frobenius-normalized,
    then ``steps`` odd-polynomial iterations are applied:
    ``x <- a*x + b*(x x^T) x + c*(x x^T)^2 x``.
    Step ``i`` uses ``ABC_LIST_STABLE[i]``. After the table runs out, its last
    entry is reused.

    Args:
        G: Tensor with at least two dimensions. It is not modified.
        steps: Number of polynomial iterations to run.
        eps: Added to the scaled norm in the normalization denominator, guarding
            against division by zero (see the reference formula below).

    Returns:
        A tensor with the same shape as ``G``, cast to float32 via ``.float()``.
        NOTE(review): ``zeropower_via_newtonschulz5`` returns ``x.to(G.dtype)``
        instead; confirm the float32 output here is intended.
    """
    assert G.ndim >= 2
    should_transpose: bool = G.size(-2) > G.size(-1)

    # Work on a private copy, as zeropower_via_newtonschulz5 does. Without it, a
    # float64 ``G`` is used directly as ``x``, and the in-place normalization below
    # would overwrite the caller's tensor.
    G = G.clone()
    x = G if G.dtype == torch.float64 else stochastic_round_(G)
    if should_transpose:
        # Iterate in the wide orientation so x @ x.mT is the smaller Gram matrix.
        x = x.mT

    # x = x / (x.norm(dim=(-2, -1), keepdim=True) * 1.01 + eps)
    # Frobenius norm >= spectral norm, so this pushes the top singular value below 1.
    stochastic_divide_with_eps_(x, x.norm(dim=(-2, -1)) * 1.01, eps)

    for step in range(steps):
        # Past the end of the coefficient table, keep reusing its last entry.
        a, b, c = ABC_LIST_STABLE[step] if step < len(ABC_LIST_STABLE) else ABC_LIST_STABLE[-1]
        s = x @ x.mT
        # goal is to compute x = a x + b S x + c S^2 x
        # we can break this up into: x = (a I + (b I + c S) S) x
        y = c * s
        y.diagonal(dim1=-2, dim2=-1).add_(b)
        y = y @ s
        y.diagonal(dim1=-2, dim2=-1).add_(a)
        x = y @ x

    if should_transpose:
        x = x.mT
    return x.float()


###### END
|
414
477
|
|
415
478
|
|
416
479
|
@decorator_knowngood
|
@@ -418,19 +481,22 @@ def legacy_zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
|
|
418
481
|
assert len(G.shape) == 2
|
419
482
|
a, b, c = (3.4445, -4.7750, 2.0315)
|
420
483
|
G = G.clone()
|
421
|
-
|
484
|
+
x = G if G.dtype == torch.float64 else stochastic_round_(G)
|
422
485
|
|
423
486
|
# X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
|
424
|
-
stochastic_divide_with_eps_(
|
487
|
+
stochastic_divide_with_eps_(x, G.norm(dim=(-2, -1)), eps) # ensure top singular value <= 1
|
425
488
|
if G.size(0) > G.size(1):
|
426
|
-
|
489
|
+
x = x.T
|
427
490
|
for _ in range(steps):
|
428
|
-
|
429
|
-
|
430
|
-
|
491
|
+
s = x @ x.mT
|
492
|
+
y = c * s
|
493
|
+
y.diagonal(dim1=-2, dim2=-1).add_(b)
|
494
|
+
y = y @ s
|
495
|
+
y.diagonal(dim1=-2, dim2=-1).add_(a)
|
496
|
+
x = y @ x
|
431
497
|
if G.size(0) > G.size(1):
|
432
|
-
|
433
|
-
return
|
498
|
+
x = x.T
|
499
|
+
return x.to(G.dtype)
|
434
500
|
|
435
501
|
|
436
502
|
@decorator_knowngood
|
@@ -492,6 +558,8 @@ def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor
|
|
492
558
|
scale_mode = OrthoScaleMode(scale_mode)
|
493
559
|
if mode == ZerothPowerMode.newtonschulz or x.shape[0] != x.shape[1]:
|
494
560
|
y = zeropower_via_newtonschulz5(x, 5)
|
561
|
+
elif mode == ZerothPowerMode.thinky_polar_express:
|
562
|
+
y = msign(x, 10)
|
495
563
|
elif mode == ZerothPowerMode.legacy_newtonschulz:
|
496
564
|
y = legacy_zeropower_via_newtonschulz5(x, 5)
|
497
565
|
elif mode == ZerothPowerMode.qr:
|
@@ -1522,7 +1590,7 @@ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
|
|
1522
1590
|
|
1523
1591
|
def copy_stochastic_(target: Tensor, source: Tensor):
    """
    Write ``source`` into ``target`` using the ``set_`` helper.

    If ``target`` is bfloat16 and ``source`` is float16, float32 or float64,
    ``source`` is first replaced by ``stochastic_round_(target, source)``.
    Any other dtype combination is passed to ``set_`` unchanged.
    """
    source_is_wider_float = source.dtype in (torch.float16, torch.float32, torch.float64)
    if target.dtype == torch.bfloat16 and source_is_wider_float:
        source = stochastic_round_(target, source)
    set_(target, source)
|
1527
1595
|
|
1528
1596
|
|
@@ -2349,10 +2417,11 @@ def bf16_matmul(x: Tensor, y: Tensor):
|
|
2349
2417
|
def if_iscompiling(fn):
|
2350
2418
|
base = getattr(torch, fn.__name__, None)
|
2351
2419
|
|
2352
|
-
|
2353
|
-
|
2354
|
-
|
2355
|
-
|
2420
|
+
@functools.wraps(fn)
|
2421
|
+
def _fn(*args, **kwargs):
|
2422
|
+
if torch.compiler.is_compiling() and base is not None:
|
2423
|
+
return base(*args, **kwargs)
|
2424
|
+
return fn(*args, **kwargs)
|
2356
2425
|
|
2357
2426
|
return _fn
|
2358
2427
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: heavyball
|
3
|
-
Version: 2.1.1
|
3
|
+
Version: 2.1.3
|
4
4
|
Summary: Efficient Optimizers
|
5
5
|
Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
|
6
6
|
Project-URL: source, https://github.com/HomebrewML/HeavyBall
|
@@ -16,11 +16,12 @@ Requires-Python: >=3.9
|
|
16
16
|
Description-Content-Type: text/markdown
|
17
17
|
License-File: LICENSE
|
18
18
|
Requires-Dist: opt-einsum>=3.4.0
|
19
|
-
Requires-Dist: torch
|
20
|
-
Requires-Dist: numpy
|
19
|
+
Requires-Dist: torch<3.0,>=2.2
|
20
|
+
Requires-Dist: numpy<2.0.0
|
21
21
|
Provides-Extra: dev
|
22
22
|
Requires-Dist: pre-commit; extra == "dev"
|
23
23
|
Requires-Dist: pytest; extra == "dev"
|
24
|
+
Requires-Dist: hypothesis; extra == "dev"
|
24
25
|
Requires-Dist: ruff; extra == "dev"
|
25
26
|
Requires-Dist: matplotlib; extra == "dev"
|
26
27
|
Requires-Dist: seaborn; extra == "dev"
|
@@ -0,0 +1,9 @@
|
|
1
|
+
heavyball/__init__.py,sha256=1BTb7G-VcfcMyS4EpuVnhE5DBp2fj_Zzs9EQr6slPzg,30491
|
2
|
+
heavyball/chainable.py,sha256=8S-7QRZYiy_ARhQ8uDu5G0Eg3ouT9Vcfk-rxbKlp4zI,42510
|
3
|
+
heavyball/helpers.py,sha256=is4Egdgoj2GUsBYdraItonqsoVIY9ZKP_VZl-hEnF1Y,31077
|
4
|
+
heavyball/utils.py,sha256=_AOFIkFyaMO39YjbvclkzivR-nKe_kLShRZda3rgMiA,104850
|
5
|
+
heavyball-2.1.3.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
|
6
|
+
heavyball-2.1.3.dist-info/METADATA,sha256=by35259YI9DvUQ8Vq958sHmAxqSVtPY5JoY5Hn0CccY,5088
|
7
|
+
heavyball-2.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
8
|
+
heavyball-2.1.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
9
|
+
heavyball-2.1.3.dist-info/RECORD,,
|
heavyball-2.1.1.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
|
|
1
|
-
heavyball/__init__.py,sha256=1BTb7G-VcfcMyS4EpuVnhE5DBp2fj_Zzs9EQr6slPzg,30491
|
2
|
-
heavyball/chainable.py,sha256=8S-7QRZYiy_ARhQ8uDu5G0Eg3ouT9Vcfk-rxbKlp4zI,42510
|
3
|
-
heavyball/helpers.py,sha256=zk_S84wpGcvO9P6kn4UeaQUIDowHxcbM9qQITEm2g5I,30267
|
4
|
-
heavyball/utils.py,sha256=zAOlSDqMbSUJEdCfoOcUbRIO94Qg4cxT40IN_UPskQk,102492
|
5
|
-
heavyball-2.1.1.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
|
6
|
-
heavyball-2.1.1.dist-info/METADATA,sha256=92i_Q4bxQgRsH8BEOYEuW0Qg43nR5jJLSPGIIJmyzxc,5037
|
7
|
-
heavyball-2.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
8
|
-
heavyball-2.1.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
9
|
-
heavyball-2.1.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|