torchzero-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. torchzero/__init__.py +4 -0
  2. torchzero/core/__init__.py +13 -0
  3. torchzero/core/module.py +471 -0
  4. torchzero/core/tensorlist_optimizer.py +219 -0
  5. torchzero/modules/__init__.py +21 -0
  6. torchzero/modules/adaptive/__init__.py +4 -0
  7. torchzero/modules/adaptive/adaptive.py +192 -0
  8. torchzero/modules/experimental/__init__.py +19 -0
  9. torchzero/modules/experimental/experimental.py +294 -0
  10. torchzero/modules/experimental/quad_interp.py +104 -0
  11. torchzero/modules/experimental/subspace.py +259 -0
  12. torchzero/modules/gradient_approximation/__init__.py +7 -0
  13. torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
  14. torchzero/modules/gradient_approximation/base_approximator.py +110 -0
  15. torchzero/modules/gradient_approximation/fdm.py +125 -0
  16. torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
  17. torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
  18. torchzero/modules/gradient_approximation/rfdm.py +125 -0
  19. torchzero/modules/line_search/__init__.py +30 -0
  20. torchzero/modules/line_search/armijo.py +56 -0
  21. torchzero/modules/line_search/base_ls.py +139 -0
  22. torchzero/modules/line_search/directional_newton.py +217 -0
  23. torchzero/modules/line_search/grid_ls.py +158 -0
  24. torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
  25. torchzero/modules/meta/__init__.py +12 -0
  26. torchzero/modules/meta/alternate.py +65 -0
  27. torchzero/modules/meta/grafting.py +195 -0
  28. torchzero/modules/meta/optimizer_wrapper.py +173 -0
  29. torchzero/modules/meta/return_overrides.py +46 -0
  30. torchzero/modules/misc/__init__.py +10 -0
  31. torchzero/modules/misc/accumulate.py +43 -0
  32. torchzero/modules/misc/basic.py +115 -0
  33. torchzero/modules/misc/lr.py +96 -0
  34. torchzero/modules/misc/multistep.py +51 -0
  35. torchzero/modules/misc/on_increase.py +53 -0
  36. torchzero/modules/momentum/__init__.py +4 -0
  37. torchzero/modules/momentum/momentum.py +106 -0
  38. torchzero/modules/operations/__init__.py +29 -0
  39. torchzero/modules/operations/multi.py +298 -0
  40. torchzero/modules/operations/reduction.py +134 -0
  41. torchzero/modules/operations/singular.py +113 -0
  42. torchzero/modules/optimizers/__init__.py +10 -0
  43. torchzero/modules/optimizers/adagrad.py +49 -0
  44. torchzero/modules/optimizers/adam.py +118 -0
  45. torchzero/modules/optimizers/lion.py +28 -0
  46. torchzero/modules/optimizers/rmsprop.py +51 -0
  47. torchzero/modules/optimizers/rprop.py +99 -0
  48. torchzero/modules/optimizers/sgd.py +54 -0
  49. torchzero/modules/orthogonalization/__init__.py +2 -0
  50. torchzero/modules/orthogonalization/newtonschulz.py +159 -0
  51. torchzero/modules/orthogonalization/svd.py +86 -0
  52. torchzero/modules/quasi_newton/__init__.py +4 -0
  53. torchzero/modules/regularization/__init__.py +22 -0
  54. torchzero/modules/regularization/dropout.py +34 -0
  55. torchzero/modules/regularization/noise.py +77 -0
  56. torchzero/modules/regularization/normalization.py +328 -0
  57. torchzero/modules/regularization/ortho_grad.py +78 -0
  58. torchzero/modules/regularization/weight_decay.py +92 -0
  59. torchzero/modules/scheduling/__init__.py +2 -0
  60. torchzero/modules/scheduling/lr_schedulers.py +131 -0
  61. torchzero/modules/scheduling/step_size.py +80 -0
  62. torchzero/modules/second_order/__init__.py +4 -0
  63. torchzero/modules/second_order/newton.py +165 -0
  64. torchzero/modules/smoothing/__init__.py +5 -0
  65. torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
  66. torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
  67. torchzero/modules/weight_averaging/__init__.py +2 -0
  68. torchzero/modules/weight_averaging/ema.py +72 -0
  69. torchzero/modules/weight_averaging/swa.py +171 -0
  70. torchzero/optim/__init__.py +10 -0
  71. torchzero/optim/experimental/__init__.py +20 -0
  72. torchzero/optim/experimental/experimental.py +343 -0
  73. torchzero/optim/experimental/ray_search.py +83 -0
  74. torchzero/optim/first_order/__init__.py +18 -0
  75. torchzero/optim/first_order/cautious.py +158 -0
  76. torchzero/optim/first_order/forward_gradient.py +70 -0
  77. torchzero/optim/first_order/optimizers.py +570 -0
  78. torchzero/optim/modular.py +132 -0
  79. torchzero/optim/quasi_newton/__init__.py +1 -0
  80. torchzero/optim/quasi_newton/directional_newton.py +58 -0
  81. torchzero/optim/second_order/__init__.py +1 -0
  82. torchzero/optim/second_order/newton.py +94 -0
  83. torchzero/optim/wrappers/__init__.py +0 -0
  84. torchzero/optim/wrappers/nevergrad.py +113 -0
  85. torchzero/optim/wrappers/nlopt.py +165 -0
  86. torchzero/optim/wrappers/scipy.py +439 -0
  87. torchzero/optim/zeroth_order/__init__.py +4 -0
  88. torchzero/optim/zeroth_order/fdm.py +87 -0
  89. torchzero/optim/zeroth_order/newton_fdm.py +146 -0
  90. torchzero/optim/zeroth_order/rfdm.py +217 -0
  91. torchzero/optim/zeroth_order/rs.py +85 -0
  92. torchzero/random/__init__.py +1 -0
  93. torchzero/random/random.py +46 -0
  94. torchzero/tensorlist.py +819 -0
  95. torchzero/utils/__init__.py +0 -0
  96. torchzero/utils/compile.py +39 -0
  97. torchzero/utils/derivatives.py +99 -0
  98. torchzero/utils/python_tools.py +25 -0
  99. torchzero/utils/torch_tools.py +92 -0
  100. torchzero-0.0.1.dist-info/LICENSE +21 -0
  101. torchzero-0.0.1.dist-info/METADATA +118 -0
  102. torchzero-0.0.1.dist-info/RECORD +104 -0
  103. torchzero-0.0.1.dist-info/WHEEL +5 -0
  104. torchzero-0.0.1.dist-info/top_level.txt +1 -0
torchzero/modules/operations/multi.py
@@ -0,0 +1,298 @@
+ from collections.abc import Iterable
+ import torch
+
+ from ...core import OptimizerModule
+
+ _Value = int | float | OptimizerModule | Iterable[OptimizerModule]
+
+ class Add(OptimizerModule):
+     """add `value` to update. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
+     def __init__(self, value: _Value):
+         super().__init__({})
+
+         if not isinstance(value, (int, float)):
+             self._set_child_('value', value)
+
+         self.value = value
+
+     @torch.no_grad()
+     def _update(self, state, ascent):
+         if isinstance(self.value, (int, float)):
+             return ascent.add_(self.value)
+
+         state_copy = state.copy(clone_ascent = True)
+         v = self.children['value'].return_ascent(state_copy)
+         return ascent.add_(v)
+
+
+ class Sub(OptimizerModule):
+     """subtracts `value` from update. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
+     def __init__(self, subtrahend: _Value):
+         super().__init__({})
+
+         if not isinstance(subtrahend, (int, float)):
+             self._set_child_('subtrahend', subtrahend)
+
+         self.subtrahend = subtrahend
+
+     @torch.no_grad()
+     def _update(self, state, ascent):
+         if isinstance(self.subtrahend, (int, float)):
+             return ascent.sub_(self.subtrahend)
+
+         state_copy = state.copy(clone_ascent = True)
+         subtrahend = self.children['subtrahend'].return_ascent(state_copy)
+         return ascent.sub_(subtrahend)
+
+ class RSub(OptimizerModule):
+     """subtracts update from `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
+     def __init__(self, minuend: _Value):
+         super().__init__({})
+
+         if not isinstance(minuend, (int, float)):
+             self._set_child_('minuend', minuend)
+
+         self.minuend = minuend
+
+     @torch.no_grad()
+     def _update(self, state, ascent):
+         if isinstance(self.minuend, (int, float)):
+             return ascent.sub_(self.minuend).neg_()
+
+         state_copy = state.copy(clone_ascent = True)
+         minuend = self.children['minuend'].return_ascent(state_copy)
+         return ascent.sub_(minuend).neg_()
+
+ class Subtract(OptimizerModule):
+     """Calculates `minuend - subtrahend`"""
+     def __init__(
+         self,
+         minuend: OptimizerModule | Iterable[OptimizerModule],
+         subtrahend: OptimizerModule | Iterable[OptimizerModule],
+     ):
+         super().__init__({})
+         self._set_child_('minuend', minuend)
+         self._set_child_('subtrahend', subtrahend)
+
+     @torch.no_grad
+     def step(self, state):
+         state_copy = state.copy(clone_ascent = True)
+         minuend = self.children['minuend'].return_ascent(state_copy)
+         state.update_attrs_(state_copy)
+         subtrahend = self.children['subtrahend'].return_ascent(state)
+
+         state.ascent = minuend.sub_(subtrahend)
+         return self._update_params_or_step_with_next(state)
+
+ class Mul(OptimizerModule):
+     """multiplies update by `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
+     def __init__(self, value: _Value):
+         super().__init__({})
+
+         if not isinstance(value, (int, float)):
+             self._set_child_('value', value)
+
+         self.value = value
+
+     @torch.no_grad()
+     def _update(self, state, ascent):
+         if isinstance(self.value, (int, float)):
+             return ascent.mul_(self.value)
+
+         state_copy = state.copy(clone_ascent = True)
+         v = self.children['value'].return_ascent(state_copy)
+         return ascent.mul_(v)
+
+
+ class Div(OptimizerModule):
+     """divides update by `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
+     def __init__(self, denominator: _Value):
+         super().__init__({})
+
+         if not isinstance(denominator, (int, float)):
+             self._set_child_('denominator', denominator)
+
+         self.denominator = denominator
+
+     @torch.no_grad()
+     def _update(self, state, ascent):
+         if isinstance(self.denominator, (int, float)):
+             return ascent.div_(self.denominator)
+
+         state_copy = state.copy(clone_ascent = True)
+         denominator = self.children['denominator'].return_ascent(state_copy)
+         return ascent.div_(denominator)
+
+ class RDiv(OptimizerModule):
+     """divides `value` by update. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
+     def __init__(self, numerator: _Value):
+         super().__init__({})
+
+         if not isinstance(numerator, (int, float)):
+             self._set_child_('numerator', numerator)
+
+         self.numerator = numerator
+
+     @torch.no_grad()
+     def _update(self, state, ascent):
+         if isinstance(self.numerator, (int, float)):
+             return ascent.reciprocal_().mul_(self.numerator)
+
+         state_copy = state.copy(clone_ascent = True)
+         numerator = self.children['numerator'].return_ascent(state_copy)
+         return ascent.reciprocal_().mul_(numerator)
+
+ class Divide(OptimizerModule):
+     """calculates *numerator / denominator*"""
+     def __init__(
+         self,
+         numerator: OptimizerModule | Iterable[OptimizerModule],
+         denominator: OptimizerModule | Iterable[OptimizerModule],
+     ):
+         super().__init__({})
+         self._set_child_('numerator', numerator)
+         self._set_child_('denominator', denominator)
+
+     @torch.no_grad
+     def step(self, state):
+         state_copy = state.copy(clone_ascent = True)
+         numerator = self.children['numerator'].return_ascent(state_copy)
+         state.update_attrs_(state_copy)
+         denominator = self.children['denominator'].return_ascent(state)
+
+         state.ascent = numerator.div_(denominator)
+         return self._update_params_or_step_with_next(state)
+
+
+ class Pow(OptimizerModule):
+     """takes ascent to the power of `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
+     def __init__(self, power: _Value):
+         super().__init__({})
+
+         if not isinstance(power, (int, float)):
+             self._set_child_('power', power)
+
+         self.power = power
+
+     @torch.no_grad()
+     def _update(self, state, ascent):
+         if isinstance(self.power, (int, float)):
+             return ascent.pow_(self.power)
+
+         state_copy = state.copy(clone_ascent = True)
+         power = self.children['power'].return_ascent(state_copy)
+         return ascent.pow_(power)
+
+ class RPow(OptimizerModule):
+     """takes `value` to the power of ascent. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
+     def __init__(self, base: _Value):
+         super().__init__({})
+
+         if not isinstance(base, (int, float)):
+             self._set_child_('base', base)
+
+         self.base = base
+
+     @torch.no_grad()
+     def _update(self, state, ascent):
+         if isinstance(self.base, (int, float)):
+             return self.base ** ascent
+
+         state_copy = state.copy(clone_ascent = True)
+         base = self.children['base'].return_ascent(state_copy)
+         return base.pow_(ascent)
+
+ class Power(OptimizerModule):
+     """calculates *base ^ power*"""
+     def __init__(
+         self,
+         base: OptimizerModule | Iterable[OptimizerModule],
+         power: OptimizerModule | Iterable[OptimizerModule],
+     ):
+         super().__init__({})
+         self._set_child_('base', base)
+         self._set_child_('power', power)
+
+     @torch.no_grad
+     def step(self, state):
+         state_copy = state.copy(clone_ascent = True)
+         base = self.children['base'].return_ascent(state_copy)
+         state.update_attrs_(state_copy)
+         power = self.children['power'].return_ascent(state)
+
+         state.ascent = base.pow_(power)
+         return self._update_params_or_step_with_next(state)
+
+
+ class Lerp(OptimizerModule):
+     """Linear interpolation between update and `end` based on scalar `weight`.
+
+     `out = update + weight * (end - update)`"""
+     def __init__(self, end: OptimizerModule | Iterable[OptimizerModule], weight: float):
+         super().__init__({})
+
+         self._set_child_('end', end)
+         self.weight = weight
+
+     @torch.no_grad()
+     def _update(self, state, ascent):
+
+         state_copy = state.copy(clone_ascent = True)
+         end = self.children['end'].return_ascent(state_copy)
+         return ascent.lerp_(end, self.weight)
+
+
+ class Interpolate(OptimizerModule):
+     """Linear interpolation between two modules' updates - `input` (the start point) and `end` - based on a scalar
+     `weight`.
+
+     `out = input + weight * (end - input)`"""
+     def __init__(
+         self,
+         input: OptimizerModule | Iterable[OptimizerModule],
+         end: OptimizerModule | Iterable[OptimizerModule],
+         weight: float,
+     ):
+         super().__init__({})
+         self._set_child_('input', input)
+         self._set_child_('end', end)
+         self.weight = weight
+
+     @torch.no_grad
+     def step(self, state):
+         state_copy = state.copy(clone_ascent = True)
+         input = self.children['input'].return_ascent(state_copy)
+         state.update_attrs_(state_copy)
+         end = self.children['end'].return_ascent(state)
+
+         state.ascent = input.lerp_(end, weight = self.weight)
+
+         return self._update_params_or_step_with_next(state)
+
+ class AddMagnitude(OptimizerModule):
+     """Add `value` multiplied by sign of the ascent, i.e. this adds `value` to the magnitude of the update.
+
+     Args:
+         value (Value): value to add to magnitude, either a float or an OptimizerModule.
+         add_to_zero (bool, optional):
+             if True, adds `value` to 0s. Otherwise, zeros remain zero.
+             Only has effect if value is a float. Defaults to True.
+     """
+     def __init__(self, value: _Value, add_to_zero=True):
+         super().__init__({})
+
+         if not isinstance(value, (int, float)):
+             self._set_child_('value', value)
+
+         self.value = value
+         self.add_to_zero = add_to_zero
+
+     @torch.no_grad()
+     def _update(self, state, ascent):
+         if isinstance(self.value, (int, float)):
+             if self.add_to_zero: return ascent.add_(ascent.clamp_magnitude(min=1).sign_().mul_(self.value))
+             return ascent.add_(ascent.sign_().mul_(self.value))
+
+         state_copy = state.copy(clone_ascent = True)
+         v = self.children['value'].return_ascent(state_copy)
+         return ascent.add_(v.abs_().mul_(ascent.sign()))
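
Editor's illustration (not part of the packaged source): the `Lerp` and `Interpolate` modules above both reduce to `torch.lerp` applied element-wise to the update. A minimal sketch with plain tensors, verifying the docstring formula `out = update + weight * (end - update)`:

import torch

update = torch.tensor([1.0, -2.0, 4.0])    # stands in for the incoming ascent/update
end = torch.tensor([0.0, 0.5, 1.0])        # stands in for what the `end` child would return
weight = 0.3

manual = update + weight * (end - update)  # formula quoted in the docstrings above
lerped = torch.lerp(update, end, weight)   # what `ascent.lerp_(end, self.weight)` does per tensor
assert torch.allclose(manual, lerped)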
torchzero/modules/operations/reduction.py
@@ -0,0 +1,134 @@
+ from collections.abc import Callable, Iterable
+ import numpy as np
+ import torch
+
+ from ...core import OptimizerModule
+
+ _Value = int | float | OptimizerModule | Iterable[OptimizerModule]
+
+
+ class Sum(OptimizerModule):
+     """calculates sum of multiple updates.
+
+     Args:
+         *modules:
+             either OptimizerModules or iterables of OptimizerModules to chain. Scalars are also allowed."""
+     def __init__(
+         self,
+         *modules: _Value,
+     ):
+         super().__init__({})
+
+         scalars = [i for i in modules if isinstance(i, (int,float))]
+         self.scalar = sum(scalars) if len(scalars) > 0 else None
+
+         for i,module in enumerate(i for i in modules if not isinstance(i, (int, float))):
+             self._set_child_(i, module)
+
+     @torch.no_grad
+     def step(self, state):
+         if len(self.children) == 1:
+             state.ascent = self.children[0].return_ascent(state)
+             if self.scalar is not None: state.ascent += self.scalar
+             return self._update_params_or_step_with_next(state)
+
+         sum = None
+         for i, c in sorted(self.children.items(), key=lambda x: x[0]):
+             if i == len(self.children) - 1: cur_state = state
+             else: cur_state = state.copy(clone_ascent = True)
+
+             if sum is None: sum = c.return_ascent(cur_state)
+             else: sum += c.return_ascent(cur_state)
+
+             if i != len(self.children) - 1: state.update_attrs_(cur_state)
+
+         assert sum is not None
+         if self.scalar is not None: sum += self.scalar
+         state.ascent = sum
+         return self._update_params_or_step_with_next(state)
+
+ class Mean(OptimizerModule):
+     """calculates mean of multiple updates.
+
+     Args:
+         *modules:
+             either OptimizerModules or iterables of OptimizerModules to chain. Scalars are also allowed."""
+
+     def __init__(
+         self,
+         *modules: _Value,
+     ):
+         super().__init__({})
+
+         scalars = [i for i in modules if isinstance(i, (int,float))]
+         self.scalar = sum(scalars) if len(scalars) > 0 else None
+
+         self.n_values = len(modules)
+
+         for i,module in enumerate(i for i in modules if not isinstance(i, (int, float))):
+             self._set_child_(i, module)
+
+     @torch.no_grad
+     def step(self, state):
+         if len(self.children) == 1:
+             state.ascent = self.children[0].return_ascent(state)
+             if self.scalar is not None: state.ascent += self.scalar
+             if self.n_values > 1: state.ascent /= self.n_values
+             return self._update_params_or_step_with_next(state)
+
+         sum = None
+         for i, c in sorted(self.children.items(), key=lambda x: x[0]):
+             if i == len(self.children) - 1: cur_state = state
+             else: cur_state = state.copy(clone_ascent = True)
+
+             if sum is None: sum = c.return_ascent(cur_state)
+             else: sum += c.return_ascent(cur_state)
+
+             if i != len(self.children) - 1: state.update_attrs_(cur_state)
+
+         assert sum is not None
+         if self.scalar is not None: sum += self.scalar
+         if self.n_values > 1: sum /= self.n_values
+         state.ascent = sum
+         return self._update_params_or_step_with_next(state)
+
+ class Product(OptimizerModule):
+     """calculates product of multiple updates.
+
+     Args:
+         *modules:
+             either OptimizerModules or iterables of OptimizerModules to chain. Scalars are also allowed."""
+
+     def __init__(
+         self,
+         *modules: _Value,
+     ):
+         super().__init__({})
+
+         scalars = [i for i in modules if isinstance(i, (int,float))]
+         self.scalar = np.prod(scalars).item() if len(scalars) > 0 else None
+
+         for i,module in enumerate(i for i in modules if not isinstance(i, (int, float))):
+             self._set_child_(i, module)
+
+     @torch.no_grad
+     def step(self, state):
+         if len(self.children) == 1:
+             state.ascent = self.children[0].return_ascent(state)
+             if self.scalar is not None: state.ascent *= self.scalar
+             return self._update_params_or_step_with_next(state)
+
+         prod = None
+         for i, c in sorted(self.children.items(), key=lambda x: x[0]):
+             if i == len(self.children) - 1: cur_state = state
+             else: cur_state = state.copy(clone_ascent = True)
+
+             if prod is None: prod = c.return_ascent(cur_state)
+             else: prod *= c.return_ascent(cur_state)
+
+             if i != len(self.children) - 1: state.update_attrs_(cur_state)
+
+         assert prod is not None
+         if self.scalar is not None: prod *= self.scalar
+         state.ascent = prod
+         return self._update_params_or_step_with_next(state)
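
Editor's illustration (not part of the packaged source): the reductions above fold scalar arguments into a single constant (`self.scalar`) and, for `Mean`, divide by the total number of arguments (`self.n_values` counts modules and scalars alike). A plain-tensor sketch of that arithmetic:

import torch

# Stand-ins for what two child modules would return, plus one scalar argument.
update_a = torch.tensor([1.0, 2.0])
update_b = torch.tensor([3.0, 4.0])
scalar = 0.5
n_values = 3  # two modules + one scalar, mirroring `self.n_values = len(modules)`

total = update_a + update_b + scalar  # what Sum would produce
mean = total / n_values               # what Mean would produce
print(mean)                           # tensor([1.5000, 2.1667])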
torchzero/modules/operations/singular.py
@@ -0,0 +1,113 @@
+ from collections.abc import Iterable
+ from operator import methodcaller
+
+ import torch
+
+ from ...core import OptimizerModule
+ from ...tensorlist import TensorList
+
+
+ class Operation(OptimizerModule):
+     """Applies an operation to the ascent, supported operations:
+
+     `abs`, `sign`, `sin`, `cos`, `tan`, `asin`, `acos`, `atan`, `sinh`, `cosh`,
+     `tanh`, `log`, `log1p`, `log2`, `log10`, `erf`, `erfc`, `exp`, `neg`, `reciprocal`,
+     `copy`, `zero`, `sqrt`, `floor`, `ceil`, `round`."""
+     def __init__(self, operation: str):
+         super().__init__({})
+         self.operation = methodcaller(f'{operation}_')
+
+     @torch.no_grad
+     def _update(self, state, ascent): return self.operation(ascent)
+
+ class Reciprocal(OptimizerModule):
+     """*1 / update*"""
+     def __init__(self,):
+         super().__init__({})
+
+     @torch.no_grad()
+     def _update(self, state, ascent): return ascent.reciprocal_()
+
+ class Negate(OptimizerModule):
+     """minus update"""
+     def __init__(self,):
+         super().__init__({})
+
+     @torch.no_grad()
+     def _update(self, state, ascent): return ascent.neg_()
+
+
+ def sign_grad_(params: Iterable[torch.Tensor]):
+     """Apply sign function to gradients of an iterable of parameters.
+
+     Args:
+         params (abc.Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
+     """
+     TensorList(params).get_existing_grads().sign_()
+
+ class Sign(OptimizerModule):
+     """applies sign function to the update"""
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def _update(self, state, ascent): return ascent.sign_()
+
+ class Abs(OptimizerModule):
+     """takes absolute values of the update."""
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def _update(self, state, ascent): return ascent.abs_()
+
+ class Sin(OptimizerModule):
+     """applies sin function to the ascent"""
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def _update(self, state, ascent): return ascent.sin_()
+
+ class Cos(OptimizerModule):
+     """applies cos function to the ascent"""
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def _update(self, state, ascent): return ascent.cos_()
+
+
+ class NanToNum(OptimizerModule):
+     """Convert `nan`, `inf` and `-inf` to numbers.
+
+     Args:
+         nan (optional): the value to replace NaNs with. Default is zero.
+         posinf (optional): if a Number, the value to replace positive infinity values with.
+             If None, positive infinity values are replaced with the greatest finite value
+             representable by input's dtype. Default is None.
+         neginf (optional): if a Number, the value to replace negative infinity values with.
+             If None, negative infinity values are replaced with the lowest finite value
+             representable by input's dtype. Default is None.
+     """
+     def __init__(self, nan=None, posinf=None, neginf=None):
+         super().__init__({})
+         self.nan = nan
+         self.posinf = posinf
+         self.neginf = neginf
+
+     @torch.no_grad()
+     def _update(self, state, ascent): return ascent.nan_to_num_(self.nan, self.posinf, self.neginf)
+
+
+ class MagnitudePower(OptimizerModule):
+     """Raises the update to the `value` power while preserving its sign; odd powers already preserve the sign and are applied directly."""
+     def __init__(self, value: int | float):
+         super().__init__({})
+         self.value = value
+
+     @torch.no_grad()
+     def _update(self, state, ascent):
+         if self.value % 2 == 1: return ascent.pow_(self.value)
+         return ascent.abs().pow_(self.value) * ascent.sign()
+
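
Editor's illustration (not part of the packaged source): `MagnitudePower` keeps the sign of the update when raising it to an even power; odd powers keep the sign on their own. With plain tensors:

import torch

update = torch.tensor([-2.0, 0.5, 3.0])

# Even power would lose the sign, so it is restored explicitly (mirrors `_update` above).
signed_sq = update.abs().pow(2) * update.sign()
print(signed_sq)      # tensor([-4.0000,  0.2500,  9.0000])

# An odd power preserves the sign by itself, so plain pow suffices.
print(update.pow(3))  # tensor([-8.0000,  0.1250, 27.0000])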
torchzero/modules/optimizers/__init__.py
@@ -0,0 +1,10 @@
+ r"""
+ This includes various optimizers as composable modules.
+ """
+ # from .adam import Adam
+ from .sgd import SGD
+ from .rprop import Rprop
+ from .rmsprop import RMSProp
+ from .adagrad import Adagrad
+ from .adam import Adam
+ from .lion import Lion
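
Editor's illustration (not part of the packaged source): given the re-exports above, these building blocks should be importable directly from the subpackage. How they are chained into a full optimizer lives elsewhere in the wheel (e.g. `torchzero/optim/modular.py`) and is not shown in this diff, so only construction is sketched here:

# Assumes torchzero 0.0.1 is installed; the names follow the re-exports above.
from torchzero.modules.optimizers import SGD, Rprop, RMSProp, Adagrad, Adam, Lion

# Adagrad's full signature appears in the adagrad.py hunk below; every argument has a default.
adagrad_module = Adagrad(lr_decay=0.0, initial_accumulator_value=0.0, eps=1e-10, alpha=1.0)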
torchzero/modules/optimizers/adagrad.py
@@ -0,0 +1,49 @@
+ from collections import abc
+
+ import torch
+
+ from ...tensorlist import TensorList
+ from ...core import OptimizerModule
+
+ def _adagrad_step_(ascent: TensorList, grad_sum: TensorList, alpha: TensorList, lr_decay: TensorList, eps: TensorList, step: int):
+     clr = alpha / (1 + step * lr_decay)
+     grad_sum.addcmul_(ascent, ascent)
+     return ascent.div_(grad_sum.sqrt().add_(eps)).mul_(clr)
+
+ class Adagrad(OptimizerModule):
+     """
+     Divides the ascent direction by the square root of the sum of squares of all past ascent directions.
+
+     Exactly matches `torch.optim.Adagrad`.
+
+     Args:
+         lr_decay (float, optional): learning rate decay. Defaults to 0.
+         initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
+         eps (float, optional): term added to the denominator to improve numerical stability. Defaults to 1e-10.
+         alpha (float, optional): learning rate. Defaults to 1.
+
+     reference
+         https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
+     """
+     def __init__(self, lr_decay: float = 0, initial_accumulator_value: float = 0, eps: float = 1e-10, alpha: float = 1):
+         defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value, eps = eps)
+         super().__init__(defaults)
+         self.cur_step = 0
+
+     @torch.no_grad
+     def _update(self, state, ascent):
+         settings = self.get_all_group_keys()
+         if self.cur_step == 0: init = ascent.full_like(settings['initial_accumulator_value'])
+         else: init = None
+         grad_sum = self.get_state_key('grad_sum', init = init) # type:ignore
+
+         updated_direction = _adagrad_step_(
+             ascent=ascent,
+             grad_sum=grad_sum,
+             alpha=settings["alpha"],
+             eps=settings["eps"],
+             lr_decay=settings["lr_decay"],
+             step=self.cur_step,
+         )
+         self.cur_step += 1
+         return updated_direction
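
Editor's illustration (not part of the packaged source): the arithmetic in `_adagrad_step_` above, mirrored on a single plain tensor instead of a `TensorList`:

import torch

alpha, lr_decay, eps = 1.0, 0.0, 1e-10
grad_sum = torch.zeros(3)                # running sum of squared ascent directions
ascent = torch.tensor([0.1, -0.2, 0.3])  # incoming ascent direction for this step
step = 0

clr = alpha / (1 + step * lr_decay)      # decayed step size, as in _adagrad_step_
grad_sum += ascent * ascent              # grad_sum.addcmul_(ascent, ascent)
adapted = ascent / (grad_sum.sqrt() + eps) * clr
print(adapted)  # approximately sign(ascent) on the first step, since |g| / sqrt(g*g) = 1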