torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- tests/test_identical.py +22 -22
- tests/test_opts.py +199 -198
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +1 -1
- torchzero/core/functional.py +1 -1
- torchzero/core/modular.py +5 -5
- torchzero/core/module.py +2 -2
- torchzero/core/objective.py +10 -10
- torchzero/core/transform.py +1 -1
- torchzero/linalg/__init__.py +3 -2
- torchzero/linalg/eigh.py +223 -4
- torchzero/linalg/orthogonalize.py +2 -4
- torchzero/linalg/qr.py +12 -0
- torchzero/linalg/solve.py +1 -3
- torchzero/linalg/svd.py +47 -20
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +10 -10
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/adam.py +1 -1
- torchzero/modules/adaptive/adan.py +1 -1
- torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +2 -1
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/msam.py +4 -4
- torchzero/modules/adaptive/muon.py +9 -6
- torchzero/modules/adaptive/natural_gradient.py +32 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rprop.py +2 -2
- torchzero/modules/adaptive/sam.py +4 -4
- torchzero/modules/adaptive/shampoo.py +28 -3
- torchzero/modules/adaptive/soap.py +3 -3
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/clipping/clipping.py +7 -7
- torchzero/modules/conjugate_gradient/cg.py +2 -2
- torchzero/modules/experimental/__init__.py +5 -0
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +2 -2
- torchzero/modules/experimental/newtonnewton.py +34 -40
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/rfdm.py +4 -4
- torchzero/modules/least_squares/gn.py +68 -45
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/escape.py +1 -1
- torchzero/modules/misc/gradient_accumulation.py +1 -1
- torchzero/modules/misc/misc.py +1 -1
- torchzero/modules/misc/multistep.py +4 -7
- torchzero/modules/misc/regularization.py +2 -2
- torchzero/modules/misc/split.py +1 -1
- torchzero/modules/misc/switch.py +2 -2
- torchzero/modules/momentum/cautious.py +3 -3
- torchzero/modules/momentum/momentum.py +1 -1
- torchzero/modules/ops/higher_level.py +1 -1
- torchzero/modules/ops/multi.py +1 -1
- torchzero/modules/projections/projection.py +5 -2
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +3 -3
- torchzero/modules/quasi_newton/lsr1.py +3 -3
- torchzero/modules/quasi_newton/quasi_newton.py +44 -29
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +17 -17
- torchzero/modules/second_order/inm.py +33 -25
- torchzero/modules/second_order/newton.py +132 -130
- torchzero/modules/second_order/newton_cg.py +3 -3
- torchzero/modules/second_order/nystrom.py +83 -32
- torchzero/modules/second_order/rsn.py +41 -44
- torchzero/modules/smoothing/laplacian.py +1 -1
- torchzero/modules/smoothing/sampling.py +2 -3
- torchzero/modules/step_size/adaptive.py +6 -6
- torchzero/modules/step_size/lr.py +2 -2
- torchzero/modules/trust_region/cubic_regularization.py +1 -1
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/variance_reduction/svrg.py +4 -5
- torchzero/modules/weight_decay/reinit.py +2 -2
- torchzero/modules/weight_decay/weight_decay.py +5 -5
- torchzero/modules/wrappers/optim_wrapper.py +4 -4
- torchzero/modules/zeroth_order/cd.py +1 -1
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/wrappers/nevergrad.py +0 -9
- torchzero/optim/wrappers/optuna.py +2 -0
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/derivatives.py +4 -4
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- torchzero/modules/adaptive/lmadagrad.py +0 -241
- torchzero-0.4.0.dist-info/RECORD +0 -191
- /torchzero/modules/{functional.py → opt_utils.py} +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/sg2.py

@@ -1,29 +1,39 @@
 import torch
 
-from ...core import
-from ...utils import TensorList,
-from
+from ...core import Chainable, Transform
+from ...utils import TensorList, unpack_dicts, unpack_states, vec_to_tensors_
+from ...linalg.linear_operator import Dense
+
 
 def sg2_(
     delta_g: torch.Tensor,
     cd: torch.Tensor,
 ) -> torch.Tensor:
-    """cd is c * perturbation
-    (or divide delta_g by two)."""
+    """cd is c * perturbation."""
 
-    M = torch.outer(
+    M = torch.outer(0.5 / cd, delta_g)
     H_hat = 0.5 * (M + M.T)
 
     return H_hat
 
 
 
-class SG2(
+class SG2(Transform):
     """second-order stochastic gradient
 
+    2SPSA (second-order SPSA)
+    ```python
+    opt = tz.Optimizer(
+        model.parameters(),
+        tz.m.SPSA(),
+        tz.m.SG2(),
+        tz.m.LR(1e-2),
+    )
+    ```
+
     SG2 with line search
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SG2(),
         tz.m.Backtracking()

@@ -32,9 +42,9 @@ class SG2(Module):
 
     SG2 with trust region
    ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
-        tz.m.LevenbergMarquardt(tz.m.SG2()),
+        tz.m.LevenbergMarquardt(tz.m.SG2(beta=0.75. n_samples=4)),
     )
     ```
 

@@ -43,24 +53,22 @@ class SG2(Module):
     def __init__(
         self,
         n_samples: int = 1,
-
+        n_first_step_samples: int = 10,
+        start_step: int = 10,
         beta: float | None = None,
-        damping: float =
-
-        one_sided: bool = False, # one-sided hessian
-        use_lstsq: bool = True,
+        damping: float = 1e-4,
+        h: float = 1e-2,
         seed=None,
+        update_freq: int = 1,
         inner: Chainable | None = None,
     ):
-        defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping,
-        super().__init__(defaults)
-
-        if inner is not None: self.set_child('inner', inner)
+        defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, seed=seed, start_step=start_step, n_first_step_samples=n_first_step_samples)
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
 
     @torch.no_grad
-    def
-
-        self.
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+        k = self.increment_counter("step", 0)
 
         params = TensorList(objective.params)
         closure = objective.closure

@@ -68,36 +76,28 @@ class SG2(Module):
             raise RuntimeError("closure is required for SG2")
         generator = self.get_generator(params[0].device, self.defaults["seed"])
 
-        h =
+        h = unpack_dicts(settings, "h")
         x_0 = params.clone()
-        n_samples =
+        n_samples = fs["n_samples"]
+        if k == 0: n_samples = fs["n_first_step_samples"]
         H_hat = None
 
+        # compute new approximation
         for i in range(n_samples):
             # generate perturbation
             cd = params.rademacher_like(generator=generator).mul_(h)
 
-            #
-
-
-
-                closure()
+            # two sided hessian approximation
+            params.add_(cd)
+            closure()
+            g_p = params.grad.fill_none_(params)
 
-
-
+            params.copy_(x_0)
+            params.sub_(cd)
+            closure()
+            g_n = params.grad.fill_none_(params)
 
-
-            else:
-                params.add_(cd)
-                closure()
-                g_p = params.grad.fill_none_(params)
-
-                params.copy_(x_0)
-                params.sub_(cd)
-                closure()
-                g_n = params.grad.fill_none_(params)
-
-                delta_g = g_p - g_n
+            delta_g = g_p - g_n
 
             # restore params
             params.set_(x_0)

@@ -114,179 +114,43 @@ class SG2(Module):
         assert H_hat is not None
         if n_samples > 1: H_hat /= n_samples
 
+        # add damping
+        if fs["damping"] != 0:
+            reg = torch.eye(H_hat.size(0), device=H_hat.device, dtype=H_hat.dtype).mul_(fs["damping"])
+            H_hat += reg
+
         # update H
         H = self.global_state.get("H", None)
         if H is None: H = H_hat
         else:
-            beta =
-            if beta is None: beta = k / (k+
+            beta = fs["beta"]
+            if beta is None: beta = (k+1) / (k+2)
             H.lerp_(H_hat, 1-beta)
 
         self.global_state["H"] = H
 
 
     @torch.no_grad
-    def
-
-
-
-
-
-
-
-
-
-
-
-
+    def apply_states(self, objective, states, settings):
+        fs = settings[0]
+        updates = objective.get_updates()
+
+        H: torch.Tensor = self.global_state["H"]
+        k = self.global_state["step"]
+        if k < fs["start_step"]:
+            # don't precondition yet
+            # I guess we can try using trace to scale the update
+            # because it will have horrible scaling otherwise
+            torch._foreach_div_(updates, H.trace())
+            return objective
+
+        b = torch.cat([t.ravel() for t in updates])
+        sol = torch.linalg.lstsq(H, b).solution # pylint:disable=not-callable
+
+        vec_to_tensors_(sol, updates)
         return objective
 
-    def get_H(self,objective=...):
-        return
-
-
-
-
-# two sided
-# we have g via x + d, x - d
-# H via g(x + d), g(x - d)
-# 1 is x, x+2d
-# 2 is x, x-2d
-# 5 evals in total
-
-# one sided
-# g via x, x + d
-# 1 is x, x + d
-# 2 is x, x - d
-# 3 evals and can use two sided for g_0
-
-class SPSA2(Module):
-    """second-order SPSA
-
-    SPSA2 with line search
-    ```python
-    opt = tz.Modular(
-        model.parameters(),
-        tz.m.SPSA2(),
-        tz.m.Backtracking()
-    )
-    ```
-
-    SPSA2 with trust region
-    ```python
-    opt = tz.Modular(
-        model.parameters(),
-        tz.m.LevenbergMarquardt(tz.m.SPSA2()),
-    )
-    ```
-    """
-
-    def __init__(
-        self,
-        n_samples: int = 1,
-        h: float = 1e-2,
-        beta: float | None = None,
-        damping: float = 0,
-        eigval_fn=None,
-        use_lstsq: bool = True,
-        seed=None,
-        inner: Chainable | None = None,
-    ):
-        defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, seed=seed, use_lstsq=use_lstsq)
-        super().__init__(defaults)
-
-        if inner is not None: self.set_child('inner', inner)
-
-    @torch.no_grad
-    def update(self, objective):
-        k = self.global_state.get('step', 0) + 1
-        self.global_state["step"] = k
-
-        params = TensorList(objective.params)
-        closure = objective.closure
-        if closure is None:
-            raise RuntimeError("closure is required for SPSA2")
-
-        generator = self.get_generator(params[0].device, self.defaults["seed"])
-
-        h = self.get_settings(params, "h")
-        x_0 = params.clone()
-        n_samples = self.defaults["n_samples"]
-        H_hat = None
-        g_0 = None
-
-        for i in range(n_samples):
-            # perturbations for g and H
-            cd_g = params.rademacher_like(generator=generator).mul_(h)
-            cd_H = params.rademacher_like(generator=generator).mul_(h)
-
-            # evaluate 4 points
-            x_p = x_0 + cd_g
-            x_n = x_0 - cd_g
-
-            params.set_(x_p)
-            f_p = closure(False)
-            params.add_(cd_H)
-            f_pp = closure(False)
-
-            params.set_(x_n)
-            f_n = closure(False)
-            params.add_(cd_H)
-            f_np = closure(False)
-
-            g_p_vec = (f_pp - f_p) / cd_H
-            g_n_vec = (f_np - f_n) / cd_H
-            delta_g = g_p_vec - g_n_vec
-
-            # restore params
-            params.set_(x_0)
-
-            # compute grad
-            g_i = (f_p - f_n) / (2 * cd_g)
-            if g_0 is None: g_0 = g_i
-            else: g_0 += g_i
-
-            # compute H hat
-            H_i = sg2_(
-                delta_g = delta_g.to_vec().div_(2.0),
-                cd = cd_g.to_vec(), # The interval is measured by the original 'cd'
-            )
-            if H_hat is None: H_hat = H_i
-            else: H_hat += H_i
-
-        assert g_0 is not None and H_hat is not None
-        if n_samples > 1:
-            g_0 /= n_samples
-            H_hat /= n_samples
-
-        # set grad to approximated grad
-        objective.grads = g_0
+    def get_H(self, objective=...):
+        return Dense(self.global_state["H"])
 
-        # update H
-        H = self.global_state.get("H", None)
-        if H is None: H = H_hat
-        else:
-            beta = self.defaults["beta"]
-            if beta is None: beta = k / (k+1)
-            H.lerp_(H_hat, 1-beta)
-
-        self.global_state["H"] = H
-
-    @torch.no_grad
-    def apply(self, objective):
-        dir = _newton_step(
-            objective=objective,
-            H = self.global_state["H"],
-            damping = self.defaults["damping"],
-            inner = self.children.get("inner", None),
-            H_tfm=None,
-            eigval_fn=self.defaults["eigval_fn"],
-            use_lstsq=self.defaults["use_lstsq"],
-            g_proj=None,
-        )
-
-        objective.updates = vec_to_tensors(dir, objective.params)
-        return objective
 
-    def get_H(self,objective=...):
-        return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
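Note: the `sg2_` estimator added above is the standard simultaneous-perturbation (2SPSA-style) Hessian estimate: an outer product of the gradient difference across a Rademacher perturbation, symmetrized. A minimal standalone sketch of the same formula on a toy quadratic (plain torch, not torchzero's module API; the quadratic, seed, and sample count are illustrative assumptions):

```python
import torch

def sg2_(delta_g: torch.Tensor, cd: torch.Tensor) -> torch.Tensor:
    # same formula as above: cd is c * perturbation,
    # delta_g is grad(x + cd) - grad(x - cd)
    M = torch.outer(0.5 / cd, delta_g)
    return 0.5 * (M + M.T)

# toy quadratic f(x) = 0.5 * x @ A @ x, whose true Hessian is A
A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
grad = lambda x: A @ x

torch.manual_seed(0)
x = torch.tensor([1.0, -1.0])
h = 1e-2
samples = []
for _ in range(500):
    cd = (torch.randint(0, 2, (2,)) * 2 - 1).float() * h   # Rademacher * h
    delta_g = grad(x + cd) - grad(x - cd)
    samples.append(sg2_(delta_g, cd))
print(torch.stack(samples).mean(0))   # approaches A as samples accumulate
```

Averaging many such samples converges to the true Hessian in expectation, which is why the module accumulates `H_hat` over `n_samples` and then EMA-updates the stored `H`.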
torchzero/modules/restarts/restars.py

@@ -8,8 +8,8 @@ from ...core import Chainable, Module, Objective
 from ...utils import TensorList
 from ..termination import TerminationCriteriaBase
 
-def _reset_except_self(
-    for m in
+def _reset_except_self(objective, modules, self: Module):
+    for m in modules:
         if m is not self:
             m.reset()
 

@@ -26,15 +26,15 @@ class RestartStrategyBase(Module, ABC):
         self.set_child('modules', modules)
 
     @abstractmethod
-    def should_reset(self,
+    def should_reset(self, objective: Objective) -> bool:
         """returns whether reset should occur"""
 
-    def _reset_on_condition(self,
+    def _reset_on_condition(self, objective: Objective):
         modules = self.children.get('modules', None)
 
-        if self.should_reset(
+        if self.should_reset(objective):
             if modules is None:
-
+                objective.post_step_hooks.append(partial(_reset_except_self, self=self))
             else:
                 modules.reset()
 

@@ -78,11 +78,11 @@ class RestartOnStuck(RestartStrategyBase):
         super().__init__(defaults, modules)
 
     @torch.no_grad
-    def should_reset(self,
+    def should_reset(self, objective):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
-        params = TensorList(
+        params = TensorList(objective.params)
         tol = self.defaults['tol']
         if tol is None: tol = torch.finfo(params[0].dtype).tiny * 2
         n_tol = self.defaults['n_tol']

@@ -124,12 +124,12 @@ class RestartEvery(RestartStrategyBase):
         defaults = dict(steps=steps)
         super().__init__(defaults, modules)
 
-    def should_reset(self,
+    def should_reset(self, objective):
         step = self.global_state.get('step', 0) + 1
         self.global_state['step'] = step
 
         n = self.defaults['steps']
-        if isinstance(n, str): n = sum(p.numel() for p in
+        if isinstance(n, str): n = sum(p.numel() for p in objective.params if p.requires_grad)
 
         # reset every n steps
         if step % n == 0:

@@ -143,9 +143,9 @@ class RestartOnTerminationCriteria(RestartStrategyBase):
         super().__init__(None, modules)
         self.set_child('criteria', criteria)
 
-    def should_reset(self,
+    def should_reset(self, objective):
         criteria = cast(TerminationCriteriaBase, self.children['criteria'])
-        return criteria.should_terminate(
+        return criteria.should_terminate(objective)
 
 class PowellRestart(RestartStrategyBase):
     """Powell's two restarting criterions for conjugate gradient methods.

@@ -171,14 +171,14 @@ class PowellRestart(RestartStrategyBase):
         defaults=dict(cond1=cond1, cond2=cond2)
         super().__init__(defaults, modules)
 
-    def should_reset(self,
-        g = TensorList(
+    def should_reset(self, objective):
+        g = TensorList(objective.get_grads())
         cond1 = self.defaults['cond1']; cond2 = self.defaults['cond2']
 
         # -------------------------------- initialize -------------------------------- #
         if 'initialized' not in self.global_state:
             self.global_state['initialized'] = 0
-            g_prev = self.get_state(
+            g_prev = self.get_state(objective.params, 'g_prev', init=g)
             return False
 
         g_g = g.dot(g)

@@ -186,7 +186,7 @@ class PowellRestart(RestartStrategyBase):
         reset = False
         # ------------------------------- 1st condition ------------------------------ #
         if cond1 is not None:
-            g_prev = self.get_state(
+            g_prev = self.get_state(objective.params, 'g_prev', must_exist=True, cls=TensorList)
             g_g_prev = g_prev.dot(g)
 
             if g_g_prev.abs() >= cond1 * g_g:

@@ -194,7 +194,7 @@ class PowellRestart(RestartStrategyBase):
 
         # ------------------------------- 2nd condition ------------------------------ #
         if (cond2 is not None) and (not reset):
-            d_g = TensorList(
+            d_g = TensorList(objective.get_updates()).dot(g)
             if (-1-cond2) * g_g < d_g < (-1 + cond2) * g_g:
                 reset = True
 
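Note: the `PowellRestart` hunks above only thread `objective` through the `should_reset` signatures; the two Powell tests themselves are unchanged. For reference, a standalone sketch of the same checks on plain vectors (the function name and the 0.2 defaults are illustrative assumptions; `d` stands in for `objective.get_updates()`):

```python
import torch

def powell_should_restart(
    g: torch.Tensor,
    g_prev: torch.Tensor,
    d: torch.Tensor,
    cond1: float | None = 0.2,
    cond2: float | None = 0.2,
) -> bool:
    """Standalone version of the two checks in PowellRestart.should_reset."""
    g_g = g.dot(g)

    # 1st condition: successive gradients remain too strongly correlated
    if cond1 is not None and g_prev.dot(g).abs() >= cond1 * g_g:
        return True

    # 2nd condition: same inequality as the module, with d playing the
    # role of the current update direction
    if cond2 is not None:
        d_g = d.dot(g)
        if (-1 - cond2) * g_g < d_g < (-1 + cond2) * g_g:
            return True

    return False

g = torch.tensor([1.0, -2.0, 0.5])
print(powell_should_restart(g, g_prev=g.clone(), d=-g))  # True: |g·g_prev| = g·g
```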
torchzero/modules/second_order/inm.py

@@ -3,9 +3,9 @@ from collections.abc import Callable
 import torch
 
 from ...core import Chainable, Transform, HessianMethod
-from ...utils import TensorList,
-from ..
-from .newton import
+from ...utils import TensorList, vec_to_tensors_, unpack_states
+from ..opt_utils import safe_clip
+from .newton import _newton_update_state_, _newton_solve, _newton_get_H
 
 @torch.no_grad
 def inm(f:torch.Tensor, J:torch.Tensor, s:torch.Tensor, y:torch.Tensor):

@@ -34,10 +34,10 @@ class ImprovedNewton(Transform):
     def __init__(
         self,
         damping: float = 0,
-        use_lstsq: bool = False,
-        update_freq: int = 1,
-        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        update_freq: int = 1,
+        precompute_inverse: bool | None = None,
+        use_lstsq: bool = False,
         hessian_method: HessianMethod = "batched_autograd",
         h: float = 1e-3,
         inner: Chainable | None = None,

@@ -65,37 +65,45 @@ class ImprovedNewton(Transform):
         x_prev, f_prev = unpack_states(states, objective.params, "x_prev", "f_prev", cls=TensorList)
 
         # initialize on 1st step, do Newton step
-        if "
+        if "H" not in self.global_state:
             x_prev.copy_(x_list)
             f_prev.copy_(f_list)
-
-            return
+            P = J
 
         # INM update
-
-
-
-
+        else:
+            s_list = x_list - x_prev
+            y_list = f_list - f_prev
+            x_prev.copy_(x_list)
+            f_prev.copy_(f_list)
 
-
+            P = inm(f, J, s=s_list.to_vec(), y=y_list.to_vec())
 
+        # update state
+        precompute_inverse = fs["precompute_inverse"]
+        if precompute_inverse is None:
+            precompute_inverse = fs["__update_freq"] >= 10
+
+        _newton_update_state_(
+            H=P,
+            state = self.global_state,
+            damping = fs["damping"],
+            eigval_fn = fs["eigval_fn"],
+            precompute_inverse = precompute_inverse,
+            use_lstsq = fs["use_lstsq"]
+        )
 
     @torch.no_grad
     def apply_states(self, objective, states, settings):
+        updates = objective.get_updates()
         fs = settings[0]
 
-
-
-            H = self.global_state["P"],
-            damping = fs["damping"],
-            H_tfm = fs["H_tfm"],
-            eigval_fn = None, # it is applied in `update`
-            use_lstsq = fs["use_lstsq"],
-        )
-
-        objective.updates = vec_to_tensors(update, objective.params)
+        b = torch.cat([t.ravel() for t in updates])
+        sol = _newton_solve(b=b, state=self.global_state, use_lstsq=fs["use_lstsq"])
 
+        vec_to_tensors_(sol, updates)
         return objective
 
+
     def get_H(self,objective=...):
-        return
+        return _newton_get_H(self.global_state)
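Note: both `SG2.apply_states` and `ImprovedNewton.apply_states` now follow the same flatten → solve → unflatten pattern: concatenate the per-parameter updates into one vector, solve against the stored matrix, and write the solution back in place. A minimal sketch of that pattern in plain torch (`precondition_` is a hypothetical helper; the inlined copy-back approximates what torchzero's `vec_to_tensors_` does):

```python
import torch

def precondition_(updates: list[torch.Tensor], H: torch.Tensor) -> None:
    # flatten per-parameter updates into one vector b
    b = torch.cat([t.ravel() for t in updates])
    # solve H x = b; lstsq (as in the new code) tolerates a singular H
    sol = torch.linalg.lstsq(H, b).solution
    # write the solution back into the original tensors in place
    offset = 0
    for t in updates:
        n = t.numel()
        t.copy_(sol[offset:offset + n].view_as(t))
        offset += n

H = 2.0 * torch.eye(3)
updates = [torch.ones(2), torch.ones(1)]
precondition_(updates, H)
print(updates)  # each update halved by the 2*I "Hessian"
```

Using `lstsq` rather than `solve`, as the new code does, keeps the step well-defined even when the estimated matrix is singular or ill-conditioned.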