torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/trust_region/trust_region.py
@@ -0,0 +1,350 @@
import math
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Sequence
from functools import partial
from typing import Any, Literal, Protocol, cast, final, overload

import torch

from ...core import Chainable, Module, Var, apply_transform
from ...utils import TensorList, safe_dict_update_, tofloat, vec_to_tensors, generic_finfo, generic_vector_norm
from ...utils.linalg.linear_operator import LinearOperator


def _flatten_tensors(tensors: list[torch.Tensor]):
    return torch.cat([t.ravel() for t in tensors])


class _RadiusStrategy(Protocol):
    def __call__(
        self,
        params: Sequence[torch.Tensor],
        closure: Callable,
        f: float,
        g: torch.Tensor,
        H: LinearOperator,
        d: torch.Tensor,
        trust_radius: float,
        eta: float, # 0.0
        nplus: float, # 3.5
        nminus: float, # 0.25
        rho_good: float, # 0.99
        rho_bad: float, # 1e-4
        boundary_tol: float | None,
        init: float,
        state: Mapping[str, Any],
        settings: Mapping[str, Any],
        radius_fn: Callable | None = torch.linalg.vector_norm,
    ) -> tuple[float, bool]:
        """returns (new trust_region value, success).

        Args:
            params (Sequence[torch.Tensor]): params tensor list
            closure (Callable): closure
            d (torch.Tensor):
                current update vector with current trust_region, which is SUBTRACTED from parameters.
                May be an exact solution to (B+yI)x=g, an approximate one, or a solution to a different
                subproblem (e.g. cubic regularization).
            f (float | torch.Tensor): loss at x0
            g (torch.Tensor): gradient vector
            H (LinearOperator | None): hessian approximation
            trust_radius (float): current trust region value
            eta (float, optional):
                if the ratio of actual to predicted reduction is larger than this, the step is accepted.
                When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.15.
            nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
            nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
            rho_good (float, optional):
                if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by `nplus`.
            rho_bad (float, optional):
                if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by `nminus`.
            boundary_tol (float | None, optional):
                The trust region only increases when the suggested step's norm is at least `(1-boundary_tol)*trust_region`.
                This prevents increasing the trust region when the solution is not on the boundary. Defaults to 1e-2.
            init (float, optional): Initial trust region value. Defaults to 1.
            state (dict, optional): global state of the module for storing persistent info.
            settings (dict, optional): all settings in case this strategy has other settings.
            radius_fn (Callable | None, optional):
                function that accepts ``(d: torch.Tensor)`` and returns the actual radius of ``d``
                (e.g. the L2 norm for an L2 trust region).
        """
        ... # pylint:disable=unnecessary-ellipsis

def _get_rho(params: Sequence[torch.Tensor], closure: Callable,
             f: float, g: torch.Tensor, H: LinearOperator, d: torch.Tensor):
    """rho is reduction/pred_reduction"""

    # evaluate actual loss reduction
    update_unflattned = vec_to_tensors(d, params)
    params = TensorList(params)
    x0 = params.clone() # same as in line searches, large directions are undone very imprecisely

    params -= update_unflattned
    f_star = closure(False)
    params.set_(x0)

    reduction = f - f_star

    # expected reduction is g.T @ p + 0.5 * p.T @ B @ p
    Hu = H.matvec(d)
    pred_reduction = g.dot(d) - 0.5 * d.dot(Hu)

    rho = reduction / (pred_reduction.clip(min=torch.finfo(g.dtype).tiny * 2))
    return rho, f_star, reduction, pred_reduction

def _get_rho_tensorlist(params: Sequence[torch.Tensor], closure: Callable,
                        f: float, g: TensorList, Hvp: Callable[[TensorList], TensorList], d: TensorList):
    """rho is reduction/pred_reduction"""
    params = TensorList(params)
    x0 = params.clone() # same as in line searches, large directions are undone very imprecisely

    # evaluate before modifying params to not break autograd
    Hu = Hvp(d)

    # actual f
    params -= d
    f_star = closure(False)
    params.copy_(x0)

    reduction = f - f_star

    # expected f is g.T @ p + 0.5 * p.T @ B @ p
    pred_reduction = g.dot(d) - 0.5 * d.dot(Hu)

    rho = reduction / (pred_reduction.clip(min=torch.finfo(g[0].dtype).tiny * 2))
    return rho, f_star, reduction, pred_reduction

@torch.no_grad
def default_radius(
    params: Sequence[torch.Tensor],
    closure: Callable,
    f: float,
    g: torch.Tensor | TensorList,
    H: LinearOperator | Callable,
    d: torch.Tensor | TensorList,
    trust_radius: float,
    eta: float, # 0.0
    nplus: float, # 3.5
    nminus: float, # 0.25
    rho_good: float, # 0.99
    rho_bad: float, # 1e-4
    boundary_tol: float | None,
    init: float,
    state: Mapping[str, Any],
    settings: Mapping[str, Any],
    radius_fn: Callable | None = generic_vector_norm,
    check_overflow: bool = True,
    # dynamic_nminus: bool=False,
) -> tuple[float, bool]:

    # when rho_bad < rho < eta, no update is made but trust region is not updated.
    if eta > rho_bad:
        warnings.warn(f"trust region eta={eta} is larger than rho_bad={rho_bad}, "
                      "this can lead to trust region getting stuck.")

    if isinstance(g, torch.Tensor):
        rho, f_star, _, _ = _get_rho(params=params, closure=closure, f=f, g=g, H=H, d=d) # pyright:ignore[reportArgumentType]
    else:
        rho, f_star, _, _ = _get_rho_tensorlist(params=params, closure=closure, f=f, g=g, Hvp=H, d=d) # pyright:ignore[reportArgumentType]

    is_finite = math.isfinite(f_star)

    # find boundary of current step
    if radius_fn is None: d_radius = trust_radius
    else: d_radius = radius_fn(d)

    # failed step
    if rho < rho_bad or not is_finite:
        # if dynamic_nminus and rho > 0: nminus = nminus * max(rho, 1e-4)
        trust_radius = d_radius*nminus

    # very good step
    elif rho > rho_good and is_finite:
        if (boundary_tol is None) or (trust_radius-d_radius)/trust_radius < boundary_tol:
            trust_radius = max(trust_radius, d_radius*nplus)

    # prevent very small or large values
    if check_overflow:
        finfo = generic_finfo(g)
        if trust_radius < finfo.tiny*2 or trust_radius > finfo.max/2:
            trust_radius = init

    # return new trust region and success boolean
    return tofloat(trust_radius), rho > eta and is_finite


def fixed_radius(
    params: Sequence[torch.Tensor],
    closure: Callable,
    f: float,
    g: torch.Tensor,
    H: LinearOperator,
    d: torch.Tensor,
    trust_radius: float,
    eta: float, # 0.0
    nplus: float, # 3.5
    nminus: float, # 0.25
    rho_good: float, # 0.99
    rho_bad: float, # 1e-4
    boundary_tol: float | None,
    init: float,
    state: Mapping[str, Any],
    settings: Mapping[str, Any],
    radius_fn: Callable | None = torch.linalg.vector_norm,
) -> tuple[float, bool]:
    return init, True

_RADIUS_KEYS = Literal['default', 'fixed']
_RADIUS_STRATEGIES: dict[_RADIUS_KEYS, _RadiusStrategy] = {
    "default": default_radius,
    "fixed": fixed_radius,
    # "dynamic": partial(default_radius, dynamic_nminus=True)
}

class TrustRegionBase(Module, ABC):
    def __init__(
        self,
        defaults: dict | None,
        hess_module: Chainable,
        # suggested default values:
        # Gould, Nicholas IM, et al. "Sensitivity of trust-region algorithms to their parameters." 4OR 3.3 (2005): 227-241.
        # which I found from https://github.com/patrick-kidger/optimistix/blob/c1dad7e75fc35bd5a4977ac3a872991e51e83d2c/optimistix/_solver/trust_region.py#L113-200
        eta: float, # 0.0
        nplus: float, # 3.5
        nminus: float, # 0.25
        rho_good: float, # 0.99
        rho_bad: float, # 1e-4
        boundary_tol: float | None, # None or 1e-1
        init: float, # 1
        max_attempts: int, # 10
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS, # "default"
        radius_fn: Callable | None, # torch.linalg.vector_norm
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        if isinstance(radius_strategy, str): radius_strategy = _RADIUS_STRATEGIES[radius_strategy]
        if defaults is None: defaults = {}

        safe_dict_update_(
            defaults,
            dict(eta=eta, nplus=nplus, nminus=nminus, rho_good=rho_good, rho_bad=rho_bad, init=init,
                 update_freq=update_freq, max_attempts=max_attempts, radius_strategy=radius_strategy,
                 boundary_tol=boundary_tol)
        )

        super().__init__(defaults)

        self._radius_fn = radius_fn
        self.set_child('hess_module', hess_module)

        if inner is not None:
            self.set_child('inner', inner)

    @abstractmethod
    def trust_solve(
        self,
        f: float,
        g: torch.Tensor,
        H: LinearOperator,
        radius: float,
        params: list[torch.Tensor],
        closure: Callable,
        settings: Mapping[str, Any],
    ) -> torch.Tensor:
        """Solve Hx=g with a trust region penalty/bound defined by `radius`"""
        ... # pylint:disable=unnecessary-ellipsis

    def trust_region_update(self, var: Var, H: LinearOperator | None) -> None:
        """updates the state of this module after H or B have been updated, if necessary"""

    def trust_region_apply(self, var: Var, tensors: list[torch.Tensor], H: LinearOperator | None) -> Var:
        """Solves the trust region subproblem and outputs ``Var`` with the solution direction."""
        assert H is not None

        params = TensorList(var.params)
        settings = self.settings[params[0]]
        g = _flatten_tensors(tensors)

        max_attempts = settings['max_attempts']

        # loss at x_0
        loss = var.loss
        closure = var.closure
        if closure is None: raise RuntimeError("Trust region requires closure")
        if loss is None: loss = var.get_loss(False)
        loss = tofloat(loss)

        # trust region step and update
        success = False
        d = None
        while not success:
            max_attempts -= 1
            if max_attempts < 0: break

            trust_radius = self.global_state.get('trust_radius', settings['init'])

            # solve Hx=g
            d = self.trust_solve(f=loss, g=g, H=H, radius=trust_radius, params=params, closure=closure, settings=settings)

            # update trust radius
            radius_strategy: _RadiusStrategy = settings['radius_strategy']
            self.global_state["trust_radius"], success = radius_strategy(
                params=params,
                closure=closure,
                d=d,
                f=loss,
                g=g,
                H=H,
                trust_radius=trust_radius,

                eta=settings["eta"],
                nplus=settings["nplus"],
                nminus=settings["nminus"],
                rho_good=settings["rho_good"],
                rho_bad=settings["rho_bad"],
                boundary_tol=settings["boundary_tol"],
                init=settings["init"],

                state=self.global_state,
                settings=settings,
                radius_fn=self._radius_fn,
            )

        assert d is not None
        if success: var.update = vec_to_tensors(d, params)
        else: var.update = params.zeros_like()

        return var


    @final
    @torch.no_grad
    def update(self, var):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        if step % self.defaults["update_freq"] == 0:

            hessian_module = self.children['hess_module']
            hessian_module.update(var)
            H = hessian_module.get_H(var)
            self.global_state["H"] = H

            self.trust_region_update(var, H=H)


    @final
    @torch.no_grad
    def apply(self, var):
        H = self.global_state.get('H', None)

        # -------------------------------- inner step -------------------------------- #
        update = var.get_update()
        if 'inner' in self.children:
            update = apply_transform(self.children['inner'], update, params=var.params, grads=var.grad, var=var)

        # ----------------------------------- apply ---------------------------------- #
        return self.trust_region_apply(var=var, tensors=update, H=H)
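The acceptance logic in `default_radius` above is a standard trust-region ratio test: the actual reduction `f - f(x - d)` is compared against the reduction `g·d - 0.5·d·H·d` predicted by the quadratic model, and the radius is shrunk or grown around the attempted step. A minimal standalone sketch of that test, using plain tensors and made-up constants rather than the module's `settings`/`LinearOperator` machinery:

```python
import torch

def ratio_test(f, f_star, g, H, d, radius, nplus=3.5, nminus=0.25,
               eta=0.0, rho_good=0.99, rho_bad=1e-4):
    """Toy version of the accept/shrink/grow logic in ``default_radius``.

    ``d`` is subtracted from the parameters, so the quadratic model
    predicts a reduction of g.d - 0.5 * d.H.d.
    """
    pred_reduction = g.dot(d) - 0.5 * d.dot(H @ d)
    rho = (f - f_star) / pred_reduction.clamp(min=torch.finfo(g.dtype).tiny * 2)

    if rho < rho_bad:        # bad step: shrink around the attempted step
        radius = d.norm().item() * nminus
    elif rho > rho_good:     # very good step: allow the radius to grow
        radius = max(radius, d.norm().item() * nplus)

    return radius, bool(rho > eta)   # new radius, whether the step is accepted

# tiny example on the quadratic f(x) = 0.5 * x.H.x with H = diag(1, 10)
H = torch.diag(torch.tensor([1.0, 10.0]))
x = torch.tensor([1.0, 1.0])
g = H @ x                    # gradient of the quadratic at x
d = 0.05 * g                 # a short candidate step along the gradient
f = 0.5 * x.dot(H @ x)
f_star = 0.5 * (x - d).dot(H @ (x - d))
print(ratio_test(f, f_star, g, H, d, radius=1.0))
```

On an exactly quadratic objective the model is exact, so `rho` comes out to 1 here and the step is accepted with an enlarged radius; the sketch omits the `boundary_tol` and overflow checks that the real function adds on top.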
torchzero/modules/variance_reduction/__init__.py
@@ -0,0 +1 @@
from .svrg import SVRG
torchzero/modules/variance_reduction/svrg.py
@@ -0,0 +1,208 @@
import warnings
from functools import partial

import torch

from ...core.module import Module
from ...utils import tofloat


def _reset_except_self(optimizer, var, self: Module):
    for m in optimizer.unrolled_modules:
        if m is not self:
            m.reset()

class SVRG(Module):
    """Stochastic variance reduced gradient method (SVRG).

    To use, put SVRG as the first module; it can be combined with any other modules.
    To reduce the variance of a gradient estimator, put the gradient estimator before SVRG.

    First, it uses the first ``accum_steps`` batches to compute the full gradient at the initial
    parameters using gradient accumulation; the model is not updated during this phase.

    Then it performs ``svrg_steps`` SVRG steps, each of which requires two forward and backward passes.

    After ``svrg_steps`` steps, it goes back to the full gradient computation step.

    As an alternative to gradient accumulation you can pass a "full_closure" argument to the ``step`` method,
    which should compute full gradients, set them to the ``.grad`` attributes of the parameters,
    and return the full loss.

    Args:
        svrg_steps (int): number of steps before recalculating the full gradient. This can be set to the length of the dataloader.
        accum_steps (int | None, optional):
            number of steps to accumulate the gradient for. Not used if "full_closure" is passed to the ``step`` method. If None, uses the value of ``svrg_steps``. Defaults to None.
        reset_before_accum (bool, optional):
            whether to reset all other modules when re-calculating the full gradient. Defaults to True.
        svrg_loss (bool, optional):
            whether to replace the loss with the SVRG loss (calculated by the same formula as the SVRG gradient). Defaults to True.
        alpha (float, optional):
            multiplier to the ``g_b(x_0) - g_f(x_0)`` correction term, can be annealed linearly from 1 to 0 as suggested in https://arxiv.org/pdf/2311.05589#page=6

    ## Examples:
    SVRG-LBFGS
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.SVRG(len(dataloader)),
        tz.m.LBFGS(),
        tz.m.Backtracking(),
    )
    ```

    For extra variance reduction one can use Online versions of algorithms, although it won't always help.
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.SVRG(len(dataloader)),
        tz.m.Online(tz.m.LBFGS()),
        tz.m.Backtracking(),
    )
    ```

    Variance reduction can also be applied to gradient estimators.
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.SPSA(),
        tz.m.SVRG(100),
        tz.m.LR(1e-2),
    )
    ```
    ## Notes

    The SVRG gradient is computed as ``g_b(x) - alpha * (g_b(x_0) - g_f(x_0))``, where:
    - ``x`` is the current parameters
    - ``x_0`` is the initial parameters, where the full gradient was computed
    - ``g_b`` refers to the mini-batch gradient at ``x`` or ``x_0``
    - ``g_f`` refers to the full gradient at ``x_0``.

    The SVRG loss is computed using the same formula.
    """
    def __init__(self, svrg_steps: int, accum_steps: int | None = None, reset_before_accum: bool = True, svrg_loss: bool = True, alpha: float = 1):
        defaults = dict(svrg_steps=svrg_steps, accum_steps=accum_steps, reset_before_accum=reset_before_accum, svrg_loss=svrg_loss, alpha=alpha)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        params = var.params
        closure = var.closure
        assert closure is not None

        if "full_grad" not in self.global_state:

            # -------------------------- calculate full gradient ------------------------- #
            if "full_closure" in var.storage:
                full_closure = var.storage['full_closure']
                with torch.enable_grad():
                    full_loss = full_closure()
                if all(p.grad is None for p in params):
                    warnings.warn("all gradients are None after evaluating full_closure.")

                full_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                self.global_state["full_loss"] = full_loss
                self.global_state["full_grad"] = full_grad
                self.global_state['x_0'] = [p.clone() for p in params]

                # current batch will be used for svrg update

            else:
                # accumulate gradients over n steps
                accum_steps = self.defaults['accum_steps']
                if accum_steps is None: accum_steps = self.defaults['svrg_steps']

                current_accum_step = self.global_state.get('current_accum_step', 0) + 1
                self.global_state['current_accum_step'] = current_accum_step

                # accumulate grads
                accumulator = self.get_state(params, 'accumulator')
                grad = var.get_grad()
                torch._foreach_add_(accumulator, grad)

                # accumulate loss
                loss_accumulator = self.global_state.get('loss_accumulator', 0)
                loss_accumulator += tofloat(var.loss)
                self.global_state['loss_accumulator'] = loss_accumulator

                # on nth step, use the accumulated gradient
                if current_accum_step >= accum_steps:
                    torch._foreach_div_(accumulator, accum_steps)
                    self.global_state["full_grad"] = accumulator
                    self.global_state["full_loss"] = loss_accumulator / accum_steps

                    self.global_state['x_0'] = [p.clone() for p in params]
                    self.clear_state_keys('accumulator')
                    del self.global_state['current_accum_step']

                # otherwise skip update until enough grads are accumulated
                else:
                    var.update = None
                    var.stop = True
                    var.skip_update = True
                    return var


        svrg_steps = self.defaults['svrg_steps']
        current_svrg_step = self.global_state.get('current_svrg_step', 0) + 1
        self.global_state['current_svrg_step'] = current_svrg_step

        # --------------------------- SVRG gradient closure -------------------------- #
        x0 = self.global_state['x_0']
        gf_x0 = self.global_state["full_grad"]
        ff_x0 = self.global_state['full_loss']
        use_svrg_loss = self.defaults['svrg_loss']
        alpha = self.get_settings(params, 'alpha')
        alpha_0 = alpha[0]
        if all(a == 1 for a in alpha): alpha = None

        def svrg_closure(backward=True):
            # g_b(x) - α * (g_b(x_0) - g_f(x_0)) and same for loss
            with torch.no_grad():
                x = [p.clone() for p in params]

                if backward:
                    # f and g at x
                    with torch.enable_grad(): fb_x = closure()
                    gb_x = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

                    # f and g at x_0
                    torch._foreach_copy_(params, x0)
                    with torch.enable_grad(): fb_x0 = closure()
                    gb_x0 = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                    torch._foreach_copy_(params, x)

                    # g_svrg = gb_x - alpha * (gb_x0 - gf_x0)
                    correction = torch._foreach_sub(gb_x0, gf_x0)
                    if alpha is not None: torch._foreach_mul_(correction, alpha)
                    g_svrg = torch._foreach_sub(gb_x, correction)

                    f_svrg = fb_x - alpha_0 * (fb_x0 - ff_x0)
                    for p, g in zip(params, g_svrg):
                        p.grad = g

                    if use_svrg_loss: return f_svrg
                    return fb_x

                # no backward
                if use_svrg_loss:
                    fb_x = closure(False)
                    torch._foreach_copy_(params, x0)
                    fb_x0 = closure(False)
                    torch._foreach_copy_(params, x)
                    f_svrg = fb_x - alpha_0 * (fb_x0 - ff_x0)
                    return f_svrg

                return closure(False)

        var.closure = svrg_closure

        # --- after svrg_steps steps reset so that new full gradient is calculated on next step --- #
        if current_svrg_step >= svrg_steps:
            del self.global_state['current_svrg_step']
            del self.global_state['full_grad']
            del self.global_state['full_loss']
            del self.global_state['x_0']
            if self.defaults['reset_before_accum']:
                var.post_step_hooks.append(partial(_reset_except_self, self=self))

        return var