torchzero 0.3.13__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +0 -7
- torchzero/core/module.py +4 -0
- torchzero/modules/conjugate_gradient/cg.py +16 -16
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +2 -5
- torchzero/modules/grad_approximation/rfdm.py +27 -110
- torchzero/modules/line_search/scipy.py +15 -3
- torchzero/modules/line_search/strong_wolfe.py +0 -2
- torchzero/modules/restarts/restars.py +5 -4
- torchzero/modules/second_order/newton_cg.py +86 -110
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +6 -4
- torchzero/modules/zeroth_order/__init__.py +1 -1
- torchzero/modules/zeroth_order/cd.py +1 -238
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +1 -0
- {torchzero-0.3.13.dist-info → torchzero-0.3.14.dist-info}/METADATA +1 -1
- {torchzero-0.3.13.dist-info → torchzero-0.3.14.dist-info}/RECORD +21 -20
- {torchzero-0.3.13.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
- {torchzero-0.3.13.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -0

torchzero/modules/second_order/newton_cg.py
CHANGED

@@ -13,21 +13,12 @@ from ..trust_region.trust_region import default_radius
 class NewtonCG(Module):
     """Newton's method with a matrix-free conjugate gradient or minimial-residual solver.
 
-
-
-    forming the full Hessian matrix, it only requires Hessian-vector products
-    (HVPs). These can be calculated efficiently using automatic
-    differentiation or approximated using finite differences.
+    Notes:
+        * In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
 
-
-    In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+        * This module requires the a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
 
-
-    This module requires the a closure passed to the optimizer step,
-    as it needs to re-evaluate the loss and gradients for calculating HVPs.
-    The closure must accept a ``backward`` argument (refer to documentation).
-
-    .. warning::
+    Warning:
         CG may fail if hessian is not positive-definite.
 
     Args:

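For reference, a minimal sketch (not taken from the package) of the ``backward``-accepting closure convention the notes above refer to; ``model``, ``criterion``, ``inputs``, ``targets`` and an already constructed ``tz.Modular`` optimizer ``opt`` are assumed placeholders:

```python
def closure(backward=True):
    # re-evaluates the loss; gradients are recomputed only when backward=True
    loss = criterion(model(inputs), targets)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

loss = opt.step(closure)  # the module re-invokes closure(backward=...) as needed for HVPs
```
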
@@ -66,26 +57,24 @@ class NewtonCG(Module):
             NewtonCG will attempt to apply preconditioning to the output of this module.
 
     Examples:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            )
-
+        Newton-CG with a backtracking line search:
+
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.NewtonCG(),
+            tz.m.Backtracking()
+        )
+        ```
+
+        Truncated Newton method (useful for large-scale problems):
+        ```
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.NewtonCG(maxiter=10),
+            tz.m.Backtracking()
+        )
+        ```
 
     """
     def __init__(

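Assuming a ``closure`` defined as in the sketch above, a hedged example of driving the truncated-Newton configuration in a training loop:

```python
opt = tz.Modular(model.parameters(), tz.m.NewtonCG(maxiter=10), tz.m.Backtracking())

for _ in range(100):
    loss = opt.step(closure)  # closure is re-evaluated internally for HVPs and the line search
```
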
@@ -95,7 +84,7 @@ class NewtonCG(Module):
         reg: float = 1e-8,
         hvp_method: Literal["forward", "central", "autograd"] = "autograd",
         solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
-        h: float = 1e-3,
+        h: float = 1e-3, # tuned 1e-4 or 1e-3
         miniter:int = 1,
         warm_start=False,
         inner: Chainable | None = None,

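For context, a generic sketch (not the package's implementation) of the forward and central finite-difference Hessian-vector products that ``hvp_method`` and ``h`` select between; ``grad_fn`` is an assumed helper that returns the gradient at a point:

```python
def hvp_forward(grad_fn, x, v, h=1e-3):
    # forward difference: one extra gradient evaluation per Hessian-vector product
    return (grad_fn(x + h * v) - grad_fn(x)) / h

def hvp_central(grad_fn, x, v, h=1e-3):
    # central difference: two extra gradient evaluations, more accurate
    return (grad_fn(x + h * v) - grad_fn(x - h * v)) / (2 * h)
```
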
@@ -187,96 +176,95 @@ class NewtonCG(Module):
 
 
 class NewtonCGSteihaug(Module):
-    """
+    """Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.
 
-
-
-    forming the full Hessian matrix, it only requires Hessian-vector products
-    (HVPs). These can be calculated efficiently using automatic
-    differentiation or approximated using finite differences.
+    Notes:
+        * In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
 
-
-    In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
-
-    .. note::
-        This module requires the a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating HVPs.
-        The closure must accept a ``backward`` argument (refer to documentation).
-
-    .. warning::
-        CG may fail if hessian is not positive-definite.
+        * This module requires the a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
 
     Args:
-        maxiter (int | None, optional):
-            Maximum number of iterations for the conjugate gradient solver.
-            By default, this is set to the number of dimensions in the
-            objective function, which is the theoretical upper bound for CG
-            convergence. Setting this to a smaller value (truncated Newton)
-            can still generate good search directions. Defaults to None.
         eta (float, optional):
-
-        nplus (float, optional):
-
-
-            trust region
-
+            if ratio of actual to predicted rediction is larger than this, step is accepted. Defaults to 0.0.
+        nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
+        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
+        rho_good (float, optional):
+            if ratio of actual to predicted rediction is larger than this, trust region size is multiplied by `nplus`.
+        rho_bad (float, optional):
+            if ratio of actual to predicted rediction is less than this, trust region size is multiplied by `nminus`.
+        init (float, optional): Initial trust region value. Defaults to 1.
+        max_attempts (max_attempts, optional):
+            maximum number of trust radius reductions per step. A zero update vector is returned when
+            this limit is exceeded. Defaults to 10.
+        max_history (int, optional):
+            CG will store this many intermediate solutions, reusing them when trust radius is reduced
+            instead of re-running CG. Each solution storage requires 2N memory. Defaults to 100.
+        boundary_tol (float | None, optional):
+            The trust region only increases when suggested step's norm is at least `(1-boundary_tol)*trust_region`.
+            This prevents increasing trust region when solution is not on the boundary. Defaults to 1e-2.
+
+        maxiter (int | None, optional):
+            maximum number of CG iterations per step. Each iteration requies one backward pass if `hvp_method="forward"`, two otherwise. Defaults to None.
+        miniter (int, optional):
+            minimal number of CG iterations. This prevents making no progress
         tol (float, optional):
-
-
-        reg (float, optional):
-
-
+            terminates CG when norm of the residual is less than this value. Defaults to 1e-8.
+            when initial guess is below tolerance. Defaults to 1.
+        reg (float, optional): hessian regularization. Defaults to 1e-8.
+        solver (str, optional): solver, "cg" or "minres". "cg" is recommended. Defaults to 'cg'.
+        adapt_tol (bool, optional):
+            if True, whenever trust radius collapses to smallest representable number,
+            the tolerance is multiplied by 0.1. Defaults to True.
+        npc_terminate (bool, optional):
+            whether to terminate CG/MINRES whenever negative curvature is detected. Defaults to False.
+
         hvp_method (str, optional):
-
+            either "forward" to use forward formula which requires one backward pass per Hvp, or "central" to use a more accurate central formula which requires two backward passes. "forward" is usually accurate enough. Defaults to "forward".
+        h (float, optional): finite difference step size. Defaults to 1e-3.
 
-            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-              This requires creating a graph for the gradient.
-            - ``"forward"``: Use a forward finite difference formula to
-              approximate the HVP. This requires one extra gradient evaluation.
-            - ``"central"``: Use a central finite difference formula for a
-              more accurate HVP approximation. This requires two extra
-              gradient evaluations.
-            Defaults to "autograd".
-        h (float, optional):
-            The step size for finite differences if :code:`hvp_method` is
-            ``"forward"`` or ``"central"``. Defaults to 1e-3.
         inner (Chainable | None, optional):
-
+            applies preconditioning to output of this module. Defaults to None.
 
-    Examples:
-
-
-        .. code-block:: python
+    ### Examples:
+    Trust-region Newton-CG:
 
-
-
-
-
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.NewtonCGSteihaug(),
+    )
+    ```
 
-    Reference:
+    ### Reference:
         Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
     """
     def __init__(
         self,
-
+        # trust region settings
         eta: float= 0.0,
         nplus: float = 3.5,
         nminus: float = 0.25,
         rho_good: float = 0.99,
         rho_bad: float = 1e-4,
         init: float = 1,
-        tol: float = 1e-8,
-        reg: float = 1e-8,
-        hvp_method: Literal["forward", "central"] = "forward",
-        solver: Literal['cg', "minres"] = 'cg',
-        h: float = 1e-3,
         max_attempts: int = 100,
         max_history: int = 100,
-        boundary_tol: float = 1e-
+        boundary_tol: float = 1e-6, # tuned
+
+        # cg settings
+        maxiter: int | None = None,
         miniter: int = 1,
-
+        tol: float = 1e-8,
+        reg: float = 1e-8,
+        solver: Literal['cg', "minres"] = 'cg',
         adapt_tol: bool = True,
         npc_terminate: bool = False,
+
+        # hvp settings
+        hvp_method: Literal["forward", "central"] = "central",
+        h: float = 1e-3, # tuned 1e-4 or 1e-3
+
+        # inner
         inner: Chainable | None = None,
     ):
         defaults = locals().copy()

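For context, a generic sketch (not the package's internal code) of the trust-region update rule that the ``eta``, ``nplus``, ``nminus``, ``rho_good``, ``rho_bad`` and ``boundary_tol`` parameters describe, where ``rho`` is the ratio of actual to predicted reduction; defaults mirror the ``__init__`` signature above:

```python
def trust_region_update(radius, rho, step_norm, *, eta=0.0, nplus=3.5, nminus=0.25,
                        rho_good=0.99, rho_bad=1e-4, boundary_tol=1e-6):
    accepted = rho > eta  # accept the step when the quadratic model predicted it well enough
    if rho < rho_bad:
        radius *= nminus  # poor prediction: shrink the trust region
    elif rho > rho_good and step_norm >= (1 - boundary_tol) * radius:
        radius *= nplus   # good prediction and step on the boundary: grow the trust region
    return accepted, radius
```
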
@@ -336,19 +324,8 @@ class NewtonCGSteihaug(Module):
             raise ValueError(hvp_method)
 
 
-        # ------------------------- update RMS preconditioner ------------------------ #
-        b = var.get_update()
-        P_mm = None
-        rms_beta = self.defaults["rms_beta"]
-        if rms_beta is not None:
-            exp_avg_sq = self.get_state(params, "exp_avg_sq", init=b, cls=TensorList)
-            exp_avg_sq.mul_(rms_beta).addcmul(b, b, value=1-rms_beta)
-            exp_avg_sq_sqrt = exp_avg_sq.sqrt().add_(1e-8)
-            def _P_mm(x):
-                return x / exp_avg_sq_sqrt
-            P_mm = _P_mm
-
         # -------------------------------- inner step -------------------------------- #
+        b = var.get_update()
         if 'inner' in self.children:
             b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
         b = as_tensorlist(b)

@@ -392,7 +369,6 @@ class NewtonCGSteihaug(Module):
                 miniter=miniter,
                 npc_terminate=npc_terminate,
                 history_size=max_history,
-                P_mm=P_mm,
             )
 
         elif solver == 'minres':

torchzero/modules/trust_region/levenberg_marquardt.py
CHANGED

@@ -14,13 +14,13 @@ class LevenbergMarquardt(TrustRegionBase):
         hess_module (Module | None, optional):
             A module that maintains a hessian approximation (not hessian inverse!).
             This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
-            When using quasi-newton methods, set
+            When using quasi-newton methods, set ``inverse=False`` when constructing them.
         y (float, optional):
             when ``y=0``, identity matrix is added to hessian, when ``y=1``, diagonal of the hessian approximation
             is added. Values between interpolate. This should only be used with Gauss-Newton. Defaults to 0.
         eta (float, optional):
             if ratio of actual to predicted rediction is larger than this, step is accepted.
-            When
+            When ``hess_module`` is ``Newton`` or ``GaussNewton``, this can be set to 0. Defaults to 0.15.
         nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
         nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
         rho_good (float, optional):

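For context, a hedged usage sketch based on this docstring; ``tz.m.Newton`` is one of the ``hess_module`` options the docstring names, and wrapping it in ``tz.Modular`` follows the examples elsewhere in this diff:

```python
# a sketch, assuming model already exists
opt = tz.Modular(
    model.parameters(),
    tz.m.LevenbergMarquardt(hess_module=tz.m.Newton()),
)
```
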
torchzero/modules/trust_region/trust_cg.py
CHANGED

@@ -60,17 +60,19 @@ class TrustCG(TrustRegionBase):
         nminus: float = 0.25,
         rho_good: float = 0.99,
         rho_bad: float = 1e-4,
-        boundary_tol: float | None = 1e-
+        boundary_tol: float | None = 1e-6, # tuned
         init: float = 1,
         max_attempts: int = 10,
         radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
         reg: float = 0,
-
+        maxiter: int | None = None,
+        miniter: int = 1,
+        cg_tol: float = 1e-8,
         prefer_exact: bool = True,
         update_freq: int = 1,
         inner: Chainable | None = None,
     ):
-        defaults = dict(reg=reg, prefer_exact=prefer_exact, cg_tol=cg_tol)
+        defaults = dict(reg=reg, prefer_exact=prefer_exact, cg_tol=cg_tol, maxiter=maxiter, miniter=miniter)
         super().__init__(
             defaults=defaults,
             hess_module=hess_module,

@@ -93,5 +95,5 @@ class TrustCG(TrustRegionBase):
         if settings['prefer_exact'] and isinstance(H, linear_operator.ScaledIdentity):
             return H.solve_bounded(g, radius)
 
-        x, _ = cg(H.matvec, g, trust_radius=radius, reg=settings['reg'], tol=settings["cg_tol"])
+        x, _ = cg(H.matvec, g, trust_radius=radius, reg=settings['reg'], maxiter=settings["maxiter"], miniter=settings["miniter"], tol=settings["cg_tol"])
         return x

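For context, a hedged sketch of constructing ``TrustCG`` with the newly exposed CG options; the ``hess_module`` choice and the ``tz.Modular`` wrapper are assumptions modeled on the other examples in this release:

```python
opt = tz.Modular(
    model.parameters(),
    tz.m.TrustCG(hess_module=tz.m.Newton(), maxiter=50, miniter=1, cg_tol=1e-8),
)
```
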
torchzero/modules/zeroth_order/__init__.py
CHANGED

@@ -1 +1 @@
-from .cd import CD
+from .cd import CD

torchzero/modules/zeroth_order/cd.py
CHANGED

@@ -9,11 +9,10 @@ import torch
 
 from ...core import Module
 from ...utils import NumberList, TensorList
-from ..line_search.adaptive import adaptive_tracking
 
 class CD(Module):
     """Coordinate descent. Proposes a descent direction along a single coordinate.
-
+    A line search such as ``tz.m.ScipyMinimizeScalar(maxiter=8)`` or a fixed step size can be used after this.
 
     Args:
         h (float, optional): finite difference step size. Defaults to 1e-3.

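For context, a hedged sketch pairing ``CD`` with the line search the docstring suggests; the composition via ``tz.Modular`` is an assumption following the pattern used throughout this diff:

```python
opt = tz.Modular(
    model.parameters(),
    tz.m.CD(),
    tz.m.ScipyMinimizeScalar(maxiter=8),
)
```
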
@@ -121,239 +120,3 @@ class CD(Module):
         var.update = update
         return var
 
-
-def _icd_get_idx(self: Module, params: TensorList):
-    ndim = params.global_numel()
-    igrad = self.get_state(params, "igrad", cls=TensorList)
-
-    # -------------------------- 1st n steps fill igrad -------------------------- #
-    index = self.global_state.get('index', 0)
-    self.global_state['index'] = index + 1
-    if index < ndim:
-        return index, igrad
-
-    # ------------------ select randomly weighted by magnitudes ------------------ #
-    igrad_abs = igrad.abs()
-    gmin = igrad_abs.global_min()
-    gmax = igrad_abs.global_max()
-
-    pmin, pmax, pow = self.get_settings(params, "pmin", "pmax", "pow", cls=NumberList)
-
-    p: TensorList = ((igrad_abs - gmin) / (gmax - gmin)) ** pow # pyright:ignore[reportOperatorIssue]
-    p.mul_(pmax-pmin).add_(pmin)
-
-    if 'np_gen' not in self.global_state:
-        self.global_state['np_gen'] = np.random.default_rng(0)
-    np_gen = self.global_state['np_gen']
-
-    p_vec = p.to_vec()
-    p_sum = p_vec.sum()
-    if p_sum > 1e-12:
-        return np_gen.choice(ndim, p=p_vec.div_(p_sum).numpy(force=True)), igrad
-
-    # --------------------- sum is too small, do cycle again --------------------- #
-    self.global_state.clear()
-    self.clear_state_keys('h_vec', 'igrad', 'alphas')
-
-    if 'generator' not in self.global_state:
-        self.global_state['generator'] = random.Random(0)
-    generator = self.global_state['generator']
-    return generator.randrange(0, p_vec.numel()), igrad
-
-class CCD(Module):
-    """Cumulative coordinate descent. This updates one gradient coordinate at a time and accumulates it
-    to the update direction. The coordinate updated is random weighted by magnitudes of current update direction.
-    As update direction ceases to be a descent direction due to old accumulated coordinates, it is decayed.
-
-    Args:
-        pmin (float, optional): multiplier to probability of picking the lowest magnitude gradient. Defaults to 0.1.
-        pmax (float, optional): multiplier to probability of picking the largest magnitude gradient. Defaults to 1.0.
-        pow (int, optional): power transform to probabilities. Defaults to 2.
-        decay (float, optional): accumulated gradient decay on failed step. Defaults to 0.5.
-        decay2 (float, optional): decay multiplier decay on failed step. Defaults to 0.25.
-        nplus (float, optional): step size increase on successful steps. Defaults to 1.5.
-        nminus (float, optional): step size increase on unsuccessful steps. Defaults to 0.75.
-    """
-    def __init__(self, pmin=0.1, pmax=1.0, pow=2, decay:float=0.8, decay2:float=0.2, nplus=1.5, nminus=0.75):
-
-        defaults = dict(pmin=pmin, pmax=pmax, pow=pow, decay=decay, decay2=decay2, nplus=nplus, nminus=nminus)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
-        if closure is None:
-            raise RuntimeError("CD requires closure")
-
-        params = TensorList(var.params)
-        p_prev = self.get_state(params, "p_prev", init=params, cls=TensorList)
-
-        f_0 = var.get_loss(False)
-        step_size = self.global_state.get('step_size', 1)
-
-        # ------------------------ hard reset on infinite loss ----------------------- #
-        if not math.isfinite(f_0):
-            del self.global_state['f_prev']
-            var.update = params - p_prev
-            self.global_state.clear()
-            self.state.clear()
-            self.global_state["step_size"] = step_size / 10
-            return var
-
-        # ---------------------------- soft reset if stuck --------------------------- #
-        if "igrad" in self.state[params[0]]:
-            n_bad = self.global_state.get('n_bad', 0)
-
-            f_prev = self.global_state.get("f_prev", None)
-            if f_prev is not None:
-
-                decay2 = self.defaults["decay2"]
-                decay = self.global_state.get("decay", self.defaults["decay"])
-
-                if f_0 >= f_prev:
-
-                    igrad = self.get_state(params, "igrad", cls=TensorList)
-                    del self.global_state['f_prev']
-
-                    # undo previous update
-                    var.update = params - p_prev
-
-                    # increment n_bad
-                    self.global_state['n_bad'] = n_bad + 1
-
-                    # decay step size
-                    self.global_state['step_size'] = step_size * self.defaults["nminus"]
-
-                    # soft reset
-                    if n_bad > 0:
-                        igrad *= decay
-                        self.global_state["decay"] = decay*decay2
-                        self.global_state['n_bad'] = 0
-
-                    return var
-
-                else:
-                    # increase step size and reset n_bad
-                    self.global_state['step_size'] = step_size * self.defaults["nplus"]
-                    self.global_state['n_bad'] = 0
-                    self.global_state["decay"] = self.defaults["decay"]
-
-        self.global_state['f_prev'] = float(f_0)
-
-        # ------------------------------ determine index ----------------------------- #
-        idx, igrad = _icd_get_idx(self, params)
-
-        # -------------------------- find descent direction -------------------------- #
-        h_vec = self.get_state(params, 'h_vec', init=lambda x: torch.full_like(x, 1e-3), cls=TensorList)
-        h = float(h_vec.flat_get(idx))
-
-        params.flat_set_lambda_(idx, lambda x: x + h)
-        f_p = closure(False)
-
-        params.flat_set_lambda_(idx, lambda x: x - 2*h)
-        f_n = closure(False)
-        params.flat_set_lambda_(idx, lambda x: x + h)
-
-        # ---------------------------------- adapt h --------------------------------- #
-        if f_0 <= f_p and f_0 <= f_n:
-            h_vec.flat_set_lambda_(idx, lambda x: max(x/2, 1e-10))
-        else:
-            if abs(f_0 - f_n) < 1e-12 or abs((f_p - f_0) / (f_0 - f_n) - 1) < 1e-2:
-                h_vec.flat_set_lambda_(idx, lambda x: min(x*2, 1e10))
-
-        # ------------------------------- update igrad ------------------------------- #
-        if f_0 < f_p and f_0 < f_n: alpha = 0
-        else: alpha = (f_p - f_n) / (2*h)
-
-        igrad.flat_set_(idx, alpha)
-
-        # ----------------------------- create the update ---------------------------- #
-        var.update = igrad * step_size
-        p_prev.copy_(params)
-        return var
-
-
-class CCDLS(Module):
-    """CCD with line search instead of adaptive step size.
-
-    Args:
-        pmin (float, optional): multiplier to probability of picking the lowest magnitude gradient. Defaults to 0.1.
-        pmax (float, optional): multiplier to probability of picking the largest magnitude gradient. Defaults to 1.0.
-        pow (int, optional): power transform to probabilities. Defaults to 2.
-        decay (float, optional): accumulated gradient decay on failed step. Defaults to 0.5.
-        decay2 (float, optional): decay multiplier decay on failed step. Defaults to 0.25.
-        maxiter (int, optional): max number of line search iterations.
-    """
-    def __init__(self, pmin=0.1, pmax=1.0, pow=2, decay=0.8, decay2=0.2, maxiter=10, ):
-        defaults = dict(pmin=pmin, pmax=pmax, pow=pow, maxiter=maxiter, decay=decay, decay2=decay2)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
-        if closure is None:
-            raise RuntimeError("CD requires closure")
-
-        params = TensorList(var.params)
-        finfo = torch.finfo(params[0].dtype)
-        f_0 = var.get_loss(False)
-
-        # ------------------------------ determine index ----------------------------- #
-        idx, igrad = _icd_get_idx(self, params)
-
-        # -------------------------- find descent direction -------------------------- #
-        h_vec = self.get_state(params, 'h_vec', init=lambda x: torch.full_like(x, 1e-3), cls=TensorList)
-        h = float(h_vec.flat_get(idx))
-
-        params.flat_set_lambda_(idx, lambda x: x + h)
-        f_p = closure(False)
-
-        params.flat_set_lambda_(idx, lambda x: x - 2*h)
-        f_n = closure(False)
-        params.flat_set_lambda_(idx, lambda x: x + h)
-
-        # ---------------------------------- adapt h --------------------------------- #
-        if f_0 <= f_p and f_0 <= f_n:
-            h_vec.flat_set_lambda_(idx, lambda x: max(x/2, finfo.tiny * 2))
-        else:
-            # here eps, not tiny
-            if abs(f_0 - f_n) < finfo.eps or abs((f_p - f_0) / (f_0 - f_n) - 1) < 1e-2:
-                h_vec.flat_set_lambda_(idx, lambda x: min(x*2, finfo.max / 2))
-
-        # ------------------------------- update igrad ------------------------------- #
-        if f_0 < f_p and f_0 < f_n: alpha = 0
-        else: alpha = (f_p - f_n) / (2*h)
-
-        igrad.flat_set_(idx, alpha)
-
-        # -------------------------------- line search ------------------------------- #
-        x0 = params.clone()
-        def f(a):
-            params.sub_(igrad, alpha=a)
-            loss = closure(False)
-            params.copy_(x0)
-            return loss
-
-        a_prev = self.global_state.get('a_prev', 1)
-        a, f_a, niter = adaptive_tracking(f, a_prev, maxiter=self.defaults['maxiter'], f_0=f_0)
-        if (a is None) or (not math.isfinite(a)) or (not math.isfinite(f_a)):
-            a = 0
-
-        # -------------------------------- set a_prev -------------------------------- #
-        decay2 = self.defaults["decay2"]
-        decay = self.global_state.get("decay", self.defaults["decay"])
-
-        if abs(a) > finfo.tiny * 2:
-            assert f_a < f_0
-            self.global_state['a_prev'] = max(min(a, finfo.max / 2), finfo.tiny * 2)
-            self.global_state["decay"] = self.defaults["decay"]
-
-        # ---------------------------- soft reset on fail ---------------------------- #
-        else:
-            igrad *= decay
-            self.global_state["decay"] = decay*decay2
-            self.global_state['a_prev'] = a_prev / 2
-
-        # -------------------------------- set update -------------------------------- #
-        var.update = igrad * a
-        return var

torchzero/utils/optimizer.py
CHANGED

@@ -110,7 +110,7 @@ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], param
     for i, param in enumerate(params):
         s = state[param]
         if key not in s:
-            if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+            if must_exist: raise KeyError(f"Key `{key}` doesn't exist in state with keys {tuple(s.keys())}")
            s[key] = _make_initial_state_value(param, init, i)
        values.append(s[key])
    return values

@@ -125,7 +125,7 @@ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], param
         s = state[param]
         for k_i, key in enumerate(keys):
             if key not in s:
-                if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+                if must_exist: raise KeyError(f"Key `{key}` doesn't exist in state with keys {tuple(s.keys())}")
                 k_init = init[k_i] if isinstance(init, (list,tuple)) else init
                 s[key] = _make_initial_state_value(param, k_init, i)
             values[k_i].append(s[key])

torchzero/utils/python_tools.py
CHANGED