heavyball 2.1.2__tar.gz → 2.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {heavyball-2.1.2 → heavyball-2.1.3}/PKG-INFO +2 -1
  2. {heavyball-2.1.2 → heavyball-2.1.3}/heavyball/helpers.py +39 -9
  3. {heavyball-2.1.2 → heavyball-2.1.3}/heavyball/utils.py +16 -10
  4. {heavyball-2.1.2 → heavyball-2.1.3}/heavyball.egg-info/PKG-INFO +2 -1
  5. {heavyball-2.1.2 → heavyball-2.1.3}/heavyball.egg-info/SOURCES.txt +7 -1
  6. {heavyball-2.1.2 → heavyball-2.1.3}/heavyball.egg-info/requires.txt +1 -0
  7. {heavyball-2.1.2 → heavyball-2.1.3}/pyproject.toml +2 -2
  8. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_bf16_params.py +0 -1
  9. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_bf16_q.py +0 -1
  10. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_bf16_storage.py +0 -1
  11. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_caution.py +0 -1
  12. heavyball-2.1.3/test/test_chainable_cpu.py +65 -0
  13. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_channels_last.py +0 -1
  14. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_closure.py +0 -1
  15. heavyball-2.1.3/test/test_cpu_features.py +134 -0
  16. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_ema.py +0 -1
  17. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_foreach.py +0 -1
  18. heavyball-2.1.3/test/test_helpers_cpu.py +107 -0
  19. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_hook.py +0 -1
  20. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_mars.py +0 -1
  21. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_memory.py +0 -1
  22. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_memory_leak.py +0 -1
  23. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_merge.py +0 -1
  24. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_nd_param.py +0 -1
  25. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_no_grad.py +0 -1
  26. heavyball-2.1.3/test/test_optimizer_cpu_smoke.py +65 -0
  27. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_save_restore.py +0 -1
  28. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_singular_values.py +1 -1
  29. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_stochastic_updates.py +0 -1
  30. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_toy_training.py +4 -4
  31. heavyball-2.1.3/test/test_utils_cpu.py +295 -0
  32. heavyball-2.1.3/test/test_utils_property.py +281 -0
  33. {heavyball-2.1.2 → heavyball-2.1.3}/LICENSE +0 -0
  34. {heavyball-2.1.2 → heavyball-2.1.3}/README.md +0 -0
  35. {heavyball-2.1.2 → heavyball-2.1.3}/heavyball/__init__.py +0 -0
  36. {heavyball-2.1.2 → heavyball-2.1.3}/heavyball/chainable.py +0 -0
  37. {heavyball-2.1.2 → heavyball-2.1.3}/heavyball.egg-info/dependency_links.txt +0 -0
  38. {heavyball-2.1.2 → heavyball-2.1.3}/heavyball.egg-info/top_level.txt +0 -0
  39. {heavyball-2.1.2 → heavyball-2.1.3}/setup.cfg +0 -0
  40. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_clip.py +0 -0
  41. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_migrate_cli.py +0 -0
  42. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_psgd_precond_init_stability.py +0 -0
  43. {heavyball-2.1.2 → heavyball-2.1.3}/test/test_soap.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: heavyball
3
- Version: 2.1.2
3
+ Version: 2.1.3
4
4
  Summary: Efficient Optimizers
5
5
  Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
6
6
  Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -21,6 +21,7 @@ Requires-Dist: numpy<2.0.0
21
21
  Provides-Extra: dev
22
22
  Requires-Dist: pre-commit; extra == "dev"
23
23
  Requires-Dist: pytest; extra == "dev"
24
+ Requires-Dist: hypothesis; extra == "dev"
24
25
  Requires-Dist: ruff; extra == "dev"
25
26
  Requires-Dist: matplotlib; extra == "dev"
26
27
  Requires-Dist: seaborn; extra == "dev"
@@ -3,7 +3,8 @@ from __future__ import annotations
3
3
  import functools
4
4
  import math
5
5
  import threading
6
- from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
6
+ from contextlib import contextmanager
7
+ from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union
7
8
 
8
9
  import numpy
9
10
  import numpy as np
@@ -11,7 +12,6 @@ import optuna
11
12
  import optunahub
12
13
  import pandas as pd
13
14
  import torch
14
- from botorch.utils.sampling import manual_seed
15
15
  from hebo.design_space.design_space import DesignSpace
16
16
  from hebo.optimizers.hebo import HEBO
17
17
  from optuna._transform import _SearchSpaceTransform
@@ -21,13 +21,6 @@ from optuna.samplers._lazy_random_state import LazyRandomState
21
21
  from optuna.study import Study
22
22
  from optuna.study._study_direction import StudyDirection
23
23
  from optuna.trial import FrozenTrial, TrialState
24
- from optuna_integration.botorch import (
25
- ehvi_candidates_func,
26
- logei_candidates_func,
27
- qehvi_candidates_func,
28
- qei_candidates_func,
29
- qparego_candidates_func,
30
- )
31
24
  from torch import Tensor
32
25
  from torch.nn import functional as F
33
26
 
@@ -37,6 +30,33 @@ _MAXINT32 = (1 << 31) - 1
37
30
  _SAMPLER_KEY = "auto:sampler"
38
31
 
39
32
 
33
+ @contextmanager
34
+ def manual_seed(seed: int | None = None) -> Generator[None, None, None]:
35
+ r"""
36
+ Contextmanager for manual setting the torch.random seed.
37
+
38
+ Args:
39
+ seed: The seed to set the random number generator to.
40
+
41
+ Returns:
42
+ Generator
43
+
44
+ Example:
45
+ >>> with manual_seed(1234):
46
+ >>> X = torch.rand(3)
47
+
48
+ copied as-is from https://github.com/meta-pytorch/botorch/blob/a42cd65f9b704cdb6f2ee64db99a022eb15295d5/botorch/utils/sampling.py#L53C1-L75C50 under the MIT License
49
+ """
50
+ old_state = torch.random.get_rng_state()
51
+ try:
52
+ if seed is not None:
53
+ torch.random.manual_seed(seed)
54
+ yield
55
+ finally:
56
+ if seed is not None:
57
+ torch.random.set_rng_state(old_state)
58
+
59
+
40
60
  class SimpleAPIBaseSampler(BaseSampler):
41
61
  def __init__(
42
62
  self,
@@ -65,6 +85,16 @@ def _get_default_candidates_func(
65
85
  """
66
86
  The original is available at https://github.com/optuna/optuna-integration/blob/156a8bc081322791015d2beefff9373ed7b24047/optuna_integration/botorch/botorch.py under the MIT License
67
87
  """
88
+
89
+ # lazy import
90
+ from optuna_integration.botorch import (
91
+ ehvi_candidates_func,
92
+ logei_candidates_func,
93
+ qehvi_candidates_func,
94
+ qei_candidates_func,
95
+ qparego_candidates_func,
96
+ )
97
+
68
98
  if n_objectives > 3 and not has_constraint and not consider_running_trials:
69
99
  return ehvi_candidates_func
70
100
  elif n_objectives > 3:
@@ -47,7 +47,7 @@ _cudnn_double_backward_pattern = re.compile(
47
47
  )
48
48
  _torch_compile_double_backward_pattern = re.compile(r"compile.*does not currently support double backward")
49
49
  _fd_error = (
50
- "You can accelerate startup by globally enabling finite_differences first " #
50
+ "You can accelerate startup by globally enabling finite_differences first "
51
51
  "(via opt.finite_differences=True or by subclassing it)\n"
52
52
  "Original Error: "
53
53
  )
@@ -418,9 +418,13 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
418
418
 
419
419
 
420
420
  ###### START
421
- # Taken from https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
421
+ # Based on https://arxiv.org/pdf/2505.16932v3
422
+ # and https://github.com/NoahAmsel/PolarExpress/blob/5454910920ca8c65afda28820cdf9e49b9436ed0/polar_express.py#L69-L82
423
+ # and https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
424
+ #
422
425
  # under the MIT License
423
426
 
427
+ # Coefficients are from https://arxiv.org/pdf/2505.16932v3
424
428
  ABC_LIST: list[tuple[float, float, float]] = [
425
429
  (8.28721201814563, -23.595886519098837, 17.300387312530933),
426
430
  (4.107059111542203, -2.9478499167379106, 0.5448431082926601),
@@ -438,7 +442,7 @@ ABC_LIST_STABLE: list[tuple[float, float, float]] = [
438
442
  ] + [ABC_LIST[-1]]
439
443
 
440
444
 
441
- def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
445
+ def msign(G: torch.Tensor, steps: int = 10, eps: float = 1e-7) -> torch.Tensor:
442
446
  """
443
447
  Polar Express algorithm for the matrix sign function:
444
448
  https://arxiv.org/abs/2505.16932
@@ -450,7 +454,9 @@ def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
450
454
  if should_transpose:
451
455
  x = x.mT
452
456
 
453
- x /= x.norm(dim=(-2, -1), keepdim=True) * 1.01
457
+ # x = x / (x.norm(dim=(-2, -1), keepdim=True) * 1.01 + eps)
458
+ stochastic_divide_with_eps_(x, x.norm(dim=(-2, -1)) * 1.01, eps)
459
+
454
460
  for step in range(steps):
455
461
  a, b, c = ABC_LIST_STABLE[step] if step < len(ABC_LIST_STABLE) else ABC_LIST_STABLE[-1]
456
462
  s = x @ x.mT
@@ -464,7 +470,6 @@ def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
464
470
 
465
471
  if should_transpose:
466
472
  x = x.mT
467
- x = torch.nan_to_num(x)
468
473
  return x.float()
469
474
 
470
475
 
@@ -1585,7 +1590,7 @@ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
1585
1590
 
1586
1591
  def copy_stochastic_(target: Tensor, source: Tensor):
1587
1592
  if target.dtype == torch.bfloat16 and source.dtype in (torch.float16, torch.float32, torch.float64):
1588
- _compilable_copy_stochastic_(target, source.float())
1593
+ source = stochastic_round_(target, source)
1589
1594
  set_(target, source)
1590
1595
 
1591
1596
 
@@ -2412,10 +2417,11 @@ def bf16_matmul(x: Tensor, y: Tensor):
2412
2417
  def if_iscompiling(fn):
2413
2418
  base = getattr(torch, fn.__name__, None)
2414
2419
 
2415
- def _fn(x):
2416
- if torch.compiler.is_compiling() and hasattr(torch, fn.__name__):
2417
- return base(x)
2418
- return fn(x)
2420
+ @functools.wraps(fn)
2421
+ def _fn(*args, **kwargs):
2422
+ if torch.compiler.is_compiling() and base is not None:
2423
+ return base(*args, **kwargs)
2424
+ return fn(*args, **kwargs)
2419
2425
 
2420
2426
  return _fn
2421
2427
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: heavyball
3
- Version: 2.1.2
3
+ Version: 2.1.3
4
4
  Summary: Efficient Optimizers
5
5
  Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
6
6
  Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -21,6 +21,7 @@ Requires-Dist: numpy<2.0.0
21
21
  Provides-Extra: dev
22
22
  Requires-Dist: pre-commit; extra == "dev"
23
23
  Requires-Dist: pytest; extra == "dev"
24
+ Requires-Dist: hypothesis; extra == "dev"
24
25
  Requires-Dist: ruff; extra == "dev"
25
26
  Requires-Dist: matplotlib; extra == "dev"
26
27
  Requires-Dist: seaborn; extra == "dev"
@@ -14,11 +14,14 @@ test/test_bf16_params.py
14
14
  test/test_bf16_q.py
15
15
  test/test_bf16_storage.py
16
16
  test/test_caution.py
17
+ test/test_chainable_cpu.py
17
18
  test/test_channels_last.py
18
19
  test/test_clip.py
19
20
  test/test_closure.py
21
+ test/test_cpu_features.py
20
22
  test/test_ema.py
21
23
  test/test_foreach.py
24
+ test/test_helpers_cpu.py
22
25
  test/test_hook.py
23
26
  test/test_mars.py
24
27
  test/test_memory.py
@@ -27,9 +30,12 @@ test/test_merge.py
27
30
  test/test_migrate_cli.py
28
31
  test/test_nd_param.py
29
32
  test/test_no_grad.py
33
+ test/test_optimizer_cpu_smoke.py
30
34
  test/test_psgd_precond_init_stability.py
31
35
  test/test_save_restore.py
32
36
  test/test_singular_values.py
33
37
  test/test_soap.py
34
38
  test/test_stochastic_updates.py
35
- test/test_toy_training.py
39
+ test/test_toy_training.py
40
+ test/test_utils_cpu.py
41
+ test/test_utils_property.py
@@ -5,6 +5,7 @@ numpy<2.0.0
5
5
  [dev]
6
6
  pre-commit
7
7
  pytest
8
+ hypothesis
8
9
  ruff
9
10
  matplotlib
10
11
  seaborn
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
5
5
  [project]
6
6
  name = "heavyball"
7
7
  description = "Efficient Optimizers"
8
- version = "2.1.2"
8
+ version = "2.1.3"
9
9
  authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
10
10
  classifiers = ["Intended Audience :: Developers",
11
11
  "Intended Audience :: Science/Research",
@@ -28,7 +28,7 @@ readme = "README.md"
28
28
  requires-python = ">=3.9"
29
29
 
30
30
  [project.optional-dependencies]
31
- dev = ["pre-commit", "pytest", "ruff", "matplotlib", "seaborn", "pandas", "typer", "optuna", "optunahub", "hebo", "lightbench"]
31
+ dev = ["pre-commit", "pytest", "hypothesis", "ruff", "matplotlib", "seaborn", "pandas", "typer", "optuna", "optunahub", "hebo", "lightbench"]
32
32
 
33
33
  [project.urls]
34
34
  source = "https://github.com/HomebrewML/HeavyBall"
@@ -8,7 +8,6 @@ from torch import nn
8
8
  from torch._dynamo import config
9
9
 
10
10
  import heavyball
11
- import heavyball.utils
12
11
  from heavyball.utils import clean, set_torch
13
12
 
14
13
  os.environ["TORCH_LOGS"] = "+recompiles"
@@ -5,7 +5,6 @@ from torch import nn
5
5
  from torch._dynamo import config
6
6
 
7
7
  import heavyball
8
- import heavyball.utils
9
8
  from heavyball.utils import clean, set_torch
10
9
 
11
10
  config.cache_size_limit = 128
@@ -5,7 +5,6 @@ from torch import nn
5
5
  from torch._dynamo import config
6
6
 
7
7
  import heavyball
8
- import heavyball.utils
9
8
  from heavyball.utils import clean, set_torch
10
9
 
11
10
  config.cache_size_limit = 128
@@ -9,7 +9,6 @@ from torch import nn
9
9
  from torch._dynamo import config
10
10
 
11
11
  import heavyball
12
- import heavyball.utils
13
12
  from heavyball.utils import clean, set_torch
14
13
 
15
14
  config.cache_size_limit = 128
@@ -0,0 +1,65 @@
1
+ import os
2
+
3
+ import torch
4
+
5
+ import heavyball.chainable as C
6
+ import heavyball.utils
7
+
8
+ os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
9
+ heavyball.utils.compile_mode = None
10
+
11
+
12
+ def _identity_update(state, group, update, grad, param):
13
+ return update
14
+
15
+
16
+ def test_chain_applies_update_on_cpu():
17
+ param = [torch.nn.Parameter(torch.zeros(2))]
18
+ grad = [torch.ones(2)]
19
+ group = {"lr": 0.1, "caution": False, "weight_decay": 0.0}
20
+
21
+ with torch.no_grad():
22
+ C.chain(lambda _: {}, group, grad, param, _identity_update)
23
+
24
+ assert torch.allclose(param[0].detach(), torch.full((2,), -0.1))
25
+
26
+
27
+ def test_branch_merges_multiple_paths():
28
+ def double(_, __, update, ___, ____):
29
+ return [u * 2 for u in update]
30
+
31
+ def negate(_, __, update, ___, ____):
32
+ return [u * -1 for u in update]
33
+
34
+ def merge_fn(outputs):
35
+ return [sum(vals) / len(vals) for vals in zip(*outputs)]
36
+
37
+ branch = C.Branch([[double], [negate]], merge_fn)
38
+
39
+ update = [torch.ones(2)]
40
+ grad = [torch.ones(2)]
41
+ param = [torch.nn.Parameter(torch.ones(2))]
42
+
43
+ result = branch(lambda _: {}, {}, update, grad, param)
44
+ expected = torch.full_like(update[0], 0.5)
45
+ assert torch.allclose(result[0], expected)
46
+
47
+
48
+ def test_set_indices_assigns_transform_ids():
49
+ def base(_, __, update, ___, ____, buffer):
50
+ assert buffer is not None
51
+ return update
52
+
53
+ zero_guard = C.ZeroGuard(base, ["buffer"])
54
+ assigned = C.set_indices([zero_guard], retain=False)[0]
55
+ assert assigned.transform_idx == 0
56
+
57
+ def state_fn(_x):
58
+ return {}
59
+
60
+ group = {"storage_dtype": "float32"}
61
+ update = [torch.ones(1)]
62
+ grad = [torch.ones(1)]
63
+ param = [torch.nn.Parameter(torch.ones(1))]
64
+
65
+ assigned(state_fn, group, update, grad, param)
@@ -9,7 +9,6 @@ from torch import nn
9
9
  from torch._dynamo import config
10
10
 
11
11
  import heavyball
12
- import heavyball.utils
13
12
  from heavyball.utils import clean, set_torch
14
13
 
15
14
  heavyball.utils.zeroth_power_mode = "newtonschulz"
@@ -6,7 +6,6 @@ from lightbench.utils import get_optim
6
6
  from torch import nn
7
7
 
8
8
  import heavyball
9
- import heavyball.utils
10
9
  from heavyball.utils import clean, set_torch
11
10
 
12
11
 
@@ -0,0 +1,134 @@
1
+ """Fast CPU-only smoke tests for non-PSGD HeavyBall features."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from copy import deepcopy
6
+
7
+ import pytest
8
+ import torch
9
+ from torch import nn
10
+
11
+ import heavyball
12
+
13
+
14
+ def _train_once(optimizer, model: nn.Module, data: torch.Tensor, target: torch.Tensor, steps: int = 3) -> float:
15
+ """Run a few optimization steps and return the final loss."""
16
+
17
+ loss = torch.tensor(float("nan"))
18
+ for _ in range(steps):
19
+ optimizer.zero_grad(set_to_none=True)
20
+ prediction = model(data)
21
+ loss = torch.nn.functional.mse_loss(prediction, target)
22
+ loss.backward()
23
+ optimizer.step()
24
+ return loss.item()
25
+
26
+
27
+ def _parameter_drift(model: nn.Module, original: list[torch.Tensor]) -> float:
28
+ current = [param.detach() for param in model.parameters()]
29
+ diffs = [curr - init for curr, init in zip(current, original, strict=True)]
30
+ stacked = torch.cat([diff.reshape(-1) for diff in diffs])
31
+ return stacked.norm().item()
32
+
33
+
34
+ def _make_batch(
35
+ in_features: int = 8, out_features: int = 4, batch: int = 16
36
+ ) -> tuple[nn.Module, torch.Tensor, torch.Tensor]:
37
+ torch.manual_seed(0x172893)
38
+ model = nn.Sequential(nn.Linear(in_features, out_features), nn.ReLU(), nn.Linear(out_features, out_features))
39
+ data = torch.randn(batch, in_features)
40
+ target = torch.randn(batch, out_features)
41
+ return model, data, target
42
+
43
+
44
+ @pytest.mark.parametrize(
45
+ "opt_name",
46
+ [
47
+ "ForeachSOAP",
48
+ "Muon",
49
+ "ForeachAdamW",
50
+ ],
51
+ )
52
+ def test_selected_optimizers_run_on_cpu(opt_name: str) -> None:
53
+ model, data, target = _make_batch()
54
+ init = [param.detach().clone() for param in model.parameters()]
55
+
56
+ opt_cls = getattr(heavyball, opt_name)
57
+ optimizer = opt_cls(model.parameters(), warmup_steps=0)
58
+ final_loss = _train_once(optimizer, model, data, target, steps=3)
59
+
60
+ assert torch.isfinite(torch.tensor(final_loss))
61
+ assert _parameter_drift(model, init) > 0.0
62
+
63
+
64
+ def test_caution_reduces_update_magnitude() -> None:
65
+ baseline_model, data, target = _make_batch()
66
+ cautious_model = deepcopy(baseline_model)
67
+
68
+ baseline_init = [param.detach().clone() for param in baseline_model.parameters()]
69
+ cautious_init = [param.detach().clone() for param in cautious_model.parameters()]
70
+
71
+ baseline_opt = heavyball.SGD(
72
+ baseline_model.parameters(),
73
+ lr=1e-3,
74
+ caution=False,
75
+ )
76
+ cautious_opt = heavyball.SGD(
77
+ cautious_model.parameters(),
78
+ lr=1e-3,
79
+ caution=True,
80
+ )
81
+
82
+ _train_once(baseline_opt, baseline_model, data, target)
83
+ _train_once(cautious_opt, cautious_model, data, target)
84
+
85
+ baseline_drift = _parameter_drift(baseline_model, baseline_init)
86
+ cautious_drift = _parameter_drift(cautious_model, cautious_init)
87
+
88
+ assert cautious_drift <= baseline_drift * 1.05 # caution should not overshoot compared to baseline
89
+
90
+
91
+ def test_mars_flag_changes_behavior() -> None:
92
+ model_a, data, target = _make_batch()
93
+ model_b = deepcopy(model_a)
94
+
95
+ opt_a = heavyball.ForeachAdamW(model_a.parameters(), mars=False, warmup_steps=0)
96
+ opt_b = heavyball.ForeachAdamW(model_b.parameters(), mars=True, warmup_steps=0)
97
+
98
+ init = [param.detach().clone() for param in model_a.parameters()]
99
+
100
+ _train_once(opt_a, model_a, data, target)
101
+ _train_once(opt_b, model_b, data, target)
102
+
103
+ baseline_drift = _parameter_drift(model_a, init)
104
+ mars_drift = _parameter_drift(model_b, init)
105
+ assert baseline_drift > 0.0
106
+ assert mars_drift > 0.0
107
+
108
+ deltas = [a.detach() - b.detach() for a, b in zip(model_a.parameters(), model_b.parameters(), strict=True)]
109
+ combined = torch.cat([delta.reshape(-1) for delta in deltas])
110
+ assert combined.norm().item() > 1e-6 # mars path should diverge from baseline
111
+
112
+
113
+ def test_sam_wrapper_requires_closure() -> None:
114
+ model = nn.Linear(4, 2)
115
+ base = heavyball.ForeachAdamW(model.parameters())
116
+ wrapper = heavyball.SAMWrapper(model.parameters(), wrapped_optimizer=base)
117
+
118
+ with pytest.raises(ValueError):
119
+ wrapper.step()
120
+
121
+ data = torch.randn(8, 4)
122
+ target = torch.randn(8, 2)
123
+
124
+ def closure():
125
+ wrapper.zero_grad()
126
+ loss = torch.nn.functional.mse_loss(model(data), target)
127
+ loss.backward()
128
+ return loss
129
+
130
+ before = [param.detach().clone() for param in model.parameters()]
131
+ wrapper.step(closure)
132
+ after = [param.detach() for param in model.parameters()]
133
+ diff = torch.cat([(a - b).reshape(-1) for a, b in zip(after, before, strict=True)])
134
+ assert diff.norm().item() > 0.0
@@ -5,7 +5,6 @@ from torch import nn
5
5
  from torch._dynamo import config
6
6
 
7
7
  import heavyball
8
- import heavyball.utils
9
8
  from heavyball.utils import clean, set_torch
10
9
 
11
10
  config.cache_size_limit = 128
@@ -4,7 +4,6 @@ from lightbench.utils import get_optim
4
4
  from torch import nn
5
5
 
6
6
  import heavyball
7
- import heavyball.utils
8
7
  from heavyball.utils import clean, set_torch
9
8
 
10
9
 
@@ -0,0 +1,107 @@
1
+ import numpy as np
2
+ import optuna
3
+ import pandas as pd
4
+ import torch
5
+ from optuna.distributions import FloatDistribution, IntDistribution
6
+ from optuna.samplers import RandomSampler
7
+ from optuna.trial import TrialState
8
+
9
+ from heavyball import helpers
10
+
11
+
12
+ def test_bound_to_torch_roundtrip_cpu():
13
+ arr = np.arange(4, dtype=np.float64).reshape(2, 2)
14
+ tensor = helpers.bound_to_torch(arr.tobytes(), arr.shape, "cpu")
15
+ assert torch.allclose(tensor, torch.from_numpy(arr.T))
16
+
17
+
18
+ def test_nextafter_matches_numpy():
19
+ forward = helpers.nextafter(0.5, 1.0)
20
+ backward = helpers.nextafter(1, 0)
21
+ assert forward == np.nextafter(0.5, 1.0)
22
+ assert backward == np.nextafter(1, 0)
23
+
24
+
25
+ def test_untransform_numerical_param_torch_handles_steps():
26
+ dist = FloatDistribution(0.0, 1.0, step=0.1)
27
+ value = torch.tensor(0.46)
28
+ untransformed = helpers._untransform_numerical_param_torch(value, dist, transform_log=False)
29
+ assert torch.isclose(untransformed, torch.tensor(0.5))
30
+
31
+
32
+ def test_simple_api_sampler_suggest_all_returns_expected():
33
+ distributions = {"x": FloatDistribution(0.0, 1.0), "y": IntDistribution(0, 3, step=1)}
34
+
35
+ class _Sampler(helpers.SimpleAPIBaseSampler):
36
+ def infer_relative_search_space(self, study, trial):
37
+ return self.search_space
38
+
39
+ def sample_relative(self, study, trial, search_space):
40
+ return {}
41
+
42
+ def sample_independent(self, study, trial, param_name, param_distribution):
43
+ return trial.params[param_name]
44
+
45
+ sampler = _Sampler(distributions)
46
+
47
+ class DummyTrial:
48
+ def __init__(self, params):
49
+ self.params = params
50
+
51
+ def _suggest(self, name, dist):
52
+ return self.params[name]
53
+
54
+ trial = DummyTrial({"x": 0.25, "y": 2})
55
+ suggestions = sampler.suggest_all(trial)
56
+ assert suggestions == {"x": 0.25, "y": 2}
57
+
58
+
59
+ def test_botorch_sampler_sample_relative_smoke(monkeypatch):
60
+ search_space = {"width": FloatDistribution(0.0, 1.0)}
61
+ study = optuna.create_study(direction="minimize", sampler=RandomSampler(seed=0))
62
+ for _ in range(3):
63
+ trial = study.ask()
64
+ width = trial.suggest_float("width", 0.0, 1.0)
65
+ study.tell(trial, width)
66
+
67
+ sampler = helpers.BoTorchSampler(search_space, n_startup_trials=1, seed=0, device="cpu")
68
+
69
+ def _dummy_candidates(params, values, *_args):
70
+ assert params.shape[1] == 1
71
+ return params.mean(dim=0)
72
+
73
+ sampler._candidates_func = _dummy_candidates
74
+
75
+ pending = study.ask()
76
+ suggestion = sampler.sample_relative(study, pending, search_space)
77
+ assert "width" in suggestion
78
+ assert 0.0 <= suggestion["width"] <= 1.0
79
+
80
+
81
+ def test_hebo_sampler_observe_and_sample(monkeypatch):
82
+ class DummyHEBO:
83
+ def __init__(self, *_args, **_kwargs):
84
+ self.observed = None
85
+
86
+ def suggest(self):
87
+ return pd.DataFrame([{"depth": 0.0}])
88
+
89
+ def observe(self, params, values):
90
+ self.observed = (params, values)
91
+
92
+ monkeypatch.setattr(helpers, "HEBO", DummyHEBO)
93
+
94
+ search_space = {"depth": FloatDistribution(0.0, 1.0)}
95
+ sampler = helpers.HEBOSampler(search_space, seed=1)
96
+
97
+ study = optuna.create_study(direction="minimize", sampler=RandomSampler(seed=1))
98
+ trial = study.ask()
99
+ trial.suggest_float("depth", 0.0, 1.0)
100
+ study.tell(trial, 0.2)
101
+
102
+ suggestion = sampler.sample_relative(study, study.ask(), search_space)
103
+ assert suggestion["depth"] == 0.0
104
+
105
+ completed = study.get_trials(deepcopy=False)[0]
106
+ sampler.after_trial(study, completed, TrialState.COMPLETE, [0.2])
107
+ assert sampler._hebo.observed is not None
@@ -9,7 +9,6 @@ from torch import nn
9
9
  from torch._dynamo import config
10
10
 
11
11
  import heavyball
12
- import heavyball.utils
13
12
  from heavyball.utils import clean, hook_optimizer_into_model, set_torch
14
13
 
15
14
  heavyball.utils.compile_mode = "default"
@@ -5,7 +5,6 @@ from torch import nn
5
5
  from torch._dynamo import config
6
6
 
7
7
  import heavyball
8
- import heavyball.utils
9
8
  from heavyball.utils import clean, set_torch
10
9
 
11
10
  config.cache_size_limit = 128
@@ -4,7 +4,6 @@ from lightbench.utils import get_optim
4
4
  from torch import nn
5
5
 
6
6
  import heavyball
7
- import heavyball.utils
8
7
  from heavyball.utils import clean, set_torch
9
8
 
10
9
 
@@ -6,7 +6,6 @@ from torch import nn
6
6
  from torch.nn import functional as F
7
7
 
8
8
  import heavyball
9
- import heavyball.utils
10
9
  from heavyball.utils import clean, set_torch
11
10
 
12
11
 
@@ -6,7 +6,6 @@ from lightbench.utils import get_optim
6
6
  from torch import nn
7
7
 
8
8
  import heavyball
9
- import heavyball.utils
10
9
  from heavyball.utils import clean, set_torch
11
10
 
12
11
 
@@ -5,7 +5,6 @@ from torch import nn
5
5
  from torch._dynamo import config
6
6
 
7
7
  import heavyball
8
- import heavyball.utils
9
8
  from heavyball.utils import set_torch
10
9
 
11
10
  config.cache_size_limit = 2**20