torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -1,369 +1,355 @@
-from abc import ABC, abstractmethod
-from typing import Literal
-
-import torch
-
-from ...core import
⋮ (removed lines 7–355 not shown)
-            reset_interval=reset_interval,
-            beta=beta,
-            update_freq=update_freq,
-            scale_first=scale_first,
-            scale_second=scale_second,
-            concat_params=concat_params,
-            inverse=True,
-            inner=inner,
-        )
-
-    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
-        return projected_gradient_(H=H, y=y)
+from abc import ABC, abstractmethod
+from typing import Literal
+
+import torch
+
+from ...core import (
+    Chainable,
+    Modular,
+    Module,
+    Transform,
+    Var,
+    apply_transform,
+)
+from ...utils import TensorList, as_tensorlist, unpack_dicts, unpack_states
+from ..line_search import LineSearchBase
+from ..quasi_newton.quasi_newton import HessianUpdateStrategy
+from ..functional import safe_clip
+
+
+class ConguateGradientBase(Transform, ABC):
+    """Base class for conjugate gradient methods. The only difference between them is how beta is calculated.
+
+    This is an abstract class, to use it, subclass it and override `get_beta`.
+
+
+    Args:
+        defaults (dict | None, optional): dictionary of settings defaults. Defaults to None.
+        clip_beta (bool, optional): whether to clip beta to be no less than 0. Defaults to False.
+        restart_interval (int | None | Literal["auto"], optional):
+            interval between resetting the search direction.
+            "auto" means number of dimensions + 1, None means no reset. Defaults to None.
+        inner (Chainable | None, optional): previous direction is added to the output of this module. Defaults to None.
+
+    Example:
+
+    ```python
+    class PolakRibiere(ConguateGradientBase):
+        def __init__(
+            self,
+            clip_beta=True,
+            restart_interval: int | None = None,
+            inner: Chainable | None = None
+        ):
+            super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+        def get_beta(self, p, g, prev_g, prev_d):
+            denom = prev_g.dot(prev_g)
+            if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
+            return g.dot(g - prev_g) / denom
+    ```
+
+    """
+    def __init__(self, defaults = None, clip_beta: bool = False, restart_interval: int | None | Literal['auto'] = None, inner: Chainable | None = None):
+        if defaults is None: defaults = {}
+        defaults['restart_interval'] = restart_interval
+        defaults['clip_beta'] = clip_beta
+        super().__init__(defaults, uses_grad=False)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_grad')
+        self.global_state.pop('stage', None)
+        self.global_state['step'] = self.global_state.get('step', 1) - 1
+
+    def initialize(self, p: TensorList, g: TensorList):
+        """runs on first step when prev_grads and prev_dir are not available"""
+
+    @abstractmethod
+    def get_beta(self, p: TensorList, g: TensorList, prev_g: TensorList, prev_d: TensorList) -> float | torch.Tensor:
+        """returns beta"""
+
+    @torch.no_grad
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        params = as_tensorlist(params)
+
+        step = self.global_state.get('step', 0) + 1
+        self.global_state['step'] = step
+
+        # initialize on first step
+        if self.global_state.get('stage', 0) == 0:
+            g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
+            d_prev.copy_(tensors)
+            g_prev.copy_(tensors)
+            self.initialize(params, tensors)
+            self.global_state['stage'] = 1
+
+        else:
+            # if `update_tensors` was called multiple times before `apply_tensors`,
+            # stage becomes 2
+            self.global_state['stage'] = 2
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        step = self.global_state['step']
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))
+
+        assert self.global_state['stage'] != 0
+        if self.global_state['stage'] == 1:
+            self.global_state['stage'] = 2
+            return tensors
+
+        params = as_tensorlist(params)
+        g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
+
+        # get beta
+        beta = self.get_beta(params, tensors, g_prev, d_prev)
+        if settings[0]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]
+
+        # inner step
+        # calculate new direction with beta
+        dir = tensors.add_(d_prev.mul_(beta))
+        d_prev.copy_(dir)
+
+        # resetting
+        restart_interval = settings[0]['restart_interval']
+        if restart_interval == 'auto': restart_interval = tensors.global_numel() + 1
+        if restart_interval is not None and step % restart_interval == 0:
+            self.state.clear()
+            self.global_state.clear()
+
+        return dir
+
+# ------------------------------- Polak-Ribière ------------------------------ #
+def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
+    denom = prev_g.dot(prev_g)
+    if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+    return g.dot(g - prev_g) / denom
+
+class PolakRibiere(ConguateGradientBase):
+    """Polak-Ribière-Polyak nonlinear conjugate gradient method.
+
+    Note:
+        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+    """
+    def __init__(self, clip_beta=True, restart_interval: int | None = None, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return polak_ribiere_beta(g, prev_g)
+
+# ------------------------------ Fletcher–Reeves ----------------------------- #
+def fletcher_reeves_beta(gg: torch.Tensor, prev_gg: torch.Tensor):
+    if prev_gg.abs() <= torch.finfo(gg.dtype).tiny * 2: return 0
+    return gg / prev_gg
+
+class FletcherReeves(ConguateGradientBase):
+    """Fletcher–Reeves nonlinear conjugate gradient method.
+
+    Note:
+        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+    """
+    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+    def initialize(self, p, g):
+        self.global_state['prev_gg'] = g.dot(g)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        gg = g.dot(g)
+        beta = fletcher_reeves_beta(gg, self.global_state['prev_gg'])
+        self.global_state['prev_gg'] = gg
+        return beta
+
+# ----------------------------- Hestenes–Stiefel ----------------------------- #
+def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
+    grad_diff = g - prev_g
+    denom = prev_d.dot(grad_diff)
+    if denom.abs() < torch.finfo(g[0].dtype).tiny * 2: return 0
+    return (g.dot(grad_diff) / denom).neg()
+
+
+class HestenesStiefel(ConguateGradientBase):
+    """Hestenes–Stiefel nonlinear conjugate gradient method.
+
+    Note:
+        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+    """
+    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return hestenes_stiefel_beta(g, prev_d, prev_g)
+
+
+# --------------------------------- Dai–Yuan --------------------------------- #
+def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
+    denom = prev_d.dot(g - prev_g)
+    if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+    return (g.dot(g) / denom).neg()
+
+class DaiYuan(ConguateGradientBase):
+    """Dai–Yuan nonlinear conjugate gradient method.
+
+    Note:
+        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1)`` after this.
+    """
+    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return dai_yuan_beta(g, prev_d, prev_g)
+
+
+# -------------------------------- Liu-Storey -------------------------------- #
+def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):
+    denom = prev_g.dot(prev_d)
+    if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+    return g.dot(g - prev_g) / denom
+
+class LiuStorey(ConguateGradientBase):
+    """Liu-Storey nonlinear conjugate gradient method.
+
+    Note:
+        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+    """
+    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return liu_storey_beta(g, prev_d, prev_g)
+
+# ----------------------------- Conjugate Descent ---------------------------- #
+def conjugate_descent_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList):
+    denom = prev_g.dot(prev_d)
+    if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+    return g.dot(g) / denom
+
+class ConjugateDescent(ConguateGradientBase):
+    """Conjugate Descent (CD).
+
+    Note:
+        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+    """
+    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return conjugate_descent_beta(g, prev_d, prev_g)
+
+
+# -------------------------------- Hager-Zhang ------------------------------- #
+def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
+    g_diff = g - prev_g
+    denom = prev_d.dot(g_diff)
+    if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+
+    term1 = 1/denom
+    # term2
+    term2 = (g_diff - (2 * prev_d * (g_diff.pow(2).global_sum()/denom))).dot(g)
+    return (term1 * term2).neg()
+
+
+class HagerZhang(ConguateGradientBase):
+    """Hager-Zhang nonlinear conjugate gradient method,
+
+    Note:
+        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+    """
+    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return hager_zhang_beta(g, prev_d, prev_g)
+
+
+# ----------------------------------- DYHS ---------------------------------- #
+def dyhs_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
+    grad_diff = g - prev_g
+    denom = prev_d.dot(grad_diff)
+    if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+
+    # Dai-Yuan
+    dy_beta = (g.dot(g) / denom).neg().clamp(min=0)
+
+    # Hestenes–Stiefel
+    hs_beta = (g.dot(grad_diff) / denom).neg().clamp(min=0)
+
+    return max(0, min(dy_beta, hs_beta)) # type:ignore
+
+class DYHS(ConguateGradientBase):
+    """Dai-Yuan - Hestenes–Stiefel hybrid conjugate gradient method.
+
+    Note:
+        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+    """
+    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return dyhs_beta(g, prev_d, prev_g)
+
+
+def projected_gradient_(H:torch.Tensor, y:torch.Tensor):
+    Hy = H @ y
+    yHy = safe_clip(y.dot(Hy))
+    H -= (Hy.outer(y) @ H) / yHy
+    return H
+
+class ProjectedGradientMethod(HessianUpdateStrategy): # this doesn't maintain hessian
+    """Projected gradient method. Directly projects the gradient onto subspace conjugate to past directions.
+
+    Notes:
+        - This method uses N^2 memory.
+        - This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+        - This is not the same as projected gradient descent.
+
+    Reference:
+        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171. (algorithm 5 in section 6)
+
+    """
+
+    def __init__(
+        self,
+        init_scale: float | Literal["auto"] = 1,
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None | Literal['auto'] = 'auto',
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        concat_params: bool = True,
+        # inverse: bool = True,
+        inner: Chainable | None = None,
+    ):
+        super().__init__(
+            defaults=None,
+            init_scale=init_scale,
+            tol=tol,
+            ptol=ptol,
+            ptol_restart=ptol_restart,
+            gtol=gtol,
+            restart_interval=restart_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            concat_params=concat_params,
+            inverse=True,
+            inner=inner,
+        )
+
+
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return projected_gradient_(H=H, y=y)
+