torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
tests/test_utils_optimizer.py
ADDED
@@ -0,0 +1,170 @@
+from collections.abc import Iterable
+from typing import Any
+from functools import partial
+import pytest
+import torch
+from torchzero.utils.optimizer import (
+    Optimizer,
+    get_group_vals,
+    get_params,
+    get_state_vals,
+)
+
+
+def _assert_same_storage_(seq1: Iterable[torch.Tensor], seq2: Iterable[torch.Tensor]):
+    seq1=tuple(seq1)
+    seq2=tuple(seq2)
+    assert len(seq1) == len(seq2), f'lengths do not match: {len(seq1)} != {len(seq2)}'
+    for t1, t2 in zip(seq1, seq2):
+        assert t1 is t2
+
+def _assert_equals_different_storage_(seq1: Iterable[torch.Tensor], seq2: Iterable[torch.Tensor]):
+    seq1=tuple(seq1)
+    seq2=tuple(seq2)
+    assert len(seq1) == len(seq2), f'lengths do not match: {len(seq1)} != {len(seq2)}'
+    for t1, t2 in zip(seq1, seq2):
+        assert t1 is not t2
+        assert (t1 == t2).all()
+
+def test_assert_compare_tensors():
+    t1 = [torch.randn(1, 3) for _ in range(10)]
+    t2 = [torch.randn(1, 3) for _ in range(10)]
+
+    _assert_same_storage_(t1, t1)
+    _assert_same_storage_(t2, t2)
+
+    with pytest.raises(AssertionError):
+        _assert_same_storage_(t1, t2)
+
+
+def test_get_params():
+    param_groups = [
+        {'params': [torch.randn(1, 1, requires_grad=True), torch.randn(1, 2, requires_grad=True)]},
+        {'params': [torch.randn(2, 1, requires_grad=True), torch.randn(2, 2, requires_grad=False)], "lr": 0.1},
+        {'params': [torch.randn(3, 1, requires_grad=False)], 'lr': 0.001, 'betas': (0.9, 0.99)},
+    ]
+    param_groups[0]['params'][0].grad = torch.randn(1, 1)
+
+    params = get_params(param_groups, mode = 'requires_grad', cls = list)
+    _assert_same_storage_(params, [*param_groups[0]['params'], param_groups[1]['params'][0]])
+
+    params = get_params(param_groups, mode = 'has_grad', cls = list)
+    _assert_same_storage_(params, [param_groups[0]['params'][0]])
+
+    params = get_params(param_groups, mode = 'all', cls = list)
+    _assert_same_storage_(params, [*param_groups[0]['params'], *param_groups[1]['params'], *param_groups[2]['params']])
+
+def test_get_group_vals():
+    param_groups = [
+        {'params': [torch.randn(2, 1, requires_grad=True), torch.randn(2, 2, requires_grad=True)], "lr": 0.1, 'beta': 0.95, 'eps': 1e-8},
+        {'params': [torch.randn(1, 1, requires_grad=True), torch.randn(1, 2, requires_grad=False)], 'lr': 0.01, 'beta': 0.99, 'eps': 1e-7},
+        {'params': [torch.randn(3, 1, requires_grad=False)], 'lr': 0.001, 'beta': 0.999, 'eps': 1e-6},
+    ]
+    param_groups[0]['params'][0].grad = torch.randn(2, 1)
+
+
+    lr = get_group_vals(param_groups, 'lr', mode = 'requires_grad', cls = list)
+    assert lr == [0.1, 0.1, 0.01], lr
+
+    lr, beta = get_group_vals(param_groups, 'lr', 'beta', mode = 'requires_grad', cls = list)
+    assert lr == [0.1, 0.1, 0.01], lr
+    assert beta == [0.95, 0.95, 0.99], beta
+
+    lr, beta, eps = get_group_vals(param_groups, ('lr', 'beta', 'eps'), mode = 'requires_grad', cls = list)
+    assert lr == [0.1, 0.1, 0.01], lr
+    assert beta == [0.95, 0.95, 0.99], beta
+    assert eps == [1e-8, 1e-8, 1e-7], eps
+
+    lr = get_group_vals(param_groups, 'lr', mode = 'has_grad', cls = list)
+    assert lr == [0.1], lr
+
+    lr, beta, eps = get_group_vals(param_groups, 'lr', 'beta', 'eps', mode = 'all', cls = list)
+    assert lr == [0.1, 0.1, 0.01, 0.01, 0.001], lr
+    assert beta == [0.95, 0.95, 0.99, 0.99, 0.999], beta
+    assert eps == [1e-8, 1e-8, 1e-7, 1e-7, 1e-6], eps
+
+def test_get_state_vals():
+    # accessing state values of a single parameter, which acts as the key, so no tensors are passed
+    tensor = torch.randn(3,3)
+    state = {tensor: {'exp_avg': torch.ones_like(tensor)}}
+    existing_cov_exp_avg = state[tensor]['exp_avg']
+    cov_exp_avg, cov_exp_avg_sq = get_state_vals(state, [tensor], ('exp_avg', 'exp_avg_sq'), init = [torch.zeros_like, lambda x: torch.full_like(x, 2)])
+    assert torch.allclose(cov_exp_avg[0], torch.ones_like(tensor))
+    assert torch.allclose(cov_exp_avg_sq[0], torch.full_like(tensor, 2))
+    assert cov_exp_avg[0] is existing_cov_exp_avg
+    assert state[tensor]['exp_avg'] is existing_cov_exp_avg
+    assert state[tensor]['exp_avg_sq'] is cov_exp_avg_sq[0]
+
+    # accessing state values of multiple parameters
+    parameters = [torch.randn(i,2) for i in range(1, 11)]
+    state = {p: {} for p in parameters}
+    exp_avgs = get_state_vals(state, parameters, 'exp_avg', cls=list)
+    assert isinstance(exp_avgs, list), type(exp_avgs)
+    assert len(exp_avgs) == 10, len(exp_avgs)
+    assert all(torch.allclose(a, torch.zeros_like(parameters[i])) for i, a in enumerate(exp_avgs))
+    exp_avgs2 = get_state_vals(state, parameters, 'exp_avg', cls=list)
+    _assert_same_storage_(exp_avgs, exp_avgs2)
+
+    # per-parameter inits
+    parameters = [torch.full((i,2), fill_value=i**2) for i in range(1, 11)]
+    state = {p: {} for p in parameters}
+    exp_avgs = get_state_vals(state, parameters, 'exp_avg', init = [partial(torch.full_like, fill_value=i) for i in range(10)], cls=list)
+    assert isinstance(exp_avgs, list), type(exp_avgs)
+    assert len(exp_avgs) == 10, len(exp_avgs)
+    assert all(torch.allclose(a, torch.full_like(parameters[i], i)) for i, a in enumerate(exp_avgs)), exp_avgs
+    exp_avgs2 = get_state_vals(state, parameters, 'exp_avg', cls=list)
+    _assert_same_storage_(exp_avgs, exp_avgs2)
+
+    # per-parmeter init with a list
+    parameters = [torch.full((i,2), fill_value=i**2) for i in range(1, 11)]
+    state = {p: {} for p in parameters}
+    inits = [torch.full([i], fill_value=i) for i in range(1, 11)]
+    exp_avgs = get_state_vals(state, parameters, 'exp_avg', init = inits, cls=list)
+    assert isinstance(exp_avgs, list), type(exp_avgs)
+    assert len(exp_avgs) == 10, len(exp_avgs)
+    _assert_equals_different_storage_(inits, exp_avgs) # inits are cloned
+    exp_avgs2 = get_state_vals(state, parameters, 'exp_avg', cls=list)
+    _assert_same_storage_(exp_avgs, exp_avgs2)
+
+    # init with a value
+    parameters = [torch.full((i,2), fill_value=i**2) for i in range(1, 11)]
+    state = {p: {} for p in parameters}
+    inits = 1
+    exp_avgs = get_state_vals(state, parameters, 'exp_avg', init = inits, cls=list)
+    assert isinstance(exp_avgs, list), type(exp_avgs)
+    assert len(exp_avgs) == 10, len(exp_avgs)
+    assert all(v==1 for v in exp_avgs), exp_avgs
+    assert exp_avgs == get_state_vals(state, parameters, 'exp_avg', cls=list) # no init because already initialized
+
+    # accessing multiple keys
+    parameters = [torch.randn(i,2) for i in range(1,11)]
+    state = {p: {} for p in parameters}
+    exp_avgs, exp_avg_sqs, max_avgs = get_state_vals(state, parameters, 'exp_avg', 'exp_avg_sq', 'max_avg', cls=list)
+    assert len(exp_avgs) == len(exp_avg_sqs) == len(max_avgs) == 10
+    assert isinstance(exp_avgs, list), type(exp_avgs)
+    assert isinstance(exp_avg_sqs, list), type(exp_avg_sqs)
+    assert isinstance(max_avgs, list), type(max_avgs)
+    assert all(torch.allclose(a, torch.zeros_like(parameters[i])) for i, a in enumerate(exp_avgs))
+    assert all(torch.allclose(a, torch.zeros_like(parameters[i])) for i, a in enumerate(exp_avg_sqs))
+    assert all(torch.allclose(a, torch.zeros_like(parameters[i])) for i, a in enumerate(max_avgs))
+    exp_avgs2 = get_state_vals(state, parameters, 'exp_avg', cls=list)
+    exp_avg_sqs2 = get_state_vals(state, parameters, 'exp_avg_sq', cls=list)
+    max_avgs2 = get_state_vals(state, parameters, 'max_avg', cls=list)
+    _assert_same_storage_(exp_avgs, exp_avgs2)
+    _assert_same_storage_(exp_avg_sqs, exp_avg_sqs2)
+    _assert_same_storage_(max_avgs, max_avgs2)
+
+    # per-key init
+    parameters = [torch.randn(i,2) for i in range(1,11)]
+    state = {p: {} for p in parameters}
+    exp_avgs, exp_avg_sqs, max_avgs = get_state_vals(state, parameters, 'exp_avg', 'exp_avg_sq', 'max_avg', init=(4,5,5.5), cls=list)
+    assert len(exp_avgs) == len(exp_avg_sqs) == len(max_avgs) == 10
+    assert isinstance(exp_avgs, list), type(exp_avgs)
+    assert isinstance(exp_avg_sqs, list), type(exp_avg_sqs)
+    assert isinstance(max_avgs, list), type(max_avgs)
+    assert all(v==4 for v in exp_avgs), exp_avgs
+    assert all(v==5 for v in exp_avg_sqs), exp_avg_sqs
+    assert all(v==5.5 for v in max_avgs), max_avgs
+    assert exp_avgs == get_state_vals(state, parameters, 'exp_avg', cls=list)
+    assert exp_avg_sqs == get_state_vals(state, parameters, 'exp_avg_sq', cls=list)
+    assert max_avgs == get_state_vals(state, parameters, 'max_avg', cls=list)
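The tests above pin down the behaviour of get_state_vals: per-parameter state entries are created lazily (zero-filled by default), returned as the requested cls, and later calls hand back the very same tensor objects. A minimal sketch of how that lets a custom optimizer keep momentum buffers, using only the call pattern exercised in the tests; the surrounding update rule is illustrative, not part of torchzero:

import torch
from torchzero.utils.optimizer import get_state_vals

params = [torch.randn(3, requires_grad=True) for _ in range(2)]
for p in params:
    p.grad = torch.randn_like(p)

state = {p: {} for p in params}

# 'exp_avg' buffers are created on first access (zeros by default, as the tests assert)
# and the same tensors are returned on every later call.
exp_avgs = get_state_vals(state, params, 'exp_avg', cls=list)

with torch.no_grad():
    for p, buf in zip(params, exp_avgs):
        buf.mul_(0.9).add_(p.grad, alpha=0.1)  # in-place, so state[p]['exp_avg'] is updated too
        p.sub_(buf, alpha=1e-2)                # illustrative momentum-SGD style step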
tests/test_vars.py
ADDED
@@ -0,0 +1,184 @@
+import pytest
+import torch
+from torchzero.core.module import Vars
+from torchzero.utils.tensorlist import TensorList
+
+@torch.no_grad
+def test_vars_get_loss():
+
+    # ---------------------------- test that it works ---------------------------- #
+    params = [torch.tensor(2.0, requires_grad=True)]
+    evaluated = False
+
+    def closure_1(backward=True):
+        assert not backward, 'backward = True'
+
+        # ensure closure only evaluates once
+        nonlocal evaluated
+        assert evaluated is False, 'closure was evaluated twice'
+        evaluated = True
+
+        loss = params[0]**2
+        if backward:
+            params[0].grad = None
+            loss.backward()
+        else:
+            assert not loss.requires_grad, "loss requires grad with backward=False"
+        return loss
+
+    vars = Vars(params=params, closure=closure_1, model=None, current_step=0)
+
+    assert vars.loss is None, vars.loss
+
+    assert (loss := vars.get_loss(backward=False)) == 4.0, loss
+    assert evaluated, evaluated
+    assert loss is vars.loss
+    assert vars.loss == 4.0
+    assert vars.loss_approx == 4.0
+    assert vars.grad is None, vars.grad
+
+    # reevaluate, which should just return already evaluated loss
+    assert (loss := vars.get_loss(backward=False)) == 4.0, loss
+    assert vars.grad is None, vars.grad
+
+
+    # ----------------------- test that backward=True works ---------------------- #
+    params = [torch.tensor(3.0, requires_grad=True)]
+    evaluated = False
+
+    def closure_2(backward=True):
+        # ensure closure only evaluates once
+        nonlocal evaluated
+        assert evaluated is False, 'closure was evaluated twice'
+        evaluated = True
+
+        loss = params[0] * 2
+        if backward:
+            assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
+            params[0].grad = None
+            loss.backward()
+        else:
+            assert not loss.requires_grad, "loss requires grad with backward=False"
+        return loss
+
+    vars = Vars(params=params, closure=closure_2, model=None, current_step=0)
+    assert vars.grad is None, vars.grad
+    assert (loss := vars.get_loss(backward=True)) == 6.0, loss
+    assert vars.grad is not None
+    assert vars.grad[0] == 2.0, vars.grad
+
+    # reevaluate, which should just return already evaluated loss
+    assert (loss := vars.get_loss(backward=True)) == 6.0, loss
+    assert vars.grad[0] == 2.0, vars.grad
+
+    # get grad, which should just return already evaluated grad
+    assert (grad := vars.get_grad())[0] == 2.0, grad
+    assert grad is vars.grad, grad
+
+    # get update, which should create and return cloned grad
+    assert vars.update is None
+    assert (update := vars.get_update())[0] == 2.0, update
+    assert update is vars.update
+    assert update is not vars.grad
+    assert vars.grad is not None
+    assert update[0] == vars.grad[0]
+
+@torch.no_grad
+def test_vars_get_grad():
+    params = [torch.tensor(2.0, requires_grad=True)]
+    evaluated = False
+
+    def closure(backward=True):
+        # ensure closure only evaluates once
+        nonlocal evaluated
+        assert evaluated is False, 'closure was evaluated twice'
+        evaluated = True
+
+        loss = params[0]**2
+        if backward:
+            assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
+            params[0].grad = None
+            loss.backward()
+        else:
+            assert not loss.requires_grad, "loss requires grad with backward=False"
+        return loss
+
+    vars = Vars(params=params, closure=closure, model=None, current_step=0)
+    assert (grad := vars.get_grad())[0] == 4.0, grad
+    assert grad is vars.grad
+
+    assert vars.loss == 4.0
+    assert (loss := vars.get_loss(backward=False)) == 4.0, loss
+    assert (loss := vars.get_loss(backward=True)) == 4.0, loss
+    assert vars.loss_approx == 4.0
+
+    assert vars.update is None, vars.update
+    assert (update := vars.get_update())[0] == 4.0, update
+
+@torch.no_grad
+def test_vars_get_update():
+    params = [torch.tensor(2.0, requires_grad=True)]
+    evaluated = False
+
+    def closure(backward=True):
+        # ensure closure only evaluates once
+        nonlocal evaluated
+        assert evaluated is False, 'closure was evaluated twice'
+        evaluated = True
+
+        loss = params[0]**2
+        if backward:
+            assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
+            params[0].grad = None
+            loss.backward()
+        else:
+            assert not loss.requires_grad, "loss requires grad with backward=False"
+        return loss
+
+    vars = Vars(params=params, closure=closure, model=None, current_step=0)
+    assert vars.update is None, vars.update
+    assert (update := vars.get_update())[0] == 4.0, update
+    assert update is vars.update
+
+    assert (grad := vars.get_grad())[0] == 4.0, grad
+    assert grad is vars.grad
+    assert grad is not update
+
+    assert vars.loss == 4.0
+    assert (loss := vars.get_loss(backward=False)) == 4.0, loss
+    assert (loss := vars.get_loss(backward=True)) == 4.0, loss
+    assert vars.loss_approx == 4.0
+
+    assert (update := vars.get_update())[0] == 4.0, update
+
+
+def _assert_vars_are_same_(v1: Vars, v2: Vars, clone_update: bool):
+    for k,v in v1.__dict__.items():
+        if not k.startswith('__'):
+            # if k == 'post_step_hooks': continue
+            if k == 'update' and clone_update:
+                if v1.update is None or v2.update is None:
+                    assert v1.update is None and v2.update is None, f'{k} is not the same, {v1 = }, {v2 = }'
+                else:
+                    assert (TensorList(v1.update) == TensorList(v2.update)).global_all()
+                    assert v1.update is not v2.update
+            else:
+                assert getattr(v2, k) is v, f'{k} is not the same, {v1 = }, {v2 = }'
+
+def test_vars_clone():
+    model = torch.nn.Sequential(torch.nn.Linear(2,2), torch.nn.Linear(2,4))
+    def closure(backward): return 1
+    vars = Vars(params=list(model.parameters()), closure=closure, model=model, current_step=0)
+
+    _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
+    _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
+
+    vars.grad = TensorList(torch.randn(5))
+    _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
+    _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
+
+    vars.update = TensorList(torch.randn(5) * 2)
+    vars.loss = torch.randn(1)
+    vars.loss_approx = vars.loss
+    _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
+    _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
torchzero/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-from . import
-from .
-from . import
-from .
+from . import core, optim, utils
+from .core import Modular
+from .utils import compile
+from . import modules as m
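The new top-level __init__.py re-exports Modular from torchzero.core, compile from torchzero.utils, and aliases the modules subpackage as m. A hypothetical usage sketch of how these pieces fit together; the Modular(params, *modules) call pattern and the tz.m.Adam / tz.m.LR class names are assumptions inferred from the file listing above (modules/optimizers/adam.py, modules/lr/lr.py), not confirmed by this diff:

import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)

# Assumed API: a Modular optimizer built from a chain of modules in tz.m.
# tz.m.Adam and tz.m.LR are guessed names, labeled hypothetical.
opt = tz.Modular(model.parameters(), tz.m.Adam(), tz.m.LR(1e-2))

x, y = torch.randn(8, 4), torch.randn(8, 1)

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(x), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)  # the closure(backward=...) convention matches tests/test_vars.py above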
torchzero/core/__init__.py
CHANGED
@@ -1,13 +1,3 @@
-import
-
-from .
-    OptimizationVars,
-    OptimizerModule,
-    _Chain,
-    _Chainable,
-    _get_loss,
-    _ScalarLoss,
-    _Targets,
-)
-
-from .tensorlist_optimizer import TensorListOptimizer, ParamsT, _ClosureType, _maybe_pass_backward
+from .module import Vars, Module, Modular, Chain, maybe_chain, Chainable
+from .transform import Transform, TensorwiseTransform, Target, apply
+from .preconditioner import Preconditioner, TensorwisePreconditioner
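For downstream code, the practical effect of this hunk is that the 0.1.8 names (OptimizationVars, OptimizerModule, TensorListOptimizer, and the private helpers) are no longer exported from torchzero.core, and the Vars/Module/Modular family plus the new transform and preconditioner bases take their place. A summary of the 0.3.2 import surface exactly as declared above; whether each removed name has a one-to-one replacement is not stated in this diff:

# 0.1.8 exports, removed in 0.3.2:
# from torchzero.core import OptimizationVars, OptimizerModule, TensorListOptimizer

# 0.3.2 exports, as re-declared in torchzero/core/__init__.py:
from torchzero.core import Vars, Module, Modular, Chain, maybe_chain, Chainable
from torchzero.core import Transform, TensorwiseTransform, Target, apply
from torchzero.core import Preconditioner, TensorwisePreconditioner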