heavyball 2.1.2__tar.gz → 2.1.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-2.1.2 → heavyball-2.1.3}/PKG-INFO +2 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/heavyball/helpers.py +39 -9
- {heavyball-2.1.2 → heavyball-2.1.3}/heavyball/utils.py +16 -10
- {heavyball-2.1.2 → heavyball-2.1.3}/heavyball.egg-info/PKG-INFO +2 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/heavyball.egg-info/SOURCES.txt +7 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/heavyball.egg-info/requires.txt +1 -0
- {heavyball-2.1.2 → heavyball-2.1.3}/pyproject.toml +2 -2
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_bf16_params.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_bf16_q.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_bf16_storage.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_caution.py +0 -1
- heavyball-2.1.3/test/test_chainable_cpu.py +65 -0
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_channels_last.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_closure.py +0 -1
- heavyball-2.1.3/test/test_cpu_features.py +134 -0
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_ema.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_foreach.py +0 -1
- heavyball-2.1.3/test/test_helpers_cpu.py +107 -0
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_hook.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_mars.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_memory.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_memory_leak.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_merge.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_nd_param.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_no_grad.py +0 -1
- heavyball-2.1.3/test/test_optimizer_cpu_smoke.py +65 -0
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_save_restore.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_singular_values.py +1 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_stochastic_updates.py +0 -1
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_toy_training.py +4 -4
- heavyball-2.1.3/test/test_utils_cpu.py +295 -0
- heavyball-2.1.3/test/test_utils_property.py +281 -0
- {heavyball-2.1.2 → heavyball-2.1.3}/LICENSE +0 -0
- {heavyball-2.1.2 → heavyball-2.1.3}/README.md +0 -0
- {heavyball-2.1.2 → heavyball-2.1.3}/heavyball/__init__.py +0 -0
- {heavyball-2.1.2 → heavyball-2.1.3}/heavyball/chainable.py +0 -0
- {heavyball-2.1.2 → heavyball-2.1.3}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-2.1.2 → heavyball-2.1.3}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-2.1.2 → heavyball-2.1.3}/setup.cfg +0 -0
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_clip.py +0 -0
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_migrate_cli.py +0 -0
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_psgd_precond_init_stability.py +0 -0
- {heavyball-2.1.2 → heavyball-2.1.3}/test/test_soap.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: heavyball
|
3
|
-
Version: 2.1.2
|
3
|
+
Version: 2.1.3
|
4
4
|
Summary: Efficient Optimizers
|
5
5
|
Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
|
6
6
|
Project-URL: source, https://github.com/HomebrewML/HeavyBall
|
@@ -21,6 +21,7 @@ Requires-Dist: numpy<2.0.0
|
|
21
21
|
Provides-Extra: dev
|
22
22
|
Requires-Dist: pre-commit; extra == "dev"
|
23
23
|
Requires-Dist: pytest; extra == "dev"
|
24
|
+
Requires-Dist: hypothesis; extra == "dev"
|
24
25
|
Requires-Dist: ruff; extra == "dev"
|
25
26
|
Requires-Dist: matplotlib; extra == "dev"
|
26
27
|
Requires-Dist: seaborn; extra == "dev"
|
@@ -3,7 +3,8 @@ from __future__ import annotations
|
|
3
3
|
import functools
|
4
4
|
import math
|
5
5
|
import threading
|
6
|
-
from
|
6
|
+
from contextlib import contextmanager
|
7
|
+
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union
|
7
8
|
|
8
9
|
import numpy
|
9
10
|
import numpy as np
|
@@ -11,7 +12,6 @@ import optuna
|
|
11
12
|
import optunahub
|
12
13
|
import pandas as pd
|
13
14
|
import torch
|
14
|
-
from botorch.utils.sampling import manual_seed
|
15
15
|
from hebo.design_space.design_space import DesignSpace
|
16
16
|
from hebo.optimizers.hebo import HEBO
|
17
17
|
from optuna._transform import _SearchSpaceTransform
|
@@ -21,13 +21,6 @@ from optuna.samplers._lazy_random_state import LazyRandomState
|
|
21
21
|
from optuna.study import Study
|
22
22
|
from optuna.study._study_direction import StudyDirection
|
23
23
|
from optuna.trial import FrozenTrial, TrialState
|
24
|
-
from optuna_integration.botorch import (
|
25
|
-
ehvi_candidates_func,
|
26
|
-
logei_candidates_func,
|
27
|
-
qehvi_candidates_func,
|
28
|
-
qei_candidates_func,
|
29
|
-
qparego_candidates_func,
|
30
|
-
)
|
31
24
|
from torch import Tensor
|
32
25
|
from torch.nn import functional as F
|
33
26
|
|
@@ -37,6 +30,33 @@ _MAXINT32 = (1 << 31) - 1
|
|
37
30
|
_SAMPLER_KEY = "auto:sampler"
|
38
31
|
|
39
32
|
|
33
|
+
@contextmanager
|
34
|
+
def manual_seed(seed: int | None = None) -> Generator[None, None, None]:
|
35
|
+
r"""
|
36
|
+
Contextmanager for manual setting the torch.random seed.
|
37
|
+
|
38
|
+
Args:
|
39
|
+
seed: The seed to set the random number generator to.
|
40
|
+
|
41
|
+
Returns:
|
42
|
+
Generator
|
43
|
+
|
44
|
+
Example:
|
45
|
+
>>> with manual_seed(1234):
|
46
|
+
>>> X = torch.rand(3)
|
47
|
+
|
48
|
+
copied as-is from https://github.com/meta-pytorch/botorch/blob/a42cd65f9b704cdb6f2ee64db99a022eb15295d5/botorch/utils/sampling.py#L53C1-L75C50 under the MIT License
|
49
|
+
"""
|
50
|
+
old_state = torch.random.get_rng_state()
|
51
|
+
try:
|
52
|
+
if seed is not None:
|
53
|
+
torch.random.manual_seed(seed)
|
54
|
+
yield
|
55
|
+
finally:
|
56
|
+
if seed is not None:
|
57
|
+
torch.random.set_rng_state(old_state)
|
58
|
+
|
59
|
+
|
40
60
|
class SimpleAPIBaseSampler(BaseSampler):
|
41
61
|
def __init__(
|
42
62
|
self,
|
@@ -65,6 +85,16 @@ def _get_default_candidates_func(
|
|
65
85
|
"""
|
66
86
|
The original is available at https://github.com/optuna/optuna-integration/blob/156a8bc081322791015d2beefff9373ed7b24047/optuna_integration/botorch/botorch.py under the MIT License
|
67
87
|
"""
|
88
|
+
|
89
|
+
# lazy import
|
90
|
+
from optuna_integration.botorch import (
|
91
|
+
ehvi_candidates_func,
|
92
|
+
logei_candidates_func,
|
93
|
+
qehvi_candidates_func,
|
94
|
+
qei_candidates_func,
|
95
|
+
qparego_candidates_func,
|
96
|
+
)
|
97
|
+
|
68
98
|
if n_objectives > 3 and not has_constraint and not consider_running_trials:
|
69
99
|
return ehvi_candidates_func
|
70
100
|
elif n_objectives > 3:
|
@@ -47,7 +47,7 @@ _cudnn_double_backward_pattern = re.compile(
|
|
47
47
|
)
|
48
48
|
_torch_compile_double_backward_pattern = re.compile(r"compile.*does not currently support double backward")
|
49
49
|
_fd_error = (
|
50
|
-
"You can accelerate startup by globally enabling finite_differences first "
|
50
|
+
"You can accelerate startup by globally enabling finite_differences first "
|
51
51
|
"(via opt.finite_differences=True or by subclassing it)\n"
|
52
52
|
"Original Error: "
|
53
53
|
)
|
@@ -418,9 +418,13 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
|
|
418
418
|
|
419
419
|
|
420
420
|
###### START
|
421
|
-
#
|
421
|
+
# Based on https://arxiv.org/pdf/2505.16932v3
|
422
|
+
# and https://github.com/NoahAmsel/PolarExpress/blob/5454910920ca8c65afda28820cdf9e49b9436ed0/polar_express.py#L69-L82
|
423
|
+
# and https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
|
424
|
+
#
|
422
425
|
# under the MIT License
|
423
426
|
|
427
|
+
# Coefficients are from https://arxiv.org/pdf/2505.16932v3
|
424
428
|
ABC_LIST: list[tuple[float, float, float]] = [
|
425
429
|
(8.28721201814563, -23.595886519098837, 17.300387312530933),
|
426
430
|
(4.107059111542203, -2.9478499167379106, 0.5448431082926601),
|
@@ -438,7 +442,7 @@ ABC_LIST_STABLE: list[tuple[float, float, float]] = [
|
|
438
442
|
] + [ABC_LIST[-1]]
|
439
443
|
|
440
444
|
|
441
|
-
def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
|
445
|
+
def msign(G: torch.Tensor, steps: int = 10, eps: float = 1e-7) -> torch.Tensor:
|
442
446
|
"""
|
443
447
|
Polar Express algorithm for the matrix sign function:
|
444
448
|
https://arxiv.org/abs/2505.16932
|
@@ -450,7 +454,9 @@ def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
|
|
450
454
|
if should_transpose:
|
451
455
|
x = x.mT
|
452
456
|
|
453
|
-
x
|
457
|
+
# x = x / (x.norm(dim=(-2, -1), keepdim=True) * 1.01 + eps)
|
458
|
+
stochastic_divide_with_eps_(x, x.norm(dim=(-2, -1)) * 1.01, eps)
|
459
|
+
|
454
460
|
for step in range(steps):
|
455
461
|
a, b, c = ABC_LIST_STABLE[step] if step < len(ABC_LIST_STABLE) else ABC_LIST_STABLE[-1]
|
456
462
|
s = x @ x.mT
|
@@ -464,7 +470,6 @@ def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
|
|
464
470
|
|
465
471
|
if should_transpose:
|
466
472
|
x = x.mT
|
467
|
-
x = torch.nan_to_num(x)
|
468
473
|
return x.float()
|
469
474
|
|
470
475
|
|
@@ -1585,7 +1590,7 @@ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
|
|
1585
1590
|
|
1586
1591
|
def copy_stochastic_(target: Tensor, source: Tensor):
|
1587
1592
|
if target.dtype == torch.bfloat16 and source.dtype in (torch.float16, torch.float32, torch.float64):
|
1588
|
-
|
1593
|
+
source = stochastic_round_(target, source)
|
1589
1594
|
set_(target, source)
|
1590
1595
|
|
1591
1596
|
|
@@ -2412,10 +2417,11 @@ def bf16_matmul(x: Tensor, y: Tensor):
|
|
2412
2417
|
def if_iscompiling(fn):
|
2413
2418
|
base = getattr(torch, fn.__name__, None)
|
2414
2419
|
|
2415
|
-
|
2416
|
-
|
2417
|
-
|
2418
|
-
|
2420
|
+
@functools.wraps(fn)
|
2421
|
+
def _fn(*args, **kwargs):
|
2422
|
+
if torch.compiler.is_compiling() and base is not None:
|
2423
|
+
return base(*args, **kwargs)
|
2424
|
+
return fn(*args, **kwargs)
|
2419
2425
|
|
2420
2426
|
return _fn
|
2421
2427
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: heavyball
|
3
|
-
Version: 2.1.2
|
3
|
+
Version: 2.1.3
|
4
4
|
Summary: Efficient Optimizers
|
5
5
|
Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
|
6
6
|
Project-URL: source, https://github.com/HomebrewML/HeavyBall
|
@@ -21,6 +21,7 @@ Requires-Dist: numpy<2.0.0
|
|
21
21
|
Provides-Extra: dev
|
22
22
|
Requires-Dist: pre-commit; extra == "dev"
|
23
23
|
Requires-Dist: pytest; extra == "dev"
|
24
|
+
Requires-Dist: hypothesis; extra == "dev"
|
24
25
|
Requires-Dist: ruff; extra == "dev"
|
25
26
|
Requires-Dist: matplotlib; extra == "dev"
|
26
27
|
Requires-Dist: seaborn; extra == "dev"
|
@@ -14,11 +14,14 @@ test/test_bf16_params.py
|
|
14
14
|
test/test_bf16_q.py
|
15
15
|
test/test_bf16_storage.py
|
16
16
|
test/test_caution.py
|
17
|
+
test/test_chainable_cpu.py
|
17
18
|
test/test_channels_last.py
|
18
19
|
test/test_clip.py
|
19
20
|
test/test_closure.py
|
21
|
+
test/test_cpu_features.py
|
20
22
|
test/test_ema.py
|
21
23
|
test/test_foreach.py
|
24
|
+
test/test_helpers_cpu.py
|
22
25
|
test/test_hook.py
|
23
26
|
test/test_mars.py
|
24
27
|
test/test_memory.py
|
@@ -27,9 +30,12 @@ test/test_merge.py
|
|
27
30
|
test/test_migrate_cli.py
|
28
31
|
test/test_nd_param.py
|
29
32
|
test/test_no_grad.py
|
33
|
+
test/test_optimizer_cpu_smoke.py
|
30
34
|
test/test_psgd_precond_init_stability.py
|
31
35
|
test/test_save_restore.py
|
32
36
|
test/test_singular_values.py
|
33
37
|
test/test_soap.py
|
34
38
|
test/test_stochastic_updates.py
|
35
|
-
test/test_toy_training.py
|
39
|
+
test/test_toy_training.py
|
40
|
+
test/test_utils_cpu.py
|
41
|
+
test/test_utils_property.py
|
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
|
5
5
|
[project]
|
6
6
|
name = "heavyball"
|
7
7
|
description = "Efficient Optimizers"
|
8
|
-
version = "2.1.2"
|
8
|
+
version = "2.1.3"
|
9
9
|
authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
|
10
10
|
classifiers = ["Intended Audience :: Developers",
|
11
11
|
"Intended Audience :: Science/Research",
|
@@ -28,7 +28,7 @@ readme = "README.md"
|
|
28
28
|
requires-python = ">=3.9"
|
29
29
|
|
30
30
|
[project.optional-dependencies]
|
31
|
-
dev = ["pre-commit", "pytest", "ruff", "matplotlib", "seaborn", "pandas", "typer", "optuna", "optunahub", "hebo", "lightbench"]
|
31
|
+
dev = ["pre-commit", "pytest", "hypothesis", "ruff", "matplotlib", "seaborn", "pandas", "typer", "optuna", "optunahub", "hebo", "lightbench"]
|
32
32
|
|
33
33
|
[project.urls]
|
34
34
|
source = "https://github.com/HomebrewML/HeavyBall"
|
@@ -0,0 +1,65 @@
|
|
1
|
+
import os
|
2
|
+
|
3
|
+
import torch
|
4
|
+
|
5
|
+
import heavyball.chainable as C
|
6
|
+
import heavyball.utils
|
7
|
+
|
8
|
+
os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
|
9
|
+
heavyball.utils.compile_mode = None
|
10
|
+
|
11
|
+
|
12
|
+
def _identity_update(state, group, update, grad, param):
|
13
|
+
return update
|
14
|
+
|
15
|
+
|
16
|
+
def test_chain_applies_update_on_cpu():
|
17
|
+
param = [torch.nn.Parameter(torch.zeros(2))]
|
18
|
+
grad = [torch.ones(2)]
|
19
|
+
group = {"lr": 0.1, "caution": False, "weight_decay": 0.0}
|
20
|
+
|
21
|
+
with torch.no_grad():
|
22
|
+
C.chain(lambda _: {}, group, grad, param, _identity_update)
|
23
|
+
|
24
|
+
assert torch.allclose(param[0].detach(), torch.full((2,), -0.1))
|
25
|
+
|
26
|
+
|
27
|
+
def test_branch_merges_multiple_paths():
|
28
|
+
def double(_, __, update, ___, ____):
|
29
|
+
return [u * 2 for u in update]
|
30
|
+
|
31
|
+
def negate(_, __, update, ___, ____):
|
32
|
+
return [u * -1 for u in update]
|
33
|
+
|
34
|
+
def merge_fn(outputs):
|
35
|
+
return [sum(vals) / len(vals) for vals in zip(*outputs)]
|
36
|
+
|
37
|
+
branch = C.Branch([[double], [negate]], merge_fn)
|
38
|
+
|
39
|
+
update = [torch.ones(2)]
|
40
|
+
grad = [torch.ones(2)]
|
41
|
+
param = [torch.nn.Parameter(torch.ones(2))]
|
42
|
+
|
43
|
+
result = branch(lambda _: {}, {}, update, grad, param)
|
44
|
+
expected = torch.full_like(update[0], 0.5)
|
45
|
+
assert torch.allclose(result[0], expected)
|
46
|
+
|
47
|
+
|
48
|
+
def test_set_indices_assigns_transform_ids():
|
49
|
+
def base(_, __, update, ___, ____, buffer):
|
50
|
+
assert buffer is not None
|
51
|
+
return update
|
52
|
+
|
53
|
+
zero_guard = C.ZeroGuard(base, ["buffer"])
|
54
|
+
assigned = C.set_indices([zero_guard], retain=False)[0]
|
55
|
+
assert assigned.transform_idx == 0
|
56
|
+
|
57
|
+
def state_fn(_x):
|
58
|
+
return {}
|
59
|
+
|
60
|
+
group = {"storage_dtype": "float32"}
|
61
|
+
update = [torch.ones(1)]
|
62
|
+
grad = [torch.ones(1)]
|
63
|
+
param = [torch.nn.Parameter(torch.ones(1))]
|
64
|
+
|
65
|
+
assigned(state_fn, group, update, grad, param)
|
@@ -0,0 +1,134 @@
|
|
1
|
+
"""Fast CPU-only smoke tests for non-PSGD HeavyBall features."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from copy import deepcopy
|
6
|
+
|
7
|
+
import pytest
|
8
|
+
import torch
|
9
|
+
from torch import nn
|
10
|
+
|
11
|
+
import heavyball
|
12
|
+
|
13
|
+
|
14
|
+
def _train_once(optimizer, model: nn.Module, data: torch.Tensor, target: torch.Tensor, steps: int = 3) -> float:
|
15
|
+
"""Run a few optimization steps and return the final loss."""
|
16
|
+
|
17
|
+
loss = torch.tensor(float("nan"))
|
18
|
+
for _ in range(steps):
|
19
|
+
optimizer.zero_grad(set_to_none=True)
|
20
|
+
prediction = model(data)
|
21
|
+
loss = torch.nn.functional.mse_loss(prediction, target)
|
22
|
+
loss.backward()
|
23
|
+
optimizer.step()
|
24
|
+
return loss.item()
|
25
|
+
|
26
|
+
|
27
|
+
def _parameter_drift(model: nn.Module, original: list[torch.Tensor]) -> float:
|
28
|
+
current = [param.detach() for param in model.parameters()]
|
29
|
+
diffs = [curr - init for curr, init in zip(current, original, strict=True)]
|
30
|
+
stacked = torch.cat([diff.reshape(-1) for diff in diffs])
|
31
|
+
return stacked.norm().item()
|
32
|
+
|
33
|
+
|
34
|
+
def _make_batch(
|
35
|
+
in_features: int = 8, out_features: int = 4, batch: int = 16
|
36
|
+
) -> tuple[nn.Module, torch.Tensor, torch.Tensor]:
|
37
|
+
torch.manual_seed(0x172893)
|
38
|
+
model = nn.Sequential(nn.Linear(in_features, out_features), nn.ReLU(), nn.Linear(out_features, out_features))
|
39
|
+
data = torch.randn(batch, in_features)
|
40
|
+
target = torch.randn(batch, out_features)
|
41
|
+
return model, data, target
|
42
|
+
|
43
|
+
|
44
|
+
@pytest.mark.parametrize(
|
45
|
+
"opt_name",
|
46
|
+
[
|
47
|
+
"ForeachSOAP",
|
48
|
+
"Muon",
|
49
|
+
"ForeachAdamW",
|
50
|
+
],
|
51
|
+
)
|
52
|
+
def test_selected_optimizers_run_on_cpu(opt_name: str) -> None:
|
53
|
+
model, data, target = _make_batch()
|
54
|
+
init = [param.detach().clone() for param in model.parameters()]
|
55
|
+
|
56
|
+
opt_cls = getattr(heavyball, opt_name)
|
57
|
+
optimizer = opt_cls(model.parameters(), warmup_steps=0)
|
58
|
+
final_loss = _train_once(optimizer, model, data, target, steps=3)
|
59
|
+
|
60
|
+
assert torch.isfinite(torch.tensor(final_loss))
|
61
|
+
assert _parameter_drift(model, init) > 0.0
|
62
|
+
|
63
|
+
|
64
|
+
def test_caution_reduces_update_magnitude() -> None:
|
65
|
+
baseline_model, data, target = _make_batch()
|
66
|
+
cautious_model = deepcopy(baseline_model)
|
67
|
+
|
68
|
+
baseline_init = [param.detach().clone() for param in baseline_model.parameters()]
|
69
|
+
cautious_init = [param.detach().clone() for param in cautious_model.parameters()]
|
70
|
+
|
71
|
+
baseline_opt = heavyball.SGD(
|
72
|
+
baseline_model.parameters(),
|
73
|
+
lr=1e-3,
|
74
|
+
caution=False,
|
75
|
+
)
|
76
|
+
cautious_opt = heavyball.SGD(
|
77
|
+
cautious_model.parameters(),
|
78
|
+
lr=1e-3,
|
79
|
+
caution=True,
|
80
|
+
)
|
81
|
+
|
82
|
+
_train_once(baseline_opt, baseline_model, data, target)
|
83
|
+
_train_once(cautious_opt, cautious_model, data, target)
|
84
|
+
|
85
|
+
baseline_drift = _parameter_drift(baseline_model, baseline_init)
|
86
|
+
cautious_drift = _parameter_drift(cautious_model, cautious_init)
|
87
|
+
|
88
|
+
assert cautious_drift <= baseline_drift * 1.05 # caution should not overshoot compared to baseline
|
89
|
+
|
90
|
+
|
91
|
+
def test_mars_flag_changes_behavior() -> None:
|
92
|
+
model_a, data, target = _make_batch()
|
93
|
+
model_b = deepcopy(model_a)
|
94
|
+
|
95
|
+
opt_a = heavyball.ForeachAdamW(model_a.parameters(), mars=False, warmup_steps=0)
|
96
|
+
opt_b = heavyball.ForeachAdamW(model_b.parameters(), mars=True, warmup_steps=0)
|
97
|
+
|
98
|
+
init = [param.detach().clone() for param in model_a.parameters()]
|
99
|
+
|
100
|
+
_train_once(opt_a, model_a, data, target)
|
101
|
+
_train_once(opt_b, model_b, data, target)
|
102
|
+
|
103
|
+
baseline_drift = _parameter_drift(model_a, init)
|
104
|
+
mars_drift = _parameter_drift(model_b, init)
|
105
|
+
assert baseline_drift > 0.0
|
106
|
+
assert mars_drift > 0.0
|
107
|
+
|
108
|
+
deltas = [a.detach() - b.detach() for a, b in zip(model_a.parameters(), model_b.parameters(), strict=True)]
|
109
|
+
combined = torch.cat([delta.reshape(-1) for delta in deltas])
|
110
|
+
assert combined.norm().item() > 1e-6 # mars path should diverge from baseline
|
111
|
+
|
112
|
+
|
113
|
+
def test_sam_wrapper_requires_closure() -> None:
|
114
|
+
model = nn.Linear(4, 2)
|
115
|
+
base = heavyball.ForeachAdamW(model.parameters())
|
116
|
+
wrapper = heavyball.SAMWrapper(model.parameters(), wrapped_optimizer=base)
|
117
|
+
|
118
|
+
with pytest.raises(ValueError):
|
119
|
+
wrapper.step()
|
120
|
+
|
121
|
+
data = torch.randn(8, 4)
|
122
|
+
target = torch.randn(8, 2)
|
123
|
+
|
124
|
+
def closure():
|
125
|
+
wrapper.zero_grad()
|
126
|
+
loss = torch.nn.functional.mse_loss(model(data), target)
|
127
|
+
loss.backward()
|
128
|
+
return loss
|
129
|
+
|
130
|
+
before = [param.detach().clone() for param in model.parameters()]
|
131
|
+
wrapper.step(closure)
|
132
|
+
after = [param.detach() for param in model.parameters()]
|
133
|
+
diff = torch.cat([(a - b).reshape(-1) for a, b in zip(after, before, strict=True)])
|
134
|
+
assert diff.norm().item() > 0.0
|
@@ -0,0 +1,107 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import optuna
|
3
|
+
import pandas as pd
|
4
|
+
import torch
|
5
|
+
from optuna.distributions import FloatDistribution, IntDistribution
|
6
|
+
from optuna.samplers import RandomSampler
|
7
|
+
from optuna.trial import TrialState
|
8
|
+
|
9
|
+
from heavyball import helpers
|
10
|
+
|
11
|
+
|
12
|
+
def test_bound_to_torch_roundtrip_cpu():
|
13
|
+
arr = np.arange(4, dtype=np.float64).reshape(2, 2)
|
14
|
+
tensor = helpers.bound_to_torch(arr.tobytes(), arr.shape, "cpu")
|
15
|
+
assert torch.allclose(tensor, torch.from_numpy(arr.T))
|
16
|
+
|
17
|
+
|
18
|
+
def test_nextafter_matches_numpy():
|
19
|
+
forward = helpers.nextafter(0.5, 1.0)
|
20
|
+
backward = helpers.nextafter(1, 0)
|
21
|
+
assert forward == np.nextafter(0.5, 1.0)
|
22
|
+
assert backward == np.nextafter(1, 0)
|
23
|
+
|
24
|
+
|
25
|
+
def test_untransform_numerical_param_torch_handles_steps():
|
26
|
+
dist = FloatDistribution(0.0, 1.0, step=0.1)
|
27
|
+
value = torch.tensor(0.46)
|
28
|
+
untransformed = helpers._untransform_numerical_param_torch(value, dist, transform_log=False)
|
29
|
+
assert torch.isclose(untransformed, torch.tensor(0.5))
|
30
|
+
|
31
|
+
|
32
|
+
def test_simple_api_sampler_suggest_all_returns_expected():
|
33
|
+
distributions = {"x": FloatDistribution(0.0, 1.0), "y": IntDistribution(0, 3, step=1)}
|
34
|
+
|
35
|
+
class _Sampler(helpers.SimpleAPIBaseSampler):
|
36
|
+
def infer_relative_search_space(self, study, trial):
|
37
|
+
return self.search_space
|
38
|
+
|
39
|
+
def sample_relative(self, study, trial, search_space):
|
40
|
+
return {}
|
41
|
+
|
42
|
+
def sample_independent(self, study, trial, param_name, param_distribution):
|
43
|
+
return trial.params[param_name]
|
44
|
+
|
45
|
+
sampler = _Sampler(distributions)
|
46
|
+
|
47
|
+
class DummyTrial:
|
48
|
+
def __init__(self, params):
|
49
|
+
self.params = params
|
50
|
+
|
51
|
+
def _suggest(self, name, dist):
|
52
|
+
return self.params[name]
|
53
|
+
|
54
|
+
trial = DummyTrial({"x": 0.25, "y": 2})
|
55
|
+
suggestions = sampler.suggest_all(trial)
|
56
|
+
assert suggestions == {"x": 0.25, "y": 2}
|
57
|
+
|
58
|
+
|
59
|
+
def test_botorch_sampler_sample_relative_smoke(monkeypatch):
|
60
|
+
search_space = {"width": FloatDistribution(0.0, 1.0)}
|
61
|
+
study = optuna.create_study(direction="minimize", sampler=RandomSampler(seed=0))
|
62
|
+
for _ in range(3):
|
63
|
+
trial = study.ask()
|
64
|
+
width = trial.suggest_float("width", 0.0, 1.0)
|
65
|
+
study.tell(trial, width)
|
66
|
+
|
67
|
+
sampler = helpers.BoTorchSampler(search_space, n_startup_trials=1, seed=0, device="cpu")
|
68
|
+
|
69
|
+
def _dummy_candidates(params, values, *_args):
|
70
|
+
assert params.shape[1] == 1
|
71
|
+
return params.mean(dim=0)
|
72
|
+
|
73
|
+
sampler._candidates_func = _dummy_candidates
|
74
|
+
|
75
|
+
pending = study.ask()
|
76
|
+
suggestion = sampler.sample_relative(study, pending, search_space)
|
77
|
+
assert "width" in suggestion
|
78
|
+
assert 0.0 <= suggestion["width"] <= 1.0
|
79
|
+
|
80
|
+
|
81
|
+
def test_hebo_sampler_observe_and_sample(monkeypatch):
|
82
|
+
class DummyHEBO:
|
83
|
+
def __init__(self, *_args, **_kwargs):
|
84
|
+
self.observed = None
|
85
|
+
|
86
|
+
def suggest(self):
|
87
|
+
return pd.DataFrame([{"depth": 0.0}])
|
88
|
+
|
89
|
+
def observe(self, params, values):
|
90
|
+
self.observed = (params, values)
|
91
|
+
|
92
|
+
monkeypatch.setattr(helpers, "HEBO", DummyHEBO)
|
93
|
+
|
94
|
+
search_space = {"depth": FloatDistribution(0.0, 1.0)}
|
95
|
+
sampler = helpers.HEBOSampler(search_space, seed=1)
|
96
|
+
|
97
|
+
study = optuna.create_study(direction="minimize", sampler=RandomSampler(seed=1))
|
98
|
+
trial = study.ask()
|
99
|
+
trial.suggest_float("depth", 0.0, 1.0)
|
100
|
+
study.tell(trial, 0.2)
|
101
|
+
|
102
|
+
suggestion = sampler.sample_relative(study, study.ask(), search_space)
|
103
|
+
assert suggestion["depth"] == 0.0
|
104
|
+
|
105
|
+
completed = study.get_trials(deepcopy=False)[0]
|
106
|
+
sampler.after_trial(study, completed, TrialState.COMPLETE, [0.2])
|
107
|
+
assert sampler._hebo.observed is not None
|