torchzero 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchzero/__init__.py +4 -0
- torchzero/core/__init__.py +13 -0
- torchzero/core/module.py +471 -0
- torchzero/core/tensorlist_optimizer.py +219 -0
- torchzero/modules/__init__.py +21 -0
- torchzero/modules/adaptive/__init__.py +4 -0
- torchzero/modules/adaptive/adaptive.py +192 -0
- torchzero/modules/experimental/__init__.py +19 -0
- torchzero/modules/experimental/experimental.py +294 -0
- torchzero/modules/experimental/quad_interp.py +104 -0
- torchzero/modules/experimental/subspace.py +259 -0
- torchzero/modules/gradient_approximation/__init__.py +7 -0
- torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
- torchzero/modules/gradient_approximation/base_approximator.py +110 -0
- torchzero/modules/gradient_approximation/fdm.py +125 -0
- torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
- torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
- torchzero/modules/gradient_approximation/rfdm.py +125 -0
- torchzero/modules/line_search/__init__.py +30 -0
- torchzero/modules/line_search/armijo.py +56 -0
- torchzero/modules/line_search/base_ls.py +139 -0
- torchzero/modules/line_search/directional_newton.py +217 -0
- torchzero/modules/line_search/grid_ls.py +158 -0
- torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
- torchzero/modules/meta/__init__.py +12 -0
- torchzero/modules/meta/alternate.py +65 -0
- torchzero/modules/meta/grafting.py +195 -0
- torchzero/modules/meta/optimizer_wrapper.py +173 -0
- torchzero/modules/meta/return_overrides.py +46 -0
- torchzero/modules/misc/__init__.py +10 -0
- torchzero/modules/misc/accumulate.py +43 -0
- torchzero/modules/misc/basic.py +115 -0
- torchzero/modules/misc/lr.py +96 -0
- torchzero/modules/misc/multistep.py +51 -0
- torchzero/modules/misc/on_increase.py +53 -0
- torchzero/modules/momentum/__init__.py +4 -0
- torchzero/modules/momentum/momentum.py +106 -0
- torchzero/modules/operations/__init__.py +29 -0
- torchzero/modules/operations/multi.py +298 -0
- torchzero/modules/operations/reduction.py +134 -0
- torchzero/modules/operations/singular.py +113 -0
- torchzero/modules/optimizers/__init__.py +10 -0
- torchzero/modules/optimizers/adagrad.py +49 -0
- torchzero/modules/optimizers/adam.py +118 -0
- torchzero/modules/optimizers/lion.py +28 -0
- torchzero/modules/optimizers/rmsprop.py +51 -0
- torchzero/modules/optimizers/rprop.py +99 -0
- torchzero/modules/optimizers/sgd.py +54 -0
- torchzero/modules/orthogonalization/__init__.py +2 -0
- torchzero/modules/orthogonalization/newtonschulz.py +159 -0
- torchzero/modules/orthogonalization/svd.py +86 -0
- torchzero/modules/quasi_newton/__init__.py +4 -0
- torchzero/modules/regularization/__init__.py +22 -0
- torchzero/modules/regularization/dropout.py +34 -0
- torchzero/modules/regularization/noise.py +77 -0
- torchzero/modules/regularization/normalization.py +328 -0
- torchzero/modules/regularization/ortho_grad.py +78 -0
- torchzero/modules/regularization/weight_decay.py +92 -0
- torchzero/modules/scheduling/__init__.py +2 -0
- torchzero/modules/scheduling/lr_schedulers.py +131 -0
- torchzero/modules/scheduling/step_size.py +80 -0
- torchzero/modules/second_order/__init__.py +4 -0
- torchzero/modules/second_order/newton.py +165 -0
- torchzero/modules/smoothing/__init__.py +5 -0
- torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
- torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
- torchzero/modules/weight_averaging/__init__.py +2 -0
- torchzero/modules/weight_averaging/ema.py +72 -0
- torchzero/modules/weight_averaging/swa.py +171 -0
- torchzero/optim/__init__.py +10 -0
- torchzero/optim/experimental/__init__.py +20 -0
- torchzero/optim/experimental/experimental.py +343 -0
- torchzero/optim/experimental/ray_search.py +83 -0
- torchzero/optim/first_order/__init__.py +18 -0
- torchzero/optim/first_order/cautious.py +158 -0
- torchzero/optim/first_order/forward_gradient.py +70 -0
- torchzero/optim/first_order/optimizers.py +570 -0
- torchzero/optim/modular.py +132 -0
- torchzero/optim/quasi_newton/__init__.py +1 -0
- torchzero/optim/quasi_newton/directional_newton.py +58 -0
- torchzero/optim/second_order/__init__.py +1 -0
- torchzero/optim/second_order/newton.py +94 -0
- torchzero/optim/wrappers/__init__.py +0 -0
- torchzero/optim/wrappers/nevergrad.py +113 -0
- torchzero/optim/wrappers/nlopt.py +165 -0
- torchzero/optim/wrappers/scipy.py +439 -0
- torchzero/optim/zeroth_order/__init__.py +4 -0
- torchzero/optim/zeroth_order/fdm.py +87 -0
- torchzero/optim/zeroth_order/newton_fdm.py +146 -0
- torchzero/optim/zeroth_order/rfdm.py +217 -0
- torchzero/optim/zeroth_order/rs.py +85 -0
- torchzero/random/__init__.py +1 -0
- torchzero/random/random.py +46 -0
- torchzero/tensorlist.py +819 -0
- torchzero/utils/__init__.py +0 -0
- torchzero/utils/compile.py +39 -0
- torchzero/utils/derivatives.py +99 -0
- torchzero/utils/python_tools.py +25 -0
- torchzero/utils/torch_tools.py +92 -0
- torchzero-0.0.1.dist-info/LICENSE +21 -0
- torchzero-0.0.1.dist-info/METADATA +118 -0
- torchzero-0.0.1.dist-info/RECORD +104 -0
- torchzero-0.0.1.dist-info/WHEEL +5 -0
- torchzero-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from ...tensorlist import TensorList
|
|
3
|
+
from ...core import OptimizerModule, _get_loss, _ClosureType
|
|
4
|
+
|
|
5
|
+
class SetGrad(OptimizerModule):
    """Writes the current update into each parameter's ``.grad`` attribute instead
    of stepping. Afterwards any pytorch optimizer that reads ``.grad`` can be stepped."""

    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def step(self, state):
        # terminal module: nothing may be chained after it
        if self.next_module is not None:
            raise ValueError("SetGrad can't have children")
        parameters = self.get_params()
        # may run the closure (which might have been modified) to obtain the update
        update = state.maybe_use_grad_(parameters)
        parameters.set_grad_(update)
        return state.get_loss()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ReturnAscent(OptimizerModule):
    """Skips the parameter update and instead hands back the update itself
    as a TensorList of tensors."""

    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def step(self, state) -> TensorList: # type:ignore
        # terminal module: nothing may be chained after it
        if self.next_module is not None:
            raise ValueError("ReturnAscent can't have children")
        # may run the closure (which might have been modified) to obtain the update
        return state.maybe_use_grad_(self.get_params())
|
|
31
|
+
|
|
32
|
+
class ReturnClosure(OptimizerModule):
    """Doesn't update parameters, instead returns the current modified closure.

    For example, if you put this after :code:`torchzero.modules.FDM(target = "closure")`,
    the closure will set `.grad` attribute to gradients approximated via finite difference.
    You can then pass that closure to something that requires closure like `torch.optim.LBFGS`."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def step(self, state) -> _ClosureType: # type:ignore
        # terminal module: nothing may be chained after it
        if self.next_module is not None: raise ValueError("ReturnClosure can't have children")
        if state.closure is None:
            # fix: message previously referred to `MakeClosure` (copy-paste error)
            raise ValueError("ReturnClosure requires a closure")
        return state.closure
|
|
46
|
+
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
r"""
|
|
2
|
+
This module includes various basic operators, notably LR for setting the learning rate,
|
|
3
|
+
as well as gradient/update clipping and normalization.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from .basic import Clone, Fill, Grad, Identity, Lambda, Zeros, Alpha, GradToUpdate, MakeClosure
|
|
7
|
+
from .lr import LR
|
|
8
|
+
from .on_increase import NegateOnLossIncrease
|
|
9
|
+
from .multistep import Multistep
|
|
10
|
+
from .accumulate import Accumulate
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from collections.abc import Callable, Iterable
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from ...tensorlist import TensorList
|
|
6
|
+
|
|
7
|
+
from ...core import OptimizerModule
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Accumulate(OptimizerModule):
    """Accumulates the update over `n_steps` steps and only steps once the
    accumulation window is full. Placing this first in the chain yields
    gradient accumulation.

    Args:
        n_steps (int): number of steps (batches) to accumulate the update over.
        mean (bool, optional):
            divide the accumulated update by the number of steps, matching
            loss functions that average over the batch. Defaults to True.
    """
    def __init__(self, n_steps: int, mean = True):
        super().__init__({})
        self.n_steps = n_steps
        self.mean = mean
        self.cur_step = 0

    @torch.no_grad
    def step(self, state):
        self.cur_step += 1

        # add the current update into the per-parameter accumulation buffer
        buffer = self.get_state_key('accumulated_grads')
        buffer += state.maybe_use_grad_(self.get_params())

        # window not full yet: do not step, just report the loss
        if self.cur_step % self.n_steps != 0:
            return state.get_loss()

        state.ascent = buffer.clone()
        if self.mean:
            state.ascent /= self.n_steps
        buffer.zero_()
        return self._update_params_or_step_with_next(state)
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
from collections.abc import Callable, Iterable
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from ...tensorlist import TensorList
|
|
6
|
+
|
|
7
|
+
from ...core import OptimizerModule, _Chainable
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Alpha(OptimizerModule):
    """Scales the update by a fixed step size that learning rate schedulers
    won't pick up (unlike `LR`)."""
    def __init__(self, alpha = 1e-3):
        super().__init__(dict(alpha = alpha))

    @torch.no_grad
    def _update(self, state, ascent):
        # scale the ascent direction in-place by the per-group alpha
        ascent *= self.get_group_key('alpha')
        return ascent
|
|
22
|
+
|
|
23
|
+
class Clone(OptimizerModule):
    """Returns a copy of the update. Useful before modules that modify the
    ascent in-place, when the original must be preserved."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def _update(self, state, ascent):
        return ascent.clone()
|
|
31
|
+
|
|
32
|
+
class Identity(OptimizerModule):
    """Passes the update through unchanged. Accepts (and ignores) any arguments."""
    def __init__(self, *args, **kwargs):
        super().__init__({})

    @torch.no_grad
    def _update(self, state, ascent):
        return ascent
|
|
39
|
+
|
|
40
|
+
class Lambda(OptimizerModule):
    """Transforms the ascent direction with an arbitrary callable.

    The callable receives a TensorList and must return the modified TensorList.

    Args:
        f (Callable): function applied to the ascent direction.
    """
    def __init__(self, f: Callable[[TensorList], TensorList]):
        super().__init__({})
        self.f = f

    @torch.no_grad()
    def _update(self, state, ascent):
        return self.f(ascent)
|
|
53
|
+
|
|
54
|
+
class Grad(OptimizerModule):
    """Replaces the update with the gradient. This is useful for chains."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def _update(self, state, ascent):
        # compute (or reuse) the gradient and store it as the current ascent
        grad = state.maybe_compute_grad_(self.get_params())
        state.ascent = grad
        return grad
|
|
63
|
+
|
|
64
|
+
class Zeros(OptimizerModule):
    """Replaces the update with zeros of the same shapes."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def _update(self, state, ascent):
        return ascent.zeros_like()
|
|
71
|
+
|
|
72
|
+
class Fill(OptimizerModule):
    """Replaces the update with tensors filled with a constant per-group `value`."""
    def __init__(self, value):
        super().__init__({"value": value})

    @torch.no_grad
    def _update(self, state, ascent):
        fill_value = self.get_group_key('value')
        return ascent.fill(fill_value)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class GradToUpdate(OptimizerModule):
    """Sets the parameters' `.grad` attributes (and state gradient) to the
    current update, then passes the update along unchanged."""
    def __init__(self):
        super().__init__({})

    def _update(self, state, ascent):
        state.set_grad_(ascent, self.get_params())
        return ascent
|
|
89
|
+
|
|
90
|
+
class MakeClosure(OptimizerModule):
    """Makes a closure that sets `.grad` attribute to the update generated by `modules`.

    The new closure, when called with `backward=True`, re-runs the child chain from
    a snapshot of the current state, writes the resulting update into `.grad`, and
    returns the loss; with `backward=False` it defers to the original closure.
    """
    def __init__(self, modules: _Chainable):
        super().__init__({})
        # register the chain whose output becomes the gradient
        self._set_child_('modules', modules)

    def step(self, state):
        if state.closure is None: raise ValueError("MakeClosure requires a closure")

        params = self.get_params()
        orig_closure = state.closure
        # snapshot the state so each closure invocation starts from the same point
        orig_state = state.copy(True)

        def new_closure(backward = True):
            if backward:
                # fresh copy per call — the child chain may mutate the state
                cloned_state = orig_state.copy(True)
                g = self.children['modules'].return_ascent(cloned_state)
                params.set_grad_(g)
                return cloned_state.get_loss()

            else:
                # no backward requested: evaluate loss via the original closure
                return orig_closure(False)

        state.closure = new_closure # type:ignore
        return self._update_params_or_step_with_next(state)
|
|
115
|
+
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from collections.abc import Callable, Iterable
|
|
3
|
+
from functools import partial
|
|
4
|
+
from typing import TYPE_CHECKING, Any, overload
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from ...tensorlist import TensorList
|
|
9
|
+
|
|
10
|
+
from ...core import OptimizerModule
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from ...optim import Modular
|
|
14
|
+
|
|
15
|
+
def _init_scheduler_hook(opt: "Modular", module: "LR", scheduler_cls, **kwargs):
    """post init hook that initializes the lr scheduler to the LR module and sets `_scheduler_step_fn`."""
    # the LR module is passed where the scheduler expects an optimizer
    # (presumably it exposes compatible `param_groups` — see _set_momentum_hook)
    scheduler = scheduler_cls(module, **kwargs)
    # only the bound step function is kept; the scheduler object itself is not stored
    module._scheduler_step_fn = scheduler.step
|
|
19
|
+
|
|
20
|
+
def _set_momentum_hook(optimizer, state, momentum):
|
|
21
|
+
for module in optimizer.unrolled_modules:
|
|
22
|
+
if 'momentum' in module.defaults:
|
|
23
|
+
for g in module.param_groups:
|
|
24
|
+
g['momentum'] = momentum
|
|
25
|
+
elif 'beta1' in module.defaults:
|
|
26
|
+
for g in module.param_groups:
|
|
27
|
+
g['beta1'] = momentum
|
|
28
|
+
|
|
29
|
+
class LR(OptimizerModule):
    """Multiplies update by the learning rate. Optionally uses an lr scheduler.

    Args:
        lr (float, optional): learning rate. Defaults to 1e-3.
        scheduler_cls (Callable[..., torch.optim.lr_scheduler.LRScheduler | Any] | None, optional):
            A scheduler class, for example `torch.optim.lr_scheduler.OneCycleLR`. Defaults to None.
        cycle_momentum (bool, optional):
            enables schedulers that support it to affect momentum (like OneCycleLR).
            The momentum will be cycled on ALL modules that have `momentum` or `beta1` setting.
            This does not support external optimizers, wrapped with `Wrap`. Defaults to True.
        sheduler_step_every (int, optional):
            step with scheduler every n optimizer steps.
            Useful when the scheduler steps once per epoch. Defaults to 1.
            NOTE(review): the name is a typo for "scheduler_step_every" but is part
            of the public signature, so it is kept for backward compatibility.
        **kwargs:
            kwargs to pass to `scheduler`.
    """
    # marker flag — presumably lets the chain identify the LR module (e.g. for LR fusing); TODO confirm
    IS_LR_MODULE = True
    def __init__(
        self,
        lr: float = 1e-3,
        scheduler_cls: Callable[..., torch.optim.lr_scheduler.LRScheduler | Any] | None = None,
        cycle_momentum: bool = True,
        sheduler_step_every: int = 1,
        # *args,
        **kwargs,
    ):

        defaults = dict(lr = lr)

        # momentum-cycling schedulers read/write a `momentum` group key, so add one
        if (scheduler_cls is not None) and cycle_momentum:
            defaults['momentum'] = 0
        super().__init__(defaults)

        self._scheduler_step_fn = None      # set by _init_scheduler_hook after init
        self.sheduler_step_every = sheduler_step_every
        self.cycle_momentum = cycle_momentum
        self.cur = 0                        # number of _update calls so far

        if scheduler_cls is not None:
            # defer scheduler construction to a post-init hook (needs the full Modular optimizer)
            self.post_init_hooks.append(lambda opt, module: _init_scheduler_hook(opt, module, scheduler_cls, **kwargs))

        # when True, the lr multiply is skipped once (a previous module already applied it)
        self._skip = False

    @torch.no_grad
    def _update(self, state, ascent):
        # step with scheduler (not on the very first update, and only every
        # `sheduler_step_every` calls)
        if self._scheduler_step_fn is not None:
            if self.cur != 0 and self.cur % self.sheduler_step_every == 0:
                self._scheduler_step_fn()

            # add a hook to cycle momentum on all momentum-carrying modules
            if self.cycle_momentum:
                state.add_post_step_hook(_set_momentum_hook)

        # remove init hook to delete reference to scheduler
        if self.cur == 0 and len(self.post_init_hooks) == 1:
            del self.post_init_hooks[0]

        # skip if lr was applied by previous module (LR fusing)
        if not self._skip:
            # multiply ascent direction by lr in-place
            lr = self.get_group_key('lr')
            ascent *= lr

        self.cur += 1
        self._skip = False  # the skip flag is one-shot
        return ascent
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from collections.abc import Callable, Iterable
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from ...tensorlist import TensorList
|
|
6
|
+
|
|
7
|
+
from ...core import OptimizerModule, _Chainable
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Multistep(OptimizerModule):
    """Performs multiple steps (per batch), passes total update to the next module.

    Args:
        modules (_Chainable): modules to perform multiple steps with.
        num_steps (int, optional): number of steps to perform. Defaults to 2.
    """
    def __init__(self, modules: _Chainable, num_steps: int = 2):
        super().__init__({})
        self.num_steps = num_steps

        self._set_child_('modules', modules)

    def step(self, state):
        # no next module, just perform multiple steps
        if self.next_module is None:
            ret = None
            for step in range(self.num_steps):
                # all but the last iteration work on a copy so the real state survives;
                # the last iteration uses the real state to produce the final return value
                state_copy = state.copy(clone_ascent=True) if step != self.num_steps - 1 else state
                ret = self.children['modules'].step(state_copy)

                # since parameters are updated after stepping, grad and fx0 must be erased as they are no longer correct
                state.grad = None; state.fx0 = None

            return ret

        # accumulate steps and pass to next module:
        # remember the starting parameters so the total movement can be measured
        p0 = self.get_params().clone()
        for step in range(self.num_steps):
            state_copy = state.copy(clone_ascent=True) if step != self.num_steps - 1 else state
            self.children['modules'].step(state_copy)

            # since parameters are updated after stepping, grad and fx0 must be erased as they are no longer correct
            state.grad = None; state.fx0 = None

        p1 = self.get_params()
        # total update = start minus end (ascent convention: it will be subtracted)
        state.ascent = p0 - p1

        # undo ascent: restore parameters in-place so the next module applies the update itself
        p1.set_(p0)

        return self._update_params_or_step_with_next(state, p1)
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from ...core import OptimizerModule
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class NegateOnLossIncrease(OptimizerModule):
    """Performs an additional evaluation to check if update increases the loss. If it does,
    negates or backtracks the update.

    Args:
        backtrack (bool, optional):
            if True, sets update to minus update, otherwise sets it to zero. Defaults to True.
    """
    def __init__(self, backtrack = True):
        super().__init__({})
        self.backtrack = backtrack

    @torch.no_grad()
    def step(self, state):
        if state.closure is None: raise ValueError('NegateOnLossIncrease requires closure.')
        # loss at the current point, evaluated lazily
        if state.fx0 is None: state.fx0 = state.closure(False)

        # subtract ascent direction from params and see if loss decreases
        params = self.get_params()
        ascent_direction = state.maybe_use_grad_(params)
        params -= ascent_direction
        state.fx0_approx = state.closure(False)

        # if this has no children, update params and return loss
        if self.next_module is None:
            # fix: removed dead `if params is None` check — `params` was already
            # dereferenced above, so it can never be None here
            if state.fx0_approx > state.fx0:
                # loss increased, so we negate the ascent direction:
                # we are currently at params - ascent direction, so adding twice
                # the ascent direction lands at params + ascent direction
                # (adding it once just restores the original params, i.e. zero update)
                params.add_(ascent_direction, alpha = 2 if self.backtrack else 1)

            # else: we are already at a lower loss point
            return state.get_loss()

        # otherwise undo the ascent direction because it is passed to the child
        params += ascent_direction

        # if loss increases, negate (or zero) the ascent direction in-place
        if state.fx0_approx > state.fx0:
            if self.backtrack: ascent_direction.neg_()
            else: ascent_direction.zero_()

        # pass the (possibly modified) ascent direction to the child
        return self.next_module.step(state)
|
|
52
|
+
|
|
53
|
+
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
from collections import abc
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from ...tensorlist import TensorList
|
|
6
|
+
from ...core import OptimizerModule
|
|
7
|
+
|
|
8
|
+
def _heavyball_step(ascent, velocity: TensorList, momentum, dampening: TensorList):
|
|
9
|
+
velocity.mul_(momentum).add_(ascent * (1 - dampening))
|
|
10
|
+
return velocity.clone()
|
|
11
|
+
|
|
12
|
+
class HeavyBall(OptimizerModule):
    """Polyak's (heavy ball) momentum. Exactly matches pytorch SGD `momentum` option.

    Args:
        momentum (float, optional): momentum decay. Defaults to 0.9.
        dampening (float, optional): momentum dampening. Defaults to 0.
    """
    def __init__(self, momentum: float = 0.9, dampening: float = 0, ):
        super().__init__(dict(momentum = momentum, dampening = dampening))

    @torch.no_grad
    def _update(self, state, ascent):
        # velocity buffer is seeded with the first ascent direction
        velocity = self.get_state_key('velocity', init = ascent)
        cfg = self.get_all_group_keys()
        return _heavyball_step(ascent, velocity, cfg['momentum'], cfg['dampening'])
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _nesterov_step_(ascent, velocity: TensorList, momentum, dampening,):
|
|
32
|
+
# update velocity with the ascent direction
|
|
33
|
+
velocity += ascent
|
|
34
|
+
|
|
35
|
+
# decay velocity (this can be moved before previous line for slightly different results)
|
|
36
|
+
velocity *= momentum
|
|
37
|
+
|
|
38
|
+
# update ascent direction with velocity
|
|
39
|
+
ascent += velocity * (1 - dampening)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class NesterovMomentum(OptimizerModule):
    """Nesterov momentum. Exactly matches pytorch SGD with `nesterov=True`,
    except this also supports dampening.

    Args:
        decay (float, optional): momentum decay. Defaults to 0.9.
        dampening (float, optional): momentum dampening. Defaults to 0.
    """
    def __init__(self, decay: float = 0.9, dampening: float = 0, ):
        super().__init__(dict(momentum = decay, dampening = dampening))

    @torch.no_grad
    def _update(self, state, ascent):
        buf = self.get_state_key('velocity')
        cfg = self.get_all_group_keys()
        # mutates both `ascent` and `buf` in-place
        _nesterov_step_(ascent, buf, cfg['momentum'], cfg['dampening'])
        return ascent
|
|
60
|
+
|
|
61
|
+
class GradientAveraging(OptimizerModule):
    """Averages last 2 gradients (TODO)"""
    def __init__(self, dampening: float = 0, ):
        super().__init__(dict(dampening = dampening))

    @torch.no_grad
    def _update(self, state, ascent):
        previous = self.get_state_key('velocity')
        damp = self.get_group_key('dampening')

        # combine the current update with the stored previous one,
        # then remember the current update for the next step
        combined = ascent + previous * (1 - damp)
        previous.copy_(ascent)
        return combined
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class RandomCoordinateMomentum(OptimizerModule):
    """Only uses `p` random coordinates of the new update. Other coordinates remain from previous update.
    This works but I don't know if it is any good.

    Args:
        p (float, optional): probability to update velocity with a new weigh value. Defaults to 0.1.
        nesterov (bool, optional): if False, update uses delayed momentum. Defaults to True.
    """
    def __init__(self, p: float = 0.1, nesterov=True):
        super().__init__(dict(p=p))
        self.nesterov = nesterov

    @torch.no_grad
    def _update(self, state, ascent):
        # velocity buffer seeded with the first ascent direction
        velocity = self.get_state_key('velocity', init = ascent)
        cfg = self.get_all_group_keys()

        # coordinate-wise mask: with probability p an entry takes the fresh value
        mask = ascent.bernoulli_like(cfg['p']).as_bool()

        if not self.nesterov:
            # delayed variant: return the old velocity, then refresh it
            stale = velocity.clone()
            velocity.masked_set_(mask = mask, value = ascent)
            return stale

        # refresh the selected coordinates first, then return a copy
        velocity.masked_set_(mask = mask, value = ascent)
        return velocity.clone()
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from .multi import (
|
|
2
|
+
Add,
|
|
3
|
+
AddMagnitude,
|
|
4
|
+
Div,
|
|
5
|
+
Divide,
|
|
6
|
+
Interpolate,
|
|
7
|
+
Lerp,
|
|
8
|
+
Mul,
|
|
9
|
+
Pow,
|
|
10
|
+
Power,
|
|
11
|
+
RDiv,
|
|
12
|
+
RPow,
|
|
13
|
+
RSub,
|
|
14
|
+
Sub,
|
|
15
|
+
Subtract,
|
|
16
|
+
)
|
|
17
|
+
from .reduction import Mean, Product, Sum
|
|
18
|
+
from .singular import (
|
|
19
|
+
Abs,
|
|
20
|
+
Cos,
|
|
21
|
+
MagnitudePower,
|
|
22
|
+
NanToNum,
|
|
23
|
+
Negate,
|
|
24
|
+
Operation,
|
|
25
|
+
Reciprocal,
|
|
26
|
+
Sign,
|
|
27
|
+
Sin,
|
|
28
|
+
sign_grad_,
|
|
29
|
+
)
|