torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/quasi_newton/trust_region.py
@@ -1,397 +0,0 @@
-"""Trust region API is currently experimental, it will probably change completely"""
-# pylint:disable=not-callable
-from abc import ABC, abstractmethod
-from typing import Any, Literal, cast, final
-from collections.abc import Sequence, Mapping
-
-import numpy as np
-import torch
-from scipy.optimize import lsq_linear
-
-from ...core import Chainable, Module, apply_transform, Var
-from ...utils import TensorList, vec_to_tensors
-from ...utils.derivatives import (
-    hessian_list_to_mat,
-    jacobian_and_hessian_wrt,
-)
-from .quasi_newton import HessianUpdateStrategy
-from ...utils.linalg import steihaug_toint_cg
-
-
-def trust_lstsq(H: torch.Tensor, g: torch.Tensor, trust_region: float):
-    res = lsq_linear(H.numpy(force=True).astype(np.float64), g.numpy(force=True).astype(np.float64), bounds=(-trust_region, trust_region))
-    x = torch.from_numpy(res.x).to(H)
-    return x, res.cost
-
-def _flatten_tensors(tensors: list[torch.Tensor]):
-    return torch.cat([t.ravel() for t in tensors])
-
-
-class TrustRegionBase(Module, ABC):
-    def __init__(
-        self,
-        defaults: dict | None = None,
-        hess_module: HessianUpdateStrategy | None = None,
-        update_freq: int = 1,
-        inner: Chainable | None = None,
-    ):
-        self._update_freq = update_freq
-        super().__init__(defaults)
-
-        if hess_module is not None:
-            self.set_child('hess_module', hess_module)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    @abstractmethod
-    def trust_region_step(self, var: Var, tensors:list[torch.Tensor], P: torch.Tensor, is_inverse:bool) -> Var:
-        """trust region logic"""
-
-
-    @final
-    @torch.no_grad
-    def update(self, var):
-        # ---------------------------------- update ---------------------------------- #
-        closure = var.closure
-        if closure is None: raise RuntimeError("Trust region requires closure")
-        params = var.params
-
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        P = None
-        is_inverse=None
-        g_list = var.grad
-        loss = var.loss
-        if step % self._update_freq == 0:
-
-            if 'hess_module' not in self.children:
-                params=var.params
-                closure=var.closure
-                if closure is None: raise ValueError('Closure is required for trust region')
-                with torch.enable_grad():
-                    loss = var.loss = var.loss_approx = closure(False)
-                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=True)
-                    g_list = [t[0] for t in g_list] # remove leading dim from loss
-                    var.grad = g_list
-                    P = hessian_list_to_mat(H_list)
-                    is_inverse=False
-
-
-            else:
-                hessian_module = cast(HessianUpdateStrategy, self.children['hess_module'])
-                hessian_module.update(var)
-                P, is_inverse = hessian_module.get_B()
-
-            if self._update_freq != 0:
-                self.global_state['B'] = P
-                self.global_state['is_inverse'] = is_inverse
-
-
-    @final
-    @torch.no_grad
-    def apply(self, var):
-        P = self.global_state['B']
-        is_inverse = self.global_state['is_inverse']
-
-        # -------------------------------- inner step -------------------------------- #
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=var.params, grads=var.grad, var=var)
-
-        # ----------------------------------- apply ---------------------------------- #
-        return self.trust_region_step(var=var, tensors=update, P=P, is_inverse=is_inverse)
-
-def _update_tr_radius(update_vec:torch.Tensor, params: Sequence[torch.Tensor], closure,
-                      loss, g:torch.Tensor, H:torch.Tensor, trust_region:float, settings: Mapping):
-    """returns (update, new_trust_region)
-
-    Args:
-        update_vec (torch.Tensor): update vector which is SUBTRACTED from parameters
-        params (_type_): params tensor list
-        closure (_type_): closure
-        loss (_type_): loss at x0
-        g (torch.Tensor): gradient vector
-        H (torch.Tensor): hessian
-        trust_region (float): current trust region value
-    """
-    # evaluate actual loss reduction
-    update_unflattned = vec_to_tensors(update_vec, params)
-    params = TensorList(params)
-    params -= update_unflattned
-    loss_star = closure(False)
-    params += update_unflattned
-    reduction = loss - loss_star
-
-    # expected reduction is g.T @ p + 0.5 * p.T @ B @ p
-    if H.ndim == 1: Hu = H * update_vec
-    else: Hu = H @ update_vec
-    pred_reduction = - (g.dot(update_vec) + 0.5 * update_vec.dot(Hu))
-    rho = reduction / (pred_reduction.clip(min=1e-8))
-
-    # failed step
-    if rho < 0.25:
-        trust_region *= settings["nminus"]
-
-    # very good step
-    elif rho > 0.75:
-        diff = trust_region - update_vec.abs()
-        if (diff.amin() / trust_region) > 1e-4: # hits boundary
-            trust_region *= settings["nplus"]
-
-    # # if the ratio is high enough then accept the proposed step
-    # if rho > settings["eta"]:
-    #     update = vec_to_tensors(update_vec, params)
-
-    # else:
-    #     update = params.zeros_like()
-
-    return trust_region, rho > settings["eta"]
-
-class TrustCG(TrustRegionBase):
-    """Trust region via Steihaug-Toint Conjugate Gradient method. This is mainly useful for quasi-newton methods.
-    If you don't use :code:`hess_module`, use the matrix-free :code:`tz.m.NewtonCGSteihaug` which only uses hessian-vector products.
-
-    Args:
-        hess_module (HessianUpdateStrategy | None, optional):
-            Hessian update strategy, must be one of the :code:`HessianUpdateStrategy` modules. Make sure to set :code:`inverse=False`. If None, uses autograd to calculate the hessian. Defaults to None.
-        eta (float, optional):
-            if ratio of actual to predicted rediction is larger than this, step is accepted.
-            When :code:`hess_module` is None, this can be set to 0. Defaults to 0.15.
-        nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
-        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
-        init (float, optional): Initial trust region value. Defaults to 1.
-        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
-        reg (int, optional): hessian regularization. Defaults to 0.
-        inner (Chainable | None, optional): preconditioning is applied to output of thise module. Defaults to None.
-
-    Examples:
-        Trust-SR1
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
-            )
-    """
-    def __init__(
-        self,
-        hess_module: HessianUpdateStrategy | None,
-        eta: float= 0.15,
-        nplus: float = 2,
-        nminus: float = 0.25,
-        init: float = 1,
-        update_freq: int = 1,
-        reg: float = 0,
-        max_attempts: int = 10,
-        inner: Chainable | None = None,
-    ):
-        defaults = dict(init=init, nplus=nplus, nminus=nminus, eta=eta, reg=reg, max_attempts=max_attempts)
-        super().__init__(defaults, hess_module=hess_module, update_freq=update_freq, inner=inner)
-
-    @torch.no_grad
-    def trust_region_step(self, var, tensors, P, is_inverse):
-        params = TensorList(var.params)
-        settings = self.settings[params[0]]
-        g = _flatten_tensors(tensors)
-
-        reg = settings['reg']
-        max_attempts = settings['max_attempts']
-
-        loss = var.loss
-        closure = var.closure
-        if closure is None: raise RuntimeError("Trust region requires closure")
-        if loss is None: loss = closure(False)
-
-        if is_inverse:
-            if P.ndim == 1: P = P.reciprocal()
-            else: raise NotImplementedError()
-
-        success = False
-        update_vec = None
-        while not success:
-            max_attempts -= 1
-            if max_attempts < 0: break
-
-            trust_region = self.global_state.get('trust_region', settings['init'])
-
-            if trust_region < 1e-8 or trust_region > 1e8:
-                trust_region = self.global_state['trust_region'] = settings['init']
-
-            update_vec = steihaug_toint_cg(P, g, trust_region, reg=reg)
-
-            self.global_state['trust_region'], success = _update_tr_radius(
-                update_vec=update_vec, params=params, closure=closure,
-                loss=loss, g=g, H=P, trust_region=trust_region, settings = settings,
-            )
-
-        assert update_vec is not None
-        if success: var.update = vec_to_tensors(update_vec, params)
-        else: var.update = params.zeros_like()
-
-        return var
-
-
-# code from https://github.com/konstmish/opt_methods/blob/master/optmethods/second_order/cubic.py
-# ported to torch
-def ls_cubic_solver(f, g:torch.Tensor, H:torch.Tensor, M: float, is_inverse: bool, loss_plus, it_max=100, epsilon=1e-8, ):
-    """
-    Solve min_z <g, z-x> + 1/2<z-x, H(z-x)> + M/3 ||z-x||^3
-
-    For explanation of Cauchy point, see "Gradient Descent
-    Efficiently Finds the Cubic-Regularized Non-Convex Newton Step"
-    https://arxiv.org/pdf/1612.00547.pdf
-    Other potential implementations can be found in paper
-    "Adaptive cubic regularisation methods"
-    https://people.maths.ox.ac.uk/cartis/papers/ARCpI.pdf
-    """
-    solver_it = 1
-    if is_inverse:
-        newton_step = - H @ g
-        H = torch.linalg.inv(H)
-    else:
-        newton_step, info = torch.linalg.solve_ex(H, g)
-        if info != 0:
-            newton_step = torch.linalg.lstsq(H, g).solution
-        newton_step.neg_()
-    if M == 0:
-        return newton_step, solver_it
-    def cauchy_point(g, H, M):
-        if torch.linalg.vector_norm(g) == 0 or M == 0:
-            return 0 * g
-        g_dir = g / torch.linalg.vector_norm(g)
-        H_g_g = H @ g_dir @ g_dir
-        R = -H_g_g / (2*M) + torch.sqrt((H_g_g/M)**2/4 + torch.linalg.vector_norm(g)/M)
-        return -R * g_dir
-
-    def conv_criterion(s, r):
-        """
-        The convergence criterion is an increasing and concave function in r
-        and it is equal to 0 only if r is the solution to the cubic problem
-        """
-        s_norm = torch.linalg.vector_norm(s)
-        return 1/s_norm - 1/r
-
-    # Solution s satisfies ||s|| >= Cauchy_radius
-    r_min = torch.linalg.vector_norm(cauchy_point(g, H, M))
-
-    if f > loss_plus(newton_step):
-        return newton_step, solver_it
-
-    r_max = torch.linalg.vector_norm(newton_step)
-    if r_max - r_min < epsilon:
-        return newton_step, solver_it
-    id_matrix = torch.eye(g.size(0), device=g.device, dtype=g.dtype)
-    s_lam = None
-    for _ in range(it_max):
-        r_try = (r_min + r_max) / 2
-        lam = r_try * M
-        s_lam = -torch.linalg.solve(H + lam*id_matrix, g)
-        solver_it += 1
-        crit = conv_criterion(s_lam, r_try)
-        if np.abs(crit) < epsilon:
-            return s_lam, solver_it
-        if crit < 0:
-            r_min = r_try
-        else:
-            r_max = r_try
-        if r_max - r_min < epsilon:
-            break
-    assert s_lam is not None
-    return s_lam, solver_it
-
-class CubicRegularization(TrustRegionBase):
-    """Cubic regularization.
-
-    .. note::
-        by default this functions like a trust region, set nplus and nminus = 1 to make regularization parameter fixed.
-        :code:`init` sets 1/regularization.
-
-    Args:
-        hess_module (HessianUpdateStrategy | None, optional):
-            Hessian update strategy, must be one of the :code:`HessianUpdateStrategy` modules. This works better with true hessian though. Make sure to set :code:`inverse=False`. If None, uses autograd to calculate the hessian. Defaults to None.
-        eta (float, optional):
-            if ratio of actual to predicted rediction is larger than this, step is accepted.
-            When :code:`hess_module` is None, this can be set to 0. Defaults to 0.0.
-        nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
-        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
-        init (float, optional): Initial trust region value. Defaults to 1.
-        maxiter (float, optional): maximum iterations when solving cubic subproblem, defaults to 1e-7.
-        eps (float, optional): epsilon for the solver, defaults to 1e-8.
-        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
-        inner (Chainable | None, optional): preconditioning is applied to output of thise module. Defaults to None.
-
-    Examples:
-        Cubic regularized newton
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.CubicRegularization(),
-            )
-
-    """
-    def __init__(
-        self,
-        hess_module: HessianUpdateStrategy | None = None,
-        eta: float= 0.0,
-        nplus: float = 2,
-        nminus: float = 0.25,
-        init: float = 1,
-        maxiter: int = 100,
-        eps: float = 1e-8,
-        update_freq: int = 1,
-        max_attempts: int = 10,
-        inner: Chainable | None = None,
-    ):
-        defaults = dict(init=init, nplus=nplus, nminus=nminus, eta=eta, maxiter=maxiter, eps=eps, max_attempts=max_attempts)
-        super().__init__(defaults, hess_module=hess_module, update_freq=update_freq, inner=inner)
-
-    @torch.no_grad
-    def trust_region_step(self, var, tensors, P, is_inverse):
-        params = TensorList(var.params)
-        settings = self.settings[params[0]]
-        g = _flatten_tensors(tensors)
-
-        maxiter = settings['maxiter']
-        max_attempts = settings['max_attempts']
-        eps = settings['eps']
-
-        loss = var.loss
-        closure = var.closure
-        if closure is None: raise RuntimeError("Trust region requires closure")
-        if loss is None: loss = closure(False)
-
-        def loss_plus(x):
-            x_unflat = vec_to_tensors(x, params)
-            params.add_(x_unflat)
-            loss_x = closure(False)
-            params.sub_(x_unflat)
-            return loss_x
-
-        success = False
-        update_vec = None
-        while not success:
-            max_attempts -= 1
-            if max_attempts < 0: break
-
-            trust_region = self.global_state.get('trust_region', settings['init'])
-            if trust_region < 1e-8 or trust_region > 1e16: trust_region = self.global_state['trust_region'] = settings['init']
-
-            update_vec, _ = ls_cubic_solver(f=loss, g=g, H=P, M=1/trust_region, is_inverse=is_inverse,
-                                            loss_plus=loss_plus, it_max=maxiter, epsilon=eps)
-            update_vec.neg_()
-
-            self.global_state['trust_region'], success = _update_tr_radius(
-                update_vec=update_vec, params=params, closure=closure,
-                loss=loss, g=g, H=P, trust_region=trust_region, settings = settings,
-            )
-
-        assert update_vec is not None
-        if success: var.update = vec_to_tensors(update_vec, params)
-        else: var.update = params.zeros_like()
-
-        return var
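For reference, the radius update that the removed _update_tr_radius applies can be summarized in a standalone sketch. The function name tr_radius_update and the flat-tensor interface are illustrative, not part of torchzero; the boundary check used before growing the radius and the in-place parameter update are omitted.

    import torch

    def tr_radius_update(actual_reduction, g, p, H, radius,
                         eta=0.15, nplus=2.0, nminus=0.25):
        # predicted reduction of the quadratic model: -(g^T p + 0.5 * p^T H p)
        Hp = H * p if H.ndim == 1 else H @ p
        predicted = -(g.dot(p) + 0.5 * p.dot(Hp))
        rho = actual_reduction / predicted.clamp(min=1e-8)
        if rho < 0.25:               # poor agreement with the model: shrink the region
            radius *= nminus
        elif rho > 0.75:             # good agreement: expand the region
            radius *= nplus
        return radius, rho > eta     # the step is accepted only when rho > eta

In the removed modules, actual_reduction is obtained by re-evaluating the closure at the trial point, and an unsuccessful step is retried with the shrunk radius up to max_attempts times.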
torchzero/modules/smoothing/gaussian.py
@@ -1,198 +0,0 @@
-import warnings
-from abc import ABC, abstractmethod
-from collections.abc import Callable, Sequence
-from functools import partial
-from typing import Literal
-
-import torch
-
-from ...core import Modular, Module, Var
-from ...utils import NumberList, TensorList
-from ...utils.derivatives import jacobian_wrt
-from ..grad_approximation import GradApproximator, GradTarget
-
-
-class Reformulation(Module, ABC):
-    def __init__(self, defaults):
-        super().__init__(defaults)
-
-    @abstractmethod
-    def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], var: Var) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
-        """returns loss and gradient, if backward is False then gradient can be None"""
-
-    def pre_step(self, var: Var) -> Var | None:
-        """This runs once before each step, whereas `closure` may run multiple times per step if further modules
-        evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
-        return var
-
-    def step(self, var):
-        ret = self.pre_step(var)
-        if isinstance(ret, Var): var = ret
-
-        if var.closure is None: raise RuntimeError("Reformulation requires closure")
-        params, closure = var.params, var.closure
-
-
-        def modified_closure(backward=True):
-            loss, grad = self.closure(backward, closure, params, var)
-
-            if grad is not None:
-                for p,g in zip(params, grad):
-                    p.grad = g
-
-            return loss
-
-        var.closure = modified_closure
-        return var
-
-
-def _decay_sigma_(self: Module, params):
-    for p in params:
-        state = self.state[p]
-        settings = self.settings[p]
-        state['sigma'] *= settings['decay']
-
-def _generate_perturbations_to_state_(self: Module, params: TensorList, n_samples, sigmas, generator):
-    perturbations = [params.sample_like(generator=generator) for _ in range(n_samples)]
-    torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in sigmas for v in [vv]*n_samples])
-    for param, prt in zip(params, zip(*perturbations)):
-        self.state[param]['perturbations'] = prt
-
-def _clear_state_hook(optimizer: Modular, var: Var, self: Module):
-    for m in optimizer.unrolled_modules:
-        if m is not self:
-            m.reset()
-
-class GaussianHomotopy(Reformulation):
-    """Approximately smoothes the function with a gaussian kernel by sampling it at random perturbed points around current point. Both function values and gradients are averaged over all samples. The perturbed points are generated before each
-    step and remain the same throughout the step.
-
-    .. note::
-        This module reformulates the objective, it modifies the closure to evaluate value and gradients of a smoothed function. All modules after this will operate on the modified objective.
-
-    .. note::
-        This module requires the a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients at perturbed points.
-
-    Args:
-        n_samples (int): number of points to sample, larger values lead to a more accurate smoothing.
-        init_sigma (float): initial scale of perturbations.
-        tol (float | None, optional):
-            if maximal parameters change value is smaller than this, sigma is reduced by :code:`decay`. Defaults to 1e-4.
-        decay (float, optional): multiplier to sigma when converged on a smoothed function. Defaults to 0.5.
-        max_steps (int | None, optional): maximum number of steps before decaying sigma. Defaults to None.
-        clear_state (bool, optional):
-            whether to clear all other module states when sigma is decayed, because the objective function changes. Defaults to True.
-        seed (int | None, optional): seed for random perturbationss. Defaults to None.
-
-    Examples:
-        Gaussian-smoothed NewtonCG
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.GaussianHomotopy(100),
-                tz.m.NewtonCG(maxiter=20),
-                tz.m.AdaptiveBacktracking(),
-            )
-
-    """
-    def __init__(
-        self,
-        n_samples: int,
-        init_sigma: float,
-        tol: float | None = 1e-4,
-        decay=0.5,
-        max_steps: int | None = None,
-        clear_state=True,
-        seed: int | None = None,
-    ):
-        defaults = dict(n_samples=n_samples, init_sigma=init_sigma, tol=tol, decay=decay, max_steps=max_steps, clear_state=clear_state, seed=seed)
-        super().__init__(defaults)
-
-
-    def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
-        if 'generator' not in self.global_state:
-            if isinstance(seed, torch.Generator): self.global_state['generator'] = seed
-            elif seed is not None: self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            else: self.global_state['generator'] = None
-        return self.global_state['generator']
-
-    def pre_step(self, var):
-        params = TensorList(var.params)
-        settings = self.settings[params[0]]
-        n_samples = settings['n_samples']
-        init_sigma = [self.settings[p]['init_sigma'] for p in params]
-        sigmas = self.get_state(params, 'sigma', init=init_sigma)
-
-        if any('perturbations' not in self.state[p] for p in params):
-            generator = self._get_generator(settings['seed'], params)
-            _generate_perturbations_to_state_(self, params=params, n_samples=n_samples, sigmas=sigmas, generator=generator)
-
-        # sigma decay rules
-        max_steps = settings['max_steps']
-        decayed = False
-        if max_steps is not None and max_steps > 0:
-            level_steps = self.global_state['level_steps'] = self.global_state.get('level_steps', 0) + 1
-            if level_steps > max_steps:
-                self.global_state['level_steps'] = 0
-                _decay_sigma_(self, params)
-                decayed = True
-
-        tol = settings['tol']
-        if tol is not None and not decayed:
-            if not any('prev_params' in self.state[p] for p in params):
-                prev_params = self.get_state(params, 'prev_params', cls=TensorList, init='param')
-            else:
-                prev_params = self.get_state(params, 'prev_params', cls=TensorList, init='param')
-                s = params - prev_params
-
-                if s.abs().global_max() <= tol:
-                    _decay_sigma_(self, params)
-                    decayed = True
-
-                prev_params.copy_(params)
-
-        if decayed:
-            generator = self._get_generator(settings['seed'], params)
-            _generate_perturbations_to_state_(self, params=params, n_samples=n_samples, sigmas=sigmas, generator=generator)
-            if settings['clear_state']:
-                var.post_step_hooks.append(partial(_clear_state_hook, self=self))
-
-    @torch.no_grad
-    def closure(self, backward, closure, params, var):
-        params = TensorList(params)
-
-        settings = self.settings[params[0]]
-        n_samples = settings['n_samples']
-
-        perturbations = list(zip(*(self.state[p]['perturbations'] for p in params)))
-
-        loss = None
-        grad = None
-        for i in range(n_samples):
-            prt = perturbations[i]
-
-            params.add_(prt)
-            if backward:
-                with torch.enable_grad(): l = closure()
-                if grad is None: grad = params.grad
-                else: grad += params.grad
-
-            else:
-                l = closure(False)
-
-            if loss is None: loss = l
-            else: loss = loss+l
-
-            params.sub_(prt)
-
-        assert loss is not None
-        if n_samples > 1:
-            loss = loss / n_samples
-            if backward:
-                assert grad is not None
-                grad.div_(n_samples)
-
-        return loss, grad
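A minimal standalone sketch of the averaging performed by the removed GaussianHomotopy.closure, assuming a single-tensor objective: the helper name smoothed_loss_and_grad is illustrative, while the actual module works on torchzero parameter lists and draws its perturbations once per step in pre_step.

    import torch

    def smoothed_loss_and_grad(f, x, perturbations):
        # average value and gradient of f over a fixed set of perturbations of x
        total_loss, total_grad = 0.0, torch.zeros_like(x)
        for eps in perturbations:
            xp = (x + eps).detach().requires_grad_(True)
            loss = f(xp)
            loss.backward()
            total_loss += float(loss)
            total_grad += xp.grad
        n = len(perturbations)
        return total_loss / n, total_grad / n

    # usage: perturbations are drawn once (as in pre_step) and reused within the step
    x = torch.randn(5)
    perts = [0.1 * torch.randn_like(x) for _ in range(10)]
    loss, grad = smoothed_loss_and_grad(lambda v: (v ** 2).sum(), x, perts)

Per the file list above, this smoothing functionality now lives in torchzero/modules/smoothing/sampling.py in 0.3.13.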