torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
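
The file list above tells the story of this release: the monolithic entry points from 0.1.8 (`torchzero/optim/modular.py`, `torchzero/optim/first_order/optimizers.py`, the top-level `torchzero/tensorlist.py`) are removed, and 0.3.1 reorganizes the package around small composable modules (`torchzero/modules/ops/*`, `torchzero/core/transform.py`), with tensor-list and optimizer utilities consolidated under `torchzero/utils/`. The added test suite exercises the new composition API; the sketch below is a minimal usage example inferred from those tests (constructor defaults and exact signatures are assumptions, since they are not part of this diff summary):

import torch
import torchzero as tz

model = torch.nn.Linear(2, 1)
x, y = torch.randn(8, 2), torch.randn(8, 1)

# Compose an optimizer from modules: an Adam update followed by a fixed learning
# rate, mirroring the combinations used in tests/test_identical.py.
opt = tz.Modular(model.parameters(), tz.m.Adam(), tz.m.LR(1e-2))

for _ in range(10):
    # Closure-based stepping, as in _get_trajectory(use_closure=True) below.
    def closure(backward=True):
        loss = torch.nn.functional.mse_loss(model(x), y)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    loss = opt.step(closure)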
docs/source/conf.py
ADDED
@@ -0,0 +1,57 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# For the full list of built-in configuration values, see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Project information -----------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
+import sys, os
+#sys.path.insert(0, os.path.abspath('.../src'))
+
+project = 'torchzero'
+copyright = '2024, Ivan Nikishev'
+author = 'Ivan Nikishev'
+
+# -- General configuration ---------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
+
+# https://sphinx-intro-tutorial.readthedocs.io/en/latest/sphinx_extensions.html
+extensions = [
+    'sphinx.ext.autodoc',
+    'sphinx.ext.autosummary',
+    'sphinx.ext.viewcode',
+    'sphinx.ext.autosectionlabel',
+    'sphinx.ext.githubpages',
+    'sphinx.ext.napoleon',
+    'autoapi.extension',
+    # 'sphinx_rtd_theme',
+]
+autosummary_generate = True
+autoapi_dirs = ['../../src']
+autoapi_type = "python"
+# autoapi_ignore = ["*/tensorlist.py"]
+
+# https://sphinx-autoapi.readthedocs.io/en/latest/reference/config.html#confval-autoapi_options
+autoapi_options = [
+    "members",
+    "undoc-members",
+    "show-inheritance",
+    "show-module-summary",
+    "imported-members",
+]
+
+
+templates_path = ['_templates']
+exclude_patterns = []
+
+# -- Options for HTML output -------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
+
+#html_theme = 'alabaster'
+html_theme = 'furo'
+html_static_path = ['_static']
+
+
+# OTHER STUFF I FOUND ON THE INTERNET AND PUT THERE HOPING IT DOES SOMETHING USEFUL
+source_suffix = ['.rst', '.md']
+master_doc = 'index'
tests/test_identical.py
ADDED
@@ -0,0 +1,230 @@
+from collections.abc import Callable, Sequence
+import pytest
+import torch
+import torchzero as tz
+
+def _booth(x, y):
+    return (x + 2 * y - 7) ** 2 + (2 * x + y - 5) ** 2
+
+_BOOTH_X0 = torch.tensor([0., -8.])
+
+def _get_trajectory(opt_fn: Callable, x0: torch.Tensor, merge: bool, use_closure: bool, steps: int):
+    """Returns a Tensor - trajectory of `opt_fn` on the booth function."""
+    trajectory = []
+    if merge:
+        params = x0.clone().requires_grad_()
+        optimizer = opt_fn([params])
+    else:
+        params = [x0[0].clone().requires_grad_(), x0[1].clone().requires_grad_()]
+        optimizer = opt_fn(params)
+
+    for _ in range(steps):
+        if use_closure:
+            def closure(backward=True):
+                trajectory.append(torch.stack([p.clone() for p in params]))
+
+                loss = _booth(*params)
+                if backward:
+                    optimizer.zero_grad()
+                    loss.backward()
+                return loss
+
+            loss = optimizer.step(closure)
+            assert torch.isfinite(loss), f'non-finite loss {loss}'
+            for p in params: assert torch.isfinite(p), f'non-finite params {params}'
+
+        else:
+            trajectory.append(torch.stack([p.clone() for p in params]))
+
+            loss = _booth(*params)
+            assert torch.isfinite(loss), f'non-finite loss {loss}'
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            for p in params: assert torch.isfinite(p), f'non-finite params {params}'
+
+
+    return torch.stack(trajectory, 0), optimizer
+
+def _compare_trajectories(opt1, t1:torch.Tensor, opt2, t2:torch.Tensor):
+    assert torch.allclose(t1, t2, rtol=1e-4, atol=1e-6), f'trajectories dont match. opts:\n{opt1}\n{opt2}\ntrajectories:\n{t1}\n{t2}'
+
+def _assert_identical_opts(opt_fns: Sequence[Callable], merge: bool, use_closure: bool, device, steps: int):
+    """checks that all `opt_fns` have identical trajectories on booth"""
+    x0 = _BOOTH_X0.clone().to(device=device)
+    base_opt = None
+    base_trajectory = None
+    for opt_fn in opt_fns:
+        t, opt = _get_trajectory(opt_fn, x0, merge, use_closure, steps)
+        if base_trajectory is None or base_opt is None:
+            base_trajectory = t
+            base_opt = opt
+        else: _compare_trajectories(base_opt, base_trajectory, opt, t)
+
+def _assert_identical_merge(opt_fn: Callable, use_closure, device, steps: int):
+    """checks that trajectories match with x and y parameters split and merged"""
+    x0 = _BOOTH_X0.clone().to(device=device)
+    merged, merged_opt = _get_trajectory(opt_fn, x0, merge=True, use_closure=use_closure, steps=steps)
+    unmerged, unmerged_opt = _get_trajectory(opt_fn, x0, merge=False, use_closure=use_closure, steps=steps)
+    _compare_trajectories(merged_opt, merged, unmerged_opt, unmerged)
+
+def _assert_identical_closure(opt_fn: Callable, merge, device, steps: int):
+    """checks that trajectories match with and without closure"""
+    x0 = _BOOTH_X0.clone().to(device=device)
+    closure, closure_opt = _get_trajectory(opt_fn, x0, merge=merge, use_closure=True, steps=steps)
+    no_closure, no_closure_opt = _get_trajectory(opt_fn, x0, merge=merge, use_closure=False, steps=steps)
+    _compare_trajectories(closure_opt, closure, no_closure_opt, no_closure)
+
+def _assert_identical_merge_closure(opt_fn: Callable, device, steps: int):
+    """checks that trajectories match with x and y parameters split and merged and with and without closure"""
+    x0 = _BOOTH_X0.clone().to(device=device)
+    merge_closure, opt_merge_closure = _get_trajectory(opt_fn, x0, merge=True, use_closure=True, steps=steps)
+    merge_no_closure, opt_merge_no_closure = _get_trajectory(opt_fn, x0, merge=True, use_closure=False, steps=steps)
+    no_merge_closure, opt_no_merge_closure = _get_trajectory(opt_fn, x0, merge=False, use_closure=True, steps=steps)
+    no_merge_no_closure, opt_no_merge_no_closure = _get_trajectory(opt_fn, x0, merge=False, use_closure=False, steps=steps)
+
+    _compare_trajectories(opt_merge_closure, merge_closure, opt_merge_no_closure, merge_no_closure)
+    _compare_trajectories(opt_merge_closure, merge_closure, opt_no_merge_closure, no_merge_closure)
+    _compare_trajectories(opt_merge_closure, merge_closure, opt_no_merge_no_closure, no_merge_no_closure)
+
+def _assert_identical_device(opt_fn: Callable, merge: bool, use_closure: bool, steps: int):
+    """checks that trajectories match on cpu and cuda."""
+    if not torch.cuda.is_available(): return
+    cpu, cpu_opt = _get_trajectory(opt_fn, _BOOTH_X0.clone().cpu(), merge=merge, use_closure=use_closure, steps=steps)
+    cuda, cuda_opt = _get_trajectory(opt_fn, _BOOTH_X0.clone().cuda(), merge=merge, use_closure=use_closure, steps=steps)
+    _compare_trajectories(cpu_opt, cpu, cuda_opt, cuda.to(cpu))
+
+@pytest.mark.parametrize('amsgrad', [True, False])
+def test_adam(amsgrad):
+    # torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
+    # pytorch applies debiasing separately so it is applied before epsilo
+    tz_fn = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad))
+    tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
+    tz_fn3 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
+    tz_fn4 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.Add(1), tz.m.Sub(1), tz.m.LR(1))
+    tz_fn5 = lambda p: tz.Modular(p, tz.m.Clone(), tz.m.Adam(amsgrad=amsgrad))
+    tz_fn_ops = lambda p: tz.Modular(
+        p,
+        tz.m.DivModules(
+            tz.m.EMA(0.9, debiased=True),
+            [tz.m.SqrtEMASquared(0.999, debiased=True, amsgrad=amsgrad), tz.m.Add(1e-8)]
+        ))
+    tz_fn_ops2 = lambda p: tz.Modular(
+        p,
+        tz.m.DivModules(
+            [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9)],
+            [tz.m.EMASquared(0.999, amsgrad=amsgrad), tz.m.Sqrt(), tz.m.Debias2(beta=0.999), tz.m.Add(1e-8)]
+        ))
+    tz_fn_ops3 = lambda p: tz.Modular(
+        p,
+        tz.m.DivModules(
+            [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9, beta2=0.999)],
+            [tz.m.EMASquared(0.999, amsgrad=amsgrad), tz.m.Sqrt(), tz.m.Add(1e-8)]
+        ))
+    tz_fn_ops4 = lambda p: tz.Modular(
+        p,
+        tz.m.DivModules(
+            [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9)],
+            [
+                tz.m.Pow(2),
+                tz.m.EMA(0.999),
+                tz.m.AccumulateMaximum() if amsgrad else tz.m.Identity(),
+                tz.m.Sqrt(),
+                tz.m.Debias2(beta=0.999),
+                tz.m.Add(1e-8)]
+        ))
+    tz_fns = (tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
+
+    _assert_identical_opts(tz_fns, merge=True, use_closure=True, device='cpu', steps=10)
+    for fn in tz_fns:
+        _assert_identical_merge_closure(fn, device='cpu', steps=10)
+        _assert_identical_device(fn, merge=True, use_closure=True, steps=10)
+
+@pytest.mark.parametrize('beta1', [0.5, 0.9])
+@pytest.mark.parametrize('beta2', [0.99, 0.999])
+@pytest.mark.parametrize('eps', [1e-1, 1e-8])
+@pytest.mark.parametrize('amsgrad', [True, False])
+@pytest.mark.parametrize('lr', [0.1, 1])
+def test_adam_hyperparams(beta1, beta2, eps, amsgrad, lr):
+    tz_fn = lambda p: tz.Modular(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad), tz.m.LR(lr))
+    tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad, alpha=lr))
+    _assert_identical_opts([tz_fn, tz_fn2], merge=True, use_closure=True, device='cpu', steps=10)
+
+@pytest.mark.parametrize('centered', [True, False])
+def test_rmsprop(centered):
+    torch_fn = lambda p: torch.optim.RMSprop(p, 1, centered=centered)
+    tz_fn = lambda p: tz.Modular(p, tz.m.RMSprop(centered=centered, init='zeros'))
+    tz_fn2 = lambda p: tz.Modular(
+        p,
+        tz.m.Div([tz.m.CenteredSqrtEMASquared(0.99) if centered else tz.m.SqrtEMASquared(0.99), tz.m.Add(1e-8)]),
+    )
+    tz_fn3 = lambda p: tz.Modular(
+        p,
+        tz.m.Div([tz.m.CenteredEMASquared(0.99) if centered else tz.m.EMASquared(0.99), tz.m.Sqrt(), tz.m.Add(1e-8)]),
+    )
+    tz_fns = (tz_fn, tz_fn2, tz_fn3)
+    _assert_identical_opts([torch_fn, *tz_fns], merge=True, use_closure=True, device='cpu', steps=10)
+    for fn in tz_fns:
+        _assert_identical_merge_closure(fn, device='cpu', steps=10)
+        _assert_identical_device(fn, merge=True, use_closure=True, steps=10)
+
+
+@pytest.mark.parametrize('beta', [0.5, 0.9])
+@pytest.mark.parametrize('eps', [1e-1, 1e-8])
+@pytest.mark.parametrize('centered', [True, False])
+@pytest.mark.parametrize('lr', [0.1, 1])
+def test_rmsprop_hyperparams(beta, eps, centered, lr):
+    tz_fn = lambda p: tz.Modular(p, tz.m.RMSprop(beta, eps, centered, init='zeros'), tz.m.LR(lr))
+    torch_fn = lambda p: torch.optim.RMSprop(p, lr, beta, eps=eps, centered=centered)
+    _assert_identical_opts([torch_fn, tz_fn], merge=True, use_closure=True, device='cpu', steps=10)
+
+
+
+@pytest.mark.parametrize('nplus', (1.2, 2))
+@pytest.mark.parametrize('nminus', (0.5, 0.9))
+@pytest.mark.parametrize('lb', [1e-8, 1])
+@pytest.mark.parametrize('ub', [50, 1.5])
+@pytest.mark.parametrize('lr', [0.1, 1])
+def test_rprop(nplus, nminus, lb, ub, lr):
+    tz_fn = lambda p: tz.Modular(p, tz.m.LR(lr), tz.m.Rprop(nplus, nminus, lb, ub, alpha=lr, backtrack=False))
+    torch_fn = lambda p: torch.optim.Rprop(p, lr, (nminus, nplus), (lb, ub))
+    _assert_identical_opts([torch_fn, tz_fn], merge=True, use_closure=True, device='cpu', steps=30)
+    _assert_identical_merge_closure(tz_fn, 'cpu', 30)
+    _assert_identical_device(tz_fn, merge=True, use_closure=True, steps=10)
+
+def test_adagrad():
+    torch_fn = lambda p: torch.optim.Adagrad(p, 1)
+    tz_fn = lambda p: tz.Modular(p, tz.m.Adagrad(), tz.m.LR(1))
+    tz_fn2 = lambda p: tz.Modular(
+        p,
+        tz.m.Div([tz.m.Pow(2), tz.m.AccumulateSum(), tz.m.Sqrt(), tz.m.Add(1e-10)]),
+    )
+
+    tz_fns = (tz_fn, tz_fn2)
+    _assert_identical_opts([torch_fn, *tz_fns], merge=True, use_closure=True, device='cpu', steps=10)
+    for fn in tz_fns:
+        _assert_identical_merge_closure(fn, device='cpu', steps=10)
+        _assert_identical_device(fn, merge=True, use_closure=True, steps=10)
+
+
+
+@pytest.mark.parametrize('initial_accumulator_value', [0, 1])
+@pytest.mark.parametrize('eps', [1e-2, 1e-10])
+@pytest.mark.parametrize('lr', [0.1, 1])
+def test_adagrad_hyperparams(initial_accumulator_value, eps, lr):
+    torch_fn = lambda p: torch.optim.Adagrad(p, lr, initial_accumulator_value=initial_accumulator_value, eps=eps)
+    tz_fn1 = lambda p: tz.Modular(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps), tz.m.LR(lr))
+    tz_fn2 = lambda p: tz.Modular(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps, alpha=lr))
+    _assert_identical_opts([torch_fn, tz_fn1, tz_fn2], merge=True, use_closure=True, device='cpu', steps=10)
+
+
+@pytest.mark.parametrize('tensorwise', [True, False])
+def test_graft(tensorwise):
+    graft1 = lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.LBFGS(), tz.m.RMSprop(), tensorwise=tensorwise), tz.m.LR(1e-1))
+    graft2 = lambda p: tz.Modular(p, tz.m.LBFGS(), tz.m.Graft([tz.m.Grad(), tz.m.RMSprop()], tensorwise=tensorwise), tz.m.LR(1e-1))
+    _assert_identical_opts([graft1, graft2], merge=True, use_closure=True, device='cpu', steps=10)
+    for fn in [graft1, graft2]:
+        if tensorwise: _assert_identical_closure(fn, merge=True, device='cpu', steps=10)
+        else: _assert_identical_merge_closure(fn, device='cpu', steps=10)
+        _assert_identical_device(fn, merge=True, use_closure=True, steps=10)
+
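The `tz_fn_ops*` variants in `test_adam` assert that the fused `tz.m.Adam` module matches the same update spelled out from elementwise ops (`EMA`, `EMASquared`, `Sqrt`, `Debias`, `Debias2`, `Add`, combined by `DivModules`). For reference, that chain corresponds to the standard bias-corrected Adam update; the plain-tensor sketch below is illustrative only and is not code from the package:

import torch

def adam_update(g, m, v, step, beta1=0.9, beta2=0.999, eps=1e-8):
    """One elementwise Adam update on gradient `g`, mutating state `m` and `v`."""
    m.mul_(beta1).add_(g, alpha=1 - beta1)           # EMA of gradients      (tz.m.EMA)
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)    # EMA of squared grads  (tz.m.EMASquared)
    m_hat = m / (1 - beta1 ** step)                  # bias correction       (tz.m.Debias)
    v_hat = v / (1 - beta2 ** step)                  # bias correction       (tz.m.Debias2)
    # amsgrad would additionally keep a running maximum of v (tz.m.AccumulateMaximum).
    return m_hat / (v_hat.sqrt() + eps)              # Sqrt + Add(eps), then DivModules

# Example: one step on a toy gradient.
g = torch.tensor([0.5, -1.0])
m, v = torch.zeros(2), torch.zeros(2)
update = adam_update(g, m, v, step=1)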
tests/test_module.py
ADDED
@@ -0,0 +1,50 @@
+from collections.abc import Iterable
+
+import torch
+from torchzero.core.module import Module, _make_param_groups
+from torchzero.utils.optimizer import get_params
+from torchzero.utils.params import _add_defaults_to_param_groups_
+
+def _assert_same_storage_(seq1: Iterable[torch.Tensor], seq2: Iterable[torch.Tensor]):
+    seq1=tuple(seq1)
+    seq2=tuple(seq2)
+    assert len(seq1) == len(seq2), f'lengths do not match: {len(seq1)} != {len(seq2)}'
+    for t1, t2 in zip(seq1, seq2):
+        assert t1 is t2
+
+
+def test_process_parameters():
+    model = torch.nn.Sequential(torch.nn.Linear(3, 6), torch.nn.Linear(6, 3))
+
+    # iterable of parameters
+    _assert_same_storage_(model.parameters(), get_params(_make_param_groups(model.parameters(), differentiable=False), 'all'))
+
+    # named parameters
+    _assert_same_storage_(model.parameters(), get_params(_make_param_groups(model.named_parameters(), differentiable=False), 'all'))
+
+    # param groups
+    param_groups = [{'params': model[0].parameters(), 'lr': 0.1}, {'params': model[1].parameters()}]
+    _assert_same_storage_(model.parameters(), get_params(_make_param_groups(param_groups, differentiable=False), 'all'))
+
+    # check that param groups dict is correct
+    param_groups = [
+        {'params': model[0].parameters(), 'lr': 0.1},
+        {'params': model[1].parameters()}
+    ]
+    expected = [
+        {'params': list(model[0].parameters()), 'lr': 0.1},
+        {'params': list(model[1].parameters())}
+    ]
+    assert _make_param_groups(param_groups, differentiable=False) == expected
+
+    # named params
+    _names = {'param_names': ['weight','bias']}
+    param_groups = [
+        {'params': model[0].named_parameters(), 'lr': 0.1},
+        {'params': model[1].named_parameters()}
+    ]
+    expected = [
+        {'params': list(model[0].parameters()), 'lr': 0.1, **_names},
+        {'params': list(model[1].parameters()), 'lr': 0.01, **_names}
+    ]
+    assert _add_defaults_to_param_groups_(_make_param_groups(param_groups, differentiable=False), {"lr": 0.01}) == expected
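
`test_process_parameters` pins down the parameter-handling contract of the new `Module` core: `_make_param_groups` normalizes a plain parameter iterable, `named_parameters()`, or explicit param-group dicts into a list of dicts (attaching `param_names` when names are given), and `_add_defaults_to_param_groups_` fills defaults such as `lr` into groups that do not set them. Assuming `tz.Modular` routes its first argument through this same normalization (a plausible but unverified inference from the fact that `_make_param_groups` lives in `torchzero.core.module`), per-group settings would be passed the same way as with `torch.optim`; the snippet below is a hypothetical sketch, not code from the package:

import torch
import torchzero as tz

model = torch.nn.Sequential(torch.nn.Linear(3, 6), torch.nn.Linear(6, 3))

# Per-group configuration mirroring the dicts used in test_process_parameters;
# whether each module honors the per-group 'lr' override is an assumption here.
opt = tz.Modular(
    [{'params': model[0].parameters(), 'lr': 0.1},
     {'params': model[1].parameters()}],
    tz.m.Adam(),
    tz.m.LR(0.01),
)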