torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/trust_region.py (new file)

@@ -0,0 +1,397 @@
+"""Trust region API is currently experimental, it will probably change completely"""
+# pylint:disable=not-callable
+from abc import ABC, abstractmethod
+from typing import Any, Literal, cast, final
+from collections.abc import Sequence, Mapping
+
+import numpy as np
+import torch
+from scipy.optimize import lsq_linear
+
+from ...core import Chainable, Module, apply_transform, Var
+from ...utils import TensorList, vec_to_tensors
+from ...utils.derivatives import (
+    hessian_list_to_mat,
+    jacobian_and_hessian_wrt,
+)
+from .quasi_newton import HessianUpdateStrategy
+from ...utils.linalg import steihaug_toint_cg
+
+
+def trust_lstsq(H: torch.Tensor, g: torch.Tensor, trust_region: float):
+    res = lsq_linear(H.numpy(force=True).astype(np.float64), g.numpy(force=True).astype(np.float64), bounds=(-trust_region, trust_region))
+    x = torch.from_numpy(res.x).to(H)
+    return x, res.cost
+
+def _flatten_tensors(tensors: list[torch.Tensor]):
+    return torch.cat([t.ravel() for t in tensors])
+
+
+class TrustRegionBase(Module, ABC):
+    def __init__(
+        self,
+        defaults: dict | None = None,
+        hess_module: HessianUpdateStrategy | None = None,
+        update_freq: int = 1,
+        inner: Chainable | None = None,
+    ):
+        self._update_freq = update_freq
+        super().__init__(defaults)
+
+        if hess_module is not None:
+            self.set_child('hess_module', hess_module)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @abstractmethod
+    def trust_region_step(self, var: Var, tensors: list[torch.Tensor], P: torch.Tensor, is_inverse: bool) -> Var:
+        """trust region logic"""
+
+    @final
+    @torch.no_grad
+    def update(self, var):
+        # ---------------------------------- update ---------------------------------- #
+        closure = var.closure
+        if closure is None: raise RuntimeError("Trust region requires closure")
+        params = var.params
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        P = None
+        is_inverse = None
+        g_list = var.grad
+        loss = var.loss
+        if step % self._update_freq == 0:
+
+            if 'hess_module' not in self.children:
+                # exact hessian via autograd
+                with torch.enable_grad():
+                    loss = var.loss = var.loss_approx = closure(False)
+                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=True)
+                g_list = [t[0] for t in g_list] # remove leading dim from loss
+                var.grad = g_list
+                P = hessian_list_to_mat(H_list)
+                is_inverse = False
+
+            else:
+                # hessian approximation maintained by a HessianUpdateStrategy child
+                hessian_module = cast(HessianUpdateStrategy, self.children['hess_module'])
+                hessian_module.update(var)
+                P, is_inverse = hessian_module.get_B()
+
+            if self._update_freq != 0:
+                self.global_state['B'] = P
+                self.global_state['is_inverse'] = is_inverse
+
+    @final
+    @torch.no_grad
+    def apply(self, var):
+        P = self.global_state['B']
+        is_inverse = self.global_state['is_inverse']
+
+        # -------------------------------- inner step -------------------------------- #
+        update = var.get_update()
+        if 'inner' in self.children:
+            update = apply_transform(self.children['inner'], update, params=var.params, grads=var.grad, var=var)
+
+        # ----------------------------------- apply ---------------------------------- #
+        return self.trust_region_step(var=var, tensors=update, P=P, is_inverse=is_inverse)
+
+def _update_tr_radius(update_vec: torch.Tensor, params: Sequence[torch.Tensor], closure,
+                      loss, g: torch.Tensor, H: torch.Tensor, trust_region: float, settings: Mapping):
+    """returns (new_trust_region, success)
+
+    Args:
+        update_vec (torch.Tensor): update vector which is SUBTRACTED from parameters
+        params (Sequence[torch.Tensor]): params tensor list
+        closure (Callable): closure
+        loss (torch.Tensor): loss at x0
+        g (torch.Tensor): gradient vector
+        H (torch.Tensor): hessian
+        trust_region (float): current trust region value
+    """
+    # evaluate actual loss reduction
+    update_unflattened = vec_to_tensors(update_vec, params)
+    params = TensorList(params)
+    params -= update_unflattened
+    loss_star = closure(False)
+    params += update_unflattened
+    reduction = loss - loss_star
+
+    # predicted reduction of the quadratic model is -(g.T @ p + 0.5 * p.T @ H @ p)
+    if H.ndim == 1: Hu = H * update_vec
+    else: Hu = H @ update_vec
+    pred_reduction = - (g.dot(update_vec) + 0.5 * update_vec.dot(Hu))
+    rho = reduction / (pred_reduction.clip(min=1e-8))
+
+    # failed step
+    if rho < 0.25:
+        trust_region *= settings["nminus"]
+
+    # very good step
+    elif rho > 0.75:
+        diff = trust_region - update_vec.abs()
+        if (diff.amin() / trust_region) > 1e-4:
+            trust_region *= settings["nplus"]
+
+    return trust_region, rho > settings["eta"]
+
+class TrustCG(TrustRegionBase):
+    """Trust region via the Steihaug-Toint conjugate gradient method. This is mainly useful for quasi-newton methods.
+    If you don't use :code:`hess_module`, consider the matrix-free :code:`tz.m.NewtonCGSteihaug`, which only uses hessian-vector products.
+
+    Args:
+        hess_module (HessianUpdateStrategy | None, optional):
+            Hessian update strategy, must be one of the :code:`HessianUpdateStrategy` modules. Make sure to set :code:`inverse=False`. If None, uses autograd to calculate the hessian. Defaults to None.
+        eta (float, optional):
+            if the ratio of actual to predicted reduction is larger than this, the step is accepted.
+            When :code:`hess_module` is None, this can be set to 0. Defaults to 0.15.
+        nplus (float, optional): increase factor on successful steps. Defaults to 2.
+        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
+        init (float, optional): initial trust region value. Defaults to 1.
+        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
+        reg (float, optional): hessian regularization. Defaults to 0.
+        max_attempts (int, optional): maximum number of radius reductions per step before a zero update is returned. Defaults to 10.
+        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+    Examples:
+        Trust-SR1
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
+            )
+    """
+    def __init__(
+        self,
+        hess_module: HessianUpdateStrategy | None,
+        eta: float = 0.15,
+        nplus: float = 2,
+        nminus: float = 0.25,
+        init: float = 1,
+        update_freq: int = 1,
+        reg: float = 0,
+        max_attempts: int = 10,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(init=init, nplus=nplus, nminus=nminus, eta=eta, reg=reg, max_attempts=max_attempts)
+        super().__init__(defaults, hess_module=hess_module, update_freq=update_freq, inner=inner)
+
+    @torch.no_grad
+    def trust_region_step(self, var, tensors, P, is_inverse):
+        params = TensorList(var.params)
+        settings = self.settings[params[0]]
+        g = _flatten_tensors(tensors)
+
+        reg = settings['reg']
+        max_attempts = settings['max_attempts']
+
+        loss = var.loss
+        closure = var.closure
+        if closure is None: raise RuntimeError("Trust region requires closure")
+        if loss is None: loss = closure(False)
+
+        if is_inverse:
+            if P.ndim == 1: P = P.reciprocal()
+            else: raise NotImplementedError()
+
+        success = False
+        update_vec = None
+        while not success:
+            max_attempts -= 1
+            if max_attempts < 0: break
+
+            trust_region = self.global_state.get('trust_region', settings['init'])
+            if trust_region < 1e-8 or trust_region > 1e8:
+                trust_region = self.global_state['trust_region'] = settings['init']
+
+            update_vec = steihaug_toint_cg(P, g, trust_region, reg=reg)
+
+            self.global_state['trust_region'], success = _update_tr_radius(
+                update_vec=update_vec, params=params, closure=closure,
+                loss=loss, g=g, H=P, trust_region=trust_region, settings=settings,
+            )
+
+        assert update_vec is not None
+        if success: var.update = vec_to_tensors(update_vec, params)
+        else: var.update = params.zeros_like()
+
+        return var
+
+
+# code from https://github.com/konstmish/opt_methods/blob/master/optmethods/second_order/cubic.py
+# ported to torch
+def ls_cubic_solver(f, g: torch.Tensor, H: torch.Tensor, M: float, is_inverse: bool, loss_plus, it_max=100, epsilon=1e-8):
+    """
+    Solve min_z <g, z-x> + 1/2 <z-x, H (z-x)> + M/3 ||z-x||^3
+
+    For an explanation of the Cauchy point, see "Gradient Descent
+    Efficiently Finds the Cubic-Regularized Non-Convex Newton Step"
+    https://arxiv.org/pdf/1612.00547.pdf
+    Other potential implementations can be found in the paper
+    "Adaptive cubic regularisation methods"
+    https://people.maths.ox.ac.uk/cartis/papers/ARCpI.pdf
+    """
+    solver_it = 1
+    if is_inverse:
+        newton_step = - H @ g
+        H = torch.linalg.inv(H)
+    else:
+        newton_step, info = torch.linalg.solve_ex(H, g)
+        if info != 0:
+            newton_step = torch.linalg.lstsq(H, g).solution
+        newton_step.neg_()
+    if M == 0:
+        return newton_step, solver_it
+
+    def cauchy_point(g, H, M):
+        if torch.linalg.vector_norm(g) == 0 or M == 0:
+            return 0 * g
+        g_dir = g / torch.linalg.vector_norm(g)
+        H_g_g = H @ g_dir @ g_dir
+        R = -H_g_g / (2*M) + torch.sqrt((H_g_g/M)**2/4 + torch.linalg.vector_norm(g)/M)
+        return -R * g_dir
+
+    def conv_criterion(s, r):
+        """
+        The convergence criterion is an increasing and concave function in r,
+        and it is equal to 0 only if r is the solution to the cubic problem.
+        """
+        s_norm = torch.linalg.vector_norm(s)
+        return 1/s_norm - 1/r
+
+    # solution s satisfies ||s|| >= Cauchy radius
+    r_min = torch.linalg.vector_norm(cauchy_point(g, H, M))
+
+    if f > loss_plus(newton_step):
+        return newton_step, solver_it
+
+    r_max = torch.linalg.vector_norm(newton_step)
+    if r_max - r_min < epsilon:
+        return newton_step, solver_it
+
+    # bisect on the radius r; lam = r * M is the implied shift of the hessian
+    id_matrix = torch.eye(g.size(0), device=g.device, dtype=g.dtype)
+    s_lam = None
+    for _ in range(it_max):
+        r_try = (r_min + r_max) / 2
+        lam = r_try * M
+        s_lam = -torch.linalg.solve(H + lam*id_matrix, g)
+        solver_it += 1
+        crit = conv_criterion(s_lam, r_try)
+        if np.abs(crit) < epsilon:
+            return s_lam, solver_it
+        if crit < 0:
+            r_min = r_try
+        else:
+            r_max = r_try
+        if r_max - r_min < epsilon:
+            break
+    assert s_lam is not None
+    return s_lam, solver_it
+
+class CubicRegularization(TrustRegionBase):
+    """Cubic regularization.
+
+    .. note::
+        By default this functions like a trust region; set :code:`nplus` and :code:`nminus` to 1 to make
+        the regularization parameter fixed. :code:`init` sets 1/regularization.
+
+    Args:
+        hess_module (HessianUpdateStrategy | None, optional):
+            Hessian update strategy, must be one of the :code:`HessianUpdateStrategy` modules, though this works better with the true hessian. Make sure to set :code:`inverse=False`. If None, uses autograd to calculate the hessian. Defaults to None.
+        eta (float, optional):
+            if the ratio of actual to predicted reduction is larger than this, the step is accepted.
+            When :code:`hess_module` is None, this can be set to 0. Defaults to 0.0.
+        nplus (float, optional): increase factor on successful steps. Defaults to 2.
+        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
+        init (float, optional): initial trust region value (the cubic penalty is M = 1/trust_region). Defaults to 1.
+        maxiter (int, optional): maximum number of iterations when solving the cubic subproblem. Defaults to 100.
+        eps (float, optional): epsilon for the subproblem solver. Defaults to 1e-8.
+        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
+        max_attempts (int, optional): maximum number of radius reductions per step before a zero update is returned. Defaults to 10.
+        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+    Examples:
+        Cubic regularized newton
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.CubicRegularization(),
+            )
+    """
+    def __init__(
+        self,
+        hess_module: HessianUpdateStrategy | None = None,
+        eta: float = 0.0,
+        nplus: float = 2,
+        nminus: float = 0.25,
+        init: float = 1,
+        maxiter: int = 100,
+        eps: float = 1e-8,
+        update_freq: int = 1,
+        max_attempts: int = 10,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(init=init, nplus=nplus, nminus=nminus, eta=eta, maxiter=maxiter, eps=eps, max_attempts=max_attempts)
+        super().__init__(defaults, hess_module=hess_module, update_freq=update_freq, inner=inner)
+
+    @torch.no_grad
+    def trust_region_step(self, var, tensors, P, is_inverse):
+        params = TensorList(var.params)
+        settings = self.settings[params[0]]
+        g = _flatten_tensors(tensors)
+
+        maxiter = settings['maxiter']
+        max_attempts = settings['max_attempts']
+        eps = settings['eps']
+
+        loss = var.loss
+        closure = var.closure
+        if closure is None: raise RuntimeError("Trust region requires closure")
+        if loss is None: loss = closure(False)
+
+        def loss_plus(x):
+            x_unflat = vec_to_tensors(x, params)
+            params.add_(x_unflat)
+            loss_x = closure(False)
+            params.sub_(x_unflat)
+            return loss_x
+
+        success = False
+        update_vec = None
+        while not success:
+            max_attempts -= 1
+            if max_attempts < 0: break
+
+            trust_region = self.global_state.get('trust_region', settings['init'])
+            if trust_region < 1e-8 or trust_region > 1e16:
+                trust_region = self.global_state['trust_region'] = settings['init']
+
+            update_vec, _ = ls_cubic_solver(f=loss, g=g, H=P, M=1/trust_region, is_inverse=is_inverse,
+                                            loss_plus=loss_plus, it_max=maxiter, epsilon=eps)
+            update_vec.neg_()
+
+            self.global_state['trust_region'], success = _update_tr_radius(
+                update_vec=update_vec, params=params, closure=closure,
+                loss=loss, g=g, H=P, trust_region=trust_region, settings=settings,
+            )
+
+        assert update_vec is not None
+        if success: var.update = vec_to_tensors(update_vec, params)
+        else: var.update = params.zeros_like()
+
+        return var
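
Both new modules share the radius-update rule in _update_tr_radius: with a step p that is subtracted from the parameters, the ratio of actual to predicted reduction is rho = (f(x) - f(x - p)) / -(g^T p + 0.5 p^T H p); rho < 0.25 shrinks the radius by nminus, rho > 0.75 can grow it by nplus, and the step is accepted when rho > eta. On a failed step the radius is shrunk and the subproblem re-solved, up to max_attempts times, after which a zero update is emitted.

For orientation, here is a minimal usage sketch assembled from the docstring examples above. The closure convention (a backward flag, called as closure(False) for loss-only evaluations) is taken from the module code; the model and data are placeholders, and opt.zero_grad() / opt.step(closure) assume the usual torch.optim-style interface of tz.Modular:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)

    # SR1 maintains the hessian approximation B (hence inverse=False);
    # TrustCG solves the trust-region subproblem with Steihaug-Toint CG.
    opt = tz.Modular(
        model.parameters(),
        tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
    )

    def closure(backward=True):
        # trust-region modules re-evaluate the loss at trial points,
        # so a closure is required
        loss = model(torch.randn(8, 4)).square().mean()
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    opt.step(closure)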

torchzero/modules/second_order/__init__.py
@@ -1,3 +1,3 @@
-from .newton import Newton
-from .newton_cg import NewtonCG
+from .newton import Newton, InverseFreeNewton
+from .newton_cg import NewtonCG, TruncatedNewtonCG
 from .nystrom import NystromSketchAndSolve, NystromPCG
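
Assuming the layout shown in this hunk, the updated import surface of torchzero/modules/second_order is as follows (names are taken from the hunk; the groupings in the comments are my reading of the names, not documented categories):

    from torchzero.modules.second_order import (
        Newton, InverseFreeNewton,          # dense Newton variants
        NewtonCG, TruncatedNewtonCG,        # CG-based variants
        NystromSketchAndSolve, NystromPCG,  # randomized sketching solvers
    )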