torchzero 0.3.13__py3-none-any.whl → 0.3.15__py3-none-any.whl
- tests/test_opts.py +4 -10
- torchzero/core/__init__.py +4 -1
- torchzero/core/chain.py +50 -0
- torchzero/core/functional.py +37 -0
- torchzero/core/modular.py +237 -0
- torchzero/core/module.py +12 -599
- torchzero/core/reformulation.py +3 -1
- torchzero/core/transform.py +7 -5
- torchzero/core/var.py +376 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/conjugate_gradient/cg.py +16 -16
- torchzero/modules/experimental/__init__.py +1 -0
- torchzero/modules/experimental/newtonnewton.py +5 -5
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/functional.py +7 -0
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +2 -5
- torchzero/modules/grad_approximation/rfdm.py +27 -110
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +1 -1
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +11 -20
- torchzero/modules/line_search/scipy.py +15 -3
- torchzero/modules/line_search/strong_wolfe.py +3 -5
- torchzero/modules/misc/misc.py +2 -2
- torchzero/modules/misc/multistep.py +13 -13
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/quasi_newton.py +15 -6
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +5 -4
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +89 -0
- torchzero/modules/second_order/inm.py +105 -0
- torchzero/modules/second_order/newton.py +103 -193
- torchzero/modules/second_order/newton_cg.py +86 -110
- torchzero/modules/second_order/nystrom.py +1 -1
- torchzero/modules/second_order/rsn.py +227 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +6 -4
- torchzero/modules/wrappers/optim_wrapper.py +49 -42
- torchzero/modules/zeroth_order/__init__.py +1 -1
- torchzero/modules/zeroth_order/cd.py +1 -238
- torchzero/utils/derivatives.py +19 -19
- torchzero/utils/linalg/linear_operator.py +50 -2
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +1 -0
- {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/METADATA +1 -1
- {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/RECORD +57 -48
- torchzero/modules/higher_order/__init__.py +0 -1
- /torchzero/modules/{higher_order → experimental}/higher_order_newton.py +0 -0
- {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/WHEEL +0 -0
- {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/top_level.txt +0 -0
torchzero/modules/line_search/line_search.py
CHANGED
@@ -10,6 +10,7 @@ import torch

 from ...core import Module, Target, Var
 from ...utils import tofloat, set_storage_
+from ..functional import clip_by_finfo


 class MaxLineSearchItersReached(Exception): pass
@@ -103,23 +104,18 @@ class LineSearchBase(Module, ABC):
     ):
         if not math.isfinite(step_size): return

-        #
-        step_size =
+        # avoid overflow error
+        step_size = clip_by_finfo(tofloat(step_size), torch.finfo(update[0].dtype))

         # skip is parameters are already at suggested step size
         if self._current_step_size == step_size: return

-        # this was basically causing floating point imprecision to build up
-        #if False:
-        #     if abs(alpha) < abs(step_size) and step_size != 0:
-        #         torch._foreach_add_(params, update, alpha=alpha)
-
-        #     else:
         assert self._initial_params is not None
         if step_size == 0:
             new_params = [p.clone() for p in self._initial_params]
         else:
             new_params = torch._foreach_sub(self._initial_params, update, alpha=step_size)
+
         for c, n in zip(params, new_params):
             set_storage_(c, n)
@@ -131,10 +127,7 @@ class LineSearchBase(Module, ABC):
         params: list[torch.Tensor],
         update: list[torch.Tensor],
     ):
-
-        # alpha = [self._current_step_size - s for s in step_size]
-        # if any(a!=0 for a in alpha):
-        #     torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+
         assert self._initial_params is not None
         if not np.isfinite(step_size).all(): step_size = [0 for _ in step_size]
@@ -248,16 +241,14 @@ class LineSearchBase(Module, ABC):
         except MaxLineSearchItersReached:
             step_size = self._best_step_size

+        step_size = clip_by_finfo(step_size, torch.finfo(update[0].dtype))
+
         # set loss_approx
         if var.loss_approx is None: var.loss_approx = self._lowest_loss

-        # this is last module
-        if var.
-
-            self.set_step_size_(step_size, params=params, update=update)
-
-        else:
-            self._set_per_parameter_step_size_([step_size*lr for lr in var.last_module_lrs], params=params, update=update)
+        # if this is last module, directly update parameters to avoid redundant operations
+        if var.modular is not None and self is var.modular.modules[-1]:
+            self.set_step_size_(step_size, params=params, update=update)

         var.stop = True; var.skip_update = True
         return var
@@ -277,7 +268,7 @@ class GridLineSearch(LineSearchBase):

     @torch.no_grad
     def search(self, update, var):
-        start,end,num=itemgetter('start','end','num')(self.defaults)
+        start, end, num = itemgetter('start', 'end', 'num')(self.defaults)

         for lr in torch.linspace(start,end,num):
            self.evaluate_f(lr.item(), var=var, backward=False)
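The body of the new `clip_by_finfo` helper is not shown in this diff (it lives in `torchzero/modules/functional.py`, +7 lines per the file list above). A minimal sketch of what such a clamp might look like, assuming it simply bounds a scalar to the dtype's finite range so that subsequent tensor arithmetic cannot overflow to inf:

```python
import math
import torch

def clip_by_finfo(x: float, finfo: torch.finfo) -> float:
    # hypothetical sketch, not the library's actual implementation:
    # clamp a scalar step size into the finite range of the dtype
    if math.isnan(x):
        return x
    return min(max(x, finfo.min), finfo.max)
```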
torchzero/modules/line_search/scipy.py
CHANGED
@@ -1,3 +1,4 @@
+import math
 from collections.abc import Mapping
 from operator import itemgetter

@@ -17,6 +18,7 @@ class ScipyMinimizeScalar(LineSearchBase):
         bounds (Sequence | None, optional):
             For method ‘bounded’, bounds is mandatory and must have two finite items corresponding to the optimization bounds. Defaults to None.
         tol (float | None, optional): Tolerance for termination. Defaults to None.
+        prev_init (bool, optional): uses previous step size as initial guess for the line search.
         options (dict | None, optional): A dictionary of solver options. Defaults to None.

         For more details on methods and arguments refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html
@@ -29,9 +31,10 @@ class ScipyMinimizeScalar(LineSearchBase):
         bracket=None,
         bounds=None,
         tol: float | None = None,
+        prev_init: bool = False,
         options=None,
     ):
-        defaults = dict(method=method,bracket=bracket,bounds=bounds,tol=tol,options=options,maxiter=maxiter)
+        defaults = dict(method=method,bracket=bracket,bounds=bounds,tol=tol,options=options,maxiter=maxiter, prev_init=prev_init)
         super().__init__(defaults)

         import scipy.optimize
@@ -48,5 +51,14 @@ class ScipyMinimizeScalar(LineSearchBase):
         options = dict(options) if isinstance(options, Mapping) else {}
         options['maxiter'] = maxiter

-
-
+        if self.defaults["prev_init"] and "x_prev" in self.global_state:
+            if bracket is None: bracket = (0, 1)
+            bracket = (*bracket[:-1], self.global_state["x_prev"])
+
+        x = self.scopt.minimize_scalar(objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options).x # pyright:ignore[reportAttributeAccessIssue]
+
+        max = torch.finfo(var.params[0].dtype).max / 2
+        if (not math.isfinite(x)) or abs(x) >= max: x = 0
+
+        self.global_state['x_prev'] = x
+        return x
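With `prev_init=True`, the previously accepted step size is spliced into the scipy bracket, so each line search starts near the last solution instead of re-bracketing from scratch. A hedged usage sketch (the module pairing is illustrative, not taken from the package docs; "brent" is scipy's default scalar method):

```python
import torchzero as tz

# reuse the previous accepted step size as one end of the bracket
opt = tz.Modular(
    model.parameters(),
    tz.m.NewtonCG(),
    tz.m.ScipyMinimizeScalar(method="brent", prev_init=True),
)
```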
torchzero/modules/line_search/strong_wolfe.py
CHANGED
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 from torch.optim.lbfgs import _cubic_interpolate

-from ...utils import as_tensorlist, totensor
+from ...utils import as_tensorlist, totensor, tofloat
 from ._polyinterp import polyinterp, polyinterp2
 from .line_search import LineSearchBase, TerminationCondition, termination_condition
 from ..step_size.adaptive import _bb_geom
@@ -92,7 +92,7 @@ class _StrongWolfe:
             return _apply_bounds(a_lo + 0.5 * (a_hi - a_lo), bounds)

         if self.interpolation in ('polynomial', 'polynomial2'):
-            finite_history = [(a, f, g) for a, (f,g) in self.history.items() if math.isfinite(a) and math.isfinite(f) and math.isfinite(g)]
+            finite_history = [(tofloat(a), tofloat(f), tofloat(g)) for a, (f,g) in self.history.items() if math.isfinite(a) and math.isfinite(f) and math.isfinite(g)]
             if bounds is None: bounds = (None, None)
             polyinterp_fn = polyinterp if self.interpolation == 'polynomial' else polyinterp2
             try:
@@ -330,7 +330,6 @@ class StrongWolfe(LineSearchBase):
         if adaptive:
             a_init *= self.global_state.get('initial_scale', 1)

-
         strong_wolfe = _StrongWolfe(
             f=objective,
             f_0=f_0,
@@ -360,7 +359,6 @@ class StrongWolfe(LineSearchBase):
         if inverted: a = -a

         if a is not None and a != 0 and math.isfinite(a):
-            #self.global_state['initial_scale'] = min(1.0, self.global_state.get('initial_scale', 1) * math.sqrt(2))
             self.global_state['initial_scale'] = 1
             self.global_state['a_prev'] = a
             self.global_state['f_prev'] = f_0
@@ -372,6 +370,6 @@ class StrongWolfe(LineSearchBase):
             self.global_state['initial_scale'] = self.global_state.get('initial_scale', 1) * 0.5
             finfo = torch.finfo(dir[0].dtype)
             if self.global_state['initial_scale'] < finfo.tiny * 2:
-                self.global_state['initial_scale'] =
+                self.global_state['initial_scale'] = init_value * 2

         return 0
torchzero/modules/misc/misc.py
CHANGED
@@ -306,8 +306,8 @@ class RandomHvp(Module):
         for i in range(n_samples):
             u = params.sample_like(distribution=distribution, variance=1)

-            Hvp, rgrad =
-            h=h, normalize=True,
+            Hvp, rgrad = var.hessian_vector_product(u, at_x0=True, rgrad=rgrad, hvp_method=hvp_method,
+                                                    h=h, normalize=True, retain_graph=i < n_samples-1)

             if D is None: D = Hvp
             else: torch._foreach_add_(D, Hvp)
torchzero/modules/misc/multistep.py
CHANGED
@@ -15,7 +15,7 @@ def _sequential_step(self: Module, var: Var, sequential: bool):
     if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')

     # store original params unless this is last module and can update params directly
-    params_before_steps =
+    params_before_steps = [p.clone() for p in params]

     # first step - pass var as usual
     var = modules[0].step(var)
@@ -27,8 +27,8 @@ def _sequential_step(self: Module, var: Var, sequential: bool):

     # update params
     if (not new_var.skip_update):
-        if new_var.last_module_lrs is not None:
-
+        # if new_var.last_module_lrs is not None:
+        #     torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)

         torch._foreach_sub_(params, new_var.get_update())
@@ -41,16 +41,16 @@ def _sequential_step(self: Module, var: Var, sequential: bool):

     # final parameter update
     if (not new_var.skip_update):
-        if new_var.last_module_lrs is not None:
-
+        # if new_var.last_module_lrs is not None:
+        #     torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)

         torch._foreach_sub_(params, new_var.get_update())

     # if last module, update is applied so return new var
-    if params_before_steps is None:
-
-
-
+    # if params_before_steps is None:
+    #     new_var.stop = True
+    #     new_var.skip_update = True
+    #     return new_var

     # otherwise use parameter difference as update
     var.update = list(torch._foreach_sub(params_before_steps, params))
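The net effect of these two hunks: `_sequential_step` no longer special-cases the last module; it always snapshots the parameters, lets the inner modules move them, and reports the accumulated movement as its own update. The pattern in isolation, as a hedged sketch rather than the library's exact code:

```python
import torch

def difference_as_update(params, inner_steps):
    # snapshot, let inner modules move the parameters in place, then
    # report the net movement as this module's update; the sign matches
    # torch._foreach_sub(params_before_steps, params) above
    before = [p.clone() for p in params]
    for step in inner_steps:
        step(params)
    return list(torch._foreach_sub(before, params))
```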
@@ -106,10 +106,10 @@ class NegateOnLossIncrease(Module):
         f_1 = closure(False)

         if f_1 <= f_0:
-            if var.is_last and var.last_module_lrs is None:
-
-
-
+            # if var.is_last and var.last_module_lrs is None:
+            #     var.stop = True
+            #     var.skip_update = True
+            #     return var

             torch._foreach_add_(var.params, update)
             return var
torchzero/modules/quasi_newton/quasi_newton.py
CHANGED
@@ -1182,16 +1182,19 @@ class ShorR(HessianUpdateStrategy):
     """Shor’s r-algorithm.

     Note:
-        A line search such as ``tz.m.StrongWolfe(a_init="quadratic", fallback=True)`` is required.
-
-
+        - A line search such as ``[tz.m.StrongWolfe(a_init="quadratic", fallback=True), tz.m.Mul(1.2)]`` is required. Similarly to conjugate gradient, ShorR doesn't have automatic step size scaling, so setting ``a_init`` in the line search is recommended.
+
+        - The line search should try to overstep a little, so it can help to multiply the direction given by the line search by a value slightly larger than 1, such as 1.2.

     References:
-
+        These are the original references, but neither seems to be available online:
+        - Shor, N. Z., Utilization of the Operation of Space Dilatation in the Minimization of Convex Functions, Kibernetika, No. 1, pp. 6-12, 1970.
+
+        - Skokov, V. A., Note on Minimization Methods Employing Space Stretching, Kibernetika, No. 4, pp. 115-117, 1974.

-        Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of Numerical Analysis 28.4 (2008): 711-720.
+        An overview is available in [Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of Numerical Analysis 28.4 (2008): 711-720](https://sites.math.washington.edu/~burke/papers/reprints/60-speed-Shor-R.pdf).

-        Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998.
+        The Skokov reference describes a more efficient formula, which can be found in [Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998](https://camo.ici.ro/books/thesis/th.pdf).
     """

     def __init__(
@@ -1229,3 +1232,9 @@ class ShorR(HessianUpdateStrategy):

     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
         return shor_r_(H=H, y=y, alpha=setting['alpha'])
+
+
+# Todd, Michael J. "The symmetric rank-one quasi-Newton method is a space-dilation subgradient algorithm." Operations research letters 5.5 (1986): 217-219.
+# TODO
+
+# Sorensen, D. C. "The q-superlinear convergence of a collinear scaling algorithm for unconstrained optimization." SIAM Journal on Numerical Analysis 17.1 (1980): 84-114.
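Putting the updated ShorR note into practice, a hedged usage sketch following the docstring's own recommendation (assuming `tz.m.Mul` scales the update, as the docstring's ``[..., tz.m.Mul(1.2)]`` suggests):

```python
import torchzero as tz

# ShorR direction, Strong-Wolfe line search, then a slight overstep (x1.2)
# as the docstring recommends
opt = tz.Modular(
    model.parameters(),
    tz.m.ShorR(),
    tz.m.StrongWolfe(a_init="quadratic", fallback=True),
    tz.m.Mul(1.2),
)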
torchzero/modules/quasi_newton/sg2.py
ADDED
@@ -0,0 +1,292 @@
+import torch
+
+from ...core import Module, Chainable, apply_transform
+from ...utils import TensorList, vec_to_tensors
+from ..second_order.newton import _newton_step, _get_H
+
+def sg2_(
+    delta_g: torch.Tensor,
+    cd: torch.Tensor,
+) -> torch.Tensor:
+    """cd is c * perturbation, and must be multiplied by two if hessian estimate is two-sided
+    (or divide delta_g by two)."""
+
+    M = torch.outer(1.0 / cd, delta_g)
+    H_hat = 0.5 * (M + M.T)
+
+    return H_hat
+
+
+
+class SG2(Module):
+    """second-order stochastic gradient
+
+    SG2 with line search
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SG2(),
+        tz.m.Backtracking()
+    )
+    ```
+
+    SG2 with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.SG2()),
+    )
+    ```
+
+    """
+
+    def __init__(
+        self,
+        n_samples: int = 1,
+        h: float = 1e-2,
+        beta: float | None = None,
+        damping: float = 0,
+        eigval_fn=None,
+        one_sided: bool = False, # one-sided hessian
+        use_lstsq: bool = True,
+        seed=None,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, one_sided=one_sided, seed=seed, use_lstsq=use_lstsq)
+        super().__init__(defaults)
+
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def update(self, var):
+        k = self.global_state.get('step', 0) + 1
+        self.global_state["step"] = k
+
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None:
+            raise RuntimeError("closure is required for SG2")
+        generator = self.get_generator(params[0].device, self.defaults["seed"])
+
+        h = self.get_settings(params, "h")
+        x_0 = params.clone()
+        n_samples = self.defaults["n_samples"]
+        H_hat = None
+
+        for i in range(n_samples):
+            # generate perturbation
+            cd = params.rademacher_like(generator=generator).mul_(h)
+
+            # one sided
+            if self.defaults["one_sided"]:
+                g_0 = TensorList(var.get_grad())
+                params.add_(cd)
+                closure()
+
+                g_p = params.grad.fill_none_(params)
+                delta_g = (g_p - g_0) * 2
+
+            # two sided
+            else:
+                params.add_(cd)
+                closure()
+                g_p = params.grad.fill_none_(params)
+
+                params.copy_(x_0)
+                params.sub_(cd)
+                closure()
+                g_n = params.grad.fill_none_(params)
+
+                delta_g = g_p - g_n
+
+            # restore params
+            params.set_(x_0)
+
+            # compute H hat
+            H_i = sg2_(
+                delta_g = delta_g.to_vec(),
+                cd = cd.to_vec(),
+            )
+
+            if H_hat is None: H_hat = H_i
+            else: H_hat += H_i
+
+        assert H_hat is not None
+        if n_samples > 1: H_hat /= n_samples
+
+        # update H
+        H = self.global_state.get("H", None)
+        if H is None: H = H_hat
+        else:
+            beta = self.defaults["beta"]
+            if beta is None: beta = k / (k+1)
+            H.lerp_(H_hat, 1-beta)
+
+        self.global_state["H"] = H
+
+
+    @torch.no_grad
+    def apply(self, var):
+        dir = _newton_step(
+            var=var,
+            H = self.global_state["H"],
+            damping = self.defaults["damping"],
+            inner = self.children.get("inner", None),
+            H_tfm=None,
+            eigval_fn=self.defaults["eigval_fn"],
+            use_lstsq=self.defaults["use_lstsq"],
+            g_proj=None,
+        )
+
+        var.update = vec_to_tensors(dir, var.params)
+        return var
+
+    def get_H(self,var=...):
+        return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
+
+
+
+
+# two sided
+# we have g via x + d, x - d
+# H via g(x + d), g(x - d)
+# 1 is x, x+2d
+# 2 is x, x-2d
+# 5 evals in total
+
+# one sided
+# g via x, x + d
+# 1 is x, x + d
+# 2 is x, x - d
+# 3 evals and can use two sided for g_0
+
+class SPSA2(Module):
+    """second-order SPSA
+
+    SPSA2 with line search
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SPSA2(),
+        tz.m.Backtracking()
+    )
+    ```
+
+    SPSA2 with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.SPSA2()),
+    )
+    ```
+    """
+
+    def __init__(
+        self,
+        n_samples: int = 1,
+        h: float = 1e-2,
+        beta: float | None = None,
+        damping: float = 0,
+        eigval_fn=None,
+        use_lstsq: bool = True,
+        seed=None,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, seed=seed, use_lstsq=use_lstsq)
+        super().__init__(defaults)
+
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def update(self, var):
+        k = self.global_state.get('step', 0) + 1
+        self.global_state["step"] = k
+
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None:
+            raise RuntimeError("closure is required for SPSA2")
+
+        generator = self.get_generator(params[0].device, self.defaults["seed"])
+
+        h = self.get_settings(params, "h")
+        x_0 = params.clone()
+        n_samples = self.defaults["n_samples"]
+        H_hat = None
+        g_0 = None
+
+        for i in range(n_samples):
+            # perturbations for g and H
+            cd_g = params.rademacher_like(generator=generator).mul_(h)
+            cd_H = params.rademacher_like(generator=generator).mul_(h)
+
+            # evaluate 4 points
+            x_p = x_0 + cd_g
+            x_n = x_0 - cd_g
+
+            params.set_(x_p)
+            f_p = closure(False)
+            params.add_(cd_H)
+            f_pp = closure(False)
+
+            params.set_(x_n)
+            f_n = closure(False)
+            params.add_(cd_H)
+            f_np = closure(False)
+
+            g_p_vec = (f_pp - f_p) / cd_H
+            g_n_vec = (f_np - f_n) / cd_H
+            delta_g = g_p_vec - g_n_vec
+
+            # restore params
+            params.set_(x_0)
+
+            # compute grad
+            g_i = (f_p - f_n) / (2 * cd_g)
+            if g_0 is None: g_0 = g_i
+            else: g_0 += g_i
+
+            # compute H hat
+            H_i = sg2_(
+                delta_g = delta_g.to_vec().div_(2.0),
+                cd = cd_g.to_vec(), # The interval is measured by the original 'cd'
+            )
+            if H_hat is None: H_hat = H_i
+            else: H_hat += H_i
+
+        assert g_0 is not None and H_hat is not None
+        if n_samples > 1:
+            g_0 /= n_samples
+            H_hat /= n_samples
+
+        # set grad to approximated grad
+        var.grad = g_0
+
+        # update H
+        H = self.global_state.get("H", None)
+        if H is None: H = H_hat
+        else:
+            beta = self.defaults["beta"]
+            if beta is None: beta = k / (k+1)
+            H.lerp_(H_hat, 1-beta)
+
+        self.global_state["H"] = H
+
+    @torch.no_grad
+    def apply(self, var):
+        dir = _newton_step(
+            var=var,
+            H = self.global_state["H"],
+            damping = self.defaults["damping"],
+            inner = self.children.get("inner", None),
+            H_tfm=None,
+            eigval_fn=self.defaults["eigval_fn"],
+            use_lstsq=self.defaults["use_lstsq"],
+            g_proj=None,
+        )
+
+        var.update = vec_to_tensors(dir, var.params)
+        return var
+
+    def get_H(self,var=...):
+        return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
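The estimator in `sg2_` is the simultaneous-perturbation Hessian estimate: with a gradient difference delta_g = g(x + cd) - g(x - cd), which is approximately 2*H*cd, the matrix M with entries M[i][j] = delta_g[j] / (2 * cd[i]) has expectation H over Rademacher perturbations, and symmetrizing gives H_hat = (M + M^T) / 2. A standalone sanity check on an exact quadratic (not library code; the `sg2_` copy below mirrors the formula in the hunk above), where the two-sided difference is divided by 2 as the docstring instructs:

```python
import torch

def sg2_(delta_g: torch.Tensor, cd: torch.Tensor) -> torch.Tensor:
    # same formula as sg2.py above
    M = torch.outer(1.0 / cd, delta_g)
    return 0.5 * (M + M.T)

torch.manual_seed(0)
A = torch.randn(4, 4)
H = A @ A.T                                   # true Hessian of f(x) = 0.5 x^T H x
est = torch.zeros(4, 4)
n = 20000
for _ in range(n):
    # cd = c * Rademacher perturbation with c = 1e-2
    cd = (torch.randint(0, 2, (4,), dtype=torch.float32) * 2 - 1) * 1e-2
    delta_g = (H @ cd) * 2                    # g(x + cd) - g(x - cd) for the quadratic
    est += sg2_(delta_g / 2, cd)              # docstring: divide two-sided delta_g by 2
est /= n
print((est - H).abs().max())                  # small: the estimate is unbiased
```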
torchzero/modules/restarts/restars.py
CHANGED
@@ -60,18 +60,18 @@ class RestartStrategyBase(Module, ABC):


 class RestartOnStuck(RestartStrategyBase):
-    """Resets the state when update (difference in parameters) is
+    """Resets the state when update (difference in parameters) is zero for multiple steps in a row.

     Args:
         modules (Chainable | None):
             modules to reset. If None, resets all modules.
         tol (float, optional):
-            step is considered failed when maximum absolute parameter difference is smaller than this. Defaults to
+            step is considered failed when maximum absolute parameter difference is smaller than this. Defaults to None (uses twice the smallest representable number).
         n_tol (int, optional):
-            number of failed consequtive steps required to trigger a reset. Defaults to
+            number of failed consecutive steps required to trigger a reset. Defaults to 10.

     """
-    def __init__(self, modules: Chainable | None, tol: float =
+    def __init__(self, modules: Chainable | None, tol: float | None = None, n_tol: int = 10):
         defaults = dict(tol=tol, n_tol=n_tol)
         super().__init__(defaults, modules)
@@ -82,6 +82,7 @@ class RestartOnStuck(RestartStrategyBase):

         params = TensorList(var.params)
         tol = self.defaults['tol']
+        if tol is None: tol = torch.finfo(params[0].dtype).tiny * 2
         n_tol = self.defaults['n_tol']
         n_bad = self.global_state.get('n_bad', 0)
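A hedged usage sketch of the new defaults: the constructor shape follows the hunk above, but the inner module names are illustrative placeholders. With `tol=None`, a step counts as failed when the parameters move by less than `2 * torch.finfo(dtype).tiny`, and the wrapped modules' state is reset after 10 such steps in a row:

```python
import torchzero as tz

# reset the wrapped chain's state once the iterate stops moving
opt = tz.Modular(
    model.parameters(),
    tz.m.RestartOnStuck([tz.m.BFGS(), tz.m.StrongWolfe()], tol=None, n_tol=10),
)
```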
torchzero/modules/second_order/__init__.py
CHANGED
@@ -1,4 +1,7 @@
-from .
+from .ifn import InverseFreeNewton
+from .inm import INM
+from .multipoint import SixthOrder3P, SixthOrder3PM2, SixthOrder5P, TwoPointNewton
+from .newton import Newton
 from .newton_cg import NewtonCG, NewtonCGSteihaug
-from .nystrom import
-from .
+from .nystrom import NystromPCG, NystromSketchAndSolve
+from .rsn import RSN