heavyball 2.1.1__py3-none-any.whl → 2.1.3__py3-none-any.whl

This diff shows the changes between publicly available versions of this package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
heavyball/helpers.py CHANGED
@@ -3,7 +3,8 @@ from __future__ import annotations
3
3
  import functools
4
4
  import math
5
5
  import threading
6
- from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
6
+ from contextlib import contextmanager
7
+ from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union
7
8
 
8
9
  import numpy
9
10
  import numpy as np
@@ -11,7 +12,6 @@ import optuna
11
12
  import optunahub
12
13
  import pandas as pd
13
14
  import torch
14
- from botorch.utils.sampling import manual_seed
15
15
  from hebo.design_space.design_space import DesignSpace
16
16
  from hebo.optimizers.hebo import HEBO
17
17
  from optuna._transform import _SearchSpaceTransform
@@ -21,13 +21,6 @@ from optuna.samplers._lazy_random_state import LazyRandomState
21
21
  from optuna.study import Study
22
22
  from optuna.study._study_direction import StudyDirection
23
23
  from optuna.trial import FrozenTrial, TrialState
24
- from optuna_integration.botorch import (
25
- ehvi_candidates_func,
26
- logei_candidates_func,
27
- qehvi_candidates_func,
28
- qei_candidates_func,
29
- qparego_candidates_func,
30
- )
31
24
  from torch import Tensor
32
25
  from torch.nn import functional as F
33
26
 
@@ -37,6 +30,33 @@ _MAXINT32 = (1 << 31) - 1
37
30
  _SAMPLER_KEY = "auto:sampler"
38
31
 
39
32
 
33
+ @contextmanager
34
+ def manual_seed(seed: int | None = None) -> Generator[None, None, None]:
35
+ r"""
36
+ Contextmanager for manual setting the torch.random seed.
37
+
38
+ Args:
39
+ seed: The seed to set the random number generator to.
40
+
41
+ Returns:
42
+ Generator
43
+
44
+ Example:
45
+ >>> with manual_seed(1234):
46
+ >>> X = torch.rand(3)
47
+
48
+ copied as-is from https://github.com/meta-pytorch/botorch/blob/a42cd65f9b704cdb6f2ee64db99a022eb15295d5/botorch/utils/sampling.py#L53C1-L75C50 under the MIT License
49
+ """
50
+ old_state = torch.random.get_rng_state()
51
+ try:
52
+ if seed is not None:
53
+ torch.random.manual_seed(seed)
54
+ yield
55
+ finally:
56
+ if seed is not None:
57
+ torch.random.set_rng_state(old_state)
58
+
59
+
40
60
  class SimpleAPIBaseSampler(BaseSampler):
41
61
  def __init__(
42
62
  self,
@@ -65,6 +85,16 @@ def _get_default_candidates_func(
65
85
  """
66
86
  The original is available at https://github.com/optuna/optuna-integration/blob/156a8bc081322791015d2beefff9373ed7b24047/optuna_integration/botorch/botorch.py under the MIT License
67
87
  """
88
+
89
+ # lazy import
90
+ from optuna_integration.botorch import (
91
+ ehvi_candidates_func,
92
+ logei_candidates_func,
93
+ qehvi_candidates_func,
94
+ qei_candidates_func,
95
+ qparego_candidates_func,
96
+ )
97
+
68
98
  if n_objectives > 3 and not has_constraint and not consider_running_trials:
69
99
  return ehvi_candidates_func
70
100
  elif n_objectives > 3:
heavyball/utils.py CHANGED
@@ -27,6 +27,7 @@ class ZerothPowerMode(enum.Enum):
27
27
  qr = "qr"
28
28
  svd = "svd"
29
29
  legacy_svd = "legacy_svd"
30
+ thinky_polar_express = "thinky_polar_express"
30
31
 
31
32
 
32
33
  class OrthoScaleMode(enum.Enum):
@@ -46,7 +47,7 @@ _cudnn_double_backward_pattern = re.compile(
46
47
  )
47
48
  _torch_compile_double_backward_pattern = re.compile(r"compile.*does not currently support double backward")
48
49
  _fd_error = (
49
- "You can accelerate startup by globally enabling finite_differences first " #
50
+ "You can accelerate startup by globally enabling finite_differences first "
50
51
  "(via opt.finite_differences=True or by subclassing it)\n"
51
52
  "Original Error: "
52
53
  )
@@ -390,12 +391,12 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
390
391
  ) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
391
392
  assert steps == 5
392
393
  G = G.clone()
393
- X = G if G.dtype == torch.float64 else stochastic_round_(G)
394
+ x = G if G.dtype == torch.float64 else stochastic_round_(G)
394
395
  if G.size(-2) > G.size(-1):
395
- X = X.mT
396
+ x = x.mT
396
397
 
397
398
  # X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
398
- stochastic_divide_with_eps_(X, G.norm(dim=(-2, -1)), eps) # ensure top singular value <= 1
399
+ stochastic_divide_with_eps_(x, G.norm(dim=(-2, -1)), eps) # ensure top singular value <= 1
399
400
  # Perform the NS iterations
400
401
  for a, b, c in [
401
402
  (4.0848, -6.8946, 2.9270),
@@ -404,13 +405,75 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
404
405
  (2.8769, -3.1427, 1.2046),
405
406
  (2.8366, -3.0525, 1.2012),
406
407
  ]:
407
- A = X @ X.mT
408
- B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
409
- X = a * X + B @ X
408
+ s = x @ x.mT
409
+ y = c * s
410
+ y.diagonal(dim1=-2, dim2=-1).add_(b)
411
+ y = y @ s
412
+ y.diagonal(dim1=-2, dim2=-1).add_(a)
413
+ x = y @ x
410
414
 
411
415
  if G.size(-2) > G.size(-1):
412
- X = X.mT
413
- return X.to(G.dtype)
416
+ x = x.mT
417
+ return x.to(G.dtype)
418
+
419
+
420
+ ###### START
421
+ # Based on https://arxiv.org/pdf/2505.16932v3
422
+ # and https://github.com/NoahAmsel/PolarExpress/blob/5454910920ca8c65afda28820cdf9e49b9436ed0/polar_express.py#L69-L82
423
+ # and https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
424
+ #
425
+ # under the MIT License
426
+
427
+ # Coefficients are from https://arxiv.org/pdf/2505.16932v3
428
+ ABC_LIST: list[tuple[float, float, float]] = [
429
+ (8.28721201814563, -23.595886519098837, 17.300387312530933),
430
+ (4.107059111542203, -2.9478499167379106, 0.5448431082926601),
431
+ (3.9486908534822946, -2.908902115962949, 0.5518191394370137),
432
+ (3.3184196573706015, -2.488488024314874, 0.51004894012372),
433
+ (2.300652019954817, -1.6689039845747493, 0.4188073119525673),
434
+ (1.891301407787398, -1.2679958271945868, 0.37680408948524835),
435
+ (1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
436
+ (1.875, -1.25, 0.375),
437
+ ]
438
+
439
+ # safety factor for numerical stability (but exclude last polynomial)
440
+ ABC_LIST_STABLE: list[tuple[float, float, float]] = [
441
+ (a / 1.01, b / 1.01**3, c / 1.01**5) for (a, b, c) in ABC_LIST[:-1]
442
+ ] + [ABC_LIST[-1]]
443
+
444
+
445
+ def msign(G: torch.Tensor, steps: int = 10, eps: float = 1e-7) -> torch.Tensor:
446
+ """
447
+ Polar Express algorithm for the matrix sign function:
448
+ https://arxiv.org/abs/2505.16932
449
+ """
450
+ assert G.ndim >= 2
451
+ should_transpose: bool = G.size(-2) > G.size(-1)
452
+
453
+ x = G if G.dtype == torch.float64 else stochastic_round_(G)
454
+ if should_transpose:
455
+ x = x.mT
456
+
457
+ # x = x / (x.norm(dim=(-2, -1), keepdim=True) * 1.01 + eps)
458
+ stochastic_divide_with_eps_(x, x.norm(dim=(-2, -1)) * 1.01, eps)
459
+
460
+ for step in range(steps):
461
+ a, b, c = ABC_LIST_STABLE[step] if step < len(ABC_LIST_STABLE) else ABC_LIST_STABLE[-1]
462
+ s = x @ x.mT
463
+ # goal is to compute x = a x + b S x + c S^2 x
464
+ # we can break this up into: x = (a I + (b I + c S) S) x
465
+ y = c * s
466
+ y.diagonal(dim1=-2, dim2=-1).add_(b)
467
+ y = y @ s
468
+ y.diagonal(dim1=-2, dim2=-1).add_(a)
469
+ x = y @ x
470
+
471
+ if should_transpose:
472
+ x = x.mT
473
+ return x.float()
474
+
475
+
476
+ ###### END
414
477
 
415
478
 
416
479
  @decorator_knowngood
@@ -418,19 +481,22 @@ def legacy_zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
418
481
  assert len(G.shape) == 2
419
482
  a, b, c = (3.4445, -4.7750, 2.0315)
420
483
  G = G.clone()
421
- X = G if G.dtype == torch.float64 else stochastic_round_(G)
484
+ x = G if G.dtype == torch.float64 else stochastic_round_(G)
422
485
 
423
486
  # X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
424
- stochastic_divide_with_eps_(X, G.norm(dim=(-2, -1)), eps) # ensure top singular value <= 1
487
+ stochastic_divide_with_eps_(x, G.norm(dim=(-2, -1)), eps) # ensure top singular value <= 1
425
488
  if G.size(0) > G.size(1):
426
- X = X.T
489
+ x = x.T
427
490
  for _ in range(steps):
428
- A = X @ X.T
429
- B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
430
- X = a * X + B @ X
491
+ s = x @ x.mT
492
+ y = c * s
493
+ y.diagonal(dim1=-2, dim2=-1).add_(b)
494
+ y = y @ s
495
+ y.diagonal(dim1=-2, dim2=-1).add_(a)
496
+ x = y @ x
431
497
  if G.size(0) > G.size(1):
432
- X = X.T
433
- return X.to(G.dtype)
498
+ x = x.T
499
+ return x.to(G.dtype)
434
500
 
435
501
 
436
502
  @decorator_knowngood
@@ -492,6 +558,8 @@ def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor
492
558
  scale_mode = OrthoScaleMode(scale_mode)
493
559
  if mode == ZerothPowerMode.newtonschulz or x.shape[0] != x.shape[1]:
494
560
  y = zeropower_via_newtonschulz5(x, 5)
561
+ elif mode == ZerothPowerMode.thinky_polar_express:
562
+ y = msign(x, 10)
495
563
  elif mode == ZerothPowerMode.legacy_newtonschulz:
496
564
  y = legacy_zeropower_via_newtonschulz5(x, 5)
497
565
  elif mode == ZerothPowerMode.qr:
@@ -1522,7 +1590,7 @@ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
1522
1590
 
1523
1591
  def copy_stochastic_(target: Tensor, source: Tensor):
1524
1592
  if target.dtype == torch.bfloat16 and source.dtype in (torch.float16, torch.float32, torch.float64):
1525
- _compilable_copy_stochastic_(target, source.float())
1593
+ source = stochastic_round_(target, source)
1526
1594
  set_(target, source)
1527
1595
 
1528
1596
 
@@ -2349,10 +2417,11 @@ def bf16_matmul(x: Tensor, y: Tensor):
2349
2417
  def if_iscompiling(fn):
2350
2418
  base = getattr(torch, fn.__name__, None)
2351
2419
 
2352
- def _fn(x):
2353
- if torch.compiler.is_compiling() and hasattr(torch, fn.__name__):
2354
- return base(x)
2355
- return fn(x)
2420
+ @functools.wraps(fn)
2421
+ def _fn(*args, **kwargs):
2422
+ if torch.compiler.is_compiling() and base is not None:
2423
+ return base(*args, **kwargs)
2424
+ return fn(*args, **kwargs)
2356
2425
 
2357
2426
  return _fn
2358
2427
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: heavyball
3
- Version: 2.1.1
3
+ Version: 2.1.3
4
4
  Summary: Efficient Optimizers
5
5
  Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
6
6
  Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -16,11 +16,12 @@ Requires-Python: >=3.9
16
16
  Description-Content-Type: text/markdown
17
17
  License-File: LICENSE
18
18
  Requires-Dist: opt-einsum>=3.4.0
19
- Requires-Dist: torch>=2.7.0
20
- Requires-Dist: numpy
19
+ Requires-Dist: torch<3.0,>=2.2
20
+ Requires-Dist: numpy<2.0.0
21
21
  Provides-Extra: dev
22
22
  Requires-Dist: pre-commit; extra == "dev"
23
23
  Requires-Dist: pytest; extra == "dev"
24
+ Requires-Dist: hypothesis; extra == "dev"
24
25
  Requires-Dist: ruff; extra == "dev"
25
26
  Requires-Dist: matplotlib; extra == "dev"
26
27
  Requires-Dist: seaborn; extra == "dev"
@@ -0,0 +1,9 @@
1
+ heavyball/__init__.py,sha256=1BTb7G-VcfcMyS4EpuVnhE5DBp2fj_Zzs9EQr6slPzg,30491
2
+ heavyball/chainable.py,sha256=8S-7QRZYiy_ARhQ8uDu5G0Eg3ouT9Vcfk-rxbKlp4zI,42510
3
+ heavyball/helpers.py,sha256=is4Egdgoj2GUsBYdraItonqsoVIY9ZKP_VZl-hEnF1Y,31077
4
+ heavyball/utils.py,sha256=_AOFIkFyaMO39YjbvclkzivR-nKe_kLShRZda3rgMiA,104850
5
+ heavyball-2.1.3.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
6
+ heavyball-2.1.3.dist-info/METADATA,sha256=by35259YI9DvUQ8Vq958sHmAxqSVtPY5JoY5Hn0CccY,5088
7
+ heavyball-2.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
8
+ heavyball-2.1.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
9
+ heavyball-2.1.3.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- heavyball/__init__.py,sha256=1BTb7G-VcfcMyS4EpuVnhE5DBp2fj_Zzs9EQr6slPzg,30491
2
- heavyball/chainable.py,sha256=8S-7QRZYiy_ARhQ8uDu5G0Eg3ouT9Vcfk-rxbKlp4zI,42510
3
- heavyball/helpers.py,sha256=zk_S84wpGcvO9P6kn4UeaQUIDowHxcbM9qQITEm2g5I,30267
4
- heavyball/utils.py,sha256=zAOlSDqMbSUJEdCfoOcUbRIO94Qg4cxT40IN_UPskQk,102492
5
- heavyball-2.1.1.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
6
- heavyball-2.1.1.dist-info/METADATA,sha256=92i_Q4bxQgRsH8BEOYEuW0Qg43nR5jJLSPGIIJmyzxc,5037
7
- heavyball-2.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
8
- heavyball-2.1.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
9
- heavyball-2.1.1.dist-info/RECORD,,