torchzero 0.3.14__py3-none-any.whl → 0.3.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +4 -3
- torchzero/core/__init__.py +4 -1
- torchzero/core/chain.py +50 -0
- torchzero/core/functional.py +37 -0
- torchzero/core/modular.py +237 -0
- torchzero/core/module.py +8 -599
- torchzero/core/reformulation.py +3 -1
- torchzero/core/transform.py +7 -5
- torchzero/core/var.py +376 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/experimental/__init__.py +1 -0
- torchzero/modules/experimental/newtonnewton.py +5 -5
- torchzero/modules/experimental/spsa1.py +2 -2
- torchzero/modules/functional.py +7 -0
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +1 -1
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +11 -20
- torchzero/modules/line_search/strong_wolfe.py +3 -3
- torchzero/modules/misc/misc.py +2 -2
- torchzero/modules/misc/multistep.py +13 -13
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/quasi_newton.py +15 -6
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +89 -0
- torchzero/modules/second_order/inm.py +105 -0
- torchzero/modules/second_order/newton.py +103 -193
- torchzero/modules/second_order/nystrom.py +1 -1
- torchzero/modules/second_order/rsn.py +227 -0
- torchzero/modules/wrappers/optim_wrapper.py +49 -42
- torchzero/utils/derivatives.py +19 -19
- torchzero/utils/linalg/linear_operator.py +50 -2
- {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/METADATA +1 -1
- {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/RECORD +44 -36
- torchzero/modules/higher_order/__init__.py +0 -1
- /torchzero/modules/{higher_order → experimental}/higher_order_newton.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/top_level.txt +0 -0
torchzero/modules/line_search/interpolation.py
ADDED
@@ -0,0 +1,160 @@
+import math
+from bisect import insort
+
+import numpy as np
+from numpy.polynomial import Polynomial
+
+
+# we have a list of points in ascending order of their `y` value
+class Point:
+    __slots__ = ("x", "y", "d")
+    def __init__(self, x, y, d):
+        self.x = x
+        self.y = y
+        self.d = d
+
+    def __lt__(self, other):
+        return self.y < other.y
+
+def _get_dpoint(points: list[Point]):
+    """returns lowest point with derivative and list of other points"""
+    for i,p in enumerate(points):
+        if p.d is not None:
+            cpoints = points.copy()
+            del cpoints[i]
+            return p, cpoints
+    return None, points
+
+# -------------------------------- quadratic2 -------------------------------- #
+def _fitmin_quadratic2(x1, y1, d1, x2, y2):
+
+    a = (y2 - y1 - d1*(x2 - x1)) / (x2 - x1)**2
+    if a <= 0: return None
+
+    b = d1 - 2*a*x1
+    # c = y_1 - d_1*x_1 + a*x_1**2
+
+    return -b / (2*a)
+
+def quadratic2(points:list[Point]):
+    pd, points = _get_dpoint(points)
+    if pd is None: return None
+    if len(points) == 0: return None
+
+    pn = points[0]
+    return _fitmin_quadratic2(pd.x, pd.y, pd.d, pn.x, pn.y)
+
+# -------------------------------- quadratic3 -------------------------------- #
+def _fitmin_quadratic3(x1, y1, x2, y2, x3, y3):
+    quad = Polynomial.fit([x1,x2,x3], [y1,y2,y3], deg=2)
+    a,b,c = quad.coef
+    if a <= 0: return None
+    return -b / (2*a)
+
+def quadratic3(points:list[Point]):
+    if len(points) < 3: return None
+
+    p1,p2,p3 = points[:3]
+    return _fitmin_quadratic3(p1.x, p1.y, p2.x, p2.y, p3.x, p3.y)
+
+# ---------------------------------- cubic3 ---------------------------------- #
+def _minimize_polynomial(poly: Polynomial):
+    roots = poly.deriv().roots()
+    vals = poly(roots)
+    argmin = np.argmin(vals)
+    return roots[argmin], vals[argmin]
+
+
+def _fitmin_cubic3(x1,y1,x2,y2,x3,y3,x4,d4):
+    """x4 is allowed to be equal to x1"""
+
+    A = np.array([
+        [x1**3, x1**2, x1, 1],
+        [x2**3, x2**2, x2, 1],
+        [x3**3, x3**2, x3, 1],
+        [3*x4**2, 2*x4, 1, 0]
+    ])
+
+    B = np.array([y1, y2, y3, d4])
+
+    try:
+        coeffs = np.linalg.solve(A, B)
+    except np.linalg.LinAlgError:
+        return None
+
+    cubic = Polynomial(coeffs)
+    x_min, y_min = _minimize_polynomial(cubic)
+    if y_min < min(y1,y2,y3): return x_min
+    return None
+
+def cubic3(points: list[Point]):
+    pd, points = _get_dpoint(points)
+    if pd is None: return None
+    if len(points) < 2: return None
+    p1, p2 = points[:2]
+    return _fitmin_cubic3(pd.x, pd.y, p1.x, p1.y, p2.x, p2.y, pd.x, pd.d)
+
+# ---------------------------------- cubic4 ---------------------------------- #
+def _fitmin_cubic4(x1, y1, x2, y2, x3, y3, x4, y4):
+    cubic = Polynomial.fit([x1,x2,x3,x4], [y1,y2,y3,y4], deg=3)
+    x_min, y_min = _minimize_polynomial(cubic)
+    if y_min < min(y1,y2,y3,y4): return x_min
+    return None
+
+def cubic4(points:list[Point]):
+    if len(points) < 4: return None
+
+    p1,p2,p3,p4 = points[:4]
+    return _fitmin_cubic4(p1.x, p1.y, p2.x, p2.y, p3.x, p3.y, p4.x, p4.y)
+
+# ---------------------------------- linear3 --------------------------------- #
+def _linear_intersection(x1,y1,s1,x2,y2,s2):
+    if s1 == 0 or s2 == 0 or s1 == s2: return None
+    return (y1 - s1*x1 - y2 + s2*x2) / (s2 - s1)
+
+def _fitmin_linear3(x1, y1, d1, x2, y2, x3, y3):
+    # we have that
+    # s2 = (y2 - y3) / (x2 - x3) # slope origin in x2 y2
+    # f1(x) = y1 + d1 * (x - x1)
+    # f2(x) = y2 + s2 * (x - x2)
+    # y1 + d1 * (x - x1) = y2 + s2 * (x - x2)
+    # y1 + d1 x - d1 x1 - y2 - s2 x + s2 x2 = 0
+    # s2 x - d1 x = y1 - d1 x1 - y2 + s2 x2
+    # x = (y1 - d1 x1 - y2 + s2 x2) / (s2 - d1)
+
+    if x2 < x1 < x3 or x3 < x1 < x2: # point with derivative in between
+        return None
+
+    if d1 > 0:
+        if x2 > x1 or x3 > x1: return None # intersection is above to the right
+        if x2 > x3: x2,y2,x3,y3 = x3,y3,x2,y2
+    if d1 < 0:
+        if x2 < x1 or x3 < x1: return None # intersection is above to the left
+        if x2 < x3: x2,y2,x3,y3 = x3,y3,x2,y2
+
+    s2 = (y2 - y3) / (x2 - x3)
+    return _linear_intersection(x1,y1,d1,x2,y2,s2)
+
+def linear3(points:list[Point]):
+    pd, points = _get_dpoint(points)
+    if pd is None: return None
+    if len(points) < 2: return None
+    p1, p2 = points[:2]
+    return _fitmin_linear3(pd.x, pd.y, pd.d, p1.x, p1.y, p2.x, p2.y)
+
+# ---------------------------------- linear4 --------------------------------- #
+def _fitmin_linear4(x1, y1, x2, y2, x3, y3, x4, y4):
+    # sort by x
+    points = ((x1,y1), (x2,y2), (x3,y3), (x4,y4))
+    points = sorted(points, key=lambda x: x[0])
+
+    (x1,y1), (x2,y2), (x3,y3), (x4,y4) = points
+    s1 = (y1 - y2) / (x1 - x2)
+    s3 = (y3 - y4) / (x3 - x4)
+
+    return _linear_intersection(x1,y1,s1,x3,y3,s3)
+
+def linear4(points:list[Point]):
+    if len(points) < 4: return None
+    p1,p2,p3,p4 = points[:4]
+    return _fitmin_linear4(p1.x, p1.y, p2.x, p2.y, p3.x, p3.y, p4.x, p4.y)
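For context, a minimal sketch of how these fitting helpers could be exercised on their own, assuming the module is importable from the path listed above; `Point(x, y, d)` holds a step size, its loss value, and an optional directional derivative, and the helpers expect the list sorted by ascending loss:

```python
# Hypothetical standalone use of the new interpolation helpers.
from torchzero.modules.line_search.interpolation import Point, quadratic2

f = lambda x: (x - 3) ** 2               # minimizer at x = 3
df = lambda x: 2 * (x - 3)

points = sorted([
    Point(x=1.0, y=f(1.0), d=df(1.0)),   # lowest-loss point carries the derivative
    Point(x=6.0, y=f(6.0), d=None),
])                                       # Point.__lt__ sorts by y

print(quadratic2(points))                # -> 3.0 for an exact quadratic
```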
torchzero/modules/line_search/line_search.py
CHANGED
@@ -10,6 +10,7 @@ import torch
 
 from ...core import Module, Target, Var
 from ...utils import tofloat, set_storage_
+from ..functional import clip_by_finfo
 
 
 class MaxLineSearchItersReached(Exception): pass
@@ -103,23 +104,18 @@ class LineSearchBase(Module, ABC):
     ):
         if not math.isfinite(step_size): return
 
-        #
-        step_size =
+        # avoid overflow error
+        step_size = clip_by_finfo(tofloat(step_size), torch.finfo(update[0].dtype))
 
         # skip is parameters are already at suggested step size
         if self._current_step_size == step_size: return
 
-        # this was basically causing floating point imprecision to build up
-        #if False:
-        #     if abs(alpha) < abs(step_size) and step_size != 0:
-        #         torch._foreach_add_(params, update, alpha=alpha)
-
-        # else:
         assert self._initial_params is not None
         if step_size == 0:
             new_params = [p.clone() for p in self._initial_params]
         else:
             new_params = torch._foreach_sub(self._initial_params, update, alpha=step_size)
+
         for c, n in zip(params, new_params):
             set_storage_(c, n)
 
@@ -131,10 +127,7 @@ class LineSearchBase(Module, ABC):
         params: list[torch.Tensor],
         update: list[torch.Tensor],
     ):
-
-        # alpha = [self._current_step_size - s for s in step_size]
-        # if any(a!=0 for a in alpha):
-        #     torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+
         assert self._initial_params is not None
         if not np.isfinite(step_size).all(): step_size = [0 for _ in step_size]
 
@@ -248,16 +241,14 @@ class LineSearchBase(Module, ABC):
         except MaxLineSearchItersReached:
             step_size = self._best_step_size
 
+        step_size = clip_by_finfo(step_size, torch.finfo(update[0].dtype))
+
        # set loss_approx
        if var.loss_approx is None: var.loss_approx = self._lowest_loss
 
-        # this is last module
-        if var.
-
-            self.set_step_size_(step_size, params=params, update=update)
-
-        else:
-            self._set_per_parameter_step_size_([step_size*lr for lr in var.last_module_lrs], params=params, update=update)
+        # if this is last module, directly update parameters to avoid redundant operations
+        if var.modular is not None and self is var.modular.modules[-1]:
+            self.set_step_size_(step_size, params=params, update=update)
 
        var.stop = True; var.skip_update = True
        return var
@@ -277,7 +268,7 @@ class GridLineSearch(LineSearchBase):
 
     @torch.no_grad
     def search(self, update, var):
-        start,end,num=itemgetter('start','end','num')(self.defaults)
+        start, end, num = itemgetter('start', 'end', 'num')(self.defaults)
 
         for lr in torch.linspace(start,end,num):
             self.evaluate_f(lr.item(), var=var, backward=False)
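The new import pulls `clip_by_finfo` from `torchzero/modules/functional.py` (which gains 7 lines in this release), but its body is not part of this diff. A minimal sketch of what such a helper presumably does, clamping a step size into the finite range of the parameter dtype so the `_foreach_sub` above cannot overflow; the packaged implementation may differ:

```python
import torch

def clip_by_finfo_sketch(x: float, finfo: torch.finfo) -> float:
    # hypothetical stand-in: clamp to the largest/smallest finite value of the dtype
    if x > finfo.max: return float(finfo.max)
    if x < finfo.min: return float(finfo.min)
    return x

# mirrors the call site in the diff:
# step_size = clip_by_finfo(tofloat(step_size), torch.finfo(update[0].dtype))
print(clip_by_finfo_sketch(1e60, torch.finfo(torch.float32)))  # ~3.4028e+38
```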
torchzero/modules/line_search/strong_wolfe.py
CHANGED
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 from torch.optim.lbfgs import _cubic_interpolate
 
-from ...utils import as_tensorlist, totensor
+from ...utils import as_tensorlist, totensor, tofloat
 from ._polyinterp import polyinterp, polyinterp2
 from .line_search import LineSearchBase, TerminationCondition, termination_condition
 from ..step_size.adaptive import _bb_geom
@@ -92,7 +92,7 @@ class _StrongWolfe:
             return _apply_bounds(a_lo + 0.5 * (a_hi - a_lo), bounds)
 
         if self.interpolation in ('polynomial', 'polynomial2'):
-            finite_history = [(a, f, g) for a, (f,g) in self.history.items() if math.isfinite(a) and math.isfinite(f) and math.isfinite(g)]
+            finite_history = [(tofloat(a), tofloat(f), tofloat(g)) for a, (f,g) in self.history.items() if math.isfinite(a) and math.isfinite(f) and math.isfinite(g)]
             if bounds is None: bounds = (None, None)
             polyinterp_fn = polyinterp if self.interpolation == 'polynomial' else polyinterp2
             try:
@@ -370,6 +370,6 @@ class StrongWolfe(LineSearchBase):
             self.global_state['initial_scale'] = self.global_state.get('initial_scale', 1) * 0.5
             finfo = torch.finfo(dir[0].dtype)
             if self.global_state['initial_scale'] < finfo.tiny * 2:
-                self.global_state['initial_scale'] =
+                self.global_state['initial_scale'] = init_value * 2
 
         return 0
torchzero/modules/misc/misc.py
CHANGED
@@ -306,8 +306,8 @@ class RandomHvp(Module):
         for i in range(n_samples):
             u = params.sample_like(distribution=distribution, variance=1)
 
-            Hvp, rgrad =
-                h=h, normalize=True,
+            Hvp, rgrad = var.hessian_vector_product(u, at_x0=True, rgrad=rgrad, hvp_method=hvp_method,
+                                                    h=h, normalize=True, retain_graph=i < n_samples-1)
 
             if D is None: D = Hvp
             else: torch._foreach_add_(D, Hvp)
torchzero/modules/misc/multistep.py
CHANGED
@@ -15,7 +15,7 @@ def _sequential_step(self: Module, var: Var, sequential: bool):
     if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
 
     # store original params unless this is last module and can update params directly
-    params_before_steps =
+    params_before_steps = [p.clone() for p in params]
 
     # first step - pass var as usual
     var = modules[0].step(var)
@@ -27,8 +27,8 @@ def _sequential_step(self: Module, var: Var, sequential: bool):
 
     # update params
     if (not new_var.skip_update):
-        if new_var.last_module_lrs is not None:
-
+        # if new_var.last_module_lrs is not None:
+        #     torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
 
         torch._foreach_sub_(params, new_var.get_update())
 
@@ -41,16 +41,16 @@ def _sequential_step(self: Module, var: Var, sequential: bool):
 
     # final parameter update
     if (not new_var.skip_update):
-        if new_var.last_module_lrs is not None:
-
+        # if new_var.last_module_lrs is not None:
+        #     torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
 
         torch._foreach_sub_(params, new_var.get_update())
 
     # if last module, update is applied so return new var
-    if params_before_steps is None:
-
-
-
+    # if params_before_steps is None:
+    #     new_var.stop = True
+    #     new_var.skip_update = True
+    #     return new_var
 
     # otherwise use parameter difference as update
     var.update = list(torch._foreach_sub(params_before_steps, params))
@@ -106,10 +106,10 @@ class NegateOnLossIncrease(Module):
             f_1 = closure(False)
 
             if f_1 <= f_0:
-                if var.is_last and var.last_module_lrs is None:
-
-
-
+                # if var.is_last and var.last_module_lrs is None:
+                #     var.stop = True
+                #     var.skip_update = True
+                #     return var
 
                 torch._foreach_add_(var.params, update)
                 return var
torchzero/modules/quasi_newton/quasi_newton.py
CHANGED
@@ -1182,16 +1182,19 @@ class ShorR(HessianUpdateStrategy):
     """Shor’s r-algorithm.
 
     Note:
-        A line search such as ``tz.m.StrongWolfe(a_init="quadratic", fallback=True)`` is required.
-
-
+        - A line search such as ``[tz.m.StrongWolfe(a_init="quadratic", fallback=True), tz.m.Mul(1.2)]`` is required. Similarly to conjugate gradient, ShorR doesn't have an automatic step size scaling, so setting ``a_init`` in the line search is recommended.
+
+        - The line search should try to overstep by a little, therefore it can help to multiply direction given by a line search by some value slightly larger than 1 such as 1.2.
 
     References:
-
+        Those are the original references, but neither seem to be available online:
+        - Shor, N. Z., Utilization of the Operation of Space Dilatation in the Minimization of Convex Functions, Kibernetika, No. 1, pp. 6-12, 1970.
+
+        - Skokov, V. A., Note on Minimization Methods Employing Space Stretching, Kibernetika, No. 4, pp. 115-117, 1974.
 
-        Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of numerical analysis 28.4 (2008): 711-720.
+        An overview is available in [Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of numerical analysis 28.4 (2008): 711-720](https://sites.math.washington.edu/~burke/papers/reprints/60-speed-Shor-R.pdf).
 
-        Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998.
+        Reference by Skokov, V. A. describes a more efficient formula which can be found here [Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998.](https://camo.ici.ro/books/thesis/th.pdf)
     """
 
     def __init__(
@@ -1229,3 +1232,9 @@ class ShorR(HessianUpdateStrategy):
 
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
         return shor_r_(H=H, y=y, alpha=setting['alpha'])
+
+
+# Todd, Michael J. "The symmetric rank-one quasi-Newton method is a space-dilation subgradient algorithm." Operations research letters 5.5 (1986): 217-219.
+# TODO
+
+# Sorensen, D. C. "The q-superlinear convergence of a collinear scaling algorithm for unconstrained optimization." SIAM Journal on Numerical Analysis 17.1 (1980): 84-114.
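Expanding the note added to the ShorR docstring above into a full modular setup (the pairing with ``StrongWolfe`` and ``Mul(1.2)`` is taken directly from the docstring; exposing ``ShorR`` under ``tz.m`` and the toy model are assumptions):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)  # placeholder model

# ShorR direction, Strong Wolfe step size, then a slight overstep as the note suggests
opt = tz.Modular(
    model.parameters(),
    tz.m.ShorR(),
    tz.m.StrongWolfe(a_init="quadratic", fallback=True),
    tz.m.Mul(1.2),
)
```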
torchzero/modules/quasi_newton/sg2.py
ADDED
@@ -0,0 +1,292 @@
+import torch
+
+from ...core import Module, Chainable, apply_transform
+from ...utils import TensorList, vec_to_tensors
+from ..second_order.newton import _newton_step, _get_H
+
+def sg2_(
+    delta_g: torch.Tensor,
+    cd: torch.Tensor,
+) -> torch.Tensor:
+    """cd is c * perturbation, and must be multiplied by two if hessian estimate is two-sided
+    (or divide delta_g by two)."""
+
+    M = torch.outer(1.0 / cd, delta_g)
+    H_hat = 0.5 * (M + M.T)
+
+    return H_hat
+
+
+
+class SG2(Module):
+    """second-order stochastic gradient
+
+    SG2 with line search
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SG2(),
+        tz.m.Backtracking()
+    )
+    ```
+
+    SG2 with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.SG2()),
+    )
+    ```
+
+    """
+
+    def __init__(
+        self,
+        n_samples: int = 1,
+        h: float = 1e-2,
+        beta: float | None = None,
+        damping: float = 0,
+        eigval_fn=None,
+        one_sided: bool = False, # one-sided hessian
+        use_lstsq: bool = True,
+        seed=None,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, one_sided=one_sided, seed=seed, use_lstsq=use_lstsq)
+        super().__init__(defaults)
+
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def update(self, var):
+        k = self.global_state.get('step', 0) + 1
+        self.global_state["step"] = k
+
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None:
+            raise RuntimeError("closure is required for SG2")
+        generator = self.get_generator(params[0].device, self.defaults["seed"])
+
+        h = self.get_settings(params, "h")
+        x_0 = params.clone()
+        n_samples = self.defaults["n_samples"]
+        H_hat = None
+
+        for i in range(n_samples):
+            # generate perturbation
+            cd = params.rademacher_like(generator=generator).mul_(h)
+
+            # one sided
+            if self.defaults["one_sided"]:
+                g_0 = TensorList(var.get_grad())
+                params.add_(cd)
+                closure()
+
+                g_p = params.grad.fill_none_(params)
+                delta_g = (g_p - g_0) * 2
+
+            # two sided
+            else:
+                params.add_(cd)
+                closure()
+                g_p = params.grad.fill_none_(params)
+
+                params.copy_(x_0)
+                params.sub_(cd)
+                closure()
+                g_n = params.grad.fill_none_(params)
+
+                delta_g = g_p - g_n
+
+            # restore params
+            params.set_(x_0)
+
+            # compute H hat
+            H_i = sg2_(
+                delta_g = delta_g.to_vec(),
+                cd = cd.to_vec(),
+            )
+
+            if H_hat is None: H_hat = H_i
+            else: H_hat += H_i
+
+        assert H_hat is not None
+        if n_samples > 1: H_hat /= n_samples
+
+        # update H
+        H = self.global_state.get("H", None)
+        if H is None: H = H_hat
+        else:
+            beta = self.defaults["beta"]
+            if beta is None: beta = k / (k+1)
+            H.lerp_(H_hat, 1-beta)
+
+        self.global_state["H"] = H
+
+
+    @torch.no_grad
+    def apply(self, var):
+        dir = _newton_step(
+            var=var,
+            H = self.global_state["H"],
+            damping = self.defaults["damping"],
+            inner = self.children.get("inner", None),
+            H_tfm=None,
+            eigval_fn=self.defaults["eigval_fn"],
+            use_lstsq=self.defaults["use_lstsq"],
+            g_proj=None,
+        )
+
+        var.update = vec_to_tensors(dir, var.params)
+        return var
+
+    def get_H(self,var=...):
+        return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
+
+
+
+
+# two sided
+# we have g via x + d, x - d
+# H via g(x + d), g(x - d)
+# 1 is x, x+2d
+# 2 is x, x-2d
+# 5 evals in total
+
+# one sided
+# g via x, x + d
+# 1 is x, x + d
+# 2 is x, x - d
+# 3 evals and can use two sided for g_0
+
+class SPSA2(Module):
+    """second-order SPSA
+
+    SPSA2 with line search
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SPSA2(),
+        tz.m.Backtracking()
+    )
+    ```
+
+    SPSA2 with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.SPSA2()),
+    )
+    ```
+    """
+
+    def __init__(
+        self,
+        n_samples: int = 1,
+        h: float = 1e-2,
+        beta: float | None = None,
+        damping: float = 0,
+        eigval_fn=None,
+        use_lstsq: bool = True,
+        seed=None,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, seed=seed, use_lstsq=use_lstsq)
+        super().__init__(defaults)
+
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def update(self, var):
+        k = self.global_state.get('step', 0) + 1
+        self.global_state["step"] = k
+
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None:
+            raise RuntimeError("closure is required for SPSA2")
+
+        generator = self.get_generator(params[0].device, self.defaults["seed"])
+
+        h = self.get_settings(params, "h")
+        x_0 = params.clone()
+        n_samples = self.defaults["n_samples"]
+        H_hat = None
+        g_0 = None
+
+        for i in range(n_samples):
+            # perturbations for g and H
+            cd_g = params.rademacher_like(generator=generator).mul_(h)
+            cd_H = params.rademacher_like(generator=generator).mul_(h)
+
+            # evaluate 4 points
+            x_p = x_0 + cd_g
+            x_n = x_0 - cd_g
+
+            params.set_(x_p)
+            f_p = closure(False)
+            params.add_(cd_H)
+            f_pp = closure(False)
+
+            params.set_(x_n)
+            f_n = closure(False)
+            params.add_(cd_H)
+            f_np = closure(False)
+
+            g_p_vec = (f_pp - f_p) / cd_H
+            g_n_vec = (f_np - f_n) / cd_H
+            delta_g = g_p_vec - g_n_vec
+
+            # restore params
+            params.set_(x_0)
+
+            # compute grad
+            g_i = (f_p - f_n) / (2 * cd_g)
+            if g_0 is None: g_0 = g_i
+            else: g_0 += g_i
+
+            # compute H hat
+            H_i = sg2_(
+                delta_g = delta_g.to_vec().div_(2.0),
+                cd = cd_g.to_vec(), # The interval is measured by the original 'cd'
+            )
+            if H_hat is None: H_hat = H_i
+            else: H_hat += H_i
+
+        assert g_0 is not None and H_hat is not None
+        if n_samples > 1:
+            g_0 /= n_samples
+            H_hat /= n_samples
+
+        # set grad to approximated grad
+        var.grad = g_0
+
+        # update H
+        H = self.global_state.get("H", None)
+        if H is None: H = H_hat
+        else:
+            beta = self.defaults["beta"]
+            if beta is None: beta = k / (k+1)
+            H.lerp_(H_hat, 1-beta)
+
+        self.global_state["H"] = H
+
+    @torch.no_grad
+    def apply(self, var):
+        dir = _newton_step(
+            var=var,
+            H = self.global_state["H"],
+            damping = self.defaults["damping"],
+            inner = self.children.get("inner", None),
+            H_tfm=None,
+            eigval_fn=self.defaults["eigval_fn"],
+            use_lstsq=self.defaults["use_lstsq"],
+            g_proj=None,
+        )
+
+        var.update = vec_to_tensors(dir, var.params)
+        return var
+
+    def get_H(self,var=...):
+        return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
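To illustrate the estimate that `sg2_` computes, here is a small self-contained check in plain PyTorch, independent of the module machinery: for a quadratic, the gradient difference across `±cd` equals `2·A·cd`, and averaging the symmetrized outer-product estimate over the Rademacher sign patterns recovers the Hessian exactly:

```python
import torch

A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])   # Hessian of f(x) = 0.5 * x^T A x
grad = lambda x: A @ x
x0 = torch.tensor([0.5, -1.0])
c = 0.5                                       # any size works exactly on a quadratic

def estimate(cd):
    delta_g = grad(x0 + cd) - grad(x0 - cd)   # two-sided gradient difference
    M = torch.outer(1.0 / (2 * cd), delta_g)  # 2*cd absorbs the two-sided factor, per the sg2_ docstring
    return 0.5 * (M + M.T)                    # symmetrized, same formula as sg2_

# average over the distinct Rademacher sign patterns in 2D
H_hat = (estimate(torch.tensor([c, c])) + estimate(torch.tensor([c, -c]))) / 2
print(torch.allclose(H_hat, A))               # True
```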
torchzero/modules/second_order/__init__.py
CHANGED
@@ -1,4 +1,7 @@
-from .
+from .ifn import InverseFreeNewton
+from .inm import INM
+from .multipoint import SixthOrder3P, SixthOrder3PM2, SixthOrder5P, TwoPointNewton
+from .newton import Newton
 from .newton_cg import NewtonCG, NewtonCGSteihaug
-from .nystrom import
-from .
+from .nystrom import NystromPCG, NystromSketchAndSolve
+from .rsn import RSN
|