torchzero 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchzero/core/__init__.py +1 -1
- torchzero/core/module.py +72 -49
- torchzero/core/tensorlist_optimizer.py +1 -1
- torchzero/modules/adaptive/adaptive.py +11 -11
- torchzero/modules/experimental/experimental.py +41 -41
- torchzero/modules/experimental/quad_interp.py +8 -8
- torchzero/modules/experimental/subspace.py +37 -37
- torchzero/modules/gradient_approximation/base_approximator.py +19 -24
- torchzero/modules/gradient_approximation/fdm.py +1 -1
- torchzero/modules/gradient_approximation/newton_fdm.py +13 -13
- torchzero/modules/gradient_approximation/rfdm.py +1 -1
- torchzero/modules/line_search/armijo.py +8 -8
- torchzero/modules/line_search/base_ls.py +8 -8
- torchzero/modules/line_search/directional_newton.py +14 -14
- torchzero/modules/line_search/grid_ls.py +7 -7
- torchzero/modules/line_search/scipy_minimize_scalar.py +3 -3
- torchzero/modules/meta/alternate.py +4 -4
- torchzero/modules/meta/grafting.py +23 -23
- torchzero/modules/meta/optimizer_wrapper.py +14 -14
- torchzero/modules/meta/return_overrides.py +8 -8
- torchzero/modules/misc/accumulate.py +6 -6
- torchzero/modules/misc/basic.py +16 -16
- torchzero/modules/misc/lr.py +2 -2
- torchzero/modules/misc/multistep.py +7 -7
- torchzero/modules/misc/on_increase.py +9 -9
- torchzero/modules/momentum/momentum.py +4 -4
- torchzero/modules/operations/multi.py +44 -44
- torchzero/modules/operations/reduction.py +28 -28
- torchzero/modules/operations/singular.py +9 -9
- torchzero/modules/optimizers/adagrad.py +1 -1
- torchzero/modules/optimizers/adam.py +8 -8
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +1 -1
- torchzero/modules/optimizers/rprop.py +1 -1
- torchzero/modules/optimizers/sgd.py +2 -2
- torchzero/modules/orthogonalization/newtonschulz.py +3 -3
- torchzero/modules/orthogonalization/svd.py +1 -1
- torchzero/modules/regularization/dropout.py +1 -1
- torchzero/modules/regularization/noise.py +3 -3
- torchzero/modules/regularization/normalization.py +5 -5
- torchzero/modules/regularization/ortho_grad.py +1 -1
- torchzero/modules/regularization/weight_decay.py +1 -1
- torchzero/modules/scheduling/lr_schedulers.py +2 -2
- torchzero/modules/scheduling/step_size.py +8 -8
- torchzero/modules/second_order/newton.py +12 -12
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/gaussian_smoothing.py +7 -7
- torchzero/modules/smoothing/laplacian_smoothing.py +1 -1
- torchzero/modules/weight_averaging/ema.py +3 -3
- torchzero/modules/weight_averaging/swa.py +8 -8
- torchzero/optim/first_order/forward_gradient.py +1 -1
- torchzero/optim/modular.py +4 -4
- torchzero/tensorlist.py +8 -1
- {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/METADATA +1 -1
- torchzero-0.1.5.dist-info/RECORD +104 -0
- torchzero-0.1.3.dist-info/RECORD +0 -104
- {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/LICENSE +0 -0
- {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/WHEEL +0 -0
- {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/top_level.txt +0 -0

torchzero/modules/experimental/subspace.py

@@ -5,14 +5,14 @@ from collections import abc
 import torch

 from ... import tensorlist as tl
-from ...core import
+from ...core import OptimizationVars, OptimizerModule, _Chain, _maybe_pass_backward
 # this whole thing can also be implemented via parameter vectors.
 # Need to test which one is more efficient...

 class Projection(ABC):
     n = 1
     @abstractmethod
-    def sample(self, params: tl.TensorList,
+    def sample(self, params: tl.TensorList, vars: OptimizationVars) -> list[tl.TensorList]:
         """Generate a projection.

         Args:
@@ -28,7 +28,7 @@ class ProjRandom(Projection):
         self.distribution: tl.Distributions = distribution
         self.n = n

-    def sample(self, params: tl.TensorList,
+    def sample(self, params: tl.TensorList, vars: OptimizationVars):
         return [params.sample_like(distribution=self.distribution) for _ in range(self.n)]


@@ -42,7 +42,7 @@ class Proj2Masks(Projection):
     def n(self):
         return self.n_pairs * 2

-    def sample(self, params: tl.TensorList,
+    def sample(self, params: tl.TensorList, vars: OptimizationVars):
         projections = []
         for i in range(self.n_pairs):
             mask = params.bernoulli_like(0.5)
@@ -55,9 +55,9 @@ class Proj2Masks(Projection):

 class ProjAscent(Projection):
     """Use ascent direction as the projection."""
-    def sample(self, params: tl.TensorList,
-        if
-        return [
+    def sample(self, params: tl.TensorList, vars: OptimizationVars):
+        if vars.ascent is None: raise ValueError
+        return [vars.ascent]

 class ProjAscentRay(Projection):
     def __init__(self, eps = 0.1, n = 1, distribution: tl.Distributions = 'normal', ):
@@ -65,14 +65,14 @@ class ProjAscentRay(Projection):
         self.distribution: tl.Distributions = distribution
         self.n = n

-    def sample(self, params: tl.TensorList,
-        if
+    def sample(self, params: tl.TensorList, vars: OptimizationVars):
+        if vars.ascent is None: raise ValueError
         mean = params.total_mean().detach().cpu().item()
-        return [
+        return [vars.ascent + vars.ascent.sample_like(mean * self.eps, distribution=self.distribution) for _ in range(self.n)]

 class ProjGrad(Projection):
-    def sample(self, params: tl.TensorList,
-        grad =
+    def sample(self, params: tl.TensorList, vars: OptimizationVars):
+        grad = vars.maybe_compute_grad_(params)
         return [grad]

 class ProjGradRay(Projection):
@@ -81,8 +81,8 @@ class ProjGradRay(Projection):
         self.distribution: tl.Distributions = distribution
         self.n = n

-    def sample(self, params: tl.TensorList,
-        grad =
+    def sample(self, params: tl.TensorList, vars: OptimizationVars):
+        grad = vars.maybe_compute_grad_(params)
         mean = params.total_mean().detach().cpu().item()
         return [grad + grad.sample_like(mean * self.eps, distribution=self.distribution) for _ in range(self.n)]

@@ -95,23 +95,23 @@ class ProjGradAscentDifference(Projection):
         """
         self.normalize = normalize

-    def sample(self, params: tl.TensorList,
-        grad =
+    def sample(self, params: tl.TensorList, vars: OptimizationVars):
+        grad = vars.maybe_compute_grad_(params)
         if self.normalize:
-            return [
+            return [vars.ascent / vars.ascent.total_vector_norm(2) - grad / grad.total_vector_norm(2)] # type:ignore

-        return [
+        return [vars.ascent - grad] # type:ignore

 class ProjLastGradDifference(Projection):
     def __init__(self):
         """Use difference between last two gradients as the projection."""
         self.last_grad = None
-    def sample(self, params: tl.TensorList,
+    def sample(self, params: tl.TensorList, vars: OptimizationVars):
         if self.last_grad is None:
-            self.last_grad =
+            self.last_grad = vars.maybe_compute_grad_(params)
             return [self.last_grad]

-        grad =
+        grad = vars.maybe_compute_grad_(params)
         diff = grad - self.last_grad
         self.last_grad = grad
         return [diff]
@@ -121,13 +121,13 @@ class ProjLastAscentDifference(Projection):
         """Use difference between last two ascent directions as the projection."""
         self.last_direction = T.cast(tl.TensorList, None)

-    def sample(self, params: tl.TensorList,
+    def sample(self, params: tl.TensorList, vars: OptimizationVars):
         if self.last_direction is None:
-            self.last_direction: tl.TensorList =
+            self.last_direction: tl.TensorList = vars.ascent # type:ignore
             return [self.last_direction]

-        diff =
-        self.last_direction =
+        diff = vars.ascent - self.last_direction # type:ignore
+        self.last_direction = vars.ascent # type:ignore
         return [diff]

 class ProjNormalize(Projection):
@@ -139,10 +139,10 @@ class ProjNormalize(Projection):
     def n(self):
         return sum(proj.n for proj in self.projections)

-    def sample(self, params: tl.TensorList,
-        vecs = [proj for obj in self.projections for proj in obj.sample(params,
+    def sample(self, params: tl.TensorList, vars: OptimizationVars): # type:ignore
+        vecs = [proj for obj in self.projections for proj in obj.sample(params, vars)]
         norms = [v.total_vector_norm(2) for v in vecs]
-        return [v/norm if norm!=0 else v.randn_like() for v,norm in zip(vecs,norms)]
+        return [v/norm if norm!=0 else v.randn_like() for v,norm in zip(vecs,norms)] # type:ignore

 class Subspace(OptimizerModule):
     """This is pretty inefficient, I thought of a much better way to do this via jvp and I will rewrite this soon.
@@ -198,17 +198,17 @@ class Subspace(OptimizerModule):
         child.add_param_group({"params": params})

     @torch.no_grad
-    def step(self,
+    def step(self, vars):
         #if self.next_module is None: raise ValueError('RandomProjection needs a child')
-        if
-        closure =
+        if vars.closure is None: raise ValueError('RandomProjection needs a closure')
+        closure = vars.closure
         params = self.get_params()

         # every `regenerate_every` steps we generate new random projections.
         if self.current_step == 0 or (self.update_every is not None and self.current_step % self.update_every == 0):

             # generate n projection vetors
-            self.projection_vectors = [sample for proj in self.projections for sample in proj.sample(params,
+            self.projection_vectors = [sample for proj in self.projections for sample in proj.sample(params, vars)]

             # child params is n scalars corresponding to each projection vector
             self.projected_params = self.children['subspace']._params[0] # type:ignore
@@ -235,7 +235,7 @@ class Subspace(OptimizerModule):
         # ascent_direction = tl.sum([ascent_direction*v for v in self.projection_vectors])

         # perform a step with the child
-        subspace_state =
+        subspace_state = vars.copy(False)
         subspace_state.closure = projected_closure
         subspace_state.ascent = None
         if subspace_state.grad is not None:
@@ -244,11 +244,11 @@ class Subspace(OptimizerModule):

         # that is going to update child's paramers, which we now project back to the full parameter space
         residual = tl.sum([vec * p for vec, p in zip(self.projection_vectors, self.projected_params)])
-
+        vars.ascent = residual.neg_()

         # move fx0 and fx0 approx to state
-        if subspace_state.fx0 is not None:
-        if subspace_state.fx0_approx is not None:
+        if subspace_state.fx0 is not None: vars.fx0 = subspace_state.fx0
+        if subspace_state.fx0_approx is not None: vars.fx0 = subspace_state.fx0_approx
         # projected_params are residuals that have been applied to actual params on previous step in some way
         # therefore they need to now become zero (otherwise they work like momentum with no decay).
         # note: THIS WON'T WORK WITH INTEGRATIONS, UNLESS THEY PERFORM FULL MINIMIZATION EACH STEP
@@ -256,4 +256,4 @@ class Subspace(OptimizerModule):
         self.projected_params.zero_()

         self.current_step += 1
-        return self._update_params_or_step_with_next(
+        return self._update_params_or_step_with_next(vars)
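
The hunks above change every Projection.sample signature to sample(self, params, vars), where vars is the new OptimizationVars object carrying the closure, the current ascent direction and cached loss values. Below is a hedged sketch of a custom projection written against that 0.1.5 signature; the ProjScaledGrad class is hypothetical, and the absolute import paths are assumptions inferred from the relative imports visible in the diff (only vars.maybe_compute_grad_ and the TensorList arithmetic shown in the hunks are relied on).

# Hedged sketch, not package code: a hypothetical projection using the 0.1.5
# sample(params, vars) signature shown in the hunks above. Import paths are
# assumptions based on the relative imports in the diff.
from torchzero import tensorlist as tl
from torchzero.core import OptimizationVars
from torchzero.modules.experimental.subspace import Projection

class ProjScaledGrad(Projection):
    """Hypothetical projection: the (lazily computed) gradient, rescaled."""
    def __init__(self, scale: float = 1.0):
        self.scale = scale

    def sample(self, params: tl.TensorList, vars: OptimizationVars):
        # maybe_compute_grad_ evaluates the gradient via vars.closure only if needed
        grad = vars.maybe_compute_grad_(params)
        return [grad * self.scale]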

torchzero/modules/gradient_approximation/base_approximator.py

@@ -5,7 +5,7 @@ from typing import Any, Literal
 import torch

 from ...core import (
-
+    OptimizationVars,
     OptimizerModule,
     _ClosureType,
     _maybe_pass_backward,
@@ -39,12 +39,12 @@ class GradientApproximatorBase(OptimizerModule, ABC):
         super().__init__(defaults, target)
         self.requires_fx0 = requires_fx0

-    def _step_make_closure_(self,
-        if
-        closure =
+    def _step_make_closure_(self, vars: OptimizationVars, params: TensorList):
+        if vars.closure is None: raise ValueError("gradient approximation requires closure")
+        closure = vars.closure

-        if self.requires_fx0: fx0 =
-        else: fx0 =
+        if self.requires_fx0: fx0 = vars.evaluate_fx0_(False)
+        else: fx0 = vars.fx0

         def new_closure(backward=True) -> _ScalarLoss:
             if backward:
@@ -56,35 +56,35 @@ class GradientApproximatorBase(OptimizerModule, ABC):

             return closure(False)

-
+        vars.closure = new_closure

-    def _step_make_target_(self,
-        if
+    def _step_make_target_(self, vars: OptimizationVars, params: TensorList):
+        if vars.closure is None: raise ValueError("gradient approximation requires closure")

-        if self.requires_fx0: fx0 =
-        else: fx0 =
+        if self.requires_fx0: fx0 = vars.evaluate_fx0_(False)
+        else: fx0 = vars.fx0

-        g,
-        if self._default_step_target == 'ascent':
-        elif self._default_step_target == 'grad':
+        g, vars.fx0, vars.fx0_approx = self._make_ascent(vars.closure, params, fx0)
+        if self._default_step_target == 'ascent': vars.ascent = g
+        elif self._default_step_target == 'grad': vars.set_grad_(g, params)
         else: raise ValueError(f"Unknown target {self._default_step_target}")

     @torch.no_grad
-    def step(self,
+    def step(self, vars: OptimizationVars):
         params = self.get_params()
         if self._default_step_target == 'closure':
-            self._step_make_closure_(
+            self._step_make_closure_(vars, params)

         else:
-            self._step_make_target_(
+            self._step_make_target_(vars, params)

-        return self._update_params_or_step_with_next(
+        return self._update_params_or_step_with_next(vars, params)

     @abstractmethod
     @torch.no_grad
     def _make_ascent(
         self,
-        #
+        # vars: OptimizationVars,
         closure: _ClosureType,
         params: TensorList,
         fx0: Any,
@@ -95,11 +95,6 @@ class GradientApproximatorBase(OptimizerModule, ABC):

         (ascent, fx0, fx0_approx)

-        :code:`ascent` is the approximated gradient,
-        :code:`fx0` is loss value strictly with initial parameters of the current step,
-        :code:`fx0_approx` is loss value with perturbed parameters (will be returned by optimizer step if fx0 is None).
-        :code:`fx0` and :code:`fx0_approx` can be None.
-
         Args:
             closure (_ClosureType): closure
             params (TensorList): parameters
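
For reference, the docstring lines removed in the last hunk described the _make_ascent contract: it returns (ascent, fx0, fx0_approx), where ascent is the approximated gradient, fx0 the loss at the unperturbed parameters, fx0_approx the loss at perturbed parameters, and either loss value may be None. The snippet below is a self-contained plain-PyTorch illustration of that three-tuple idea using a naive forward-difference estimate on a single tensor; it does not use or depend on the package, and all names in it are made up for the example.

# Self-contained sketch (plain PyTorch, not torchzero code): the same
# (ascent, fx0, fx0_approx) idea with a forward-difference gradient estimate.
import torch

def forward_difference_ascent(closure, param: torch.Tensor, eps: float = 1e-4):
    """Approximate d closure / d param by forward differences."""
    with torch.no_grad():
        fx0 = closure()                       # loss at the unperturbed parameters
        ascent = torch.zeros_like(param)
        flat_p, flat_g = param.view(-1), ascent.view(-1)
        for i in range(flat_p.numel()):
            orig = flat_p[i].item()
            flat_p[i] = orig + eps            # perturb one coordinate
            flat_g[i] = (closure() - fx0) / eps
            flat_p[i] = orig                  # restore it
        return ascent, fx0, None              # no perturbed-loss value kept here

# usage: one approximate gradient step on f(x) = ||x||^2 without autograd
x = torch.tensor([1.0, -2.0, 3.0])
g, fx0, _ = forward_difference_ascent(lambda: (x ** 2).sum(), x)
x -= 0.1 * g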

torchzero/modules/gradient_approximation/fdm.py

@@ -4,7 +4,7 @@ import torch

 from ...utils.python_tools import _ScalarLoss
 from ...tensorlist import TensorList
-from ...core import _ClosureType, OptimizerModule,
+from ...core import _ClosureType, OptimizerModule, OptimizationVars
 from ._fd_formulas import _FD_Formulas
 from .base_approximator import GradientApproximatorBase


torchzero/modules/gradient_approximation/newton_fdm.py

@@ -121,16 +121,16 @@ class NewtonFDM(OptimizerModule):
         self.tol = tol

     @torch.no_grad
-    def step(self,
+    def step(self, vars):
         """Returns a new ascent direction."""
-        if
-        if
+        if vars.closure is None: raise ValueError('NewtonFDM requires a closure.')
+        if vars.ascent is not None: raise ValueError('NewtonFDM got ascent direction')

         params = self.get_params()
         epsilons = self.get_group_key('eps')

         # evaluate fx0.
-        if
+        if vars.fx0 is None: vars.fx0 = vars.closure(False)

         # evaluate gradients and hessian via finite differences.
         grads = params.zeros_like()
@@ -152,7 +152,7 @@ class NewtonFDM(OptimizerModule):
                     cur2 += 1
                     continue
                 _three_point_2cd_(
-                    closure =
+                    closure = vars.closure,
                     idx1 = idx1,
                     idx2 = idx2,
                     p1 = flat_param1,
@@ -161,7 +161,7 @@ class NewtonFDM(OptimizerModule):
                     hessian = hessian,
                     eps1 = eps1,
                     eps2 = eps2,
-                    fx0 =
+                    fx0 = vars.fx0,
                     i1 = cur1,
                     i2 = cur2,
                 )
@@ -181,18 +181,18 @@ class NewtonFDM(OptimizerModule):
             newton_step, success = _fallback_gd(hessian, gvec)

         # update params or pass the gradients to the child.
-
+        vars.ascent = grads.from_vec(newton_step)


         # validate if newton step decreased loss
         if self.validate:

-            params.sub_(
-            fx1 =
-            params.add_(
+            params.sub_(vars.ascent)
+            fx1 = vars.closure(False)
+            params.add_(vars.ascent)

             # if loss increases, set ascent direction to gvec times lr
-            if fx1 -
-
+            if fx1 - vars.fx0 > vars.fx0 * self.tol:
+                vars.ascent = grads.from_vec(gvec) * self.gd_lr

-        return self._update_params_or_step_with_next(
+        return self._update_params_or_step_with_next(vars, params)
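
The validation branch above accepts the finite-difference Newton step only if it does not increase the loss by more than fx0 * tol, and otherwise falls back to a plain gradient step scaled by gd_lr. Below is a standalone illustration of that accept-or-fall-back rule in plain PyTorch, independent of the package's hessian and gvec bookkeeping; the function name, default values, and the toy quadratic are made up for the example.

# Standalone sketch of the rule in the hunk above, not package code.
import torch

def validated_newton_direction(f, x, grad, hessian, fx0, gd_lr=1e-2, tol=1.0):
    d = torch.linalg.solve(hessian, grad)    # Newton direction H^{-1} g
    fx1 = f(x - d)                           # loss after taking the step
    if fx1 - fx0 > fx0 * tol:                # worsened beyond the tolerance
        d = grad * gd_lr                     # gradient-descent fallback
    return d

# usage on the quadratic f(x) = 0.5 x^T A x - b^T x, whose minimizer is A^{-1} b
A = torch.tensor([[3.0, 0.0], [0.0, 1.0]])
b = torch.tensor([1.0, 2.0])
f = lambda v: 0.5 * v @ A @ v - b @ v
x = torch.zeros(2)
d = validated_newton_direction(f, x, A @ x - b, A, f(x))
x = x - d                                    # lands exactly on A^{-1} b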

torchzero/modules/gradient_approximation/rfdm.py

@@ -4,7 +4,7 @@ import torch

 from ...utils.python_tools import _ScalarLoss
 from ...tensorlist import Distributions, TensorList
-from ...core import _ClosureType, OptimizerModule,
+from ...core import _ClosureType, OptimizerModule, OptimizationVars
 from ._fd_formulas import _FD_Formulas
 from .base_approximator import GradientApproximatorBase


torchzero/modules/line_search/armijo.py

@@ -1,7 +1,7 @@
 import torch

 from ...tensorlist import TensorList
-from ...core import
+from ...core import OptimizationVars
 from .base_ls import LineSearchBase


@@ -32,23 +32,23 @@ class ArmijoLS(LineSearchBase):
         self.max_iter = max_iter

     @torch.no_grad
-    def _find_best_lr(self,
-        if
-        ascent =
-        grad =
+    def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
+        if vars.closure is None: raise RuntimeError(f"Line searches ({self.__class__.__name__}) require a closure")
+        ascent = vars.maybe_use_grad_(params)
+        grad = vars.maybe_compute_grad_(params)
         alpha = self.get_first_group_key('alpha')
-        if
+        if vars.fx0 is None: vars.fx0 = vars.closure(False)

         # loss decrease per lr=1 if function was linear
         decrease_per_lr = (grad*ascent).total_sum()

         for _ in range(self.max_iter):
-            loss = self._evaluate_lr_(alpha,
+            loss = self._evaluate_lr_(alpha, vars.closure, ascent, params)

             # expected decrease
             expected_decrease = decrease_per_lr * alpha

-            if (
+            if (vars.fx0 - loss) / expected_decrease >= self.beta:
                 return alpha

             alpha *= self.mul
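
The acceptance test above is the classic Armijo sufficient-decrease condition: with decrease_per_lr = <grad, ascent>, a step size alpha is accepted once (fx0 - loss) / (decrease_per_lr * alpha) >= beta, i.e. the realised decrease is at least beta times what the linear model predicts, and otherwise alpha is multiplied by mul and retried. A standalone backtracking sketch of the same rule follows (plain PyTorch; the function name and parameter defaults are chosen for the example, not taken from the package).

# Standalone sketch of the Armijo backtracking rule used by ArmijoLS above.
import torch

def armijo_backtrack(f, x, direction, grad, alpha=1.0, beta=1e-4, mul=0.5, max_iter=20):
    fx0 = f(x)
    decrease_per_lr = torch.dot(grad, direction)     # predicted decrease for lr=1 if f were linear
    for _ in range(max_iter):
        loss = f(x - alpha * direction)              # candidate step along minus the direction
        if (fx0 - loss) / (decrease_per_lr * alpha) >= beta:
            return alpha                             # sufficient decrease reached
        alpha *= mul                                 # shrink and retry
    return alpha

# usage: f(x) = ||x||^2, descending along its gradient
x = torch.tensor([3.0, -4.0])
g = 2 * x
lr = armijo_backtrack(lambda v: (v ** 2).sum(), x, g, g)    # accepts lr = 0.5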

torchzero/modules/line_search/base_ls.py

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 import torch

 from ...tensorlist import TensorList
-from ...core import _ClosureType,
+from ...core import _ClosureType, OptimizationVars, OptimizerModule, _maybe_pass_backward
 from ...utils.python_tools import _ScalarLoss


@@ -108,20 +108,20 @@ class LineSearchBase(OptimizerModule, ABC):
         if isinstance(v, torch.Tensor): return v.detach().cpu().item()
         return float(v)

-    def _find_best_lr(self,
+    def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
         """This should return the best lr."""
         ... # pylint:disable=unnecessary-ellipsis

     @torch.no_grad
-    def step(self,
+    def step(self, vars: OptimizationVars):
         self._reset()
         if self.log_lrs: self._lrs.append({})

         params = self.get_params()
-        ascent_direction =
+        ascent_direction = vars.maybe_use_grad_(params)

         try:
-            lr = self._find_best_lr(
+            lr = self._find_best_lr(vars, params) # pylint:disable=assignment-from-no-return
         except MaxIterReached:
             lr = self._best_lr

@@ -133,7 +133,7 @@ class LineSearchBase(OptimizerModule, ABC):
         # otherwise undo the update by setting lr to 0 and instead multiply ascent direction by lr.
         self._set_lr_(0, ascent_direction, params)
         ascent_direction.mul_(self._best_lr)
-
-        if
-        return self.next_module.step(
+        vars.ascent = ascent_direction
+        if vars.fx0_approx is None: vars.fx0_approx = self._lowest_loss
+        return self.next_module.step(vars)

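
Together these hunks define the subclass contract after the rename: a line search overrides _find_best_lr(self, vars, params) and returns a float, while the base step() applies the chosen lr. A hedged sketch of a minimal subclass against that contract is below; the absolute import paths, the base-class constructor behaviour, and the ThreeCandidateLS class name are assumptions, and _evaluate_lr_ is used exactly the way the armijo.py and grid_ls.py hunks use it.

# Hedged sketch, not package code: a hypothetical line search over three fixed
# candidate step sizes, written against the 0.1.5 _find_best_lr(vars, params)
# contract shown above. Import paths and base-class constructor details are
# assumptions; only calls visible in the diff hunks are relied on.
import torch
from torchzero.tensorlist import TensorList
from torchzero.core import OptimizationVars
from torchzero.modules.line_search.base_ls import LineSearchBase

class ThreeCandidateLS(LineSearchBase):
    @torch.no_grad
    def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
        if vars.closure is None: raise ValueError("line searches require a closure")
        ascent = vars.maybe_use_grad_(params)        # the gradient becomes the ascent direction if none is set
        best_lr, best_loss = 0.0, float('inf')
        for lr in (0.01, 0.1, 1.0):
            loss = self._evaluate_lr_(lr, vars.closure, ascent, params)
            if loss < best_loss:
                best_lr, best_loss = lr, loss
        return best_lr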

torchzero/modules/line_search/directional_newton.py

@@ -2,7 +2,7 @@ import numpy as np
 import torch

 from ...tensorlist import TensorList
-from ...core import
+from ...core import OptimizationVars
 from .base_ls import LineSearchBase

 _FloatOrTensor = float | torch.Tensor
@@ -57,14 +57,14 @@ class DirectionalNewton(LineSearchBase):
         self.validate_step = validate_step

     @torch.no_grad
-    def _find_best_lr(self,
-        if
-        closure =
+    def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
+        if vars.closure is None: raise ValueError('QuardaticLS requires closure')
+        closure = vars.closure

         params = self.get_params()
-        grad =
-        ascent =
-        if
+        grad = vars.maybe_compute_grad_(params)
+        ascent = vars.maybe_use_grad_(params)
+        if vars.fx0 is None: vars.fx0 = vars.closure(False) # at this stage maybe_compute_grad could've evaluated fx0

         alpha: float = self.get_first_group_key('alpha') # this doesn't support variable lrs but we still want to support schedulers

@@ -78,7 +78,7 @@ class DirectionalNewton(LineSearchBase):
         if y1_prime != 0:
             xmin, a = _fit_and_minimize_quadratic_2points_grad(
                 x1=0,
-                y1=
+                y1=vars.fx0,
                 y1_prime=-y1_prime,
                 x2=alpha,
                 # we stepped in the direction of minus gradient times lr.
@@ -172,14 +172,14 @@ class DirectionalNewton3Points(LineSearchBase):
         self.validate_step = validate_step

     @torch.no_grad
-    def _find_best_lr(self,
-        if
-        closure =
-        ascent_direction =
+    def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
+        if vars.closure is None: raise ValueError('QuardaticLS requires closure')
+        closure = vars.closure
+        ascent_direction = vars.ascent
         if ascent_direction is None: raise ValueError('Ascent direction is None')
         alpha: float = self.get_first_group_key('alpha')

-        if
+        if vars.fx0 is None: vars.fx0 = vars.closure(False)
         params = self.get_params()

         # make a step in the direction and evaluate f(x2)
@@ -190,7 +190,7 @@ class DirectionalNewton3Points(LineSearchBase):

         # if gradients weren't 0
         xmin, a = _newton_step_3points(
-            0,
+            0, vars.fx0,
             # we stepped in the direction of minus ascent_direction.
             alpha, y2,
             alpha * 2, y3
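
DirectionalNewton3Points samples the loss at step sizes 0, alpha and 2 * alpha along minus the ascent direction and jumps to the vertex of the interpolating parabola; the package's _newton_step_3points helper is assumed to compute something equivalent, since its body is not shown in the diff. The closed form of that vertex is easy to state on its own:

# Standalone helper (not package code): vertex of the parabola through
# (0, f0), (a, f1), (2a, f2); mirrors the three-point model described above.
def quadratic_vertex_3points(a, f0, f1, f2):
    curvature = f2 - 2 * f1 + f0              # proportional to the second derivative
    if curvature == 0:
        return None                            # the three samples are collinear
    return a * (3 * f0 - 4 * f1 + f2) / (2 * curvature)

# example: f(t) = (t - 1)^2 sampled at t = 0, 1, 2 recovers the minimizer t = 1
assert quadratic_vertex_3points(1.0, 1.0, 0.0, 1.0) == 1.0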

torchzero/modules/line_search/grid_ls.py

@@ -5,7 +5,7 @@ import numpy as np
 import torch

 from ...tensorlist import TensorList
-from ...core import _ClosureType,
+from ...core import _ClosureType, OptimizationVars
 from .base_ls import LineSearchBase

 class GridLS(LineSearchBase):
@@ -34,16 +34,16 @@ class GridLS(LineSearchBase):
         self.stop_on_worsened = stop_on_worsened

     @torch.no_grad
-    def _find_best_lr(self,
-        if
-        if
+    def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
+        if vars.closure is None: raise ValueError("closure is not set")
+        if vars.ascent is None: raise ValueError("ascent_direction is not set")

         if self.stop_on_improvement:
-            if
-            self._lowest_loss =
+            if vars.fx0 is None: vars.fx0 = vars.closure(False)
+            self._lowest_loss = vars.fx0

         for lr in self.lrs:
-            loss = self._evaluate_lr_(float(lr),
+            loss = self._evaluate_lr_(float(lr), vars.closure, vars.ascent, params)

             # if worsened
             if self.stop_on_worsened and loss != self._lowest_loss:
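
The scan in GridLS is a plain best-of-grid search; stripped of the module plumbing, its stop_on_worsened behaviour reduces to the loop below (a standalone sketch with made-up names, not the package's code, and without the stop_on_improvement branch).

# Standalone sketch of the GridLS scan: try each lr, keep the best, optionally
# stop as soon as a candidate is worse than the best seen so far.
def grid_line_search(evaluate_lr, lrs, stop_on_worsened=True):
    best_lr, best_loss = 0.0, float('inf')
    for lr in lrs:
        loss = evaluate_lr(lr)
        if loss < best_loss:
            best_lr, best_loss = lr, loss
        elif stop_on_worsened:
            break
    return best_lr

# usage: the 1-D quadratic (lr - 0.3)^2 is minimised at lr = 0.3
best = grid_line_search(lambda lr: (lr - 0.3) ** 2, [0.1, 0.2, 0.3, 0.4, 0.5])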

torchzero/modules/line_search/scipy_minimize_scalar.py

@@ -7,7 +7,7 @@ except ModuleNotFoundError:
     scopt = typing.cast(typing.Any, None)

 from ...tensorlist import TensorList
-from ...core import
+from ...core import OptimizationVars

 from .base_ls import LineSearchBase, MaxIterReached

@@ -45,11 +45,11 @@ class ScipyMinimizeScalarLS(LineSearchBase):
         self.options = options

     @torch.no_grad
-    def _find_best_lr(self,
+    def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
         try:
             res = scopt.minimize_scalar(
                 self._evaluate_lr_ensure_float,
-                args = (
+                args = (vars.closure, vars.ascent, params),
                 method = self.method,
                 tol = self.tol,
                 bracket = self.bracket,
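
ScipyMinimizeScalarLS delegates the step-size search to scipy.optimize.minimize_scalar, treating lr -> loss(params - lr * ascent) as a scalar objective via its _evaluate_lr_ensure_float helper. A standalone sketch of that delegation, with a toy NumPy objective in place of the module's helper (all names below are made up for the example):

# Standalone sketch (NumPy + SciPy, not package code): pick a step size for a
# fixed descent direction by scalar minimisation, as the module above does.
import numpy as np
from scipy.optimize import minimize_scalar

x = np.array([3.0, -4.0])
direction = 2 * x                                  # gradient of f(x) = ||x||^2
objective = lambda lr: float(np.sum((x - lr * direction) ** 2))
res = minimize_scalar(objective, method='brent', tol=1e-6)
x = x - res.x * direction                          # res.x is the chosen learning rate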

torchzero/modules/meta/alternate.py

@@ -40,7 +40,7 @@ class Alternate(OptimizerModule):
         if len(self.mode) != len(self.children):
             raise ValueError(f"got {len(self.children)} modules but {len(mode)} repeats, they should be the same")

-    def step(self,
+    def step(self, vars):
         if self.mode == 'random':
             module = self.random.choice(list(self.children.values()))

@@ -58,8 +58,8 @@ class Alternate(OptimizerModule):
         self.remaining -= 1

         if self.next_module is None:
-            return module.step(
+            return module.step(vars)

-
-        return self._update_params_or_step_with_next(
+        vars.ascent = module.return_ascent(vars)
+        return self._update_params_or_step_with_next(vars)
