torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +55 -22
- tests/test_tensorlist.py +3 -3
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +20 -130
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +111 -0
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +76 -26
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +15 -15
- torchzero/modules/quasi_newton/lsr1.py +18 -17
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +257 -48
- torchzero/modules/second_order/newton.py +38 -21
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +19 -19
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.8.dist-info/RECORD +0 -130
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
|
@@ -8,7 +8,7 @@ from typing import Any
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import torch
|
|
10
10
|
|
|
11
|
-
from ...core import Module, Target,
|
|
11
|
+
from ...core import Module, Target, Var
|
|
12
12
|
from ...utils import tofloat
|
|
13
13
|
|
|
14
14
|
|
|
@@ -62,12 +62,12 @@ class LineSearch(Module, ABC):
|
|
|
62
62
|
if any(a!=0 for a in alpha):
|
|
63
63
|
torch._foreach_add_(params, torch._foreach_mul(update, alpha))
|
|
64
64
|
|
|
65
|
-
def _loss(self, step_size: float,
|
|
65
|
+
def _loss(self, step_size: float, var: Var, closure, params: list[torch.Tensor],
|
|
66
66
|
update: list[torch.Tensor], backward:bool=False) -> float:
|
|
67
67
|
|
|
68
68
|
# if step_size is 0, we might already know the loss
|
|
69
|
-
if (
|
|
70
|
-
return tofloat(
|
|
69
|
+
if (var.loss is not None) and (step_size == 0):
|
|
70
|
+
return tofloat(var.loss)
|
|
71
71
|
|
|
72
72
|
# check max iter
|
|
73
73
|
if self._maxiter is not None and self._current_iter >= self._maxiter: raise MaxLineSearchItersReached
|
|
@@ -85,23 +85,23 @@ class LineSearch(Module, ABC):
|
|
|
85
85
|
self._lowest_loss = tofloat(loss)
|
|
86
86
|
self._best_step_size = step_size
|
|
87
87
|
|
|
88
|
-
# if evaluated loss at step size 0, set it to
|
|
88
|
+
# if evaluated loss at step size 0, set it to var.loss
|
|
89
89
|
if step_size == 0:
|
|
90
|
-
|
|
91
|
-
if backward:
|
|
90
|
+
var.loss = loss
|
|
91
|
+
if backward: var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
|
|
92
92
|
|
|
93
93
|
return tofloat(loss)
|
|
94
94
|
|
|
95
|
-
def _loss_derivative(self, step_size: float,
|
|
95
|
+
def _loss_derivative(self, step_size: float, var: Var, closure,
|
|
96
96
|
params: list[torch.Tensor], update: list[torch.Tensor]):
|
|
97
97
|
# if step_size is 0, we might already know the derivative
|
|
98
|
-
if (
|
|
99
|
-
loss = self._loss(step_size=step_size,
|
|
100
|
-
derivative = - sum(t.sum() for t in torch._foreach_mul(
|
|
98
|
+
if (var.grad is not None) and (step_size == 0):
|
|
99
|
+
loss = self._loss(step_size=step_size,var=var,closure=closure,params=params,update=update,backward=False)
|
|
100
|
+
derivative = - sum(t.sum() for t in torch._foreach_mul(var.grad, update))
|
|
101
101
|
|
|
102
102
|
else:
|
|
103
103
|
# loss with a backward pass sets params.grad
|
|
104
|
-
loss = self._loss(step_size=step_size,
|
|
104
|
+
loss = self._loss(step_size=step_size,var=var,closure=closure,params=params,update=update,backward=True)
|
|
105
105
|
|
|
106
106
|
# directional derivative
|
|
107
107
|
derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
|
|
@@ -109,60 +109,60 @@ class LineSearch(Module, ABC):
|
|
|
109
109
|
|
|
110
110
|
return loss, tofloat(derivative)
|
|
111
111
|
|
|
112
|
-
def evaluate_step_size(self, step_size: float,
|
|
113
|
-
closure =
|
|
112
|
+
def evaluate_step_size(self, step_size: float, var: Var, backward:bool=False):
|
|
113
|
+
closure = var.closure
|
|
114
114
|
if closure is None: raise RuntimeError('line search requires closure')
|
|
115
|
-
return self._loss(step_size=step_size,
|
|
115
|
+
return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update(),backward=backward)
|
|
116
116
|
|
|
117
|
-
def evaluate_step_size_loss_and_derivative(self, step_size: float,
|
|
118
|
-
closure =
|
|
117
|
+
def evaluate_step_size_loss_and_derivative(self, step_size: float, var: Var):
|
|
118
|
+
closure = var.closure
|
|
119
119
|
if closure is None: raise RuntimeError('line search requires closure')
|
|
120
|
-
return self._loss_derivative(step_size=step_size,
|
|
120
|
+
return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
|
|
121
121
|
|
|
122
|
-
def make_objective(self,
|
|
123
|
-
closure =
|
|
122
|
+
def make_objective(self, var: Var, backward:bool=False):
|
|
123
|
+
closure = var.closure
|
|
124
124
|
if closure is None: raise RuntimeError('line search requires closure')
|
|
125
|
-
return partial(self._loss,
|
|
125
|
+
return partial(self._loss, var=var, closure=closure, params=var.params, update=var.get_update(), backward=backward)
|
|
126
126
|
|
|
127
|
-
def make_objective_with_derivative(self,
|
|
128
|
-
closure =
|
|
127
|
+
def make_objective_with_derivative(self, var: Var):
|
|
128
|
+
closure = var.closure
|
|
129
129
|
if closure is None: raise RuntimeError('line search requires closure')
|
|
130
|
-
return partial(self._loss_derivative,
|
|
130
|
+
return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_update())
|
|
131
131
|
|
|
132
132
|
@abstractmethod
|
|
133
|
-
def search(self, update: list[torch.Tensor],
|
|
133
|
+
def search(self, update: list[torch.Tensor], var: Var) -> float:
|
|
134
134
|
"""Finds the step size to use"""
|
|
135
135
|
|
|
136
136
|
@torch.no_grad
|
|
137
|
-
def step(self,
|
|
137
|
+
def step(self, var: Var) -> Var:
|
|
138
138
|
self._reset()
|
|
139
|
-
params =
|
|
140
|
-
update =
|
|
139
|
+
params = var.params
|
|
140
|
+
update = var.get_update()
|
|
141
141
|
|
|
142
142
|
try:
|
|
143
|
-
step_size = self.search(update=update,
|
|
143
|
+
step_size = self.search(update=update, var=var)
|
|
144
144
|
except MaxLineSearchItersReached:
|
|
145
145
|
step_size = self._best_step_size
|
|
146
146
|
|
|
147
147
|
# set loss_approx
|
|
148
|
-
if
|
|
148
|
+
if var.loss_approx is None: var.loss_approx = self._lowest_loss
|
|
149
149
|
|
|
150
150
|
# this is last module - set step size to found step_size times lr
|
|
151
|
-
if
|
|
151
|
+
if var.is_last:
|
|
152
152
|
|
|
153
|
-
if
|
|
153
|
+
if var.last_module_lrs is None:
|
|
154
154
|
self.set_step_size_(step_size, params=params, update=update)
|
|
155
155
|
|
|
156
156
|
else:
|
|
157
|
-
self._set_per_parameter_step_size_([step_size*lr for lr in
|
|
157
|
+
self._set_per_parameter_step_size_([step_size*lr for lr in var.last_module_lrs], params=params, update=update)
|
|
158
158
|
|
|
159
|
-
|
|
160
|
-
return
|
|
159
|
+
var.stop = True; var.skip_update = True
|
|
160
|
+
return var
|
|
161
161
|
|
|
162
162
|
# revert parameters and multiply update by step size
|
|
163
163
|
self.set_step_size_(0, params=params, update=update)
|
|
164
|
-
torch._foreach_mul_(
|
|
165
|
-
return
|
|
164
|
+
torch._foreach_mul_(var.update, step_size)
|
|
165
|
+
return var
|
|
166
166
|
|
|
167
167
|
|
|
168
168
|
class GridLineSearch(LineSearch):
|
|
@@ -172,10 +172,10 @@ class GridLineSearch(LineSearch):
|
|
|
172
172
|
super().__init__(defaults)
|
|
173
173
|
|
|
174
174
|
@torch.no_grad
|
|
175
|
-
def search(self, update,
|
|
176
|
-
start,end,num=itemgetter('start','end','num')(self.settings[
|
|
175
|
+
def search(self, update, var):
|
|
176
|
+
start,end,num=itemgetter('start','end','num')(self.settings[var.params[0]])
|
|
177
177
|
|
|
178
178
|
for lr in torch.linspace(start,end,num):
|
|
179
|
-
self.evaluate_step_size(lr.item(),
|
|
179
|
+
self.evaluate_step_size(lr.item(), var=var, backward=False)
|
|
180
180
|
|
|
181
181
|
return self._best_step_size
|
|
@@ -7,6 +7,21 @@ from .line_search import LineSearch
|
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
class ScipyMinimizeScalar(LineSearch):
|
|
10
|
+
"""Line search via :code:`scipy.optimize.minimize_scalar` which implements brent, golden search and bounded brent methods.
|
|
11
|
+
|
|
12
|
+
Args:
|
|
13
|
+
method (str | None, optional): "brent", "golden" or "bounded". Defaults to None.
|
|
14
|
+
maxiter (int | None, optional): maximum number of function evaluations the line search is allowed to perform. Defaults to None.
|
|
15
|
+
bracket (Sequence | None, optional):
|
|
16
|
+
Either a triple (xa, xb, xc) satisfying xa < xb < xc and func(xb) < func(xa) and func(xb) < func(xc), or a pair (xa, xb) to be used as initial points for a downhill bracket search. Defaults to None.
|
|
17
|
+
bounds (Sequence | None, optional):
|
|
18
|
+
For method ‘bounded’, bounds is mandatory and must have two finite items corresponding to the optimization bounds. Defaults to None.
|
|
19
|
+
tol (float | None, optional): Tolerance for termination. Defaults to None.
|
|
20
|
+
options (dict | None, optional): A dictionary of solver options. Defaults to None.
|
|
21
|
+
|
|
22
|
+
For more details on methods and arguments refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html
|
|
23
|
+
|
|
24
|
+
"""
|
|
10
25
|
def __init__(
|
|
11
26
|
self,
|
|
12
27
|
method: str | None = None,
|
|
@@ -24,10 +39,10 @@ class ScipyMinimizeScalar(LineSearch):
|
|
|
24
39
|
|
|
25
40
|
|
|
26
41
|
@torch.no_grad
|
|
27
|
-
def search(self, update,
|
|
28
|
-
objective = self.make_objective(
|
|
42
|
+
def search(self, update, var):
|
|
43
|
+
objective = self.make_objective(var=var)
|
|
29
44
|
method, bracket, bounds, tol, options, maxiter = itemgetter(
|
|
30
|
-
'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[
|
|
45
|
+
'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
|
|
31
46
|
|
|
32
47
|
if maxiter is not None:
|
|
33
48
|
options = dict(options) if isinstance(options, Mapping) else {}
|
|
@@ -183,6 +183,21 @@ def _notfinite(x):
|
|
|
183
183
|
return not math.isfinite(x)
|
|
184
184
|
|
|
185
185
|
class StrongWolfe(LineSearch):
|
|
186
|
+
"""Cubic interpolation line search satisfying Strong Wolfe condition.
|
|
187
|
+
|
|
188
|
+
Args:
|
|
189
|
+
init (float, optional): Initial step size. Defaults to 1.0.
|
|
190
|
+
c1 (float, optional): Acceptance value for weak wolfe condition. Defaults to 1e-4.
|
|
191
|
+
c2 (float, optional): Acceptance value for strong wolfe condition (set to 0.1 for conjugate gradient). Defaults to 0.9.
|
|
192
|
+
maxiter (int, optional): Maximum number of line search iterations. Defaults to 25.
|
|
193
|
+
maxzoom (int, optional): Maximum number of zoom iterations. Defaults to 10.
|
|
194
|
+
expand (float, optional): Expansion factor (multiplier to step size when weak condition not satisfied). Defaults to 2.0.
|
|
195
|
+
adaptive (bool, optional):
|
|
196
|
+
when enabled, if line search failed, initial step size is reduced.
|
|
197
|
+
Otherwise it is reset to initial value. Defaults to True.
|
|
198
|
+
plus_minus (bool, optional):
|
|
199
|
+
If enabled and the direction is not descent direction, performs line search in opposite direction. Defaults to False.
|
|
200
|
+
"""
|
|
186
201
|
def __init__(
|
|
187
202
|
self,
|
|
188
203
|
init: float = 1.0,
|
|
@@ -193,23 +208,22 @@ class StrongWolfe(LineSearch):
|
|
|
193
208
|
# a_max: float = 1e10,
|
|
194
209
|
expand: float = 2.0,
|
|
195
210
|
adaptive = True,
|
|
196
|
-
fallback = False,
|
|
197
211
|
plus_minus = False,
|
|
198
212
|
):
|
|
199
213
|
defaults=dict(init=init,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom,
|
|
200
|
-
expand=expand, adaptive=adaptive,
|
|
214
|
+
expand=expand, adaptive=adaptive, plus_minus=plus_minus)
|
|
201
215
|
super().__init__(defaults=defaults)
|
|
202
216
|
|
|
203
217
|
self.global_state['initial_scale'] = 1.0
|
|
204
218
|
self.global_state['beta_scale'] = 1.0
|
|
205
219
|
|
|
206
220
|
@torch.no_grad
|
|
207
|
-
def search(self, update,
|
|
208
|
-
objective = self.make_objective_with_derivative(
|
|
221
|
+
def search(self, update, var):
|
|
222
|
+
objective = self.make_objective_with_derivative(var=var)
|
|
209
223
|
|
|
210
|
-
init, c1, c2, maxiter, maxzoom, expand, adaptive,
|
|
224
|
+
init, c1, c2, maxiter, maxzoom, expand, adaptive, plus_minus = itemgetter(
|
|
211
225
|
'init', 'c1', 'c2', 'maxiter', 'maxzoom',
|
|
212
|
-
'expand', 'adaptive', '
|
|
226
|
+
'expand', 'adaptive', 'plus_minus')(self.settings[var.params[0]])
|
|
213
227
|
|
|
214
228
|
f_0, g_0 = objective(0)
|
|
215
229
|
|
|
@@ -232,29 +246,4 @@ class StrongWolfe(LineSearch):
|
|
|
232
246
|
|
|
233
247
|
# fallback to backtracking on fail
|
|
234
248
|
if adaptive: self.global_state['initial_scale'] *= 0.5
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
objective = self.make_objective(vars=vars)
|
|
238
|
-
|
|
239
|
-
# # directional derivative
|
|
240
|
-
g_0 = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), vars.get_update()))
|
|
241
|
-
|
|
242
|
-
step_size = backtracking_line_search(
|
|
243
|
-
objective,
|
|
244
|
-
g_0,
|
|
245
|
-
init=init * self.global_state["initial_scale"],
|
|
246
|
-
beta=0.5 * self.global_state["beta_scale"],
|
|
247
|
-
c=c1,
|
|
248
|
-
maxiter=maxiter * 2,
|
|
249
|
-
a_min=None,
|
|
250
|
-
try_negative=plus_minus,
|
|
251
|
-
)
|
|
252
|
-
|
|
253
|
-
# found an alpha that reduces loss
|
|
254
|
-
if step_size is not None:
|
|
255
|
-
self.global_state['beta_scale'] = min(1.0, self.global_state.get('beta_scale', 1) * math.sqrt(1.5))
|
|
256
|
-
return step_size
|
|
257
|
-
|
|
258
|
-
# on fail reduce beta scale value
|
|
259
|
-
self.global_state['beta_scale'] /= 1.5
|
|
260
|
-
return 0
|
|
249
|
+
return 0
|
|
@@ -6,31 +6,43 @@ from .line_search import LineSearch
|
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class TrustRegion(LineSearch):
|
|
9
|
-
"""Basic first order trust region
|
|
9
|
+
"""Basic first order trust region method. Re-evaluates the function after stepping, if value decreased sufficiently,
|
|
10
|
+
step size is increased. If value increased, step size is decreased. This is prone to collapsing.
|
|
11
|
+
|
|
12
|
+
Args:
|
|
13
|
+
nplus (float, optional): multiplier to step size on successful steps. Defaults to 1.5.
|
|
14
|
+
nminus (float, optional): multiplier to step size on unsuccessful steps. Defaults to 0.75.
|
|
15
|
+
c (float, optional): descent condition. Defaults to 1e-4.
|
|
16
|
+
init (float, optional): initial step size. Defaults to 1.
|
|
17
|
+
backtrack (bool, optional): whether to undo the step if value increased. Defaults to True.
|
|
18
|
+
adaptive (bool, optional):
|
|
19
|
+
If enabled, when multiple consecutive steps have been successful or unsuccessful,
|
|
20
|
+
the corresponding multipliers are increased, otherwise they are reset. Defaults to True.
|
|
21
|
+
"""
|
|
10
22
|
def __init__(self, nplus: float=1.5, nminus: float=0.75, c: float=1e-4, init: float = 1, backtrack: bool = True, adaptive: bool = True):
|
|
11
23
|
defaults = dict(nplus=nplus, nminus=nminus, c=c, init=init, backtrack=backtrack, adaptive=adaptive)
|
|
12
24
|
super().__init__(defaults)
|
|
13
25
|
|
|
14
26
|
@torch.no_grad
|
|
15
|
-
def search(self, update,
|
|
27
|
+
def search(self, update, var):
|
|
16
28
|
|
|
17
|
-
nplus, nminus, c, init, backtrack, adaptive = itemgetter('nplus','nminus','c','init','backtrack', 'adaptive')(self.settings[
|
|
29
|
+
nplus, nminus, c, init, backtrack, adaptive = itemgetter('nplus','nminus','c','init','backtrack', 'adaptive')(self.settings[var.params[0]])
|
|
18
30
|
step_size = self.global_state.setdefault('step_size', init)
|
|
19
31
|
previous_success = self.global_state.setdefault('previous_success', False)
|
|
20
32
|
nplus_mul = self.global_state.setdefault('nplus_mul', 1)
|
|
21
33
|
nminus_mul = self.global_state.setdefault('nminus_mul', 1)
|
|
22
34
|
|
|
23
35
|
|
|
24
|
-
f_0 = self.evaluate_step_size(0,
|
|
36
|
+
f_0 = self.evaluate_step_size(0, var, backward=False)
|
|
25
37
|
|
|
26
38
|
# directional derivative (0 if c = 0 because it is not needed)
|
|
27
39
|
if c == 0: d = 0
|
|
28
|
-
else: d = -sum(t.sum() for t in torch._foreach_mul(
|
|
40
|
+
else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), update))
|
|
29
41
|
|
|
30
42
|
# test step size
|
|
31
43
|
sufficient_f = f_0 + c * step_size * min(d, 0) # pyright:ignore[reportArgumentType]
|
|
32
44
|
|
|
33
|
-
f_1 = self.evaluate_step_size(step_size,
|
|
45
|
+
f_1 = self.evaluate_step_size(step_size, var, backward=False)
|
|
34
46
|
|
|
35
47
|
proposed = step_size
|
|
36
48
|
|
torchzero/modules/lr/__init__.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
1
|
from .lr import LR, StepSize, Warmup
|
|
2
|
-
from .
|
|
2
|
+
from .adaptive import PolyakStepSize, RandomStepSize
|
|
@@ -1,18 +1,20 @@
|
|
|
1
|
+
"""Various step size strategies"""
|
|
1
2
|
import random
|
|
2
3
|
from typing import Any
|
|
3
|
-
|
|
4
|
+
from operator import itemgetter
|
|
4
5
|
import torch
|
|
5
6
|
|
|
6
7
|
from ...core import Transform
|
|
7
|
-
from ...utils import TensorList, NumberList
|
|
8
|
+
from ...utils import TensorList, NumberList, unpack_dicts
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
class PolyakStepSize(Transform):
|
|
11
|
-
"""Polyak step-size.
|
|
12
|
+
"""Polyak's step-size method.
|
|
12
13
|
|
|
13
14
|
Args:
|
|
14
15
|
max (float | None, optional): maximum possible step size. Defaults to None.
|
|
15
|
-
min_obj_value (int, optional):
|
|
16
|
+
min_obj_value (int, optional):
|
|
17
|
+
(estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
|
|
16
18
|
use_grad (bool, optional):
|
|
17
19
|
if True, uses dot product of update and gradient to compute the step size.
|
|
18
20
|
Otherwise, dot product of update with itself is used, which has no geometric meaning so it probably won't work well.
|
|
@@ -28,29 +30,24 @@ class PolyakStepSize(Transform):
|
|
|
28
30
|
super().__init__(defaults, uses_grad=use_grad)
|
|
29
31
|
|
|
30
32
|
@torch.no_grad
|
|
31
|
-
def
|
|
32
|
-
loss = vars.get_loss(False)
|
|
33
|
+
def apply(self, tensors, params, grads, loss, states, settings):
|
|
33
34
|
assert grads is not None
|
|
34
35
|
tensors = TensorList(tensors)
|
|
35
36
|
grads = TensorList(grads)
|
|
36
|
-
alpha =
|
|
37
|
-
|
|
38
|
-
parameterwise =
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
37
|
+
alpha = NumberList(s['alpha'] for s in settings)
|
|
38
|
+
|
|
39
|
+
parameterwise, use_grad, max, min_obj_value = itemgetter('parameterwise', 'use_grad', 'max', 'min_obj_value')(settings[0])
|
|
40
|
+
|
|
41
|
+
if use_grad: denom = tensors.dot(grads)
|
|
42
|
+
else: denom = tensors.dot(tensors)
|
|
42
43
|
|
|
43
44
|
if parameterwise:
|
|
44
|
-
if use_grad: denom = (tensors * grads).sum()
|
|
45
|
-
else: denom = tensors.pow(2).sum()
|
|
46
45
|
polyak_step_size: TensorList | Any = (loss - min_obj_value) / denom.where(denom!=0, 1)
|
|
47
46
|
polyak_step_size = polyak_step_size.where(denom != 0, 0)
|
|
48
47
|
if max is not None: polyak_step_size = polyak_step_size.clamp_max(max)
|
|
49
48
|
|
|
50
49
|
else:
|
|
51
|
-
if
|
|
52
|
-
else: denom = tensors.dot(tensors)
|
|
53
|
-
if denom == 0: polyak_step_size = 0 # we converged
|
|
50
|
+
if denom.abs() <= torch.finfo(denom.dtype).eps: polyak_step_size = 0 # converged
|
|
54
51
|
else: polyak_step_size = (loss - min_obj_value) / denom
|
|
55
52
|
|
|
56
53
|
if max is not None:
|
|
@@ -60,9 +57,8 @@ class PolyakStepSize(Transform):
|
|
|
60
57
|
return tensors
|
|
61
58
|
|
|
62
59
|
|
|
63
|
-
|
|
64
60
|
class RandomStepSize(Transform):
|
|
65
|
-
"""Uses random global step size from `low` to `high`.
|
|
61
|
+
"""Uses random global or layer-wise step size from `low` to `high`.
|
|
66
62
|
|
|
67
63
|
Args:
|
|
68
64
|
low (float, optional): minimum learning rate. Defaults to 0.
|
|
@@ -76,21 +72,21 @@ class RandomStepSize(Transform):
|
|
|
76
72
|
super().__init__(defaults, uses_grad=False)
|
|
77
73
|
|
|
78
74
|
@torch.no_grad
|
|
79
|
-
def
|
|
80
|
-
|
|
81
|
-
parameterwise =
|
|
75
|
+
def apply(self, tensors, params, grads, loss, states, settings):
|
|
76
|
+
s = settings[0]
|
|
77
|
+
parameterwise = s['parameterwise']
|
|
82
78
|
|
|
83
|
-
seed =
|
|
79
|
+
seed = s['seed']
|
|
84
80
|
if 'generator' not in self.global_state:
|
|
85
81
|
self.global_state['generator'] = random.Random(seed)
|
|
86
82
|
generator: random.Random = self.global_state['generator']
|
|
87
83
|
|
|
88
84
|
if parameterwise:
|
|
89
|
-
low, high =
|
|
85
|
+
low, high = unpack_dicts(settings, 'low', 'high')
|
|
90
86
|
lr = [generator.uniform(l, h) for l, h in zip(low, high)]
|
|
91
87
|
else:
|
|
92
|
-
low =
|
|
93
|
-
high =
|
|
88
|
+
low = s['low']
|
|
89
|
+
high = s['high']
|
|
94
90
|
lr = generator.uniform(low, high)
|
|
95
91
|
|
|
96
92
|
torch._foreach_mul_(tensors, lr)
|
torchzero/modules/lr/lr.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
|
+
"""Learning rate"""
|
|
1
2
|
import torch
|
|
2
3
|
|
|
3
4
|
from ...core import Transform
|
|
4
|
-
from ...utils import NumberList, TensorList, generic_eq
|
|
5
|
-
|
|
5
|
+
from ...utils import NumberList, TensorList, generic_eq, unpack_dicts
|
|
6
6
|
|
|
7
7
|
def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
|
|
8
8
|
"""multiplies by lr if lr is not 1"""
|
|
@@ -11,48 +11,52 @@ def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
|
|
|
11
11
|
return tensors * lr
|
|
12
12
|
|
|
13
13
|
class LR(Transform):
|
|
14
|
+
"""Learning rate. Adding this module also adds support for LR schedulers."""
|
|
14
15
|
def __init__(self, lr: float):
|
|
15
16
|
defaults=dict(lr=lr)
|
|
16
17
|
super().__init__(defaults, uses_grad=False)
|
|
17
18
|
|
|
18
19
|
@torch.no_grad
|
|
19
|
-
def
|
|
20
|
-
return lazy_lr(TensorList(tensors), lr=
|
|
20
|
+
def apply(self, tensors, params, grads, loss, states, settings):
|
|
21
|
+
return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)
|
|
21
22
|
|
|
22
23
|
class StepSize(Transform):
|
|
23
|
-
"""this is exactly the same as LR, except the `lr` parameter can be renamed to any other name"""
|
|
24
|
+
"""this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
|
|
24
25
|
def __init__(self, step_size: float, key = 'step_size'):
|
|
25
26
|
defaults={"key": key, key: step_size}
|
|
26
27
|
super().__init__(defaults, uses_grad=False)
|
|
27
28
|
|
|
28
29
|
@torch.no_grad
|
|
29
|
-
def
|
|
30
|
-
|
|
31
|
-
for p in params:
|
|
32
|
-
settings = self.settings[p]
|
|
33
|
-
lrs.append(settings[settings['key']])
|
|
34
|
-
return lazy_lr(TensorList(tensors), lr=lrs, inplace=True)
|
|
30
|
+
def apply(self, tensors, params, grads, loss, states, settings):
|
|
31
|
+
return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)
|
|
35
32
|
|
|
36
33
|
|
|
37
|
-
def
|
|
34
|
+
def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberList, steps: float):
|
|
38
35
|
"""returns warm up lr scalar"""
|
|
39
36
|
if step > steps: return end_lr
|
|
40
37
|
return start_lr + (end_lr - start_lr) * (step / steps)
|
|
41
38
|
|
|
42
39
|
class Warmup(Transform):
|
|
40
|
+
"""Learning rate warmup, linearly increases learning rate multiplier from :code:`start_lr` to :code:`end_lr` over :code:`steps` steps.
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
start_lr (float, optional): initial learning rate multiplier on first step. Defaults to 1e-5.
|
|
44
|
+
end_lr (float, optional): learning rate multiplier at the end and after warmup. Defaults to 1.
|
|
45
|
+
steps (int, optional): number of steps to perform warmup for. Defaults to 100.
|
|
46
|
+
"""
|
|
43
47
|
def __init__(self, start_lr = 1e-5, end_lr:float = 1, steps = 100):
|
|
44
48
|
defaults = dict(start_lr=start_lr,end_lr=end_lr, steps=steps)
|
|
45
49
|
super().__init__(defaults, uses_grad=False)
|
|
46
50
|
|
|
47
51
|
@torch.no_grad
|
|
48
|
-
def
|
|
49
|
-
start_lr, end_lr =
|
|
50
|
-
num_steps =
|
|
52
|
+
def apply(self, tensors, params, grads, loss, states, settings):
|
|
53
|
+
start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
|
|
54
|
+
num_steps = settings[0]['steps']
|
|
51
55
|
step = self.global_state.get('step', 0)
|
|
52
56
|
|
|
53
57
|
target = lazy_lr(
|
|
54
58
|
TensorList(tensors),
|
|
55
|
-
lr=
|
|
59
|
+
lr=_warmup_lr(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
|
|
56
60
|
inplace=True
|
|
57
61
|
)
|
|
58
62
|
self.global_state['step'] = step + 1
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
"""Modules that perform averaging over a history of past updates."""
|
|
1
2
|
from collections import deque
|
|
2
3
|
from collections.abc import Sequence
|
|
3
4
|
from typing import Any, Literal, cast
|
|
@@ -9,14 +10,19 @@ from ...utils import tolist
|
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
class Averaging(TensorwiseTransform):
|
|
13
|
+
"""Average of past :code:`history_size` updates.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
history_size (int): Number of past updates to average
|
|
17
|
+
target (Target, optional): target. Defaults to 'update'.
|
|
18
|
+
"""
|
|
12
19
|
def __init__(self, history_size: int, target: Target = 'update'):
|
|
13
20
|
defaults = dict(history_size=history_size)
|
|
14
21
|
super().__init__(uses_grad=False, defaults=defaults, target=target)
|
|
15
22
|
|
|
16
23
|
@torch.no_grad
|
|
17
|
-
def
|
|
18
|
-
history_size =
|
|
19
|
-
state = self.state[param]
|
|
24
|
+
def apply_tensor(self, tensor, param, grad, loss, state, settings):
|
|
25
|
+
history_size = settings['history_size']
|
|
20
26
|
if 'history' not in state:
|
|
21
27
|
state['history'] = deque(maxlen=history_size)
|
|
22
28
|
state['average'] = torch.zeros_like(tensor)
|
|
@@ -29,15 +35,19 @@ class Averaging(TensorwiseTransform):
|
|
|
29
35
|
return average / len(history)
|
|
30
36
|
|
|
31
37
|
class WeightedAveraging(TensorwiseTransform):
|
|
32
|
-
"""
|
|
38
|
+
"""Weighted average of past :code:`len(weights)` updates.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
weights (Sequence[float]): a sequence of weights from oldest to newest.
|
|
42
|
+
target (Target, optional): target. Defaults to 'update'.
|
|
43
|
+
"""
|
|
33
44
|
def __init__(self, weights: Sequence[float] | torch.Tensor | Any, target: Target = 'update'):
|
|
34
45
|
defaults = dict(weights = tolist(weights))
|
|
35
46
|
super().__init__(uses_grad=False, defaults=defaults, target=target)
|
|
36
47
|
|
|
37
48
|
@torch.no_grad
|
|
38
|
-
def
|
|
39
|
-
weights =
|
|
40
|
-
state = self.state[param]
|
|
49
|
+
def apply_tensor(self, tensor, param, grad, loss, state, settings):
|
|
50
|
+
weights = settings['weights']
|
|
41
51
|
|
|
42
52
|
if 'history' not in state:
|
|
43
53
|
state['history'] = deque(maxlen=len(weights))
|
|
@@ -59,14 +69,19 @@ class WeightedAveraging(TensorwiseTransform):
|
|
|
59
69
|
|
|
60
70
|
|
|
61
71
|
class MedianAveraging(TensorwiseTransform):
|
|
72
|
+
"""Median of past :code:`history_size` updates.
|
|
73
|
+
|
|
74
|
+
Args:
|
|
75
|
+
history_size (int): Number of past updates to take the median of
|
|
76
|
+
target (Target, optional): target. Defaults to 'update'.
|
|
77
|
+
"""
|
|
62
78
|
def __init__(self, history_size: int, target: Target = 'update'):
|
|
63
79
|
defaults = dict(history_size = history_size)
|
|
64
80
|
super().__init__(uses_grad=False, defaults=defaults, target=target)
|
|
65
81
|
|
|
66
82
|
@torch.no_grad
|
|
67
|
-
def
|
|
68
|
-
history_size =
|
|
69
|
-
state = self.state[param]
|
|
83
|
+
def apply_tensor(self, tensor, param, grad, loss, state, settings):
|
|
84
|
+
history_size = settings['history_size']
|
|
70
85
|
|
|
71
86
|
if 'history' not in state:
|
|
72
87
|
state['history'] = deque(maxlen=history_size)
|