torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/core/tensorlist_optimizer.py (deleted)
@@ -1,219 +0,0 @@
- from typing import Literal, Any, overload, TypeVar
- from abc import ABC
- from collections.abc import Callable, Sequence, Iterable, Mapping, MutableSequence
- import numpy as np
- import torch
- import torch.optim.optimizer
- from torch.optim.optimizer import ParamsT
-
- from ..tensorlist import TensorList, NumberList
- from ..utils.torch_tools import totensor, tofloat
- from ..utils.python_tools import _ScalarLoss
-
- _StateInit = Literal['params', 'grad'] | Callable | TensorList
-
- _ClosureType = Callable[..., _ScalarLoss]
- """
-
- Closure example:
-
- .. code-block:: python
-
-     def closure(backward = True):
-         loss = model(inputs)
-         if backward:
-             optimizer.zero_grad()
-             loss.backward()
-         return loss
-
- This closure will also work with all built in pytorch optimizers including LBFGS, as well as and most custom ones.
- """
-
- def _maybe_pass_backward(closure: _ClosureType, backward: bool) -> _ScalarLoss:
-     """not passing backward when it is true makes this work with closures with no `backward` argument"""
-     if backward:
-         with torch.enable_grad(): return closure()
-     return closure(False)
-
- CLS = TypeVar('CLS')
- class TensorListOptimizer(torch.optim.Optimizer, ABC):
-     """torch.optim.Optimizer with some additional methods related to TensorList.
-
-     Args:
-         params (ParamsT): iterable of parameters.
-         defaults (_type_): dictionary with default parameters for the optimizer.
-     """
-     def __init__(self, params: ParamsT, defaults):
-         super().__init__(params, defaults)
-         self._params: list[torch.Tensor] = [param for group in self.param_groups for param in group['params']]
-         self.has_complex = any(torch.is_complex(x) for x in self._params)
-         """True if any of the params are complex"""
-
-     def add_param_group(self, param_group: dict[str, Any]) -> None:
-         super().add_param_group(param_group)
-         self._params: list[torch.Tensor] = [param for group in self.param_groups for param in group['params']]
-         self.has_complex = any(torch.is_complex(x) for x in self._params)
-
-     # def get_params[CLS: Any](self, cls: type[CLS] = TensorList) -> CLS:
-     def get_params(self, cls: type[CLS] = TensorList) -> CLS:
-         """returns all params with `requires_grad = True` as a TensorList."""
-         return cls(p for p in self._params if p.requires_grad) # type:ignore
-
-     def ensure_grad_(self):
-         """Replaces None grad attribute with zeroes for all parameters that require grad."""
-         for p in self.get_params():
-             if p.requires_grad and p.grad is None: p.grad = torch.zeros_like(p)
-
-     # def get_state_key[CLS: MutableSequence](self, key: str, init: _StateInit = torch.zeros_like, params=None, cls: type[CLS] = TensorList) -> CLS:
-     def get_state_key(self, key: str, init: _StateInit = torch.zeros_like, params=None, cls: type[CLS] = TensorList) -> CLS:
-         """Returns a tensorlist of all `key` states of all params with `requires_grad = True`.
-
-         Args:
-             key (str): key to create/access.
-             init: Initial value if key doesn't exist. Can be `params`, `grad`, or callable such as `torch.zeros_like`.
-                 Defaults to torch.zeros_like.
-             params (optional): optionally pass params if you already created them. Defaults to None.
-             cls (optional): optionally specify any other MutableSequence subclass to use instead of TensorList.
-
-         Returns:
-             TensorList: TensorList with the `key` state. Those tensors are stored in the optimizer, so modify them in-place.
-         """
-         value = cls()
-         if params is None: params = self.get_params()
-         for pi, p in enumerate(params):
-             state = self.state[p]
-             if key not in state:
-                 if callable(init): state[key] = init(p)
-                 elif isinstance(init, TensorList): state[key] = init[pi].clone()
-                 elif init == 'params': state[key] = p.clone().detach()
-                 elif init == 'grad': state[key] = p.grad.clone().detach() if p.grad is not None else torch.zeros_like(p)
-                 else: raise ValueError(f'unknown init - {init}')
-             value.append(state[key]) # type:ignore
-         return value
-
-     # def get_state_keys[CLS: MutableSequence](
-     def get_state_keys(
-         self,
-         *keys: str,
-         inits: _StateInit | Sequence[_StateInit] = torch.zeros_like,
-         params=None,
-         cls: type[CLS] = TensorList,
-     ) -> list[CLS]:
-         """Returns a TensorList with the `key` states of all `params`. Creates the states if they don't exist."""
-
-         values = [cls() for _ in range(len(keys))]
-         if params is None: params = self.get_params()
-         if callable(inits) or isinstance(inits, str): inits = [inits] * len(keys) # type:ignore
-
-         for pi, p in enumerate(params):
-             state = self.state[p]
-             for i, (key, init) in enumerate(zip(keys, inits)): # type:ignore
-                 if key not in state:
-                     if callable(init): state[key] = init(p)
-                     elif isinstance(init, TensorList): state[key] = init[pi].clone()
-                     elif init == 'params': state[key] = p.clone().detach()
-                     elif init == 'grad': state[key] = p.grad.clone().detach() if p.grad is not None else torch.zeros_like(p)
-                     else: raise ValueError(f'unknown init - {init}')
-                 values[i].append(state[key]) # type:ignore
-         return values
-
-     def _yield_groups_key(self, key: str):
-         for group in self.param_groups:
-             value = group[key]
-             for p in group['params']:
-                 if p.requires_grad: yield value
-
-
-     # def get_group_key[CLS: Any](self, key: str, cls: type[CLS] = NumberList) -> CLS:
-     def get_group_key(self, key: str, cls: type[CLS] = NumberList) -> CLS:
-         """Returns a TensorList with the param_groups `key` setting of each param."""
-         return cls(self._yield_groups_key(key)) # type:ignore
-
-     def get_first_group_key(self, key:str) -> Any:
-         """Returns the param_groups `key` setting of the first param."""
-         return next(iter(self._yield_groups_key(key)))
-
-     # def get_all_group_keys[CLS: Any](self, cls: type[CLS] = NumberList) -> dict[str, CLS]:
-     def get_all_group_keys(self, cls: type[CLS] = NumberList) -> dict[str, CLS]:
-         all_values: dict[str, CLS] = {}
-         for group in self.param_groups:
-
-             n_params = len([p for p in group['params'] if p.requires_grad])
-
-             for key, value in group.items():
-                 if key != 'params':
-                     if key not in all_values: all_values[key] = cls(value for _ in range(n_params)) # type:ignore
-                     else: all_values[key].extend([value for _ in range(n_params)]) # type:ignore
-
-         return all_values
-
-     # def get_group_keys[CLS: MutableSequence](self, *keys: str, cls: type[CLS] = NumberList) -> list[CLS]:
-     def get_group_keys(self, *keys: str, cls: type[CLS] = NumberList) -> list[CLS]:
-         """Returns a list with the param_groups `key` setting of each param."""
-
-         all_values: list[CLS] = [cls() for _ in keys]
-         for group in self.param_groups:
-
-             n_params = len([p for p in group['params'] if p.requires_grad])
-
-             for i, key in enumerate(keys):
-                 value = group[key]
-                 all_values[i].extend([value for _ in range(n_params)]) # type:ignore
-
-         return all_values
-
-     @torch.no_grad
-     def evaluate_loss_at_vec(self, vec, closure=None, params = None, backward=False, ensure_float=False):
-         """_summary_
-
-         Args:
-             vec (_type_): _description_
-             closure (_type_, optional): _description_. Defaults to None.
-             params (_type_, optional): _description_. Defaults to None.
-             backward (bool, optional): _description_. Defaults to False.
-             ensure_float (bool, optional): _description_. Defaults to False.
-
-         Returns:
-             _type_: _description_
-         """
-         vec = totensor(vec)
-         if closure is None: closure = self._closure # type:ignore # pylint:disable=no-member
-         if params is None: params = self.get_params()
-
-         params.from_vec_(vec.to(params[0]))
-         loss = _maybe_pass_backward(closure, backward)
-
-         if ensure_float: return tofloat(loss)
-         return _maybe_pass_backward(closure, backward)
-
-     @overload
-     def evaluate_loss_grad_at_vec(self, vec, closure=None, params = None, to_numpy: Literal[True] = False) -> tuple[float, np.ndarray]: ... # type:ignore
-     @overload
-     def evaluate_loss_grad_at_vec(self, vec, closure=None, params = None, to_numpy: Literal[False] = False) -> tuple[_ScalarLoss, torch.Tensor]: ...
-     @torch.no_grad
-     def evaluate_loss_grad_at_vec(self, vec, closure=None, params = None, to_numpy: Literal[True] | Literal[False]=False):
-         """_summary_
-
-         Args:
-             vec (_type_): _description_
-             closure (_type_, optional): _description_. Defaults to None.
-             params (_type_, optional): _description_. Defaults to None.
-             to_numpy (Literal[True] | Literal[False], optional): _description_. Defaults to False.
-
-         Returns:
-             _type_: _description_
-         """
-         if params is None: params = self.get_params()
-         loss = self.evaluate_loss_at_vec(vec, closure, params, backward = True, ensure_float = to_numpy)
-         grad = params.grad.to_vec()
-
-         if to_numpy: return tofloat(loss), grad.detach().cpu().numpy()
-         return loss, grad
-
-
-     @torch.no_grad
-     def _maybe_evaluate_closure(self, closure, backward=True):
-         loss = None
-         if closure is not None:
-             loss = _maybe_pass_backward(closure, backward)
-         return loss
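
The closure protocol documented in the removed tensorlist_optimizer.py above (a closure taking an optional `backward` flag) has the same shape as the closures accepted by PyTorch's built-in closure-based optimizers. A minimal sketch for illustration only; `model`, `inputs`, `targets`, `criterion` and the LBFGS setup are placeholders, not part of this package:

    import torch

    # Hypothetical setup, purely illustrative.
    model = torch.nn.Linear(4, 1)
    inputs, targets = torch.randn(8, 4), torch.randn(8, 1)
    criterion = torch.nn.MSELoss()
    opt = torch.optim.LBFGS(model.parameters(), lr=0.1)

    def closure(backward=True):
        # Same shape as the closure documented in the removed module:
        # compute the loss, and run backward only when requested.
        loss = criterion(model(inputs), targets)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    opt.step(closure)  # LBFGS calls closure() with no arguments, so backward defaults to True

Because `backward` defaults to True, built-in optimizers that call `closure()` without arguments still get a fresh backward pass, which is what the removed docstring means by compatibility with LBFGS and most custom optimizers.
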
torchzero/modules/adaptive/__init__.py (deleted)
@@ -1,4 +0,0 @@
- r"""
- Modules related to adapting the learning rate.
- """
- from .adaptive import Cautious, UseGradMagnitude, UseGradSign, ScaleLRBySignChange, NegateOnSignChange
torchzero/modules/adaptive/adaptive.py (deleted)
@@ -1,192 +0,0 @@
- import typing
- import torch
-
- from ...core import OptimizerModule
-
- class Cautious(OptimizerModule):
-     """Negates update for parameters where update and gradient sign is inconsistent.
-     Optionally normalizes the update by the number of parameters that are not masked.
-     This is meant to be used after any momentum-based modules.
-
-     Args:
-         normalize (bool, optional):
-             renormalize update after masking.
-             only has effect when mode is 'zero'. Defaults to False.
-         eps (float, optional): epsilon for normalization. Defaults to 1e-6.
-         mode (str, optional):
-             what to do with updates with inconsistent signs.
-
-             "zero" - set them to zero (as in paper)
-
-             "grad" - set them to the gradient
-
-             "negate" - negate them (same as using update magnitude and gradient sign)
-
-     reference
-         *Cautious Optimizers: Improving Training with One Line of Code.
-         Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu*
-     """
-     def __init__(self, normalize = False, eps=1e-6, mode: typing.Literal['zero', 'grad', 'backtrack'] = 'zero'):
-         super().__init__({})
-         self.eps = eps
-         self.normalize = normalize
-         self.mode: typing.Literal['zero', 'grad', 'backtrack'] = mode
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         params = self.get_params()
-         grad = vars.maybe_compute_grad_(params)
-
-         # mask will be > 0 for parameters where both signs are the same
-         mask = (ascent * grad) > 0
-         if self.mode in ('zero', 'grad'):
-             if self.normalize and self.mode == 'zero':
-                 fmask = mask.to(ascent[0].dtype)
-                 fmask /= fmask.total_mean() + self.eps # type:ignore
-             else:
-                 fmask = mask
-
-             ascent *= fmask
-
-             if self.mode == 'grad':
-                 ascent += grad * mask.logical_not_()
-
-             return ascent
-
-         # mode = 'backtrack'
-         ascent -= ascent.mul(2).mul_(mask.logical_not_())
-         return ascent
-
-
- class UseGradSign(OptimizerModule):
-     """
-     Uses update magnitude but gradient sign.
-     """
-     def __init__(self):
-         super().__init__({})
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         params = self.get_params()
-         grad = vars.maybe_compute_grad_(params)
-
-         return ascent.abs_().mul_(grad.sign())
-
- class UseGradMagnitude(OptimizerModule):
-     """
-     Uses update sign but gradient magnitude.
-     """
-     def __init__(self):
-         super().__init__({})
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         params = self.get_params()
-         grad = vars.maybe_compute_grad_(params)
-
-         return ascent.sign_().mul_(grad.abs())
-
-
- class ScaleLRBySignChange(OptimizerModule):
-     """
-     learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign,
-     or `nminus` if it did.
-
-     This is part of RProp update rule.
-
-     Args:
-         nplus (float): learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign
-         nminus (float): learning rate gets multiplied by `nminus` if ascent/gradient changed the sign
-         lb (float): lower bound for lr.
-         ub (float): upper bound for lr.
-         alpha (float): initial learning rate.
-
-     """
-     def __init__(self, nplus: float = 1.2, nminus: float = 0.5, lb = 1e-6, ub = 50, alpha=1, use_grad=False):
-         defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
-         super().__init__(defaults)
-         self.current_step = 0
-         self.use_grad = use_grad
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         params = self.get_params()
-
-         if self.use_grad: cur = vars.maybe_compute_grad_(params)
-         else: cur = ascent
-
-         nplus, nminus, lb, ub = self.get_group_keys('nplus', 'nminus', 'lb', 'ub')
-         prev, lrs = self.get_state_keys('prev_ascent', 'lrs', params=params)
-
-         # initialize on 1st step
-         if self.current_step == 0:
-             lrs.fill_(self.get_group_key('alpha'))
-             ascent.mul_(lrs)
-             prev.copy_(ascent)
-             self.current_step += 1
-             return ascent
-
-         mask = cur * prev
-         sign_changed = mask < 0
-         sign_same = mask > 0
-
-         # multiply magnitudes where sign didn't change
-         lrs.masked_set_(sign_same, lrs * nplus)
-         # multiply magnitudes where sign changed
-         lrs.masked_set_(sign_changed, lrs * nminus)
-         # bounds
-         lrs.clamp_(lb, ub)
-
-         ascent.mul_(lrs)
-         prev.copy_(cur)
-         self.current_step += 1
-         return ascent
-
-
-
- class NegateOnSignChange(OptimizerModule):
-     """Negates or undoes update for parameters where where gradient or update sign changes.
-
-     This is part of RProp update rule.
-
-     Args:
-         normalize (bool, optional): renormalize update after masking. Defaults to False.
-         eps (_type_, optional): epsilon for normalization. Defaults to 1e-6.
-         use_grad (bool, optional): if True, tracks sign change of the gradient,
-             otherwise track sign change of the update. Defaults to True.
-         backtrack (bool, optional): if True, undoes the update when sign changes, otherwise negates it.
-             Defaults to True.
-
-     """
-     # todo: add momentum to negation (to cautious as well and rprop negation as well)
-     def __init__(self, normalize = False, eps=1e-6, use_grad = False, backtrack = True):
-         super().__init__({})
-         self.eps = eps
-         self.normalize = normalize
-         self.use_grad = use_grad
-         self.backtrack = backtrack
-         self.current_step = 0
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         params = self.get_params()
-
-         if self.use_grad: cur = vars.maybe_compute_grad_(params)
-         else: cur = ascent
-
-         prev = self.get_state_key('prev')
-
-         # initialize on first step
-         if self.current_step == 0:
-             prev.set_(cur)
-             self.current_step += 1
-             return ascent
-
-         # mask will be > 0 for parameters where both signs are the same
-         mask = (cur * prev) < 0
-         if self.backtrack: ascent.masked_set_(mask, prev)
-         else: ascent.select_set_(mask, 0)
-
-         prev.set_(cur)
-         self.current_step += 1
-         return ascent
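
For reference, the masking rule described in the removed Cautious docstring ('zero' mode: drop update components whose sign disagrees with the gradient, optionally renormalizing by the surviving fraction) can be sketched on plain tensors. This is an illustrative reimplementation, not the package's TensorList-based code; the function name and toy values below are made up:

    import torch

    def cautious_zero_mode(update: torch.Tensor, grad: torch.Tensor,
                           normalize: bool = False, eps: float = 1e-6) -> torch.Tensor:
        # Keep only components whose update and gradient agree in sign ("zero" mode);
        # optionally rescale by the fraction of surviving components, per the docstring above.
        mask = (update * grad) > 0
        fmask = mask.to(update.dtype)
        if normalize:
            fmask = fmask / (fmask.mean() + eps)
        return update * fmask

    # toy example: only the first component agrees in sign with the gradient and survives
    u = torch.tensor([0.5, -0.2, 0.1])
    g = torch.tensor([1.0, 0.3, -0.4])
    print(cautious_zero_mode(u, g))
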
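Similarly, the Rprop-style rule in the removed ScaleLRBySignChange (grow a per-parameter step size by `nplus` when the sign is unchanged, shrink it by `nminus` when it flips, then clamp to `[lb, ub]`) can be sketched as a standalone function. Again a hedged illustration with placeholder names, not the module's group/state machinery:

    import torch

    def rprop_step_sizes(cur: torch.Tensor, prev: torch.Tensor, lrs: torch.Tensor,
                         nplus: float = 1.2, nminus: float = 0.5,
                         lb: float = 1e-6, ub: float = 50.0) -> torch.Tensor:
        # Grow step sizes where the sign is unchanged, shrink where it flipped,
        # leave zero products alone, and clamp to the [lb, ub] bounds.
        prod = cur * prev
        lrs = torch.where(prod > 0, lrs * nplus, lrs)
        lrs = torch.where(prod < 0, lrs * nminus, lrs)
        return lrs.clamp_(lb, ub)

    lrs = torch.full((3,), 0.1)
    prev = torch.tensor([1.0, -1.0, 1.0])
    cur = torch.tensor([0.5, 1.0, 0.0])
    print(rprop_step_sizes(cur, prev, lrs))  # roughly [0.12, 0.05, 0.10]
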