torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
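Several modules also change location in this release: the old `lr/` package is removed in favour of the new `step_size/` package, several `ops/` files (misc.py, split.py, debug.py, switch.py) move into the new `misc/` package, and the dct, fft and structural projections move under `experimental/`. A minimal sketch of the affected import paths follows; only the file paths themselves are confirmed by the listing, so treat the imports as assumptions about the 0.3.11 layout.

```python
# Hedged sketch: module paths that appear to move between 0.3.9 and 0.3.11,
# inferred from the file listing above. Public re-exports are not shown in this
# diff, so only plain submodule imports are used.
from torchzero.modules.step_size import lr            # 0.3.9: torchzero.modules.lr.lr (removed)
from torchzero.modules.misc import split, switch      # 0.3.9: torchzero.modules.ops.split / ops.switch
from torchzero.modules.experimental import dct, fft   # 0.3.9: torchzero.modules.projections.dct / .fft
```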
torchzero/modules/line_search/line_search.py (+107 -49)

```diff
@@ -8,15 +8,16 @@ from typing import Any
 import numpy as np
 import torch
 
-from ...core import Module, Target,
+from ...core import Module, Target, Var
 from ...utils import tofloat
 
 
 class MaxLineSearchItersReached(Exception): pass
 
 
-class
+class LineSearchBase(Module, ABC):
     """Base class for line searches.
+
     This is an abstract class, to use it, subclass it and override `search`.
 
     Args:
@@ -26,6 +27,62 @@ class LineSearch(Module, ABC):
             the objective this many times, and step size with the lowest loss value will be used.
             This is useful when passing `make_objective` to an external library which
             doesn't have a maxiter option. Defaults to None.
+
+    Other useful methods:
+        * `evaluate_step_size` - returns loss with a given scalar step size
+        * `evaluate_step_size_loss_and_derivative` - returns loss and directional derivative with a given scalar step size
+        * `make_objective` - creates a function that accepts a scalar step size and returns loss. This can be passed to a scalar solver, such as scipy.optimize.minimize_scalar.
+        * `make_objective_with_derivative` - creates a function that accepts a scalar step size and returns a tuple with loss and directional derivative. This can be passed to a scalar solver.
+
+    Examples:
+        #### Basic line search
+
+        This evaluates all step sizes in a range by using the :code:`self.evaluate_step_size` method.
+
+        .. code-block:: python
+
+            class GridLineSearch(LineSearch):
+                def __init__(self, start, end, num):
+                    defaults = dict(start=start,end=end,num=num)
+                    super().__init__(defaults)
+
+                @torch.no_grad
+                def search(self, update, var):
+                    settings = self.settings[var.params[0]]
+                    start = settings["start"]
+                    end = settings["end"]
+                    num = settings["num"]
+
+                    lowest_loss = float("inf")
+                    best_step_size = best_step_size
+
+                    for step_size in torch.linspace(start,end,num):
+                        loss = self.evaluate_step_size(step_size.item(), var=var, backward=False)
+                        if loss < lowest_loss:
+                            lowest_loss = loss
+                            best_step_size = step_size
+
+                    return best_step_size
+
+        #### Using external solver via self.make_objective
+
+        Here we let :code:`scipy.optimize.minimize_scalar` solver find the best step size via :code:`self.make_objective`
+
+        .. code-block:: python
+
+            class ScipyMinimizeScalar(LineSearch):
+                def __init__(self, method: str | None = None):
+                    defaults = dict(method=method)
+                    super().__init__(defaults)
+
+                @torch.no_grad
+                def search(self, update, var):
+                    objective = self.make_objective(var=var)
+                    method = self.settings[var.params[0]]["method"]
+
+                    res = self.scopt.minimize_scalar(objective, method=method)
+                    return res.x
+
     """
     def __init__(self, defaults: dict[str, Any] | None, maxiter: int | None = None):
         super().__init__(defaults)
@@ -62,12 +119,12 @@ class LineSearch(Module, ABC):
         if any(a!=0 for a in alpha):
             torch._foreach_add_(params, torch._foreach_mul(update, alpha))
 
-    def _loss(self, step_size: float,
+    def _loss(self, step_size: float, var: Var, closure, params: list[torch.Tensor],
               update: list[torch.Tensor], backward:bool=False) -> float:
 
         # if step_size is 0, we might already know the loss
-        if (
-        return tofloat(
+        if (var.loss is not None) and (step_size == 0):
+            return tofloat(var.loss)
 
         # check max iter
         if self._maxiter is not None and self._current_iter >= self._maxiter: raise MaxLineSearchItersReached
@@ -85,23 +142,23 @@ class LineSearch(Module, ABC):
             self._lowest_loss = tofloat(loss)
             self._best_step_size = step_size
 
-        # if evaluated loss at step size 0, set it to
+        # if evaluated loss at step size 0, set it to var.loss
         if step_size == 0:
-
-            if backward:
+            var.loss = loss
+            if backward: var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
 
         return tofloat(loss)
 
-    def _loss_derivative(self, step_size: float,
+    def _loss_derivative(self, step_size: float, var: Var, closure,
                          params: list[torch.Tensor], update: list[torch.Tensor]):
         # if step_size is 0, we might already know the derivative
-        if (
-            loss = self._loss(step_size=step_size,
-            derivative = - sum(t.sum() for t in torch._foreach_mul(
+        if (var.grad is not None) and (step_size == 0):
+            loss = self._loss(step_size=step_size,var=var,closure=closure,params=params,update=update,backward=False)
+            derivative = - sum(t.sum() for t in torch._foreach_mul(var.grad, update))
 
         else:
             # loss with a backward pass sets params.grad
-            loss = self._loss(step_size=step_size,
+            loss = self._loss(step_size=step_size,var=var,closure=closure,params=params,update=update,backward=True)
 
             # directional derivative
             derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
@@ -109,73 +166,74 @@ class LineSearch(Module, ABC):
 
         return loss, tofloat(derivative)
 
-    def evaluate_step_size(self, step_size: float,
-        closure =
+    def evaluate_step_size(self, step_size: float, var: Var, backward:bool=False):
+        closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
-        return self._loss(step_size=step_size,
+        return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update(),backward=backward)
 
-    def evaluate_step_size_loss_and_derivative(self, step_size: float,
-        closure =
+    def evaluate_step_size_loss_and_derivative(self, step_size: float, var: Var):
+        closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
-        return self._loss_derivative(step_size=step_size,
+        return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
 
-    def make_objective(self,
-        closure =
+    def make_objective(self, var: Var, backward:bool=False):
+        closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
-        return partial(self._loss,
+        return partial(self._loss, var=var, closure=closure, params=var.params, update=var.get_update(), backward=backward)
 
-    def make_objective_with_derivative(self,
-        closure =
+    def make_objective_with_derivative(self, var: Var):
+        closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
-        return partial(self._loss_derivative,
+        return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_update())
 
     @abstractmethod
-    def search(self, update: list[torch.Tensor],
+    def search(self, update: list[torch.Tensor], var: Var) -> float:
         """Finds the step size to use"""
 
     @torch.no_grad
-    def step(self,
+    def step(self, var: Var) -> Var:
         self._reset()
-        params =
-        update =
+        params = var.params
+        update = var.get_update()
 
         try:
-            step_size = self.search(update=update,
+            step_size = self.search(update=update, var=var)
         except MaxLineSearchItersReached:
             step_size = self._best_step_size
 
         # set loss_approx
-        if
+        if var.loss_approx is None: var.loss_approx = self._lowest_loss
 
         # this is last module - set step size to found step_size times lr
-        if
+        if var.is_last:
 
-            if
+            if var.last_module_lrs is None:
                 self.set_step_size_(step_size, params=params, update=update)
 
             else:
-                self._set_per_parameter_step_size_([step_size*lr for lr in
+                self._set_per_parameter_step_size_([step_size*lr for lr in var.last_module_lrs], params=params, update=update)
 
-
-            return
+            var.stop = True; var.skip_update = True
+            return var
 
         # revert parameters and multiply update by step size
         self.set_step_size_(0, params=params, update=update)
-        torch._foreach_mul_(
-        return
+        torch._foreach_mul_(var.update, step_size)
+        return var
 
 
-class GridLineSearch(LineSearch):
-    """Mostly for testing, this is not practical"""
-    def __init__(self, start, end, num):
-        defaults = dict(start=start,end=end,num=num)
-        super().__init__(defaults)
 
-
-
-
+# class GridLineSearch(LineSearch):
+#     """Mostly for testing, this is not practical"""
+#     def __init__(self, start, end, num):
+#         defaults = dict(start=start,end=end,num=num)
+#         super().__init__(defaults)
+
+#     @torch.no_grad
+#     def search(self, update, var):
+#         start,end,num=itemgetter('start','end','num')(self.settings[var.params[0]])
 
-
-
+#         for lr in torch.linspace(start,end,num):
+#             self.evaluate_step_size(lr.item(), var=var, backward=False)
 
-
+#         return self._best_step_size
```
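The hunks above rename the `LineSearch` base class to `LineSearchBase` and thread an explicit `Var` object through `search`, `evaluate_step_size` and the `make_objective*` helpers. As a quick illustration of that API, here is a hedged sketch of a custom subclass; the `tz.Modular` / `tz.m` usage follows the docstring examples in this diff, while the halving schedule itself is purely illustrative and not part of torchzero.

```python
# Hedged sketch of the LineSearchBase API shown in the diff above.
# The direct import path follows the file listing; the schedule is illustrative.
import torch
import torchzero as tz
from torchzero.modules.line_search.line_search import LineSearchBase

class HalvingLineSearch(LineSearchBase):
    """Try step sizes 1, 1/2, 1/4, ... and keep the one with the lowest loss."""
    def __init__(self, num: int = 8):
        super().__init__(defaults=dict(num=num))

    @torch.no_grad
    def search(self, update, var):
        num = self.settings[var.params[0]]["num"]
        best_loss, best_step = float("inf"), 0.0
        step = 1.0
        for _ in range(num):
            loss = self.evaluate_step_size(step, var=var, backward=False)
            if loss < best_loss:
                best_loss, best_step = loss, step
            step *= 0.5
        return best_step

model = torch.nn.Linear(2, 1)
opt = tz.Modular(model.parameters(), tz.m.LBFGS(), HalvingLineSearch())
```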
torchzero/modules/line_search/polynomial.py (new file, +233 -0)

```diff
@@ -0,0 +1,233 @@
+import numpy as np
+import torch
+
+from .line_search import LineSearchBase
+
+
+# polynomial interpolation
+# this code is from https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
+# PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
+def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
+    """
+    Gives the minimizer and minimum of the interpolating polynomial over given points
+    based on function and derivative information. Defaults to bisection if no critical
+    points are valid.
+
+    Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
+    modifications.
+
+    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
+    Last edited 12/6/18.
+
+    Inputs:
+        points (nparray): two-dimensional array with each point of form [x f g]
+        x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
+        x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
+        plot (bool): plot interpolating polynomial
+
+    Outputs:
+        x_sol (float): minimizer of interpolating polynomial
+        F_min (float): minimum of interpolating polynomial
+
+    Note:
+        . Set f or g to np.nan if they are unknown
+
+    """
+    no_points = points.shape[0]
+    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+    x_min = np.min(points[:, 0])
+    x_max = np.max(points[:, 0])
+
+    # compute bounds of interpolation area
+    if x_min_bound is None:
+        x_min_bound = x_min
+    if x_max_bound is None:
+        x_max_bound = x_max
+
+    # explicit formula for quadratic interpolation
+    if no_points == 2 and order == 2 and plot is False:
+        # Solution to quadratic interpolation is given by:
+        # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
+        # x_min = x1 - g1/(2a)
+        # if x1 = 0, then is given by:
+        # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))
+
+        if points[0, 0] == 0:
+            x_sol = -points[0, 2] * points[1, 0] ** 2 / (2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0]))
+        else:
+            a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / (points[0, 0] - points[1, 0]) ** 2
+            x_sol = points[0, 0] - points[0, 2]/(2*a)
+
+        x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+
+    # explicit formula for cubic interpolation
+    elif no_points == 2 and order == 3 and plot is False:
+        # Solution to cubic interpolation is given by:
+        # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
+        # d2 = sqrt(d1^2 - g1*g2)
+        # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
+        d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0]))
+        d2 = np.sqrt(d1 ** 2 - points[0, 2] * points[1, 2])
+        if np.isreal(d2):
+            x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2))
+            x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+        else:
+            x_sol = (x_max_bound + x_min_bound)/2
+
+    # solve linear system
+    else:
+        # define linear constraints
+        A = np.zeros((0, order + 1))
+        b = np.zeros((0, 1))
+
+        # add linear constraints on function values
+        for i in range(no_points):
+            if not np.isnan(points[i, 1]):
+                constraint = np.zeros((1, order + 1))
+                for j in range(order, -1, -1):
+                    constraint[0, order - j] = points[i, 0] ** j
+                A = np.append(A, constraint, 0)
+                b = np.append(b, points[i, 1])
+
+        # add linear constraints on gradient values
+        for i in range(no_points):
+            if not np.isnan(points[i, 2]):
+                constraint = np.zeros((1, order + 1))
+                for j in range(order):
+                    constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
+                A = np.append(A, constraint, 0)
+                b = np.append(b, points[i, 2])
+
+        # check if system is solvable
+        if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
+            x_sol = (x_min_bound + x_max_bound)/2
+            f_min = np.inf
+        else:
+            # solve linear system for interpolating polynomial
+            coeff = np.linalg.solve(A, b)
+
+            # compute critical points
+            dcoeff = np.zeros(order)
+            for i in range(len(coeff) - 1):
+                dcoeff[i] = coeff[i] * (order - i)
+
+            crit_pts = np.array([x_min_bound, x_max_bound])
+            crit_pts = np.append(crit_pts, points[:, 0])
+
+            if not np.isinf(dcoeff).any():
+                roots = np.roots(dcoeff)
+                crit_pts = np.append(crit_pts, roots)
+
+            # test critical points
+            f_min = np.inf
+            x_sol = (x_min_bound + x_max_bound) / 2  # defaults to bisection
+            for crit_pt in crit_pts:
+                if np.isreal(crit_pt) and crit_pt >= x_min_bound and crit_pt <= x_max_bound:
+                    F_cp = np.polyval(coeff, crit_pt)
+                    if np.isreal(F_cp) and F_cp < f_min:
+                        x_sol = np.real(crit_pt)
+                        f_min = np.real(F_cp)
+
+            if(plot):
+                import matplotlib.pyplot as plt
+                plt.figure()
+                x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound)/10000)
+                f = np.polyval(coeff, x)
+                plt.plot(x, f)
+                plt.plot(x_sol, f_min, 'x')
+
+    return x_sol
+
+
+
+# class PolynomialLineSearch(LineSearch):
+#     """TODO
+
+#     Line search via polynomial interpolation.
+
+#     Args:
+#         init (float, optional): Initial step size. Defaults to 1.0.
+#         c1 (float, optional): Acceptance value for weak wolfe condition. Defaults to 1e-4.
+#         c2 (float, optional): Acceptance value for strong wolfe condition (set to 0.1 for conjugate gradient). Defaults to 0.9.
+#         maxiter (int, optional): Maximum number of line search iterations. Defaults to 25.
+#         maxzoom (int, optional): Maximum number of zoom iterations. Defaults to 10.
+#         expand (float, optional): Expansion factor (multipler to step size when weak condition not satisfied). Defaults to 2.0.
+#         adaptive (bool, optional):
+#             when enabled, if line search failed, initial step size is reduced.
+#             Otherwise it is reset to initial value. Defaults to True.
+#         plus_minus (bool, optional):
+#             If enabled and the direction is not descent direction, performs line search in opposite direction. Defaults to False.
+
+
+#     Examples:
+#         Conjugate gradient method with strong wolfe line search. Nocedal, Wright recommend setting c2 to 0.1 for CG.
+
+#         .. code-block:: python
+
+#             opt = tz.Modular(
+#                 model.parameters(),
+#                 tz.m.PolakRibiere(),
+#                 tz.m.StrongWolfe(c2=0.1)
+#             )
+
+#         LBFGS strong wolfe line search:
+
+#         .. code-block:: python
+
+#             opt = tz.Modular(
+#                 model.parameters(),
+#                 tz.m.LBFGS(),
+#                 tz.m.StrongWolfe()
+#             )
+
+#     """
+#     def __init__(
+#         self,
+#         init: float = 1.0,
+#         c1: float = 1e-4,
+#         c2: float = 0.9,
+#         maxiter: int = 25,
+#         maxzoom: int = 10,
+#         # a_max: float = 1e10,
+#         expand: float = 2.0,
+#         adaptive = True,
+#         plus_minus = False,
+#     ):
+#         defaults=dict(init=init,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom,
+#                       expand=expand, adaptive=adaptive, plus_minus=plus_minus)
+#         super().__init__(defaults=defaults)
+
+#         self.global_state['initial_scale'] = 1.0
+#         self.global_state['beta_scale'] = 1.0
+
+#     @torch.no_grad
+#     def search(self, update, var):
+#         objective = self.make_objective_with_derivative(var=var)
+
+#         init, c1, c2, maxiter, maxzoom, expand, adaptive, plus_minus = itemgetter(
+#             'init', 'c1', 'c2', 'maxiter', 'maxzoom',
+#             'expand', 'adaptive', 'plus_minus')(self.settings[var.params[0]])
+
+#         f_0, g_0 = objective(0)
+
+#         step_size,f_a = strong_wolfe(
+#             objective,
+#             f_0=f_0, g_0=g_0,
+#             init=init * self.global_state.setdefault("initial_scale", 1),
+#             c1=c1,
+#             c2=c2,
+#             maxiter=maxiter,
+#             maxzoom=maxzoom,
+#             expand=expand,
+#             plus_minus=plus_minus,
+#         )
+
+#         if f_a is not None and (f_a > f_0 or _notfinite(f_a)): step_size = None
+#         if step_size is not None and step_size != 0 and not _notfinite(step_size):
+#             self.global_state['initial_scale'] = min(1.0, self.global_state['initial_scale'] * math.sqrt(2))
+#         return step_size
+
+#         # fallback to backtracking on fail
+#         if adaptive: self.global_state['initial_scale'] *= 0.5
+#         return 0
```
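For a quick sanity check of the vendored helper, a hedged example of calling `polyinterp` directly on two sampled points of f(x) = (x - 1)^2; the import path follows the file listing above.

```python
# Hedged sketch: calling the vendored polyinterp helper on rows of [x, f(x), f'(x)].
# With function values and derivatives at two points, it takes the cubic branch.
import numpy as np
from torchzero.modules.line_search.polynomial import polyinterp

points = np.array([
    [0.0, 1.0, -2.0],  # x=0: f=1, f'=-2   (f(x) = (x - 1)**2)
    [2.0, 1.0,  2.0],  # x=2: f=1, f'= 2
])
x_sol = polyinterp(points)
print(x_sol)  # ~1.0, the true minimizer, clipped to the [0, 2] bracket
```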
torchzero/modules/line_search/scipy.py (+20 -5)

```diff
@@ -3,10 +3,25 @@ from operator import itemgetter
 
 import torch
 
-from .line_search import
+from .line_search import LineSearchBase
 
 
-class ScipyMinimizeScalar(
+class ScipyMinimizeScalar(LineSearchBase):
+    """Line search via :code:`scipy.optimize.minimize_scalar` which implements brent, golden search and bounded brent methods.
+
+    Args:
+        method (str | None, optional): "brent", "golden" or "bounded". Defaults to None.
+        maxiter (int | None, optional): maximum number of function evaluations the line search is allowed to perform. Defaults to None.
+        bracket (Sequence | None, optional):
+            Either a triple (xa, xb, xc) satisfying xa < xb < xc and func(xb) < func(xa) and func(xb) < func(xc), or a pair (xa, xb) to be used as initial points for a downhill bracket search. Defaults to None.
+        bounds (Sequence | None, optional):
+            For method ‘bounded’, bounds is mandatory and must have two finite items corresponding to the optimization bounds. Defaults to None.
+        tol (float | None, optional): Tolerance for termination. Defaults to None.
+        options (dict | None, optional): A dictionary of solver options. Defaults to None.
+
+    For more details on methods and arguments refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html
+
+    """
     def __init__(
         self,
         method: str | None = None,
@@ -24,10 +39,10 @@ class ScipyMinimizeScalar(LineSearch):
 
 
     @torch.no_grad
-    def search(self, update,
-        objective = self.make_objective(
+    def search(self, update, var):
+        objective = self.make_objective(var=var)
         method, bracket, bounds, tol, options, maxiter = itemgetter(
-            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[
+            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
 
         if maxiter is not None:
             options = dict(options) if isinstance(options, Mapping) else {}
```
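ScipyMinimizeScalar now subclasses `LineSearchBase` and documents its constructor arguments above. A hedged usage sketch, mirroring the `tz.Modular` pattern from the StrongWolfe docstring in this diff; whether the class is also re-exported under `tz.m` is not shown here, so the direct import path from the file listing is used.

```python
# Hedged sketch: a scipy-backed line search after an LBFGS direction module.
import torch
import torchzero as tz
from torchzero.modules.line_search.scipy import ScipyMinimizeScalar

model = torch.nn.Linear(2, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.LBFGS(),
    # "bounded" requires bounds; maxiter caps the number of objective evaluations
    ScipyMinimizeScalar(method="bounded", bounds=(0.0, 2.0), maxiter=16),
)
# Line searches require a closure that re-evaluates the loss (the base class raises
# "line search requires closure" otherwise), typically passed as opt.step(closure).
```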
torchzero/modules/line_search/strong_wolfe.py (+52 -36)

```diff
@@ -1,3 +1,4 @@
+"""this needs to be reworked maybe but it also works"""
 import math
 import warnings
 from operator import itemgetter
@@ -5,8 +6,7 @@ from operator import itemgetter
 import torch
 from torch.optim.lbfgs import _cubic_interpolate
 
-from .line_search import
-from .backtracking import backtracking_line_search
+from .line_search import LineSearchBase
 from ...utils import totensor
 
 
@@ -182,7 +182,47 @@ def _notfinite(x):
     if isinstance(x, torch.Tensor): return not torch.isfinite(x).all()
     return not math.isfinite(x)
 
-class StrongWolfe(
+class StrongWolfe(LineSearchBase):
+    """Cubic interpolation line search satisfying Strong Wolfe condition.
+
+    Args:
+        init (float, optional): Initial step size. Defaults to 1.0.
+        c1 (float, optional): Acceptance value for weak wolfe condition. Defaults to 1e-4.
+        c2 (float, optional): Acceptance value for strong wolfe condition (set to 0.1 for conjugate gradient). Defaults to 0.9.
+        maxiter (int, optional): Maximum number of line search iterations. Defaults to 25.
+        maxzoom (int, optional): Maximum number of zoom iterations. Defaults to 10.
+        expand (float, optional): Expansion factor (multipler to step size when weak condition not satisfied). Defaults to 2.0.
+        use_prev (bool, optional):
+            if True, previous step size is used as the initial step size on the next step.
+        adaptive (bool, optional):
+            when enabled, if line search failed, initial step size is reduced.
+            Otherwise it is reset to initial value. Defaults to True.
+        plus_minus (bool, optional):
+            If enabled and the direction is not descent direction, performs line search in opposite direction. Defaults to False.
+
+
+    Examples:
+        Conjugate gradient method with strong wolfe line search. Nocedal, Wright recommend setting c2 to 0.1 for CG.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.PolakRibiere(),
+                tz.m.StrongWolfe(c2=0.1)
+            )
+
+        LBFGS strong wolfe line search:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LBFGS(),
+                tz.m.StrongWolfe()
+            )
+
+    """
     def __init__(
         self,
         init: float = 1.0,
@@ -192,26 +232,27 @@ class StrongWolfe(LineSearch):
         maxzoom: int = 10,
         # a_max: float = 1e10,
         expand: float = 2.0,
+        use_prev: bool = False,
         adaptive = True,
-        fallback = False,
         plus_minus = False,
     ):
         defaults=dict(init=init,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom,
-                      expand=expand, adaptive=adaptive,
+                      expand=expand, adaptive=adaptive, plus_minus=plus_minus,use_prev=use_prev)
         super().__init__(defaults=defaults)
 
         self.global_state['initial_scale'] = 1.0
         self.global_state['beta_scale'] = 1.0
 
     @torch.no_grad
-    def search(self, update,
-        objective = self.make_objective_with_derivative(
+    def search(self, update, var):
+        objective = self.make_objective_with_derivative(var=var)
 
-        init, c1, c2, maxiter, maxzoom, expand, adaptive,
+        init, c1, c2, maxiter, maxzoom, expand, adaptive, plus_minus, use_prev = itemgetter(
             'init', 'c1', 'c2', 'maxiter', 'maxzoom',
-            'expand', 'adaptive', '
+            'expand', 'adaptive', 'plus_minus', 'use_prev')(self.settings[var.params[0]])
 
         f_0, g_0 = objective(0)
+        if use_prev: init = self.global_state.get('prev_alpha', init)
 
         step_size,f_a = strong_wolfe(
             objective,
@@ -228,33 +269,8 @@ class StrongWolfe(LineSearch):
         if f_a is not None and (f_a > f_0 or _notfinite(f_a)): step_size = None
         if step_size is not None and step_size != 0 and not _notfinite(step_size):
             self.global_state['initial_scale'] = min(1.0, self.global_state['initial_scale'] * math.sqrt(2))
+            self.global_state['prev_alpha'] = step_size
             return step_size
 
-        # fallback to backtracking on fail
         if adaptive: self.global_state['initial_scale'] *= 0.5
-
-
-        objective = self.make_objective(vars=vars)
-
-        # # directional derivative
-        g_0 = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), vars.get_update()))
-
-        step_size = backtracking_line_search(
-            objective,
-            g_0,
-            init=init * self.global_state["initial_scale"],
-            beta=0.5 * self.global_state["beta_scale"],
-            c=c1,
-            maxiter=maxiter * 2,
-            a_min=None,
-            try_negative=plus_minus,
-        )
-
-        # found an alpha that reduces loss
-        if step_size is not None:
-            self.global_state['beta_scale'] = min(1.0, self.global_state.get('beta_scale', 1) * math.sqrt(1.5))
-            return step_size
-
-        # on fail reduce beta scale value
-        self.global_state['beta_scale'] /= 1.5
-        return 0
+        return 0
```
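Besides the new docstring, the notable functional changes above are the new `use_prev` flag, which warm-starts each search from the previously accepted step size, and the removal of the backtracking fallback. A hedged sketch combining `use_prev` with the conjugate-gradient example from the docstring:

```python
# Hedged sketch: the CG + StrongWolfe example from the docstring above,
# with the new use_prev flag (warm-start from the previous accepted step size).
import torch
import torchzero as tz

model = torch.nn.Linear(2, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.PolakRibiere(),
    tz.m.StrongWolfe(c2=0.1, use_prev=True),
)
```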
torchzero/modules/misc/__init__.py (new file, +27 -0)

```diff
@@ -0,0 +1,27 @@
+from .debug import PrintLoss, PrintParams, PrintShape, PrintUpdate
+from .escape import EscapeAnnealing
+from .gradient_accumulation import GradientAccumulation
+from .misc import (
+    DivByLoss,
+    FillLoss,
+    GradSign,
+    GraftGradToUpdate,
+    GraftToGrad,
+    GraftToParams,
+    HpuEstimate,
+    LastAbsoluteRatio,
+    LastDifference,
+    LastGradDifference,
+    LastProduct,
+    LastRatio,
+    MulByLoss,
+    NoiseSign,
+    Previous,
+    RandomHvp,
+    Relative,
+    UpdateSign,
+)
+from .multistep import Multistep, NegateOnLossIncrease, Online, Sequential
+from .regularization import Dropout, PerturbWeights, WeightDropout
+from .split import Split
+from .switch import Alternate, Switch
```
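The names collected above become importable from the new `torchzero.modules.misc` package; a minimal import sketch (constructor signatures are not visible in this diff, so none are assumed):

```python
# The new misc package re-exports these names at package level.
from torchzero.modules.misc import (
    GradientAccumulation,    # from .gradient_accumulation
    Multistep, Online,       # from .multistep
    Dropout, WeightDropout,  # from .regularization
    Split,                   # from .split
    Alternate, Switch,       # from .switch
)
```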