torchzero 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those versions.
- torchzero/core/__init__.py +1 -1
- torchzero/core/module.py +72 -49
- torchzero/core/tensorlist_optimizer.py +1 -1
- torchzero/modules/adaptive/adaptive.py +11 -11
- torchzero/modules/experimental/experimental.py +41 -41
- torchzero/modules/experimental/quad_interp.py +8 -8
- torchzero/modules/experimental/subspace.py +37 -37
- torchzero/modules/gradient_approximation/base_approximator.py +19 -24
- torchzero/modules/gradient_approximation/fdm.py +1 -1
- torchzero/modules/gradient_approximation/newton_fdm.py +13 -13
- torchzero/modules/gradient_approximation/rfdm.py +1 -1
- torchzero/modules/line_search/armijo.py +8 -8
- torchzero/modules/line_search/base_ls.py +8 -8
- torchzero/modules/line_search/directional_newton.py +14 -14
- torchzero/modules/line_search/grid_ls.py +7 -7
- torchzero/modules/line_search/scipy_minimize_scalar.py +3 -3
- torchzero/modules/meta/alternate.py +4 -4
- torchzero/modules/meta/grafting.py +23 -23
- torchzero/modules/meta/optimizer_wrapper.py +14 -14
- torchzero/modules/meta/return_overrides.py +8 -8
- torchzero/modules/misc/accumulate.py +6 -6
- torchzero/modules/misc/basic.py +16 -16
- torchzero/modules/misc/lr.py +2 -2
- torchzero/modules/misc/multistep.py +7 -7
- torchzero/modules/misc/on_increase.py +9 -9
- torchzero/modules/momentum/momentum.py +4 -4
- torchzero/modules/operations/multi.py +44 -44
- torchzero/modules/operations/reduction.py +28 -28
- torchzero/modules/operations/singular.py +9 -9
- torchzero/modules/optimizers/adagrad.py +1 -1
- torchzero/modules/optimizers/adam.py +8 -8
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +1 -1
- torchzero/modules/optimizers/rprop.py +1 -1
- torchzero/modules/optimizers/sgd.py +2 -2
- torchzero/modules/orthogonalization/newtonschulz.py +3 -3
- torchzero/modules/orthogonalization/svd.py +1 -1
- torchzero/modules/regularization/dropout.py +1 -1
- torchzero/modules/regularization/noise.py +3 -3
- torchzero/modules/regularization/normalization.py +5 -5
- torchzero/modules/regularization/ortho_grad.py +1 -1
- torchzero/modules/regularization/weight_decay.py +1 -1
- torchzero/modules/scheduling/lr_schedulers.py +2 -2
- torchzero/modules/scheduling/step_size.py +8 -8
- torchzero/modules/second_order/newton.py +12 -12
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/gaussian_smoothing.py +7 -7
- torchzero/modules/smoothing/laplacian_smoothing.py +1 -1
- torchzero/modules/weight_averaging/ema.py +3 -3
- torchzero/modules/weight_averaging/swa.py +8 -8
- torchzero/optim/first_order/forward_gradient.py +1 -1
- torchzero/optim/modular.py +4 -4
- torchzero/tensorlist.py +8 -1
- {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/METADATA +1 -1
- torchzero-0.1.5.dist-info/RECORD +104 -0
- torchzero-0.1.3.dist-info/RECORD +0 -104
- {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/LICENSE +0 -0
- {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/WHEEL +0 -0
- {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/top_level.txt +0 -0
@@ -45,15 +45,15 @@ class Graft(OptimizerModule):
 
 
     @torch.no_grad
-    def step(self,
-        state_copy =
+    def step(self, vars):
+        state_copy = vars.copy(clone_ascent=True)
         magnitude = self.children['magnitude'].return_ascent(state_copy)
 
-        if state_copy.grad is not None:
-        if state_copy.fx0 is not None:
-        if state_copy.fx0_approx is not None:
+        if state_copy.grad is not None: vars.grad = state_copy.grad
+        if state_copy.fx0 is not None: vars.fx0 = state_copy.fx0
+        if state_copy.fx0_approx is not None: vars.fx0_approx = state_copy.fx0_approx
 
-        direction = self.children['direction'].return_ascent(
+        direction = self.children['direction'].return_ascent(vars)
 
         if self.layerwise:
             M = magnitude.norm(self.ord)
@@ -65,8 +65,8 @@ class Graft(OptimizerModule):
             D = direction.total_vector_norm(self.ord)
         if D == 0: D = M
 
-
-        return self._update_params_or_step_with_next(
+        vars.ascent = direction.mul_(M / (D + self.eps))
+        return self._update_params_or_step_with_next(vars)
 
 
 
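The two Graft hunks above carry the core of the module: the step magnitude is taken from one child module, the direction from another, and the direction is rescaled to the magnitude's norm. A minimal plain-tensor sketch of that rule (the graft helper, its arguments, and the single-tensor setting are illustrative, not the torchzero API):

import torch

def graft(magnitude_update: torch.Tensor, direction_update: torch.Tensor,
          ord: float = 2.0, eps: float = 1e-8) -> torch.Tensor:
    # Rescale the direction to the norm of the magnitude update,
    # mirroring `direction.mul_(M / (D + self.eps))` in the hunk above.
    M = magnitude_update.norm(ord)
    D = direction_update.norm(ord)
    if D == 0:
        D = M  # same degenerate-case handling as `if D == 0: D = M`
    return direction_update * (M / (D + eps))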
@@ -94,17 +94,17 @@ class SignGrafting(OptimizerModule):
 
 
     @torch.no_grad
-    def step(self,
-        state_copy =
+    def step(self, vars):
+        state_copy = vars.copy(clone_ascent=True)
         magnitude = self.children['magnitude'].return_ascent(state_copy)
 
         # make sure to store grad and fx0 if it was calculated
-
+        vars.update_attrs_(state_copy)
 
-        sign = self.children['sign'].return_ascent(
+        sign = self.children['sign'].return_ascent(vars)
 
-
-        return self._update_params_or_step_with_next(
+        vars.ascent = magnitude.copysign_(sign)
+        return self._update_params_or_step_with_next(vars)
 
 
 class IntermoduleCautious(OptimizerModule):
@@ -153,17 +153,17 @@ class IntermoduleCautious(OptimizerModule):
         self.mode: Literal["zero", "grad", "backtrack", "compare_module"] = mode
 
     @torch.no_grad
-    def step(self,
+    def step(self, vars):
         params = None
-        state_copy =
+        state_copy = vars.copy(clone_ascent=True)
         ascent = self.children['main'].return_ascent(state_copy)
-
+        vars.update_attrs_(state_copy)
 
-        if self.compare_mode == 'module': compare = self.children['compare'].return_ascent(
+        if self.compare_mode == 'module': compare = self.children['compare'].return_ascent(vars)
         else:
             params = self.get_params()
-            if self.compare_mode == 'ascent': compare: TensorList =
-            elif self.compare_mode == 'grad': compare: TensorList =
+            if self.compare_mode == 'ascent': compare: TensorList = vars.maybe_use_grad_(params)
+            elif self.compare_mode == 'grad': compare: TensorList = vars.maybe_compute_grad_(params)
             else: raise ValueError(f'Invalid compare_module: {self.compare_mode}')
 
         # mask will be > 0 for parameters where both signs are the same
@@ -185,11 +185,11 @@ class IntermoduleCautious(OptimizerModule):
 
         if self.mode == 'grad':
             params = self.get_params()
-            ascent +=
+            ascent += vars.maybe_compute_grad_(params) * mask.logical_not_()
 
         elif self.mode == 'compare_module':
             ascent += compare * mask.logical_not_()
 
-
-        return self._update_params_or_step_with_next(
+        vars.ascent = ascent
+        return self._update_params_or_step_with_next(vars, params)
 
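For orientation, the cautioning idea this class builds on (keep the components of the update whose sign agrees with a comparison tensor, and optionally refill the disagreeing components from the gradient or from a compare module, as the second hunk does) can be sketched with plain tensors. The cautious_update helper and its arguments are illustrative only:

import torch

def cautious_update(ascent: torch.Tensor, compare: torch.Tensor,
                    fill: torch.Tensor | None = None) -> torch.Tensor:
    mask = (ascent * compare) > 0   # True where both signs agree
    out = ascent * mask             # zero out disagreeing components ('zero' mode)
    if fill is not None:            # e.g. the raw gradient or the compare direction
        out = out + fill * (~mask)  # analogous to `ascent += ... * mask.logical_not_()`
    return out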
@@ -64,7 +64,7 @@ class Wrap(OptimizerModule):
         self.optimizer = self._optimizer_cls(params, *self._args, **self._kwargs)
 
     @torch.no_grad
-    def step(self,
+    def step(self, vars):
         # check attrs
         # if self.pass_closure:
         # if state.closure is None: raise ValueError('ClosureOptimizerWrapper requires closure.')
@@ -75,22 +75,22 @@ class Wrap(OptimizerModule):
 
         if self.next_module is None:
             # set grad to ascent and make a step with the optimizer
-            g =
+            g = vars.maybe_use_grad_(params)
             params.set_grad_(g)
-
-            return
+            vars.fx0 = self.optimizer.step()
+            return vars.get_loss()
 
 
         params_before_step = params.clone()
 
-        g =
+        g = vars.maybe_use_grad_(params)
         params.set_grad_(g)
-
+        vars.fx0 = self.optimizer.step()
 
         # calculate update as difference in params
-
+        vars.ascent = params_before_step - params
         params.set_(params_before_step)
-        return self.next_module.step(
+        return self.next_module.step(vars)
 
 
 class WrapClosure(OptimizerModule):
@@ -148,7 +148,7 @@ class WrapClosure(OptimizerModule):
         self.optimizer = self._optimizer_cls(params, *self._args, **self._kwargs)
 
     @torch.no_grad
-    def step(self,
+    def step(self, vars):
         # check attrs
         # if self.pass_closure:
         # if state.closure is None: raise ValueError('ClosureOptimizerWrapper requires closure.')
@@ -159,15 +159,15 @@ class WrapClosure(OptimizerModule):
 
         if self.next_module is None:
             # set grad to ascent and make a step with the optimizer
-
-            return
+            vars.fx0 = self.optimizer.step(vars.closure) # type:ignore
+            return vars.get_loss()
 
 
         params_before_step = params.clone()
-
+        vars.fx0 = self.optimizer.step(vars.closure) # type:ignore
 
         # calculate update as difference in params
-
+        vars.ascent = params_before_step - params
         params.set_(params_before_step)
-        return self.next_module.step(
+        return self.next_module.step(vars)
 
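Both Wrap and WrapClosure recover the wrapped optimizer's update as the difference between the parameters before and after its step, then restore the parameters so downstream modules decide how (and whether) to apply it. A standalone sketch of that pattern against a generic torch.optim optimizer; extract_update and its arguments are illustrative, and inner_opt is assumed to have been constructed over the same params:

import torch

def extract_update(params: list[torch.Tensor], grads: list[torch.Tensor],
                   inner_opt: torch.optim.Optimizer) -> list[torch.Tensor]:
    before = [p.detach().clone() for p in params]
    for p, g in zip(params, grads):
        p.grad = g                       # feed the ascent direction in as .grad
    inner_opt.step()                     # let the wrapped optimizer move the params
    update = [b - p.detach() for b, p in zip(before, params)]  # params_before - params
    with torch.no_grad():
        for p, b in zip(params, before):
            p.copy_(b)                   # undo the inner step
    return update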
@@ -9,12 +9,12 @@ class SetGrad(OptimizerModule):
         super().__init__({})
 
     @torch.no_grad
-    def step(self,
+    def step(self, vars):
         if self.next_module is not None: raise ValueError("SetGrad can't have children")
         params = self.get_params()
-        g =
+        g = vars.maybe_use_grad_(params) # this may execute the closure which might be modified
         params.set_grad_(g)
-        return
+        return vars.get_loss()
 
 
 class ReturnAscent(OptimizerModule):
@@ -23,10 +23,10 @@ class ReturnAscent(OptimizerModule):
         super().__init__({})
 
     @torch.no_grad
-    def step(self,
+    def step(self, vars) -> TensorList: # type:ignore
         if self.next_module is not None: raise ValueError("ReturnAscent can't have children")
         params = self.get_params()
-        update =
+        update = vars.maybe_use_grad_(params) # this will execute the closure which might be modified
         return update
 
 class ReturnClosure(OptimizerModule):
@@ -38,9 +38,9 @@ class ReturnClosure(OptimizerModule):
         super().__init__({})
 
     @torch.no_grad
-    def step(self,
+    def step(self, vars) -> _ClosureType: # type:ignore
         if self.next_module is not None: raise ValueError("ReturnClosure can't have children")
-        if
+        if vars.closure is None:
             raise ValueError("MakeClosure requires closure")
-        return
+        return vars.closure
 
@@ -26,18 +26,18 @@ class Accumulate(OptimizerModule):
         self.cur_step = 0
 
     @torch.no_grad
-    def step(self,
+    def step(self, vars):
         self.cur_step += 1
 
         params = self.get_params()
         accumulated_update = self.get_state_key('accumulated_grads')
-        accumulated_update +=
+        accumulated_update += vars.maybe_use_grad_(params)
 
         if self.cur_step % self.n_steps == 0:
-
-            if self.mean:
+            vars.ascent = accumulated_update.clone()
+            if self.mean: vars.ascent /= self.n_steps
             accumulated_update.zero_()
-            return self._update_params_or_step_with_next(
+            return self._update_params_or_step_with_next(vars)
 
 
-        return
+        return vars.get_loss()
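The Accumulate hunk is the usual gradient-accumulation pattern: sum the incoming update for n_steps calls, optionally average it, emit one combined step, and reset the buffer. A self-contained sketch of that logic (the Accumulator class is illustrative, not the torchzero module):

import torch

class Accumulator:
    def __init__(self, n_steps: int, mean: bool = True):
        self.n_steps, self.mean, self.cur_step = n_steps, mean, 0
        self.buffer: torch.Tensor | None = None

    def accumulate(self, update: torch.Tensor) -> torch.Tensor | None:
        self.cur_step += 1
        self.buffer = update.clone() if self.buffer is None else self.buffer + update
        if self.cur_step % self.n_steps == 0:
            out = self.buffer / self.n_steps if self.mean else self.buffer.clone()
            self.buffer.zero_()
            return out      # a combined update is ready
        return None         # keep accumulating, no step this time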
torchzero/modules/misc/basic.py CHANGED

@@ -14,7 +14,7 @@ class Alpha(OptimizerModule):
         super().__init__(defaults)
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         # multiply ascent direction by lr in-place
         lr = self.get_group_key('alpha')
         ascent *= lr
@@ -27,7 +27,7 @@ class Clone(OptimizerModule):
         super().__init__({})
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent): return ascent.clone()
 
 class Identity(OptimizerModule):
     """Does nothing."""
@@ -35,7 +35,7 @@ class Identity(OptimizerModule):
         super().__init__({})
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent): return ascent
 
 class Lambda(OptimizerModule):
     """Applies a function to the ascent direction.
@@ -49,7 +49,7 @@ class Lambda(OptimizerModule):
         self.f = f
 
     @torch.no_grad()
-    def _update(self,
+    def _update(self, vars, ascent): return self.f(ascent)
 
 class Grad(OptimizerModule):
     """Uses gradient as the update. This is useful for chains."""
@@ -57,8 +57,8 @@ class Grad(OptimizerModule):
         super().__init__({})
 
     @torch.no_grad
-    def _update(self,
-        ascent =
+    def _update(self, vars, ascent):
+        ascent = vars.ascent = vars.maybe_compute_grad_(self.get_params())
         return ascent
 
 class Zeros(OptimizerModule):
@@ -66,7 +66,7 @@ class Zeros(OptimizerModule):
         super().__init__({})
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         return ascent.zeros_like()
 
 class Fill(OptimizerModule):
@@ -74,7 +74,7 @@ class Fill(OptimizerModule):
         super().__init__({"value": value})
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         return ascent.fill(self.get_group_key('value'))
 
 
@@ -83,8 +83,8 @@ class GradToUpdate(OptimizerModule):
     def __init__(self):
         super().__init__({})
 
-    def _update(self,
-
+    def _update(self, vars, ascent):
+        vars.set_grad_(ascent, self.get_params())
         return ascent
 
 class MakeClosure(OptimizerModule):
@@ -93,12 +93,12 @@ class MakeClosure(OptimizerModule):
         super().__init__({})
         self._set_child_('modules', modules)
 
-    def step(self,
-        if
+    def step(self, vars):
+        if vars.closure is None: raise ValueError("MakeClosure requires a closure")
 
         params = self.get_params()
-        orig_closure =
-        orig_state =
+        orig_closure = vars.closure
+        orig_state = vars.copy(True)
 
         def new_closure(backward = True):
             if backward:
@@ -110,6 +110,6 @@ class MakeClosure(OptimizerModule):
             else:
                 return orig_closure(False)
 
-
-        return self._update_params_or_step_with_next(
+        vars.closure = new_closure # type:ignore
+        return self._update_params_or_step_with_next(vars)
 
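All of the basic.py changes converge on the same hook: _update(self, vars, ascent) receives the current ascent direction and returns a (possibly new) one. Based only on the signatures and calls visible in these hunks, a user-defined module in the 0.1.5 style might look like the sketch below; the torchzero.core import path is an assumption, and ScaledCopy is hypothetical, combining the Clone and Alpha patterns shown above:

import torch
from torchzero.core import OptimizerModule  # import path assumed, not verified

class ScaledCopy(OptimizerModule):
    """Hypothetical module: returns a scaled clone of the ascent direction."""
    def __init__(self, scale: float = 0.5):
        super().__init__({"scale": scale})

    @torch.no_grad
    def _update(self, vars, ascent):
        ascent = ascent.clone()                # as in Clone._update above
        ascent *= self.get_group_key('scale')  # as in Alpha._update above
        return ascent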
torchzero/modules/misc/lr.py CHANGED

@@ -71,7 +71,7 @@ class LR(OptimizerModule):
         self._skip = False
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         # step with scheduler
         if self._scheduler_step_fn is not None:
             if self.cur != 0 and self.cur % self.sheduler_step_every == 0:
@@ -79,7 +79,7 @@ class LR(OptimizerModule):
 
         # add a hook to cycle momentum
         if self.cycle_momentum:
-
+            vars.add_post_step_hook(_set_momentum_hook)
 
         # remove init hook to delete reference to scheduler
         if self.cur == 0 and len(self.post_init_hooks) == 1:
@@ -20,32 +20,32 @@ class Multistep(OptimizerModule):
 
         self._set_child_('modules', modules)
 
-    def step(self,
+    def step(self, vars):
         # no next module, just perform multiple steps
         if self.next_module is None:
             ret = None
             for step in range(self.num_steps):
-                state_copy =
+                state_copy = vars.copy(clone_ascent=True) if step != self.num_steps - 1 else vars
                 ret = self.children['modules'].step(state_copy)
 
                 # since parameters are updated after stepping, grad and fx0 must be erased as they are no longer correct
-
+                vars.grad = None; vars.fx0 = None
 
             return ret
 
         # accumulate steps and pass to next module
         p0 = self.get_params().clone()
         for step in range(self.num_steps):
-            state_copy =
+            state_copy = vars.copy(clone_ascent=True) if step != self.num_steps - 1 else vars
             self.children['modules'].step(state_copy)
 
             # since parameters are updated after stepping, grad and fx0 must be erased as they are no longer correct
-
+            vars.grad = None; vars.fx0 = None
 
         p1 = self.get_params()
-
+        vars.ascent = p0 - p1
 
         # undo ascent
         p1.set_(p0)
 
-        return self._update_params_or_step_with_next(
+        return self._update_params_or_step_with_next(vars, p1)
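Multistep runs its child chain several times, records the net parameter change p0 - p1 as a single combined ascent, and then restores the parameters so the rest of the chain applies that delta itself. A standalone sketch of the idea (multistep_delta and inner_step are illustrative names):

import torch

def multistep_delta(param: torch.Tensor, inner_step, num_steps: int) -> torch.Tensor:
    p0 = param.detach().clone()
    for _ in range(num_steps):
        inner_step(param)               # inner update rule mutates `param` in place
    delta = p0 - param.detach()         # combined ascent, like `p0 - p1` above
    with torch.no_grad():
        param.copy_(p0)                 # undo, so the caller applies the delta itself
    return delta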
@@ -16,38 +16,38 @@ class NegateOnLossIncrease(OptimizerModule):
         self.backtrack = backtrack
 
     @torch.no_grad()
-    def step(self,
-        if
-        if
+    def step(self, vars):
+        if vars.closure is None: raise ValueError('NegateOnLossIncrease requires closure.')
+        if vars.fx0 is None: vars.fx0 = vars.closure(False)
 
         # subtract ascent direction to params and see if loss decreases
         params = self.get_params()
-        ascent_direction =
+        ascent_direction = vars.maybe_use_grad_(params)
         params -= ascent_direction
-
+        vars.fx0_approx = vars.closure(False)
 
         # if this has no children, update params and return loss
         if self.next_module is None:
             if params is None: params = self.get_params()
 
-            if
+            if vars.fx0_approx > vars.fx0:
                 # loss increased, so we negate thea scent direction
                 # we are currently at params - ascent direction
                 # so we add twice the ascent direction
                 params.add_(ascent_direction, alpha = 2 if self.backtrack else 1)
 
             # else: we are already at a lower loss point
-            return
+            return vars.get_loss()
 
         # otherwise undo the ascent direction because it is passed to the child
         params += ascent_direction
 
         # if loss increases, negate ascent direction
-        if
+        if vars.fx0_approx > vars.fx0:
             if self.backtrack: ascent_direction.neg_()
             else: ascent_direction.zero_()
 
         # otherwise undo the ascent direction and pass the updated ascent direction to the child
-        return self.next_module.step(
+        return self.next_module.step(vars)
 
 
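NegateOnLossIncrease is a cheap safeguard: apply the proposed step tentatively, re-evaluate the loss, and flip (or zero) the update if the loss went up. A plain-tensor sketch of that trial-step logic with illustrative names; closure is assumed to return the current loss:

import torch

def negate_on_increase(param: torch.Tensor, ascent: torch.Tensor, closure,
                       backtrack: bool = True) -> torch.Tensor:
    with torch.no_grad():
        fx0 = closure()                 # loss at the current point
        param -= ascent                 # trial step
        fx_trial = closure()            # loss after the trial step
        param += ascent                 # undo the trial step
    if fx_trial > fx0:                  # loss increased
        return -ascent if backtrack else torch.zeros_like(ascent)
    return ascent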
@@ -21,7 +21,7 @@ class HeavyBall(OptimizerModule):
         super().__init__(defaults)
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         velocity = self.get_state_key('velocity', init = ascent)
         settings = self.get_all_group_keys()
         updated_direction = _heavyball_step(ascent, velocity, settings['momentum'], settings['dampening'])
@@ -52,7 +52,7 @@ class NesterovMomentum(OptimizerModule):
         super().__init__(defaults)
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         velocity = self.get_state_key('velocity')
         settings = self.get_all_group_keys()
         _nesterov_step_(ascent, velocity, settings['momentum'], settings['dampening'])
@@ -65,7 +65,7 @@ class GradientAveraging(OptimizerModule):
         super().__init__(defaults)
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         velocity = self.get_state_key('velocity')
         dampening = self.get_group_key('dampening')
 
@@ -89,7 +89,7 @@ class RandomCoordinateMomentum(OptimizerModule):
         self.nesterov = nesterov
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         velocity = self.get_state_key('velocity', init = ascent)
         settings = self.get_all_group_keys()
 
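These momentum hunks only touch the _update signatures, but they reference a helper with the signature _heavyball_step(ascent, velocity, momentum, dampening). The exact formula is not visible in this diff; the textbook heavy-ball velocity update such a helper would typically implement looks like this (sketch, not torchzero's code):

import torch

def heavyball_step(ascent: torch.Tensor, velocity: torch.Tensor,
                   momentum: float, dampening: float) -> torch.Tensor:
    # v <- momentum * v + (1 - dampening) * g; the velocity becomes the new direction
    velocity.mul_(momentum).add_(ascent, alpha=1 - dampening)
    return velocity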