torchzero-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchzero/__init__.py +4 -0
- torchzero/core/__init__.py +13 -0
- torchzero/core/module.py +471 -0
- torchzero/core/tensorlist_optimizer.py +219 -0
- torchzero/modules/__init__.py +21 -0
- torchzero/modules/adaptive/__init__.py +4 -0
- torchzero/modules/adaptive/adaptive.py +192 -0
- torchzero/modules/experimental/__init__.py +19 -0
- torchzero/modules/experimental/experimental.py +294 -0
- torchzero/modules/experimental/quad_interp.py +104 -0
- torchzero/modules/experimental/subspace.py +259 -0
- torchzero/modules/gradient_approximation/__init__.py +7 -0
- torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
- torchzero/modules/gradient_approximation/base_approximator.py +110 -0
- torchzero/modules/gradient_approximation/fdm.py +125 -0
- torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
- torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
- torchzero/modules/gradient_approximation/rfdm.py +125 -0
- torchzero/modules/line_search/__init__.py +30 -0
- torchzero/modules/line_search/armijo.py +56 -0
- torchzero/modules/line_search/base_ls.py +139 -0
- torchzero/modules/line_search/directional_newton.py +217 -0
- torchzero/modules/line_search/grid_ls.py +158 -0
- torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
- torchzero/modules/meta/__init__.py +12 -0
- torchzero/modules/meta/alternate.py +65 -0
- torchzero/modules/meta/grafting.py +195 -0
- torchzero/modules/meta/optimizer_wrapper.py +173 -0
- torchzero/modules/meta/return_overrides.py +46 -0
- torchzero/modules/misc/__init__.py +10 -0
- torchzero/modules/misc/accumulate.py +43 -0
- torchzero/modules/misc/basic.py +115 -0
- torchzero/modules/misc/lr.py +96 -0
- torchzero/modules/misc/multistep.py +51 -0
- torchzero/modules/misc/on_increase.py +53 -0
- torchzero/modules/momentum/__init__.py +4 -0
- torchzero/modules/momentum/momentum.py +106 -0
- torchzero/modules/operations/__init__.py +29 -0
- torchzero/modules/operations/multi.py +298 -0
- torchzero/modules/operations/reduction.py +134 -0
- torchzero/modules/operations/singular.py +113 -0
- torchzero/modules/optimizers/__init__.py +10 -0
- torchzero/modules/optimizers/adagrad.py +49 -0
- torchzero/modules/optimizers/adam.py +118 -0
- torchzero/modules/optimizers/lion.py +28 -0
- torchzero/modules/optimizers/rmsprop.py +51 -0
- torchzero/modules/optimizers/rprop.py +99 -0
- torchzero/modules/optimizers/sgd.py +54 -0
- torchzero/modules/orthogonalization/__init__.py +2 -0
- torchzero/modules/orthogonalization/newtonschulz.py +159 -0
- torchzero/modules/orthogonalization/svd.py +86 -0
- torchzero/modules/quasi_newton/__init__.py +4 -0
- torchzero/modules/regularization/__init__.py +22 -0
- torchzero/modules/regularization/dropout.py +34 -0
- torchzero/modules/regularization/noise.py +77 -0
- torchzero/modules/regularization/normalization.py +328 -0
- torchzero/modules/regularization/ortho_grad.py +78 -0
- torchzero/modules/regularization/weight_decay.py +92 -0
- torchzero/modules/scheduling/__init__.py +2 -0
- torchzero/modules/scheduling/lr_schedulers.py +131 -0
- torchzero/modules/scheduling/step_size.py +80 -0
- torchzero/modules/second_order/__init__.py +4 -0
- torchzero/modules/second_order/newton.py +165 -0
- torchzero/modules/smoothing/__init__.py +5 -0
- torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
- torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
- torchzero/modules/weight_averaging/__init__.py +2 -0
- torchzero/modules/weight_averaging/ema.py +72 -0
- torchzero/modules/weight_averaging/swa.py +171 -0
- torchzero/optim/__init__.py +10 -0
- torchzero/optim/experimental/__init__.py +20 -0
- torchzero/optim/experimental/experimental.py +343 -0
- torchzero/optim/experimental/ray_search.py +83 -0
- torchzero/optim/first_order/__init__.py +18 -0
- torchzero/optim/first_order/cautious.py +158 -0
- torchzero/optim/first_order/forward_gradient.py +70 -0
- torchzero/optim/first_order/optimizers.py +570 -0
- torchzero/optim/modular.py +132 -0
- torchzero/optim/quasi_newton/__init__.py +1 -0
- torchzero/optim/quasi_newton/directional_newton.py +58 -0
- torchzero/optim/second_order/__init__.py +1 -0
- torchzero/optim/second_order/newton.py +94 -0
- torchzero/optim/wrappers/__init__.py +0 -0
- torchzero/optim/wrappers/nevergrad.py +113 -0
- torchzero/optim/wrappers/nlopt.py +165 -0
- torchzero/optim/wrappers/scipy.py +439 -0
- torchzero/optim/zeroth_order/__init__.py +4 -0
- torchzero/optim/zeroth_order/fdm.py +87 -0
- torchzero/optim/zeroth_order/newton_fdm.py +146 -0
- torchzero/optim/zeroth_order/rfdm.py +217 -0
- torchzero/optim/zeroth_order/rs.py +85 -0
- torchzero/random/__init__.py +1 -0
- torchzero/random/random.py +46 -0
- torchzero/tensorlist.py +819 -0
- torchzero/utils/__init__.py +0 -0
- torchzero/utils/compile.py +39 -0
- torchzero/utils/derivatives.py +99 -0
- torchzero/utils/python_tools.py +25 -0
- torchzero/utils/torch_tools.py +92 -0
- torchzero-0.0.1.dist-info/LICENSE +21 -0
- torchzero-0.0.1.dist-info/METADATA +118 -0
- torchzero-0.0.1.dist-info/RECORD +104 -0
- torchzero-0.0.1.dist-info/WHEEL +5 -0
- torchzero-0.0.1.dist-info/top_level.txt +1 -0
torchzero/optim/wrappers/scipy.py
@@ -0,0 +1,439 @@
from typing import Literal, Any
from collections import abc
from functools import partial

import numpy as np
import torch

import scipy.optimize

from ...core import _ClosureType, TensorListOptimizer
from ...utils.derivatives import jacobian, jacobian_list_to_vec, hessian, hessian_list_to_mat, jacobian_and_hessian
from ...modules import WrapClosure
from ...modules.experimental.subspace import Projection, Proj2Masks, ProjGrad, ProjNormalize, Subspace
from ...modules.second_order.newton import regularize_hessian_
from ...tensorlist import TensorList
from ..modular import Modular


def _ensure_float(x):
    if isinstance(x, torch.Tensor): return x.detach().cpu().item()
    if isinstance(x, np.ndarray): return x.item()
    return float(x)

def _ensure_numpy(x):
    if isinstance(x, torch.Tensor): return x.detach().cpu()
    if isinstance(x, np.ndarray): return x
    return np.array(x)

class ScipyMinimize(TensorListOptimizer):
    """Use scipy.optimize.minimize as a pytorch optimizer. Note that this performs a full minimization on each step,
    so usually you would want to perform a single step, although performing multiple steps will refine the
    solution.

    Please refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
    for a detailed description of the args.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        method (str | None, optional): type of solver.
            If None, scipy will select one of BFGS, L-BFGS-B, SLSQP,
            depending on whether or not the problem has constraints or bounds.
            Defaults to None.
        lb (optional): lower bound on variables, can be set per parameter group. Defaults to None.
        ub (optional): upper bound on variables, can be set per parameter group. Defaults to None.
        constraints (tuple, optional): constraints definition. Defaults to ().
        tol (float | None, optional): tolerance for termination. Defaults to None.
        callback (Callable | None, optional): a callable called after each iteration. Defaults to None.
        options (dict | None, optional): a dictionary of solver options. Defaults to None.
        jac (str, optional): method for computing the gradient vector.
            Only for CG, BFGS, Newton-CG, L-BFGS-B, TNC, SLSQP, dogleg, trust-ncg, trust-krylov, trust-exact and trust-constr.
            In addition to the scipy options, this supports 'autograd', which uses pytorch autograd.
            This setting is ignored for methods that don't require the gradient. Defaults to 'autograd'.
        hess (str, optional):
            method for computing the Hessian matrix.
            Only for Newton-CG, dogleg, trust-ncg, trust-krylov, trust-exact and trust-constr.
            This setting is ignored for methods that don't require the hessian. Defaults to 'autograd'.
        tikhonov (float, optional):
            optional hessian regularizer value. Only has effect for methods that require the hessian.
    """
    def __init__(
        self,
        params,
        method: Literal['nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg',
                        'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp',
                        'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact',
                        'trust-krylov'] | str | None = None,
        lb = None,
        ub = None,
        constraints = (),
        tol: float | None = None,
        callback = None,
        options = None,
        jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
        hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = 'autograd',
        tikhonov: float | Literal['eig'] = 0,
    ):
        defaults = dict(lb=lb, ub=ub)
        super().__init__(params, defaults)
        self.method = method
        self.constraints = constraints
        self.tol = tol
        self.callback = callback
        self.options = options

        self.jac = jac
        self.hess = hess
        self.tikhonov: float | Literal['eig'] = tikhonov

        self.use_jac_autograd = jac.lower() == 'autograd' and (method is None or method.lower() in [
            'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'slsqp', 'dogleg',
            'trust-ncg', 'trust-krylov', 'trust-exact', 'trust-constr',
        ])
        self.use_hess_autograd = isinstance(hess, str) and hess.lower() == 'autograd' and method is not None and method.lower() in [
            'newton-cg', 'dogleg', 'trust-ncg', 'trust-krylov', 'trust-exact'
        ]

        if self.jac == 'autograd':
            if self.use_jac_autograd: self.jac = True
            else: self.jac = None

    def _hess(self, x: np.ndarray, params: TensorList, closure: _ClosureType): # type:ignore
        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        with torch.enable_grad():
            value = closure(False)
            H = hessian([value], wrt = params) # type:ignore
        Hmat = hessian_list_to_mat(H)
        regularize_hessian_(Hmat, self.tikhonov)
        return Hmat.detach().cpu().numpy()

    def _objective(self, x: np.ndarray, params: TensorList, closure: _ClosureType):
        # set params to x
        params.from_vec_(torch.from_numpy(x).to(params[0], copy=False))

        # return value and maybe gradients
        if self.use_jac_autograd:
            with torch.enable_grad(): value = _ensure_float(closure())
            return value, params.ensure_grad_().grad.to_vec().detach().cpu().numpy()
        return _ensure_float(closure(False))

    @torch.no_grad
    def step(self, closure: _ClosureType): # type:ignore # pylint:disable = signature-differs
        params = self.get_params()

        # determine hess argument
        if self.hess == 'autograd':
            if self.use_hess_autograd: hess = partial(self._hess, params = params, closure = closure)
            else: hess = None
        else: hess = self.hess

        x0 = params.to_vec().detach().cpu().numpy()

        # make bounds
        lb, ub = self.get_group_keys('lb', 'ub', cls=list)
        bounds = []
        for p, l, u in zip(params, lb, ub):
            bounds.extend([(l, u)] * p.numel())

        if self.method is not None and (self.method.lower() == 'tnc' or self.method.lower() == 'slsqp'):
            x0 = x0.astype(np.float64) # those methods error without this

        res = scipy.optimize.minimize(
            partial(self._objective, params = params, closure = closure),
            x0 = x0,
            method=self.method,
            bounds=bounds,
            constraints=self.constraints,
            tol=self.tol,
            callback=self.callback,
            options=self.options,
            jac = self.jac,
            hess = hess,
        )

        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        return res.fun

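Below is a minimal usage sketch for `ScipyMinimize`, not part of the package source: the toy model, data, and loss are illustrative, the import path is taken from the file listing above, and the closure follows the convention visible in the code, where `closure(False)` evaluates the loss without a backward pass and `closure()` also populates gradients.

```python
import torch
import torch.nn.functional as F
from torchzero.optim.wrappers.scipy import ScipyMinimize

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

opt = ScipyMinimize(model.parameters(), method='l-bfgs-b', jac='autograd')

def closure(backward=True):
    loss = F.mse_loss(model(X), y)
    if backward:
        model.zero_grad()
        loss.backward()
    return loss

final_loss = opt.step(closure)  # one step runs a full scipy.optimize.minimize solve
```
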
class ScipyRoot(TensorListOptimizer):
    """Find a root of a vector function (UNTESTED!).

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        method (str, optional): one of the methods supported by scipy.optimize.root. Defaults to 'hybr'.
        tol (float | None, optional): tolerance for termination. Defaults to None.
        callback (Callable | None, optional): a callable called after each iteration. Defaults to None.
        options (dict | None, optional): a dictionary of solver options. Defaults to None.
        jac (Literal['2-point', '3-point', 'cs', 'autograd'], optional): jacobian calculation method. Defaults to 'autograd'.
    """
    def __init__(
        self,
        params,
        method: Literal[
            "hybr",
            "lm",
            "broyden1",
            "broyden2",
            "anderson",
            "linearmixing",
            "diagbroyden",
            "excitingmixing",
            "krylov",
            "df-sane",
        ] = 'hybr',
        tol: float | None = None,
        callback = None,
        options = None,
        jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
    ):
        super().__init__(params, {})
        self.method = method
        self.tol = tol
        self.callback = callback
        self.options = options

        self.jac = jac
        if self.jac == 'autograd': self.jac = True

    def _objective(self, x: np.ndarray, params: TensorList, closure: _ClosureType):
        # set params to x
        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))

        # return value and maybe gradients
        if self.jac:
            with torch.enable_grad():
                value = closure(False)
                if not isinstance(value, torch.Tensor):
                    raise TypeError(f"Autograd jacobian requires closure to return torch.Tensor, got {type(value)}")
                jac = jacobian_list_to_vec(jacobian([value], wrt=params))
            return _ensure_numpy(value), jac.detach().cpu().numpy()
        return _ensure_numpy(closure(False))

    @torch.no_grad
    def step(self, closure: _ClosureType): # type:ignore # pylint:disable = signature-differs
        params = self.get_params()

        x0 = params.to_vec().detach().cpu().numpy()

        res = scipy.optimize.root(
            partial(self._objective, params = params, closure = closure),
            x0 = x0,
            method=self.method,
            tol=self.tol,
            callback=self.callback,
            options=self.options,
            jac = self.jac,
        )

        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        return res.fun

class ScipyRootOptimization(TensorListOptimizer):
    """Optimization via finding roots of the gradient with `scipy.optimize.root` (for experiments, won't work well on most problems).

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        method (str, optional): one of the methods from https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.root.html#scipy.optimize.root. Defaults to 'hybr'.
        tol (float | None, optional): tolerance. Defaults to None.
        callback (Callable | None, optional): callback. Defaults to None.
        options (dict | None, optional): options for the solver. Defaults to None.
        jac (Literal['2-point', '3-point', 'cs', 'autograd'], optional): jacobian calculation method. Defaults to 'autograd'.
        tikhonov (float | Literal['eig'], optional): tikhonov regularization (only for 'hybr' and 'lm'). Defaults to 0.
        add_loss (float, optional): adds the loss value multiplied by this to the jacobian, to try to avoid finding maxima. Defaults to 0.
        mul_loss (float, optional): multiplies the jacobian by the loss value multiplied by this, to try to avoid finding maxima. Defaults to 0.
    """
    def __init__(
        self,
        params,
        method: Literal[
            "hybr",
            "lm",
            "broyden1",
            "broyden2",
            "anderson",
            "linearmixing",
            "diagbroyden",
            "excitingmixing",
            "krylov",
            "df-sane",
        ] = 'hybr',
        tol: float | None = None,
        callback = None,
        options = None,
        jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
        tikhonov: float | Literal['eig'] = 0,
        add_loss: float = 0,
        mul_loss: float = 0,
    ):
        super().__init__(params, {})
        self.method = method
        self.tol = tol
        self.callback = callback
        self.options = options
        self.value = None
        self.tikhonov: float | Literal['eig'] = tikhonov
        self.add_loss = add_loss
        self.mul_loss = mul_loss

        self.jac = jac == 'autograd'

        # those don't require jacobian
        if self.method.lower() in ('broyden1', 'broyden2', 'anderson', 'linearmixing', 'diagbroyden', 'excitingmixing', 'krylov', 'df-sane'):
            self.jac = None

    def _objective(self, x: np.ndarray, params: TensorList, closure: _ClosureType):
        # set params to x
        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))

        # return gradients and maybe hessian
        if self.jac:
            with torch.enable_grad():
                self.value = closure(False)
                if not isinstance(self.value, torch.Tensor):
                    raise TypeError(f"Autograd jacobian requires closure to return torch.Tensor, got {type(self.value)}")
                jac_list, hess_list = jacobian_and_hessian([self.value], wrt=params)
            jac = jacobian_list_to_vec(jac_list)
            hess = hessian_list_to_mat(hess_list)
            regularize_hessian_(hess, self.tikhonov)
            if self.mul_loss != 0: jac *= self.value * self.mul_loss
            if self.add_loss != 0: jac += self.value * self.add_loss
            return jac.detach().cpu().numpy(), hess.detach().cpu().numpy()

        # return the gradients
        with torch.enable_grad(): self.value = closure()
        jac = params.ensure_grad_().grad.to_vec()
        if self.mul_loss != 0: jac *= self.value * self.mul_loss
        if self.add_loss != 0: jac += self.value * self.add_loss
        return jac.detach().cpu().numpy()

    @torch.no_grad
    def step(self, closure: _ClosureType): # type:ignore # pylint:disable = signature-differs
        params = self.get_params()

        x0 = params.to_vec().detach().cpu().numpy()

        res = scipy.optimize.root(
            partial(self._objective, params = params, closure = closure),
            x0 = x0,
            method=self.method,
            tol=self.tol,
            callback=self.callback,
            options=self.options,
            jac = self.jac,
        )

        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        return self.value

class ScipyDE(TensorListOptimizer):
    """Use scipy.optimize.differential_evolution as a pytorch optimizer. Note that this performs a full minimization on each step,
    so usually you would want to perform a single step. This also requires bounds to be specified.

    Please refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.differential_evolution.html
    for all other args.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        bounds (tuple[float, float]): tuple with lower and upper bounds.
            DE requires bounds to be specified.

    other args:
        refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.differential_evolution.html
    """
    def __init__(
        self,
        params,
        bounds: tuple[float,float],
        strategy: Literal['best1bin', 'best1exp', 'rand1bin', 'rand1exp', 'rand2bin', 'rand2exp',
                          'randtobest1bin', 'randtobest1exp', 'currenttobest1bin', 'currenttobest1exp',
                          'best2exp', 'best2bin'] = 'best1bin',
        maxiter: int = 1000,
        popsize: int = 15,
        tol: float = 0.01,
        mutation = (0.5, 1),
        recombination: float = 0.7,
        seed = None,
        callback = None,
        disp: bool = False,
        polish: bool = False,
        init: str = 'latinhypercube',
        atol: int = 0,
        updating: str = 'immediate',
        workers: int = 1,
        constraints = (),
        *,
        integrality = None,
    ):
        super().__init__(params, {})

        kwargs = locals().copy()
        del kwargs['self'], kwargs['params'], kwargs['bounds'], kwargs['__class__']
        self._kwargs = kwargs
        self._lb, self._ub = bounds

    def _objective(self, x: np.ndarray, params: TensorList, closure: _ClosureType):
        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        return _ensure_float(closure(False))

    @torch.no_grad
    def step(self, closure: _ClosureType): # type:ignore # pylint:disable = signature-differs
        params = self.get_params()

        x0 = params.to_vec().detach().cpu().numpy()
        bounds = [(self._lb, self._ub)] * len(x0)

        res = scipy.optimize.differential_evolution(
            partial(self._objective, params = params, closure = closure),
            x0 = x0,
            bounds=bounds,
            **self._kwargs
        )

        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        return res.fun

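A minimal usage sketch for `ScipyDE`, again with an illustrative model and data: differential evolution is gradient-free, so the closure only needs to return the loss, and the single `(lb, ub)` pair is applied to every parameter, as the `bounds` construction in `step` above shows.

```python
import torch
import torch.nn.functional as F
from torchzero.optim.wrappers.scipy import ScipyDE

model = torch.nn.Linear(4, 1)
X, y = torch.randn(32, 4), torch.randn(32, 1)

# bounds are required; the same (lb, ub) pair is applied to every parameter
opt = ScipyDE(model.parameters(), bounds=(-5, 5), maxiter=200)

def closure(backward=True):
    # DE only needs loss values, so backward is never used
    return F.mse_loss(model(X), y)

final_loss = opt.step(closure)  # one step runs a full differential evolution solve
```
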
class ScipyMinimizeSubspace(Modular):
    """For experiments; won't work well on most problems.

    Explanation: optimizes in a small subspace using scipy.optimize.minimize, but doesn't seem to work well."""
    def __init__(
        self,
        params,
        projections: Projection | abc.Iterable[Projection] = (
            Proj2Masks(5),
            ProjNormalize(
                ProjGrad(),
            )
        ),
        method=None,
        lb = None,
        ub = None,
        constraints=(),
        tol=None,
        callback=None,
        options=None,
        jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
        hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = '2-point',
    ):

        scopt = WrapClosure(
            ScipyMinimize,
            method = method,
            lb = lb,
            ub = ub,
            constraints = constraints,
            tol = tol,
            callback = callback,
            options = options,
            jac = jac,
            hess = hess
        )
        modules = [
            Subspace(scopt, projections),
        ]

        super().__init__(params, modules)

torchzero/optim/zeroth_order/fdm.py
@@ -0,0 +1,87 @@
from typing import Literal

import torch

from ...modules import FDM as _FDM, WrapClosure, SGD, WeightDecay, LR
from ...modules.gradient_approximation._fd_formulas import _FD_Formulas
from ..modular import Modular


class FDM(Modular):
    """Gradient approximation via finite difference.

    This performs `n + 1` evaluations per step with `forward` and `backward` formulas,
    and `2 * n` with `central` formula, where n is the number of parameters.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate. Defaults to 1e-3.
        eps (float, optional): finite difference epsilon. Defaults to 1e-3.
        formula (_FD_Formulas, optional): finite difference formula. Defaults to "forward".
        n_points (Literal[2, 3], optional): number of points for the finite difference formula, 2 or 3. Defaults to 2.
        momentum (float, optional): momentum. Defaults to 0.
        dampening (float, optional): momentum dampening. Defaults to 0.
        nesterov (bool, optional):
            enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
        decoupled (bool, optional):
            decouples weight decay from the gradient. If True, weight decay doesn't depend on the learning rate.
    """
    def __init__(
        self,
        params,
        lr: float = 1e-3,
        eps: float = 1e-3,
        formula: _FD_Formulas = "forward",
        n_points: Literal[2, 3] = 2,
        momentum: float = 0,
        dampening: float = 0,
        nesterov: bool = False,
        weight_decay: float = 0,
        decoupled=False,
    ):
        modules: list = [
            _FDM(eps = eps, formula=formula, n_points=n_points),
            SGD(momentum = momentum, dampening = dampening, weight_decay = weight_decay if not decoupled else 0, nesterov = nesterov),
            LR(lr),
        ]
        if decoupled: modules.append(WeightDecay(weight_decay))
        super().__init__(params, modules)

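A minimal usage sketch for the `FDM` optimizer above (not part of the package source; the model and data are illustrative): the closure never needs to call backward, since the gradient is approximated from extra loss evaluations.

```python
import torch
import torch.nn.functional as F
from torchzero.optim.zeroth_order.fdm import FDM

model = torch.nn.Linear(8, 1)
X, y = torch.randn(32, 8), torch.randn(32, 1)

opt = FDM(model.parameters(), lr=1e-2, eps=1e-3, formula='central')

def closure(backward=True):
    # FDM only evaluates the loss; no .backward() is required
    return F.mse_loss(model(X), y)

for _ in range(100):
    loss = opt.step(closure)
```
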
class FDMWrapper(Modular):
    """Gradient approximation via finite difference. This wraps any other optimizer.
    This also supports optimizers that perform multiple gradient evaluations per step, like LBFGS.

    Example:
    ```
    lbfgs = torch.optim.LBFGS(params, lr = 1)
    fdm = FDMWrapper(optimizer = lbfgs)
    ```

    This performs `n + 1` evaluations per step with `forward` and `backward` formulas,
    and `2 * n` with `central` formula.

    Args:
        optimizer (torch.optim.Optimizer): optimizer that will perform optimization using FDM-approximated gradients.
        eps (float, optional): finite difference epsilon. Defaults to 1e-3.
        formula (_FD_Formulas, optional): finite difference formula. Defaults to "forward".
        n_points (Literal[2, 3], optional): number of points for the finite difference formula, 2 or 3. Defaults to 2.
    """
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        eps: float = 1e-3,
        formula: _FD_Formulas = "forward",
        n_points: Literal[2, 3] = 2,
    ):
        modules = [
            _FDM(eps = eps, formula=formula, n_points=n_points, target = 'closure'),
            WrapClosure(optimizer)
        ]
        # some optimizers have an `eps` setting in their param groups too;
        # it should not be passed to FDM
        super().__init__([p for g in optimizer.param_groups.copy() for p in g['params']], modules)

torchzero/optim/zeroth_order/newton_fdm.py
@@ -0,0 +1,146 @@
from typing import Any, Literal
import torch

from ...modules import (LR, FallbackLinearSystemSolvers,
                        LinearSystemSolvers, LineSearches, ClipNorm)
from ...modules import NewtonFDM as _NewtonFDM, get_line_search
from ...modules.experimental.subspace import Proj2Masks, ProjRandom, Subspace
from ..modular import Modular


class NewtonFDM(Modular):
    """Newton method with gradient and hessian approximated via finite difference.

    This performs approximately `4 * n^2 + 1` evaluations per step;
    if `diag` is True, performs `n * 2 + 1` evaluations per step.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate. Defaults to 1.
        eps (float, optional): epsilon for finite difference.
            Note that with float32 this needs to be quite high to avoid numerical instability. Defaults to 1e-2.
        diag (bool, optional): whether to only approximate diagonal elements of the hessian.
            This also ignores `solver` if True. Defaults to False.
        solver (LinearSystemSolvers, optional):
            solver for Hx = g. Defaults to "cholesky_lu" (cholesky, or LU if it fails).
        fallback (FallbackLinearSystemSolvers, optional):
            what to do if the solver fails. Defaults to "safe_diag"
            (takes nonzero diagonal elements, or falls back to gradient descent if all elements are 0).
        max_norm (float | None, optional): if set, clips the update norm to this value. Defaults to None.
        validate (bool, optional):
            checks, with an additional forward pass, that the step didn't increase the loss by more than `loss * tol`.
            If it did, the step is undone and a gradient descent step is performed instead.
        tol (float, optional):
            only has effect if `validate` is enabled.
            If the loss increased by more than `loss * tol`, a gradient descent step is performed.
            Set this to 0 to guarantee that the loss always decreases. Defaults to 2.
        gd_lr (float, optional):
            only has effect if `validate` is enabled.
            Gradient descent step learning rate. Defaults to 1e-2.
        line_search (OptimizerModule | None, optional): line search module, can be None. Defaults to 'brent'.
    """
    def __init__(
        self,
        params,
        lr: float = 1,
        eps: float = 1e-2,
        diag=False,
        solver: LinearSystemSolvers = "cholesky_lu",
        fallback: FallbackLinearSystemSolvers = "safe_diag",
        max_norm: float | None = None,
        validate=False,
        tol: float = 2,
        gd_lr = 1e-2,
        line_search: LineSearches | None = 'brent',
    ):
        modules: list[Any] = [
            _NewtonFDM(eps = eps, diag = diag, solver=solver, fallback=fallback, validate=validate, tol=tol, gd_lr=gd_lr),
        ]

        if max_norm is not None:
            modules.append(ClipNorm(max_norm))

        modules.append(LR(lr))

        if line_search is not None:
            modules.append(get_line_search(line_search))

        super().__init__(params, modules)

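A minimal usage sketch for `NewtonFDM` (illustrative model and data): with n parameters each step costs roughly `4 * n^2 + 1` loss evaluations, so this is only practical for very small models.

```python
import torch
import torch.nn.functional as F
from torchzero.optim.zeroth_order.newton_fdm import NewtonFDM

model = torch.nn.Linear(5, 1)  # 6 parameters, so roughly 145 evaluations per step
X, y = torch.randn(32, 5), torch.randn(32, 1)

opt = NewtonFDM(model.parameters(), lr=1, eps=1e-2, line_search='brent')

def closure(backward=True):
    # gradient and hessian are approximated via finite differences; no backward needed
    return F.mse_loss(model(X), y)

for _ in range(10):
    loss = opt.step(closure)
```
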
class RandomSubspaceNewtonFDM(Modular):
    """This projects the parameters into a smaller dimensional subspace,
    making approximating the hessian via finite difference feasible.

    This performs approximately `4 * subspace_ndim^2 + 1` evaluations per step;
    if `diag` is True, performs `subspace_ndim * 2 + 1` evaluations per step.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        subspace_ndim (int, optional): number of random subspace dimensions. Defaults to 3.
        lr (float, optional): learning rate. Defaults to 1.
        eps (float, optional): epsilon for finite difference.
            Note that with float32 this needs to be quite high to avoid numerical instability. Defaults to 1e-2.
        diag (bool, optional): whether to only approximate diagonal elements of the hessian. Defaults to False.
        solver (LinearSystemSolvers, optional):
            solver for Hx = g. Defaults to "cholesky_lu" (cholesky, or LU if it fails).
        fallback (FallbackLinearSystemSolvers, optional):
            what to do if the solver fails. Defaults to "safe_diag"
            (takes nonzero diagonal elements, or falls back to gradient descent if all elements are 0).
        max_norm (float | None, optional): if set, clips the update norm to this value. Defaults to None.
        validate (bool, optional):
            checks, with an additional forward pass, that the step didn't increase the loss by more than `loss * tol`.
            If it did, the step is undone and a gradient descent step is performed instead.
        tol (float, optional):
            only has effect if `validate` is enabled.
            If the loss increased by more than `loss * tol`, a gradient descent step is performed.
            Set this to 0 to guarantee that the loss always decreases. Defaults to 2.
        gd_lr (float, optional):
            only has effect if `validate` is enabled.
            Gradient descent step learning rate. Defaults to 1e-2.
        line_search (OptimizerModule | None, optional): line search module, can be None. Defaults to 'brent'.
        randomize_every (int, optional): generates new random projections every n steps. Defaults to 1.
    """
    def __init__(
        self,
        params,
        subspace_ndim: int = 3,
        lr: float = 1,
        eps: float = 1e-2,
        diag=False,
        solver: LinearSystemSolvers = "cholesky_lu",
        fallback: FallbackLinearSystemSolvers = "safe_diag",
        max_norm: float | None = None,
        validate=False,
        tol: float = 2,
        gd_lr = 1e-2,
        line_search: LineSearches | None = 'brent',
        randomize_every: int = 1,
    ):
        if subspace_ndim == 1: projections = [ProjRandom(1)]
        else:
            projections: list[Any] = [Proj2Masks(subspace_ndim//2)]
            if subspace_ndim % 2 == 1: projections.append(ProjRandom(1))

        modules: list[Any] = [
            Subspace(
                modules = _NewtonFDM(
                    eps = eps,
                    diag = diag,
                    solver=solver,
                    fallback=fallback,
                    validate=validate,
                    tol=tol,
                    gd_lr=gd_lr
                ),
                projections = projections,
                update_every=randomize_every),
        ]
        if max_norm is not None:
            modules.append(ClipNorm(max_norm))

        modules.append(LR(lr))

        if line_search is not None:
            modules.append(get_line_search(line_search))

        super().__init__(params, modules)

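And a matching sketch for `RandomSubspaceNewtonFDM` (illustrative model and data), which ties the per-step cost to `subspace_ndim` rather than to the full parameter count, so it can be tried on larger models.

```python
import torch
import torch.nn.functional as F
from torchzero.optim.zeroth_order.newton_fdm import RandomSubspaceNewtonFDM

model = torch.nn.Linear(100, 10)  # ~1010 parameters, but only a 4-dimensional subspace is searched
X, y = torch.randn(64, 100), torch.randn(64, 10)

opt = RandomSubspaceNewtonFDM(model.parameters(), subspace_ndim=4, randomize_every=1)

def closure(backward=True):
    # the random projections and the FDM Newton step only need loss evaluations
    return F.mse_loss(model(X), y)

for _ in range(50):
    loss = opt.step(closure)
```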