torchzero-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. torchzero/__init__.py +4 -0
  2. torchzero/core/__init__.py +13 -0
  3. torchzero/core/module.py +471 -0
  4. torchzero/core/tensorlist_optimizer.py +219 -0
  5. torchzero/modules/__init__.py +21 -0
  6. torchzero/modules/adaptive/__init__.py +4 -0
  7. torchzero/modules/adaptive/adaptive.py +192 -0
  8. torchzero/modules/experimental/__init__.py +19 -0
  9. torchzero/modules/experimental/experimental.py +294 -0
  10. torchzero/modules/experimental/quad_interp.py +104 -0
  11. torchzero/modules/experimental/subspace.py +259 -0
  12. torchzero/modules/gradient_approximation/__init__.py +7 -0
  13. torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
  14. torchzero/modules/gradient_approximation/base_approximator.py +110 -0
  15. torchzero/modules/gradient_approximation/fdm.py +125 -0
  16. torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
  17. torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
  18. torchzero/modules/gradient_approximation/rfdm.py +125 -0
  19. torchzero/modules/line_search/__init__.py +30 -0
  20. torchzero/modules/line_search/armijo.py +56 -0
  21. torchzero/modules/line_search/base_ls.py +139 -0
  22. torchzero/modules/line_search/directional_newton.py +217 -0
  23. torchzero/modules/line_search/grid_ls.py +158 -0
  24. torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
  25. torchzero/modules/meta/__init__.py +12 -0
  26. torchzero/modules/meta/alternate.py +65 -0
  27. torchzero/modules/meta/grafting.py +195 -0
  28. torchzero/modules/meta/optimizer_wrapper.py +173 -0
  29. torchzero/modules/meta/return_overrides.py +46 -0
  30. torchzero/modules/misc/__init__.py +10 -0
  31. torchzero/modules/misc/accumulate.py +43 -0
  32. torchzero/modules/misc/basic.py +115 -0
  33. torchzero/modules/misc/lr.py +96 -0
  34. torchzero/modules/misc/multistep.py +51 -0
  35. torchzero/modules/misc/on_increase.py +53 -0
  36. torchzero/modules/momentum/__init__.py +4 -0
  37. torchzero/modules/momentum/momentum.py +106 -0
  38. torchzero/modules/operations/__init__.py +29 -0
  39. torchzero/modules/operations/multi.py +298 -0
  40. torchzero/modules/operations/reduction.py +134 -0
  41. torchzero/modules/operations/singular.py +113 -0
  42. torchzero/modules/optimizers/__init__.py +10 -0
  43. torchzero/modules/optimizers/adagrad.py +49 -0
  44. torchzero/modules/optimizers/adam.py +118 -0
  45. torchzero/modules/optimizers/lion.py +28 -0
  46. torchzero/modules/optimizers/rmsprop.py +51 -0
  47. torchzero/modules/optimizers/rprop.py +99 -0
  48. torchzero/modules/optimizers/sgd.py +54 -0
  49. torchzero/modules/orthogonalization/__init__.py +2 -0
  50. torchzero/modules/orthogonalization/newtonschulz.py +159 -0
  51. torchzero/modules/orthogonalization/svd.py +86 -0
  52. torchzero/modules/quasi_newton/__init__.py +4 -0
  53. torchzero/modules/regularization/__init__.py +22 -0
  54. torchzero/modules/regularization/dropout.py +34 -0
  55. torchzero/modules/regularization/noise.py +77 -0
  56. torchzero/modules/regularization/normalization.py +328 -0
  57. torchzero/modules/regularization/ortho_grad.py +78 -0
  58. torchzero/modules/regularization/weight_decay.py +92 -0
  59. torchzero/modules/scheduling/__init__.py +2 -0
  60. torchzero/modules/scheduling/lr_schedulers.py +131 -0
  61. torchzero/modules/scheduling/step_size.py +80 -0
  62. torchzero/modules/second_order/__init__.py +4 -0
  63. torchzero/modules/second_order/newton.py +165 -0
  64. torchzero/modules/smoothing/__init__.py +5 -0
  65. torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
  66. torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
  67. torchzero/modules/weight_averaging/__init__.py +2 -0
  68. torchzero/modules/weight_averaging/ema.py +72 -0
  69. torchzero/modules/weight_averaging/swa.py +171 -0
  70. torchzero/optim/__init__.py +10 -0
  71. torchzero/optim/experimental/__init__.py +20 -0
  72. torchzero/optim/experimental/experimental.py +343 -0
  73. torchzero/optim/experimental/ray_search.py +83 -0
  74. torchzero/optim/first_order/__init__.py +18 -0
  75. torchzero/optim/first_order/cautious.py +158 -0
  76. torchzero/optim/first_order/forward_gradient.py +70 -0
  77. torchzero/optim/first_order/optimizers.py +570 -0
  78. torchzero/optim/modular.py +132 -0
  79. torchzero/optim/quasi_newton/__init__.py +1 -0
  80. torchzero/optim/quasi_newton/directional_newton.py +58 -0
  81. torchzero/optim/second_order/__init__.py +1 -0
  82. torchzero/optim/second_order/newton.py +94 -0
  83. torchzero/optim/wrappers/__init__.py +0 -0
  84. torchzero/optim/wrappers/nevergrad.py +113 -0
  85. torchzero/optim/wrappers/nlopt.py +165 -0
  86. torchzero/optim/wrappers/scipy.py +439 -0
  87. torchzero/optim/zeroth_order/__init__.py +4 -0
  88. torchzero/optim/zeroth_order/fdm.py +87 -0
  89. torchzero/optim/zeroth_order/newton_fdm.py +146 -0
  90. torchzero/optim/zeroth_order/rfdm.py +217 -0
  91. torchzero/optim/zeroth_order/rs.py +85 -0
  92. torchzero/random/__init__.py +1 -0
  93. torchzero/random/random.py +46 -0
  94. torchzero/tensorlist.py +819 -0
  95. torchzero/utils/__init__.py +0 -0
  96. torchzero/utils/compile.py +39 -0
  97. torchzero/utils/derivatives.py +99 -0
  98. torchzero/utils/python_tools.py +25 -0
  99. torchzero/utils/torch_tools.py +92 -0
  100. torchzero-0.0.1.dist-info/LICENSE +21 -0
  101. torchzero-0.0.1.dist-info/METADATA +118 -0
  102. torchzero-0.0.1.dist-info/RECORD +104 -0
  103. torchzero-0.0.1.dist-info/WHEEL +5 -0
  104. torchzero-0.0.1.dist-info/top_level.txt +1 -0
torchzero/modules/meta/return_overrides.py
@@ -0,0 +1,46 @@
+ import torch
+ from ...tensorlist import TensorList
+ from ...core import OptimizerModule, _get_loss, _ClosureType
+
+ class SetGrad(OptimizerModule):
+     """Doesn't update parameters; instead replaces each parameter's `.grad` attribute with the current update.
+     You can then step with any pytorch optimizer that uses the `.grad` attribute."""
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def step(self, state):
+         if self.next_module is not None: raise ValueError("SetGrad can't have children")
+         params = self.get_params()
+         g = state.maybe_use_grad_(params) # this may execute the closure, which might be modified
+         params.set_grad_(g)
+         return state.get_loss()
+
+
+ class ReturnAscent(OptimizerModule):
+     """Doesn't update parameters; instead returns the update as a TensorList of tensors."""
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def step(self, state) -> TensorList: # type:ignore
+         if self.next_module is not None: raise ValueError("ReturnAscent can't have children")
+         params = self.get_params()
+         update = state.maybe_use_grad_(params) # this will execute the closure, which might be modified
+         return update
+
+ class ReturnClosure(OptimizerModule):
+     """Doesn't update parameters; instead returns the current modified closure.
+     For example, if you put this after :code:`torchzero.modules.FDM(target = "closure")`,
+     the closure will set the `.grad` attribute to gradients approximated via finite differences.
+     You can then pass that closure to something that requires a closure, like `torch.optim.LBFGS`."""
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def step(self, state) -> _ClosureType: # type:ignore
+         if self.next_module is not None: raise ValueError("ReturnClosure can't have children")
+         if state.closure is None:
+             raise ValueError("ReturnClosure requires a closure")
+         return state.closure
+
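The pattern these modules expose is the usual PyTorch `.grad` handoff. As a rough, self-contained sketch in plain PyTorch (it does not use torchzero's API), this is what "compute an update yourself, write it into `.grad`, then let a standard optimizer apply it" looks like:

# Illustration only: the ".grad handoff" idea behind SetGrad, in plain PyTorch.
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
grads = torch.autograd.grad(loss, list(model.parameters()))

with torch.no_grad():
    for p, g in zip(model.parameters(), grads):
        p.grad = g.sign()   # any custom update, e.g. the sign of the gradient
opt.step()                  # SGD consumes .grad as usual
opt.zero_grad()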
torchzero/modules/misc/__init__.py
@@ -0,0 +1,10 @@
+ r"""
+ This module includes various basic operators, notably LR for setting the learning rate,
+ as well as gradient/update clipping and normalization.
+ """
+
+ from .basic import Clone, Fill, Grad, Identity, Lambda, Zeros, Alpha, GradToUpdate, MakeClosure
+ from .lr import LR
+ from .on_increase import NegateOnLossIncrease
+ from .multistep import Multistep
+ from .accumulate import Accumulate
torchzero/modules/misc/accumulate.py
@@ -0,0 +1,43 @@
+ from collections.abc import Callable, Iterable
+
+ import torch
+
+ from ...tensorlist import TensorList
+
+ from ...core import OptimizerModule
+
+
+ class Accumulate(OptimizerModule):
+     """Accumulates the update over `n_steps` steps, and only steps once the update has been accumulated.
+     Put this as the first module to get gradient accumulation.
+
+     Args:
+         n_steps (int): number of steps (batches) to accumulate the update over.
+         mean (bool, optional):
+             If True, divides the accumulated update by the number of steps,
+             since most loss functions calculate the mean of all samples
+             over the batch. Defaults to True.
+     """
+     def __init__(self, n_steps: int, mean = True):
+
+         super().__init__({})
+         self.n_steps = n_steps
+         self.mean = mean
+         self.cur_step = 0
+
+     @torch.no_grad
+     def step(self, state):
+         self.cur_step += 1
+
+         params = self.get_params()
+         accumulated_update = self.get_state_key('accumulated_grads')
+         accumulated_update += state.maybe_use_grad_(params)
+
+         if self.cur_step % self.n_steps == 0:
+             state.ascent = accumulated_update.clone()
+             if self.mean: state.ascent /= self.n_steps
+             accumulated_update.zero_()
+             return self._update_params_or_step_with_next(state)
+
+
+         return state.get_loss()
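For reference, the same behaviour written as an ordinary PyTorch training loop (plain PyTorch only, not torchzero code): gradients are accumulated over `n_steps` batches and a single optimizer step is taken with their mean, since `backward()` sums into `.grad`:

# Illustration only: plain-PyTorch gradient accumulation with mean scaling.
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
n_steps = 4

for i in range(16):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / n_steps).backward()        # divide so the accumulated grad is a mean
    if (i + 1) % n_steps == 0:
        opt.step()
        opt.zero_grad()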
torchzero/modules/misc/basic.py
@@ -0,0 +1,115 @@
+ from collections.abc import Callable, Iterable
+
+ import torch
+
+ from ...tensorlist import TensorList
+
+ from ...core import OptimizerModule, _Chainable
+
+
+ class Alpha(OptimizerModule):
+     """Multiplies the update by a fixed multiplier `alpha`; unlike `LR`, this won't get picked up by learning rate schedulers."""
+     def __init__(self, alpha = 1e-3):
+         defaults = dict(alpha = alpha)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def _update(self, state, ascent):
+         # multiply ascent direction by alpha in-place
+         lr = self.get_group_key('alpha')
+         ascent *= lr
+         return ascent
+
+ class Clone(OptimizerModule):
+     """Clones the update. Some modules update ascent in-place, so this may be
+     useful if you need to preserve it."""
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def _update(self, state, ascent): return ascent.clone()
+
+ class Identity(OptimizerModule):
+     """Does nothing."""
+     def __init__(self, *args, **kwargs):
+         super().__init__({})
+
+     @torch.no_grad
+     def _update(self, state, ascent): return ascent
+
+ class Lambda(OptimizerModule):
+     """Applies a function to the ascent direction.
+     The function must take a TensorList as its argument and return the modified TensorList.
+
+     Args:
+         f (Callable): function that maps a TensorList to a TensorList.
+     """
+     def __init__(self, f: Callable[[TensorList], TensorList]):
+         super().__init__({})
+         self.f = f
+
+     @torch.no_grad()
+     def _update(self, state, ascent): return self.f(ascent)
+
+ class Grad(OptimizerModule):
+     """Uses the gradient as the update. This is useful for chains."""
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def _update(self, state, ascent):
+         ascent = state.ascent = state.maybe_compute_grad_(self.get_params())
+         return ascent
+
+ class Zeros(OptimizerModule):
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def _update(self, state, ascent):
+         return ascent.zeros_like()
+
+ class Fill(OptimizerModule):
+     def __init__(self, value):
+         super().__init__({"value": value})
+
+     @torch.no_grad
+     def _update(self, state, ascent):
+         return ascent.fill(self.get_group_key('value'))
+
+
+ class GradToUpdate(OptimizerModule):
+     """Sets the state gradient and the parameters' `.grad` attributes to the current update."""
+     def __init__(self):
+         super().__init__({})
+
+     def _update(self, state, ascent):
+         state.set_grad_(ascent, self.get_params())
+         return ascent
+
+ class MakeClosure(OptimizerModule):
+     """Makes a closure that sets the `.grad` attribute to the update generated by `modules`."""
+     def __init__(self, modules: _Chainable):
+         super().__init__({})
+         self._set_child_('modules', modules)
+
+     def step(self, state):
+         if state.closure is None: raise ValueError("MakeClosure requires a closure")
+
+         params = self.get_params()
+         orig_closure = state.closure
+         orig_state = state.copy(True)
+
+         def new_closure(backward = True):
+             if backward:
+                 cloned_state = orig_state.copy(True)
+                 g = self.children['modules'].return_ascent(cloned_state)
+                 params.set_grad_(g)
+                 return cloned_state.get_loss()
+
+             else:
+                 return orig_closure(False)
+
+         state.closure = new_closure # type:ignore
+         return self._update_params_or_step_with_next(state)
+
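`MakeClosure` and `ReturnClosure` target the closure contract used by optimizers such as `torch.optim.LBFGS`, which call the closure themselves. A minimal plain-PyTorch sketch of that contract (not torchzero code):

# Illustration only: the closure that LBFGS expects zeroes grads, recomputes
# the loss, fills .grad via backward() and returns the loss.
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.LBFGS(model.parameters(), lr=1.0, max_iter=10)
x, y = torch.randn(32, 4), torch.randn(32, 1)

def closure():
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()          # fills .grad, which LBFGS then reads
    return loss

opt.step(closure)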
torchzero/modules/misc/lr.py
@@ -0,0 +1,96 @@
+ import random
+ from collections.abc import Callable, Iterable
+ from functools import partial
+ from typing import TYPE_CHECKING, Any, overload
+
+ import torch
+
+ from ...tensorlist import TensorList
+
+ from ...core import OptimizerModule
+
+ if TYPE_CHECKING:
+     from ...optim import Modular
+
+ def _init_scheduler_hook(opt: "Modular", module: "LR", scheduler_cls, **kwargs):
+     """Post-init hook that initializes the lr scheduler on the LR module and sets `_scheduler_step_fn`."""
+     scheduler = scheduler_cls(module, **kwargs)
+     module._scheduler_step_fn = scheduler.step
+
+ def _set_momentum_hook(optimizer, state, momentum):
+     for module in optimizer.unrolled_modules:
+         if 'momentum' in module.defaults:
+             for g in module.param_groups:
+                 g['momentum'] = momentum
+         elif 'beta1' in module.defaults:
+             for g in module.param_groups:
+                 g['beta1'] = momentum
+
+ class LR(OptimizerModule):
+     """Multiplies the update by the learning rate. Optionally uses an lr scheduler.
+
+     Args:
+         lr (float, optional): learning rate. Defaults to 1e-3.
+         scheduler_cls (Callable[..., torch.optim.lr_scheduler.LRScheduler | Any] | None, optional):
+             A scheduler class, for example `torch.optim.lr_scheduler.OneCycleLR`. Defaults to None.
+         cycle_momentum (bool, optional):
+             enables schedulers that support it (like OneCycleLR) to affect momentum.
+             The momentum will be cycled on ALL modules that have a `momentum` or `beta1` setting.
+             This does not support external optimizers wrapped with `Wrap`. Defaults to True.
+         sheduler_step_every (int, optional):
+             step with the scheduler every n optimizer steps.
+             Useful when the scheduler is meant to step once per epoch. Defaults to 1.
+         **kwargs:
+             kwargs to pass to `scheduler_cls`.
+     """
+     IS_LR_MODULE = True
+     def __init__(
+         self,
+         lr: float = 1e-3,
+         scheduler_cls: Callable[..., torch.optim.lr_scheduler.LRScheduler | Any] | None = None,
+         cycle_momentum: bool = True,
+         sheduler_step_every: int = 1,
+         # *args,
+         **kwargs,
+     ):
+
+         defaults = dict(lr = lr)
+
+         if (scheduler_cls is not None) and cycle_momentum:
+             defaults['momentum'] = 0
+         super().__init__(defaults)
+
+         self._scheduler_step_fn = None
+         self.sheduler_step_every = sheduler_step_every
+         self.cycle_momentum = cycle_momentum
+         self.cur = 0
+
+         if scheduler_cls is not None:
+             self.post_init_hooks.append(lambda opt, module: _init_scheduler_hook(opt, module, scheduler_cls, **kwargs))
+
+         self._skip = False
+
+     @torch.no_grad
+     def _update(self, state, ascent):
+         # step with the scheduler
+         if self._scheduler_step_fn is not None:
+             if self.cur != 0 and self.cur % self.sheduler_step_every == 0:
+                 self._scheduler_step_fn()
+
+         # add a hook to cycle momentum
+         if self.cycle_momentum:
+             state.add_post_step_hook(_set_momentum_hook)
+
+         # remove the init hook to delete the reference to the scheduler
+         if self.cur == 0 and len(self.post_init_hooks) == 1:
+             del self.post_init_hooks[0]
+
+         # skip if lr was already applied by the previous module (LR fusing)
+         if not self._skip:
+             # multiply ascent direction by lr in-place
+             lr = self.get_group_key('lr')
+             ascent *= lr
+
+         self.cur += 1
+         self._skip = False
+         return ascent
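For context, the underlying PyTorch behaviour that `LR` hooks into: a scheduler such as `OneCycleLR` is constructed on an optimizer, cycles momentum alongside the learning rate when `cycle_momentum=True`, and is stepped once per optimizer step. A plain-PyTorch sketch (not torchzero code):

# Illustration only: OneCycleLR on a standard optimizer, one scheduler step
# per optimizer step.
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
sched = torch.optim.lr_scheduler.OneCycleLR(
    opt, max_lr=0.1, total_steps=100, cycle_momentum=True)

for _ in range(100):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    sched.step()               # one scheduler step per optimizer step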
torchzero/modules/misc/multistep.py
@@ -0,0 +1,51 @@
+ from collections.abc import Callable, Iterable
+
+ import torch
+
+ from ...tensorlist import TensorList
+
+ from ...core import OptimizerModule, _Chainable
+
+
+ class Multistep(OptimizerModule):
+     """Performs multiple steps (per batch) and passes the total update to the next module.
+
+     Args:
+         modules (_Chainable): modules to perform multiple steps with.
+         num_steps (int, optional): number of steps to perform. Defaults to 2.
+     """
+     def __init__(self, modules: _Chainable, num_steps: int = 2):
+         super().__init__({})
+         self.num_steps = num_steps
+
+         self._set_child_('modules', modules)
+
+     def step(self, state):
+         # no next module, just perform multiple steps
+         if self.next_module is None:
+             ret = None
+             for step in range(self.num_steps):
+                 state_copy = state.copy(clone_ascent=True) if step != self.num_steps - 1 else state
+                 ret = self.children['modules'].step(state_copy)
+
+                 # parameters are updated after stepping, so grad and fx0 must be erased as they are no longer correct
+                 state.grad = None; state.fx0 = None
+
+             return ret
+
+         # accumulate steps and pass the total update to the next module
+         p0 = self.get_params().clone()
+         for step in range(self.num_steps):
+             state_copy = state.copy(clone_ascent=True) if step != self.num_steps - 1 else state
+             self.children['modules'].step(state_copy)
+
+             # parameters are updated after stepping, so grad and fx0 must be erased as they are no longer correct
+             state.grad = None; state.fx0 = None
+
+         p1 = self.get_params()
+         state.ascent = p0 - p1
+
+         # undo the steps
+         p1.set_(p0)
+
+         return self._update_params_or_step_with_next(state, p1)
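The displacement trick used above can be illustrated standalone: take several inner steps, record the total parameter movement `p0 - p1` as one aggregated update, then restore the original parameters so something else can apply (or modify) that update. A rough plain-PyTorch sketch (not torchzero code):

# Illustration only: aggregate several inner SGD steps into one update.
import torch

model = torch.nn.Linear(4, 1)
inner = torch.optim.SGD(model.parameters(), lr=1e-2)
x, y = torch.randn(8, 4), torch.randn(8, 1)

p0 = [p.detach().clone() for p in model.parameters()]
for _ in range(2):                                 # num_steps inner steps
    inner.zero_grad()
    torch.nn.functional.mse_loss(model(x), y).backward()
    inner.step()

with torch.no_grad():
    total_update = [a - p for a, p in zip(p0, model.parameters())]  # p0 - p1
    for p, a in zip(model.parameters(), p0):
        p.copy_(a)                                 # undo the inner steps
# total_update can now be applied as a single step, e.g. p -= lr * u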
torchzero/modules/misc/on_increase.py
@@ -0,0 +1,53 @@
+ import torch
+
+ from ...core import OptimizerModule
+
+
+ class NegateOnLossIncrease(OptimizerModule):
+     """Performs an additional evaluation to check whether the update increases the loss. If it does,
+     negates or zeroes the update.
+
+     Args:
+         backtrack (bool, optional):
+             if True, sets the update to minus the update, otherwise sets it to zero. Defaults to True.
+     """
+     def __init__(self, backtrack = True):
+         super().__init__({})
+         self.backtrack = backtrack
+
+     @torch.no_grad()
+     def step(self, state):
+         if state.closure is None: raise ValueError('NegateOnLossIncrease requires closure.')
+         if state.fx0 is None: state.fx0 = state.closure(False)
+
+         # subtract ascent direction from params and see if the loss decreases
+         params = self.get_params()
+         ascent_direction = state.maybe_use_grad_(params)
+         params -= ascent_direction
+         state.fx0_approx = state.closure(False)
+
+         # if this has no children, update params and return loss
+         if self.next_module is None:
+             if params is None: params = self.get_params()
+
+             if state.fx0_approx > state.fx0:
+                 # loss increased, so we negate the ascent direction;
+                 # we are currently at params - ascent direction,
+                 # so we add twice the ascent direction
+                 params.add_(ascent_direction, alpha = 2 if self.backtrack else 1)
+
+             # else: we are already at a lower loss point
+             return state.get_loss()
+
+         # otherwise undo the ascent direction because it is passed to the child
+         params += ascent_direction
+
+         # if the loss increased, negate the ascent direction
+         if state.fx0_approx > state.fx0:
+             if self.backtrack: ascent_direction.neg_()
+             else: ascent_direction.zero_()
+
+         # pass the (possibly negated or zeroed) ascent direction to the child
+         return self.next_module.step(state)
+
+
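The check performed here can be sketched in plain PyTorch (not torchzero code): evaluate the loss, apply the candidate update, re-evaluate, and if the loss went up either step back twice as far (negating the update) or simply undo it:

# Illustration only: negate-or-undo a candidate update when the loss increases.
import torch

model = torch.nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

def loss_fn():
    return torch.nn.functional.mse_loss(model(x), y)

loss0 = loss_fn()
loss0.backward()
backtrack = True

with torch.no_grad():
    update = [1e-2 * p.grad for p in model.parameters()]   # candidate update
    for p, u in zip(model.parameters(), update):
        p -= u                                              # try the update
    if loss_fn() > loss0:                                   # loss increased
        for p, u in zip(model.parameters(), update):
            p += 2 * u if backtrack else u                  # negate or undo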
torchzero/modules/momentum/__init__.py
@@ -0,0 +1,4 @@
+ """
+ Modules that implement momentum.
+ """
+ from .momentum import HeavyBall, NesterovMomentum, RandomCoordinateMomentum, GradientAveraging
torchzero/modules/momentum/momentum.py
@@ -0,0 +1,106 @@
+ from collections import abc
+
+ import torch
+
+ from ...tensorlist import TensorList
+ from ...core import OptimizerModule
+
+ def _heavyball_step(ascent, velocity: TensorList, momentum, dampening: TensorList):
+     velocity.mul_(momentum).add_(ascent * (1 - dampening))
+     return velocity.clone()
+
+ class HeavyBall(OptimizerModule):
+     """Polyak's (heavy ball) momentum. Exactly matches pytorch SGD's `momentum` option.
+
+     Args:
+         momentum (float, optional): momentum decay. Defaults to 0.9.
+         dampening (float, optional): momentum dampening. Defaults to 0.
+     """
+     def __init__(self, momentum: float = 0.9, dampening: float = 0, ):
+         defaults = dict(momentum = momentum, dampening = dampening)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def _update(self, state, ascent):
+         velocity = self.get_state_key('velocity', init = ascent)
+         settings = self.get_all_group_keys()
+         updated_direction = _heavyball_step(ascent, velocity, settings['momentum'], settings['dampening'])
+         return updated_direction
+
+
+ def _nesterov_step_(ascent, velocity: TensorList, momentum, dampening,):
+     # update velocity with the ascent direction
+     velocity += ascent
+
+     # decay velocity (this can be moved before the previous line for slightly different results)
+     velocity *= momentum
+
+     # update ascent direction with velocity
+     ascent += velocity * (1 - dampening)
+
+
+ class NesterovMomentum(OptimizerModule):
+     """Nesterov momentum. Exactly matches pytorch SGD with `nesterov=True`,
+     except this also supports dampening.
+
+     Args:
+         decay (float, optional): momentum decay. Defaults to 0.9.
+         dampening (float, optional): momentum dampening. Defaults to 0.
+     """
+     def __init__(self, decay: float = 0.9, dampening: float = 0, ):
+         defaults = dict(momentum = decay, dampening = dampening)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def _update(self, state, ascent):
+         velocity = self.get_state_key('velocity')
+         settings = self.get_all_group_keys()
+         _nesterov_step_(ascent, velocity, settings['momentum'], settings['dampening'])
+         return ascent
+
+ class GradientAveraging(OptimizerModule):
+     """Averages the last 2 gradients (TODO)"""
+     def __init__(self, dampening: float = 0, ):
+         defaults = dict(dampening = dampening)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def _update(self, state, ascent):
+         velocity = self.get_state_key('velocity')
+         dampening = self.get_group_key('dampening')
+
+         new_direction = ascent + velocity * (1-dampening)
+         velocity.copy_(ascent)
+
+         return new_direction
+
+
+ class RandomCoordinateMomentum(OptimizerModule):
+     """Only uses `p` random coordinates of the new update. Other coordinates remain from the previous update.
+     This works, but I don't know if it is any good.
+
+     Args:
+         p (float, optional): probability to update a velocity coordinate with the new value. Defaults to 0.1.
+         nesterov (bool, optional): if False, the update uses delayed momentum. Defaults to True.
+     """
+     def __init__(self, p: float = 0.1, nesterov=True):
+         defaults = dict(p=p)
+         super().__init__(defaults)
+         self.nesterov = nesterov
+
+     @torch.no_grad
+     def _update(self, state, ascent):
+         velocity = self.get_state_key('velocity', init = ascent)
+         settings = self.get_all_group_keys()
+
+         # pick p velocity indices to update with the new ascent direction
+         indexes = ascent.bernoulli_like(settings['p']).as_bool()
+
+         if self.nesterov:
+             # update the velocity at those indices
+             velocity.masked_set_(mask = indexes, value = ascent)
+             return velocity.clone()
+
+         new_ascent = velocity.clone()
+         velocity.masked_set_(mask = indexes, value = ascent)
+         return new_ascent
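For the "exactly matches pytorch SGD" claims above, the reference point is PyTorch's own `torch.optim.SGD`, which exposes the same heavy-ball and Nesterov variants through its `momentum`, `dampening` and `nesterov` arguments (illustration only, not torchzero code):

import torch

model = torch.nn.Linear(4, 1)

# heavy-ball: v = momentum * v + (1 - dampening) * g;  p -= lr * v
heavy_ball = torch.optim.SGD(model.parameters(), lr=1e-2,
                             momentum=0.9, dampening=0.0)

# Nesterov: the step uses g + momentum * v instead of v
# (PyTorch requires dampening == 0 when nesterov=True)
nesterov = torch.optim.SGD(model.parameters(), lr=1e-2,
                           momentum=0.9, nesterov=True)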
torchzero/modules/operations/__init__.py
@@ -0,0 +1,29 @@
+ from .multi import (
+     Add,
+     AddMagnitude,
+     Div,
+     Divide,
+     Interpolate,
+     Lerp,
+     Mul,
+     Pow,
+     Power,
+     RDiv,
+     RPow,
+     RSub,
+     Sub,
+     Subtract,
+ )
+ from .reduction import Mean, Product, Sum
+ from .singular import (
+     Abs,
+     Cos,
+     MagnitudePower,
+     NanToNum,
+     Negate,
+     Operation,
+     Reciprocal,
+     Sign,
+     Sin,
+     sign_grad_,
+ )