torchzero 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchzero/__init__.py +4 -0
- torchzero/core/__init__.py +13 -0
- torchzero/core/module.py +471 -0
- torchzero/core/tensorlist_optimizer.py +219 -0
- torchzero/modules/__init__.py +21 -0
- torchzero/modules/adaptive/__init__.py +4 -0
- torchzero/modules/adaptive/adaptive.py +192 -0
- torchzero/modules/experimental/__init__.py +19 -0
- torchzero/modules/experimental/experimental.py +294 -0
- torchzero/modules/experimental/quad_interp.py +104 -0
- torchzero/modules/experimental/subspace.py +259 -0
- torchzero/modules/gradient_approximation/__init__.py +7 -0
- torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
- torchzero/modules/gradient_approximation/base_approximator.py +110 -0
- torchzero/modules/gradient_approximation/fdm.py +125 -0
- torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
- torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
- torchzero/modules/gradient_approximation/rfdm.py +125 -0
- torchzero/modules/line_search/__init__.py +30 -0
- torchzero/modules/line_search/armijo.py +56 -0
- torchzero/modules/line_search/base_ls.py +139 -0
- torchzero/modules/line_search/directional_newton.py +217 -0
- torchzero/modules/line_search/grid_ls.py +158 -0
- torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
- torchzero/modules/meta/__init__.py +12 -0
- torchzero/modules/meta/alternate.py +65 -0
- torchzero/modules/meta/grafting.py +195 -0
- torchzero/modules/meta/optimizer_wrapper.py +173 -0
- torchzero/modules/meta/return_overrides.py +46 -0
- torchzero/modules/misc/__init__.py +10 -0
- torchzero/modules/misc/accumulate.py +43 -0
- torchzero/modules/misc/basic.py +115 -0
- torchzero/modules/misc/lr.py +96 -0
- torchzero/modules/misc/multistep.py +51 -0
- torchzero/modules/misc/on_increase.py +53 -0
- torchzero/modules/momentum/__init__.py +4 -0
- torchzero/modules/momentum/momentum.py +106 -0
- torchzero/modules/operations/__init__.py +29 -0
- torchzero/modules/operations/multi.py +298 -0
- torchzero/modules/operations/reduction.py +134 -0
- torchzero/modules/operations/singular.py +113 -0
- torchzero/modules/optimizers/__init__.py +10 -0
- torchzero/modules/optimizers/adagrad.py +49 -0
- torchzero/modules/optimizers/adam.py +118 -0
- torchzero/modules/optimizers/lion.py +28 -0
- torchzero/modules/optimizers/rmsprop.py +51 -0
- torchzero/modules/optimizers/rprop.py +99 -0
- torchzero/modules/optimizers/sgd.py +54 -0
- torchzero/modules/orthogonalization/__init__.py +2 -0
- torchzero/modules/orthogonalization/newtonschulz.py +159 -0
- torchzero/modules/orthogonalization/svd.py +86 -0
- torchzero/modules/quasi_newton/__init__.py +4 -0
- torchzero/modules/regularization/__init__.py +22 -0
- torchzero/modules/regularization/dropout.py +34 -0
- torchzero/modules/regularization/noise.py +77 -0
- torchzero/modules/regularization/normalization.py +328 -0
- torchzero/modules/regularization/ortho_grad.py +78 -0
- torchzero/modules/regularization/weight_decay.py +92 -0
- torchzero/modules/scheduling/__init__.py +2 -0
- torchzero/modules/scheduling/lr_schedulers.py +131 -0
- torchzero/modules/scheduling/step_size.py +80 -0
- torchzero/modules/second_order/__init__.py +4 -0
- torchzero/modules/second_order/newton.py +165 -0
- torchzero/modules/smoothing/__init__.py +5 -0
- torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
- torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
- torchzero/modules/weight_averaging/__init__.py +2 -0
- torchzero/modules/weight_averaging/ema.py +72 -0
- torchzero/modules/weight_averaging/swa.py +171 -0
- torchzero/optim/__init__.py +10 -0
- torchzero/optim/experimental/__init__.py +20 -0
- torchzero/optim/experimental/experimental.py +343 -0
- torchzero/optim/experimental/ray_search.py +83 -0
- torchzero/optim/first_order/__init__.py +18 -0
- torchzero/optim/first_order/cautious.py +158 -0
- torchzero/optim/first_order/forward_gradient.py +70 -0
- torchzero/optim/first_order/optimizers.py +570 -0
- torchzero/optim/modular.py +132 -0
- torchzero/optim/quasi_newton/__init__.py +1 -0
- torchzero/optim/quasi_newton/directional_newton.py +58 -0
- torchzero/optim/second_order/__init__.py +1 -0
- torchzero/optim/second_order/newton.py +94 -0
- torchzero/optim/wrappers/__init__.py +0 -0
- torchzero/optim/wrappers/nevergrad.py +113 -0
- torchzero/optim/wrappers/nlopt.py +165 -0
- torchzero/optim/wrappers/scipy.py +439 -0
- torchzero/optim/zeroth_order/__init__.py +4 -0
- torchzero/optim/zeroth_order/fdm.py +87 -0
- torchzero/optim/zeroth_order/newton_fdm.py +146 -0
- torchzero/optim/zeroth_order/rfdm.py +217 -0
- torchzero/optim/zeroth_order/rs.py +85 -0
- torchzero/random/__init__.py +1 -0
- torchzero/random/random.py +46 -0
- torchzero/tensorlist.py +819 -0
- torchzero/utils/__init__.py +0 -0
- torchzero/utils/compile.py +39 -0
- torchzero/utils/derivatives.py +99 -0
- torchzero/utils/python_tools.py +25 -0
- torchzero/utils/torch_tools.py +92 -0
- torchzero-0.0.1.dist-info/LICENSE +21 -0
- torchzero-0.0.1.dist-info/METADATA +118 -0
- torchzero-0.0.1.dist-info/RECORD +104 -0
- torchzero-0.0.1.dist-info/WHEEL +5 -0
- torchzero-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
from typing import Any, Literal
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from ...tensorlist import TensorList
|
|
8
|
+
from ...core import _ClosureType, OptimizationState
|
|
9
|
+
from .base_ls import LineSearchBase
|
|
10
|
+
|
|
11
|
+
class GridLS(LineSearchBase):
    """Line search that evaluates a fixed grid of step sizes and keeps the best one.

    Args:
        lrs (Sequence[float] | np.ndarray | torch.Tensor): candidate step sizes, tested in order.
        stop_on_improvement (bool, optional):
            stop as soon as a candidate matches the tracked lowest loss,
            which is seeded with the loss at the current point. Defaults to False.
        stop_on_worsened (bool, optional):
            stop as soon as a candidate does not match the tracked lowest loss.
            Only meaningful when `lrs` are ordered (e.g. ascending). Defaults to False.
        log_lrs (bool, optional):
            forwarded to `LineSearchBase`; saves tested lrs and their losses
            into `_lrs` (for debugging). Defaults to False.
    """

    def __init__(
        self,
        lrs: Sequence[float] | np.ndarray | torch.Tensor,
        stop_on_improvement=False,
        stop_on_worsened=False,
        log_lrs = False,
    ):
        super().__init__({}, maxiter=None, log_lrs=log_lrs)
        self.stop_on_worsened = stop_on_worsened
        self.stop_on_improvement = stop_on_improvement
        self.lrs = lrs

    @torch.no_grad
    def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
        """Evaluate the candidates in `self.lrs` and return `self._best_lr` as a float.

        Raises:
            ValueError: if `state.closure` or `state.ascent` is None.
        """
        closure = state.closure
        if closure is None:
            raise ValueError("closure is not set")

        direction = state.ascent
        if direction is None:
            raise ValueError("ascent_direction is not set")

        if self.stop_on_improvement:
            # improvement is measured against the loss at the current point,
            # so make sure it is known and use it as the starting record
            if state.fx0 is None:
                state.fx0 = closure(False)
            self._lowest_loss = state.fx0

        for candidate in self.lrs:
            # NOTE(review): `_evaluate_lr_` is defined in the base class; presumably it
            # updates `_lowest_loss` / `_best_lr` when the candidate is a new best - confirm
            loss = self._evaluate_lr_(float(candidate), closure, direction, params)

            # candidate did not become the record -> loss worsened
            if self.stop_on_worsened and loss != self._lowest_loss:
                break

            # candidate is the record -> loss improved
            if self.stop_on_improvement and loss == self._lowest_loss:
                break

        return float(self._best_lr)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class MultiplicativeLS(GridLS):
    """Starts with `init` lr, then keeps multiplying it by `mul` until loss stops decreasing.

    The tested grid is `init * mul**i` for `i` in `range(num)`.

    Args:
        init (float, optional): initial lr. Defaults to 0.001.
        mul (float, optional): lr multiplier. Defaults to 2.
        num (int, optional): maximum number of multiplication steps. Defaults to 10.
        stop_on_improvement (bool, optional): stops if loss improves compared to current loss. Defaults to False.
        stop_on_worsened (bool, optional):
            stops if next lr loss is worse than previous one.
            this assumes that lrs are in ascending order. Defaults to True.
        log_lrs (bool, optional):
            saves lrs and losses with them into optimizer._lrs (for debugging).
            Defaults to False.
    """

    def __init__(
        self,
        init: float = 0.001,
        mul: float = 2,
        num=10,
        stop_on_improvement=False,
        stop_on_worsened=True,
        log_lrs = False,
    ):
        # `log_lrs` was documented but previously not accepted; it is now
        # forwarded like in the other GridLS subclasses.
        super().__init__(
            [init * mul**i for i in range(num)],
            stop_on_improvement=stop_on_improvement,
            stop_on_worsened=stop_on_worsened,
            log_lrs=log_lrs,
        )
|
|
88
|
+
|
|
89
|
+
class BacktrackingLS(GridLS):
    """Tests `init` lr, then keeps multiplying it by `mul` until loss becomes better than initial loss.

    The candidate grid is `init * mul**i` for `i` in `range(num)`; with the default
    `mul = 0.5` the candidates therefore shrink at every step.

    note: this doesn't include Armijo–Goldstein condition.

    Args:
        init (float, optional): initial lr. Defaults to 1.
        mul (float, optional): lr multiplier. Defaults to 0.5.
        num (int, optional): maximum number of multiplication steps. Defaults to 10.
        stop_on_improvement (bool, optional): stops if loss improves compared to current loss. Defaults to True.
        stop_on_worsened (bool, optional):
            stops if next lr loss is worse than previous one. Defaults to False.
        log_lrs (bool, optional):
            saves lrs and losses with them into optimizer._lrs (for debugging).
            Defaults to False.
    """

    def __init__(
        self,
        init: float = 1,
        mul: float = 0.5,
        num=10,
        stop_on_improvement=True,
        stop_on_worsened=False,
        log_lrs = False,
    ):
        # geometric sequence of step sizes, starting at `init`
        candidates = [init * mul**power for power in range(num)]
        super().__init__(
            candidates,
            stop_on_improvement=stop_on_improvement,
            stop_on_worsened=stop_on_worsened,
            log_lrs=log_lrs,
        )
|
|
122
|
+
|
|
123
|
+
class LinspaceLS(GridLS):
    """Grid line search over `torch.linspace(start, end, steps)`; picks the best lr.

    Args:
        start (float, optional): first lr of the grid. Defaults to 0.001.
        end (float, optional): last lr of the grid. Defaults to 2.
        steps (int, optional): number of grid points. Defaults to 10.
        stop_on_improvement (bool, optional): forwarded to `GridLS`. Defaults to False.
        stop_on_worsened (bool, optional): forwarded to `GridLS`. Defaults to False.
        log_lrs (bool, optional): forwarded to `GridLS`. Defaults to False.
    """

    def __init__(
        self,
        start: float = 0.001,
        end: float = 2,
        steps=10,
        stop_on_improvement=False,
        stop_on_worsened=False,
        log_lrs = False,
    ):
        grid = torch.linspace(start, end, steps)
        super().__init__(
            grid,
            stop_on_improvement=stop_on_improvement,
            stop_on_worsened=stop_on_worsened,
            log_lrs=log_lrs,
        )
|
|
140
|
+
|
|
141
|
+
class ArangeLS(GridLS):
    """Grid line search over `torch.arange(start, end, step)`; picks the best lr.

    Args:
        start (float, optional): first lr of the grid. Defaults to 0.001.
        end (float, optional): upper bound passed to `torch.arange`. Defaults to 2.
        step (float, optional): spacing between consecutive lrs. Defaults to 0.1.
        stop_on_improvement (bool, optional): forwarded to `GridLS`. Defaults to False.
        stop_on_worsened (bool, optional): forwarded to `GridLS`. Defaults to False.
        log_lrs (bool, optional): forwarded to `GridLS`. Defaults to False.
    """

    def __init__(
        self,
        start: float = 0.001,
        end: float = 2,
        step=0.1,
        stop_on_improvement=False,
        stop_on_worsened=False,
        log_lrs = False,
    ):
        grid = torch.arange(start, end, step)
        super().__init__(
            grid,
            stop_on_improvement=stop_on_improvement,
            stop_on_worsened=stop_on_worsened,
            log_lrs=log_lrs,
        )
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
try:
|
|
5
|
+
import scipy.optimize as scopt
|
|
6
|
+
except ModuleNotFoundError:
|
|
7
|
+
scopt = typing.cast(typing.Any, None)
|
|
8
|
+
|
|
9
|
+
from ...tensorlist import TensorList
|
|
10
|
+
from ...core import OptimizationState
|
|
11
|
+
|
|
12
|
+
from .base_ls import LineSearchBase, MaxIterReached
|
|
13
|
+
|
|
14
|
+
if typing.TYPE_CHECKING:
|
|
15
|
+
import scipy.optimize as scopt
|
|
16
|
+
|
|
17
|
+
class ScipyMinimizeScalarLS(LineSearchBase):
    """Line search via `scipy.optimize.minimize_scalar`. All args except maxiter are the same as for it.

    Args:
        method (Optional[str], optional): 'brent', 'golden' or 'bounded'. Defaults to None.
        maxiter (Optional[int], optional): hard limit on maximum number of function evaluations. Defaults to None.
        bracket (optional): bracket. Defaults to None.
        bounds (optional): bounds. Defaults to None.
        tol (Optional[float], optional): tolerance passed to `minimize_scalar`. Defaults to None.
        options (optional): options for method. Defaults to None.
        log_lrs (bool, optional): logs lrs and values into `_lrs`. Defaults to False.

    Raises:
        ModuleNotFoundError: on construction, if scipy is not installed.
    """

    def __init__(
        self,
        method: str | None = None,
        maxiter: int | None = None,
        bracket = None,
        bounds = None,
        tol: float | None = None,
        options = None,
        log_lrs = False,
    ):
        if scopt is None: raise ModuleNotFoundError("scipy is not installed")
        super().__init__({}, maxiter=maxiter, log_lrs=log_lrs)
        self.method = method
        self.tol = tol
        self.bracket = bracket
        self.bounds = bounds
        self.options = options

    @torch.no_grad
    def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
        """Run `minimize_scalar` over the step size and return `self._best_lr` as a float.

        Raises:
            ValueError: if `state.closure` or `state.ascent` is None.
        """
        # fail early with a clear message, same as GridLS does
        if state.closure is None: raise ValueError("closure is not set")
        if state.ascent is None: raise ValueError("ascent_direction is not set")

        try:
            # the scipy result object is intentionally discarded: the returned
            # value is `self._best_lr`, which is maintained outside of scipy
            scopt.minimize_scalar(
                self._evaluate_lr_ensure_float,
                args = (state.closure, state.ascent, params),
                method = self.method,
                tol = self.tol,
                bracket = self.bracket,
                bounds = self.bounds,
                options = self.options,
            ) # type:ignore
        except MaxIterReached:
            # NOTE(review): presumably raised by the evaluation helper once `maxiter`
            # is exhausted - confirm in base_ls. The search is simply cut short here.
            pass

        return float(self._best_lr)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Modules that use other modules."""
|
|
2
|
+
# from .chain import Chain, ChainReturn
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
from .alternate import Alternate
|
|
6
|
+
from .grafting import Graft, IntermoduleCautious, SignGrafting
|
|
7
|
+
from .return_overrides import ReturnAscent, ReturnClosure, SetGrad
|
|
8
|
+
|
|
9
|
+
# if sys.version_info[1] < 12:
|
|
10
|
+
from .optimizer_wrapper import Wrap, WrapClosure
|
|
11
|
+
# else:
|
|
12
|
+
# from .optimizer_wrapper import Wrap, WrapClosure
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from collections.abc import Iterable
|
|
3
|
+
from typing import Any, Literal
|
|
4
|
+
|
|
5
|
+
from ...core import OptimizerModule, _Chainable
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Alternate(OptimizerModule):
    """Alternates stepping with multiple modules.

    Args:
        modules (Iterable[OptimizerModule | Iterable[OptimizerModule]]): modules to alternate between.
        mode (int | list[int] | tuple[int] | "random"], optional):
            can be integer - number of repeats for all modules;
            list or tuple of integers per each module with number of repeats;
            "random" to pick module randomly each time. Defaults to 1.
        seed (int | None, optional): seed for "random" mode. Defaults to None.

    Raises:
        ValueError: if no modules are given, if `mode` is an unknown string,
            if the number of repeats differs from the number of modules,
            or if any repeat count is smaller than 1.
    """

    def __init__(
        self,
        modules: Iterable[_Chainable],
        mode: int | list[int] | tuple[int] | Literal["random"] = 1,
        seed: int | None = None
    ):
        super().__init__({})
        modules = list(modules)
        if len(modules) == 0:
            raise ValueError("Alternate requires at least one module")

        for i,m in enumerate(modules):
            self._set_child_(i, m)

        self.random = random.Random(seed)

        if isinstance(mode, int): mode = [mode for _ in modules]
        elif isinstance(mode, str) and mode != 'random':
            raise ValueError(f"Invalid mode: {mode!r}")
        self.mode: list[int] | tuple[int] | Literal['random'] = mode

        self.cur = 0
        if self.mode == 'random': self.remaining = 0
        else:
            # validate before indexing, so that a bad `mode` gives a clear error
            if len(self.mode) != len(self.children):
                raise ValueError(f"got {len(self.children)} modules but {len(mode)} repeats, they should be the same")
            # a repeat count below 1 would drive `remaining` negative in `step`,
            # after which the schedule never advances again
            if any(repeats < 1 for repeats in self.mode):
                raise ValueError(f"all repeats must be at least 1, got {list(self.mode)}")
            self.remaining = self.mode[0]

    def step(self, state):
        """Step with the module selected for this call.

        When this is the last module in the chain, the selected module performs
        the step itself; otherwise its ascent is stored in `state.ascent` and
        passed further down the chain.
        """
        if self.mode == 'random':
            module = self.random.choice(list(self.children.values()))

        else:
            if self.remaining == 0:
                # current module used up its repeats: move to the next one, wrapping around
                self.cur = (self.cur + 1) % len(self.mode)
                self.remaining = self.mode[self.cur]

            module = self.children[self.cur]
            self.remaining -= 1

        if self.next_module is None:
            return module.step(state)

        state.ascent = module.return_ascent(state)
        return self._update_params_or_step_with_next(state)
|
|
65
|
+
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
from collections.abc import Iterable
|
|
2
|
+
from typing import Literal
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from ...core import OptimizerModule
|
|
6
|
+
from ...tensorlist import TensorList
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Graft(OptimizerModule):
    """
    Optimizer grafting (magnitude#direction).

    Rescales the update produced by `direction` so that its norm equals the norm
    of the update produced by `magnitude`. The norm can be taken over all weights
    at once or separately per layer.

    Args:
        magnitude (OptimizerModule | Iterable[OptimizerModule]):
            module to use magnitude from.
            If sequence of modules is provided, they will be chained.
        direction (OptimizerModule | Iterable[OptimizerModule]):
            module/modules to use direction from.
            If sequence of modules is provided, they will be chained.
        ord (int, optional): norm type. Defaults to 2.
        eps (float, optional): epsilon for numerical stability. Defaults to 1e-8.
        layerwise (bool, optional): whether to apply grafting layerwise. Defaults to False.

    reference
        *Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C.
        Learning Rate Grafting: Transferability of Optimizer Tuning.*
    """

    def __init__(
        self,
        magnitude: OptimizerModule | Iterable[OptimizerModule],
        direction: OptimizerModule | Iterable[OptimizerModule],
        ord: float = 2,
        eps: float = 1e-8,
        layerwise: bool = False,
        # TODO: channelwise
    ):
        super().__init__({})
        self._set_child_('magnitude', magnitude)
        self._set_child_('direction', direction)
        self.layerwise = layerwise
        self.eps = eps
        self.ord = ord

    @torch.no_grad
    def step(self, state):
        """Compute both updates, rescale the direction update and continue the chain."""
        # the magnitude branch works on a copy so that the direction branch
        # still receives the original state
        magnitude_state = state.copy(clone_ascent=True)
        magnitude = self.children['magnitude'].return_ascent(magnitude_state)

        # keep whatever the magnitude branch has already evaluated
        for attr in ('grad', 'fx0', 'fx0_approx'):
            value = getattr(magnitude_state, attr)
            if value is not None:
                setattr(state, attr, value)

        direction = self.children['direction'].return_ascent(state)

        if not self.layerwise:
            target_norm = magnitude.total_vector_norm(self.ord)
            current_norm = direction.total_vector_norm(self.ord)
            # zero direction: divide by the magnitude norm instead of zero
            if current_norm == 0:
                current_norm = target_norm
        else:
            target_norm = magnitude.norm(self.ord)
            current_norm = direction.norm(self.ord)
            # same zero-norm substitution, applied per layer
            current_norm.select_set_(current_norm == 0, target_norm)

        scale = target_norm / (current_norm + self.eps)
        state.ascent = direction.mul_(scale)
        return self._update_params_or_step_with_next(state)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class SignGrafting(OptimizerModule):
    """Weight-wise grafting-like operation: the update of `magnitude` is given
    the element-wise sign of the update of `sign`.

    Args:
        magnitude (OptimizerModule | Iterable[OptimizerModule]):
            module to take magnitude from.
            If sequence of modules is provided, they will be chained.
        sign (OptimizerModule | Iterable[OptimizerModule]):
            module to take sign from.
            If sequence of modules is provided, they will be chained.
    """

    def __init__(
        self,
        magnitude: OptimizerModule | Iterable[OptimizerModule],
        sign: OptimizerModule | Iterable[OptimizerModule],
    ):
        super().__init__({})

        self._set_child_('magnitude', magnitude)
        self._set_child_('sign', sign)

    @torch.no_grad
    def step(self, state):
        """Combine both updates with `copysign_` and continue the chain."""
        # the magnitude branch gets its own copy of the state
        magnitude_state = state.copy(clone_ascent=True)
        magnitude_update = self.children['magnitude'].return_ascent(magnitude_state)

        # make sure to store grad and fx0 if it was calculated
        state.update_attrs_(magnitude_state)

        sign_update = self.children['sign'].return_ascent(state)

        # in-place: magnitudes stay, signs come from `sign_update`
        state.ascent = magnitude_update.copysign_(sign_update)
        return self._update_params_or_step_with_next(state)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class IntermoduleCautious(OptimizerModule):
    """Masks update for parameters where updates of two modules or module chains have inconsistent sign.

    What happens to the masked entries is controlled by `mode` (zeroed by default).
    Optionally normalizes the update by the number of parameters that are not masked.

    Args:
        main_module (OptimizerModule | Iterable[OptimizerModule]):
            main module or sequence of modules to chain, which update will be used with a consistency mask applied.
        compare_module (OptimizerModule | Iterable[OptimizerModule]):
            module or sequence of modules to chain, which update will be used to compute a consistency mask.
            Can also be set to `ascent` to compare to update that is passed `main_module`, or `grad` to compare
            to gradients.
        normalize (bool, optional):
            renormalize update after masking.
            only has effect when mode is 'zero'. Defaults to False.
        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
        mode (str, optional):
            what to do with updates with inconsistent signs.

            "zero" - set them to zero (as in paper)

            "grad" - set them to the gradient

            "compare_module" - set them to `compare_module`'s update

            "backtrack" - negate them (same as using update magnitude and gradient sign).
            "negate" is accepted as an alias.
    """

    def __init__(
        self,
        main_module: OptimizerModule | Iterable[OptimizerModule],
        compare_module: OptimizerModule | Iterable[OptimizerModule] | Literal['ascent', 'grad'],
        normalize=False,
        eps=1e-6,
        mode: Literal["zero", "grad", "backtrack", "compare_module", "negate"] = "zero",
    ):
        super().__init__({})

        self._set_child_('main',main_module)
        if isinstance(compare_module, str): self.compare_mode = compare_module
        else:
            self._set_child_('compare', compare_module)
            self.compare_mode = 'module'
        self.eps = eps
        self.normalize = normalize
        self.mode: Literal["zero", "grad", "backtrack", "compare_module", "negate"] = mode

    @torch.no_grad
    def step(self, state):
        """Compute the main update, mask sign-inconsistent entries and continue the chain.

        Raises:
            ValueError: if `compare_module` was a string other than 'ascent' or 'grad'.
        """
        params = None
        # the main branch works on a copy so the comparison still sees the original state
        state_copy = state.copy(clone_ascent=True)
        ascent = self.children['main'].return_ascent(state_copy)
        state.update_attrs_(state_copy)

        if self.compare_mode == 'module': compare = self.children['compare'].return_ascent(state)
        else:
            params = self.get_params()
            if self.compare_mode == 'ascent': compare: TensorList = state.maybe_use_grad_(params)
            elif self.compare_mode == 'grad': compare: TensorList = state.maybe_compute_grad_(params)
            else: raise ValueError(f'Invalid compare_module: {self.compare_mode}')

        # mask will be > 0 for parameters where both signs are the same
        mask = (ascent * compare) > 0

        # "negate" is the documented name of this mode, "backtrack" the original
        # identifier; previously "negate" silently fell through to plain masking
        if self.mode in ('backtrack', 'negate'):
            # subtracting twice the entry flips its sign; `logical_not_` inverts
            # the mask in-place, which is fine since the mask is not used afterwards
            ascent -= ascent.mul(2).mul_(mask.logical_not_())

        else:
            # normalize if mode is `zero`
            if self.normalize and self.mode == 'zero':
                fmask = mask.to(ascent[0].dtype)
                fmask /= fmask.total_mean() + self.eps
            else:
                fmask = mask

            # apply the mask
            ascent *= fmask

            # the mask has been applied, so it can now be inverted in-place
            # to fill the masked-out entries
            if self.mode == 'grad':
                params = self.get_params()
                ascent += state.maybe_compute_grad_(params) * mask.logical_not_()

            elif self.mode == 'compare_module':
                ascent += compare * mask.logical_not_()

        state.ascent = ascent
        return self._update_params_or_step_with_next(state, params)
|
|
195
|
+
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
from collections.abc import Callable, Sequence
|
|
2
|
+
from typing import Any, overload
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from typing_extensions import Concatenate, ParamSpec
|
|
6
|
+
|
|
7
|
+
from ...core import OptimizerModule
|
|
8
|
+
from .return_overrides import SetGrad
|
|
9
|
+
|
|
10
|
+
K = ParamSpec('K')
|
|
11
|
+
|
|
12
|
+
class Wrap(OptimizerModule):
    """
    Wraps any torch.optim.Optimizer.

    Sets .grad attribute to the current update and steps with the `optimizer`.

    Additionally, if this is not the last module, this takes the update of `optimizer`,
    undoes it and passes to the next module instead. That means you can chain multiple
    optimizers together.

    Args:
        optimizer (torch.optim.Optimizer): optimizer to wrap,
            or a callable (class) that constructs the optimizer.
        kwargs:
            if class is passed, kwargs are passed to the constructor.
            parameters are passed separately and automatically
            which is the point of passing a constructor
            instead of an optimizer directly.

    This can be constructed in two ways.
    .. code-block:: python
        wrapper = Wrap(torch.optim.SGD(model.parameters(), lr = 0.1))
        # or
        wrapper = Wrap(torch.optim.SGD, lr = 0.1)
    """

    @overload
    def __init__(self, optimizer: torch.optim.Optimizer): ...
    @overload
    def __init__(
        self,
        optimizer: Callable[Concatenate[Any, K], torch.optim.Optimizer],
        *args: K.args,
        **kwargs: K.kwargs,
    ): ...
    def __init__(self, optimizer, *args, **kwargs):

        super().__init__({})
        # either a ready optimizer instance or a factory; resolved in `_initialize_`
        self._optimizer_cls: torch.optim.Optimizer | Callable[..., torch.optim.Optimizer] = optimizer
        self._args = args
        self._kwargs = kwargs

    def _initialize_(self, params, set_passed_params):
        """Initializes this optimizer and all children with the given parameters."""
        super()._initialize_(params, set_passed_params=set_passed_params)
        if isinstance(self._optimizer_cls, torch.optim.Optimizer) or not callable(self._optimizer_cls):
            self.optimizer = self._optimizer_cls
        else:
            # a factory was given: build the optimizer on the parameters passed here
            self.optimizer = self._optimizer_cls(params, *self._args, **self._kwargs)

    @torch.no_grad
    def step(self, state):
        """Step with the wrapped optimizer, using the current update as `.grad`.

        As the last module the optimizer step is kept and the loss is returned.
        Otherwise the step is undone and the parameter difference is passed to
        the next module as the new update.
        """
        params = self.get_params()
        is_last = self.next_module is None

        # a snapshot is only needed when the step has to be undone afterwards
        params_before_step = None if is_last else params.clone()

        # set grad to ascent and make a step with the optimizer
        g = state.maybe_use_grad_(params)
        params.set_grad_(g)

        # `Optimizer.step()` without a closure returns None; keep a loss that is
        # already stored in `state.fx0` instead of overwriting it with None
        loss = self.optimizer.step()
        if loss is not None: state.fx0 = loss

        if is_last:
            return state.get_loss()

        # calculate update as difference in params
        state.ascent = params_before_step - params
        params.set_(params_before_step)
        return self.next_module.step(state)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class WrapClosure(OptimizerModule):
    """
    Wraps any torch.optim.Optimizer and steps it with the closure stored in the state.

    This only works with modules with :code:`target = "Closure"` argument.
    The modified closure will be passed to the optimizer.

    Alternatively, any module can be turned into a closure module by using :any:`MakeClosure` module,
    in that case this should be placed after MakeClosure.

    Args:
        optimizer (torch.optim.Optimizer): optimizer to wrap,
            or a callable (class) that constructs the optimizer.
        kwargs:
            if class is passed, kwargs are passed to the constructor.
            parameters are passed separately and automatically
            which is the point of passing a constructor
            instead of an optimizer directly.

    This can be constructed in two ways.

    .. code-block:: python

        wrapper = WrapClosure(torch.optim.SGD(model.parameters(), lr = 0.1))
        # or
        wrapper = WrapClosure(torch.optim.SGD, lr = 0.1)

    """

    @overload
    def __init__(self, optimizer: torch.optim.Optimizer,): ...
    @overload
    def __init__(
        self,
        optimizer: Callable[Concatenate[Any, K], torch.optim.Optimizer],
        *args: K.args,
        **kwargs: K.kwargs,
    ): ...
    def __init__(self, optimizer, *args, **kwargs):

        super().__init__({})
        # either a ready optimizer instance or a factory; resolved in `_initialize_`
        self._optimizer_cls: torch.optim.Optimizer | Callable[..., torch.optim.Optimizer] = optimizer
        self._kwargs = kwargs
        self._args = args

    def _initialize_(self, params, set_passed_params):
        """Initializes this optimizer and all children with the given parameters."""
        super()._initialize_(params, set_passed_params=set_passed_params)
        source = self._optimizer_cls
        if callable(source) and not isinstance(source, torch.optim.Optimizer):
            # a factory was given: build the optimizer on the parameters passed here
            self.optimizer = source(params, *self._args, **self._kwargs)
        else:
            self.optimizer = source

    @torch.no_grad
    def step(self, state):
        """Step the wrapped optimizer with `state.closure`.

        As the last module the optimizer step is kept and the loss is returned.
        Otherwise the step is undone and the parameter difference is passed to
        the next module as the update.
        """
        params = self.get_params()
        downstream = self.next_module

        # a snapshot is only needed when the step has to be undone afterwards
        snapshot = params.clone() if downstream is not None else None

        state.fx0 = self.optimizer.step(state.closure) # type:ignore

        if downstream is None:
            return state.get_loss()

        # calculate update as difference in params
        state.ascent = snapshot - params
        params.set_(snapshot)
        return downstream.step(state)
|
|
173
|
+
|