torchzero 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchzero/__init__.py +4 -0
- torchzero/core/__init__.py +13 -0
- torchzero/core/module.py +471 -0
- torchzero/core/tensorlist_optimizer.py +219 -0
- torchzero/modules/__init__.py +21 -0
- torchzero/modules/adaptive/__init__.py +4 -0
- torchzero/modules/adaptive/adaptive.py +192 -0
- torchzero/modules/experimental/__init__.py +19 -0
- torchzero/modules/experimental/experimental.py +294 -0
- torchzero/modules/experimental/quad_interp.py +104 -0
- torchzero/modules/experimental/subspace.py +259 -0
- torchzero/modules/gradient_approximation/__init__.py +7 -0
- torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
- torchzero/modules/gradient_approximation/base_approximator.py +110 -0
- torchzero/modules/gradient_approximation/fdm.py +125 -0
- torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
- torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
- torchzero/modules/gradient_approximation/rfdm.py +125 -0
- torchzero/modules/line_search/__init__.py +30 -0
- torchzero/modules/line_search/armijo.py +56 -0
- torchzero/modules/line_search/base_ls.py +139 -0
- torchzero/modules/line_search/directional_newton.py +217 -0
- torchzero/modules/line_search/grid_ls.py +158 -0
- torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
- torchzero/modules/meta/__init__.py +12 -0
- torchzero/modules/meta/alternate.py +65 -0
- torchzero/modules/meta/grafting.py +195 -0
- torchzero/modules/meta/optimizer_wrapper.py +173 -0
- torchzero/modules/meta/return_overrides.py +46 -0
- torchzero/modules/misc/__init__.py +10 -0
- torchzero/modules/misc/accumulate.py +43 -0
- torchzero/modules/misc/basic.py +115 -0
- torchzero/modules/misc/lr.py +96 -0
- torchzero/modules/misc/multistep.py +51 -0
- torchzero/modules/misc/on_increase.py +53 -0
- torchzero/modules/momentum/__init__.py +4 -0
- torchzero/modules/momentum/momentum.py +106 -0
- torchzero/modules/operations/__init__.py +29 -0
- torchzero/modules/operations/multi.py +298 -0
- torchzero/modules/operations/reduction.py +134 -0
- torchzero/modules/operations/singular.py +113 -0
- torchzero/modules/optimizers/__init__.py +10 -0
- torchzero/modules/optimizers/adagrad.py +49 -0
- torchzero/modules/optimizers/adam.py +118 -0
- torchzero/modules/optimizers/lion.py +28 -0
- torchzero/modules/optimizers/rmsprop.py +51 -0
- torchzero/modules/optimizers/rprop.py +99 -0
- torchzero/modules/optimizers/sgd.py +54 -0
- torchzero/modules/orthogonalization/__init__.py +2 -0
- torchzero/modules/orthogonalization/newtonschulz.py +159 -0
- torchzero/modules/orthogonalization/svd.py +86 -0
- torchzero/modules/quasi_newton/__init__.py +4 -0
- torchzero/modules/regularization/__init__.py +22 -0
- torchzero/modules/regularization/dropout.py +34 -0
- torchzero/modules/regularization/noise.py +77 -0
- torchzero/modules/regularization/normalization.py +328 -0
- torchzero/modules/regularization/ortho_grad.py +78 -0
- torchzero/modules/regularization/weight_decay.py +92 -0
- torchzero/modules/scheduling/__init__.py +2 -0
- torchzero/modules/scheduling/lr_schedulers.py +131 -0
- torchzero/modules/scheduling/step_size.py +80 -0
- torchzero/modules/second_order/__init__.py +4 -0
- torchzero/modules/second_order/newton.py +165 -0
- torchzero/modules/smoothing/__init__.py +5 -0
- torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
- torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
- torchzero/modules/weight_averaging/__init__.py +2 -0
- torchzero/modules/weight_averaging/ema.py +72 -0
- torchzero/modules/weight_averaging/swa.py +171 -0
- torchzero/optim/__init__.py +10 -0
- torchzero/optim/experimental/__init__.py +20 -0
- torchzero/optim/experimental/experimental.py +343 -0
- torchzero/optim/experimental/ray_search.py +83 -0
- torchzero/optim/first_order/__init__.py +18 -0
- torchzero/optim/first_order/cautious.py +158 -0
- torchzero/optim/first_order/forward_gradient.py +70 -0
- torchzero/optim/first_order/optimizers.py +570 -0
- torchzero/optim/modular.py +132 -0
- torchzero/optim/quasi_newton/__init__.py +1 -0
- torchzero/optim/quasi_newton/directional_newton.py +58 -0
- torchzero/optim/second_order/__init__.py +1 -0
- torchzero/optim/second_order/newton.py +94 -0
- torchzero/optim/wrappers/__init__.py +0 -0
- torchzero/optim/wrappers/nevergrad.py +113 -0
- torchzero/optim/wrappers/nlopt.py +165 -0
- torchzero/optim/wrappers/scipy.py +439 -0
- torchzero/optim/zeroth_order/__init__.py +4 -0
- torchzero/optim/zeroth_order/fdm.py +87 -0
- torchzero/optim/zeroth_order/newton_fdm.py +146 -0
- torchzero/optim/zeroth_order/rfdm.py +217 -0
- torchzero/optim/zeroth_order/rs.py +85 -0
- torchzero/random/__init__.py +1 -0
- torchzero/random/random.py +46 -0
- torchzero/tensorlist.py +819 -0
- torchzero/utils/__init__.py +0 -0
- torchzero/utils/compile.py +39 -0
- torchzero/utils/derivatives.py +99 -0
- torchzero/utils/python_tools.py +25 -0
- torchzero/utils/torch_tools.py +92 -0
- torchzero-0.0.1.dist-info/LICENSE +21 -0
- torchzero-0.0.1.dist-info/METADATA +118 -0
- torchzero-0.0.1.dist-info/RECORD +104 -0
- torchzero-0.0.1.dist-info/WHEEL +5 -0
- torchzero-0.0.1.dist-info/top_level.txt +1 -0
torchzero/modules/operations/multi.py
@@ -0,0 +1,298 @@
from collections.abc import Iterable

import torch

from ...core import OptimizerModule

_Value = int | float | OptimizerModule | Iterable[OptimizerModule]


class Add(OptimizerModule):
    """add `value` to update. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
    def __init__(self, value: _Value):
        super().__init__({})

        if not isinstance(value, (int, float)):
            self._set_child_('value', value)

        self.value = value

    @torch.no_grad()
    def _update(self, state, ascent):
        if isinstance(self.value, (int, float)):
            return ascent.add_(self.value)

        state_copy = state.copy(clone_ascent = True)
        v = self.children['value'].return_ascent(state_copy)
        return ascent.add_(v)


class Sub(OptimizerModule):
    """subtracts `value` from update. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
    def __init__(self, subtrahend: _Value):
        super().__init__({})

        if not isinstance(subtrahend, (int, float)):
            self._set_child_('subtrahend', subtrahend)

        self.subtrahend = subtrahend

    @torch.no_grad()
    def _update(self, state, ascent):
        if isinstance(self.subtrahend, (int, float)):
            return ascent.sub_(self.subtrahend)

        state_copy = state.copy(clone_ascent = True)
        subtrahend = self.children['subtrahend'].return_ascent(state_copy)
        return ascent.sub_(subtrahend)

class RSub(OptimizerModule):
    """subtracts update from `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
    def __init__(self, minuend: _Value):
        super().__init__({})

        if not isinstance(minuend, (int, float)):
            self._set_child_('minuend', minuend)

        self.minuend = minuend

    @torch.no_grad()
    def _update(self, state, ascent):
        if isinstance(self.minuend, (int, float)):
            return ascent.sub_(self.minuend).neg_()

        state_copy = state.copy(clone_ascent = True)
        minuend = self.children['minuend'].return_ascent(state_copy)
        return ascent.sub_(minuend).neg_()

class Subtract(OptimizerModule):
    """Calculates `minuend - subtrahend`"""
    def __init__(
        self,
        minuend: OptimizerModule | Iterable[OptimizerModule],
        subtrahend: OptimizerModule | Iterable[OptimizerModule],
    ):
        super().__init__({})
        self._set_child_('minuend', minuend)
        self._set_child_('subtrahend', subtrahend)

    @torch.no_grad
    def step(self, state):
        state_copy = state.copy(clone_ascent = True)
        minuend = self.children['minuend'].return_ascent(state_copy)
        state.update_attrs_(state_copy)
        subtrahend = self.children['subtrahend'].return_ascent(state)

        state.ascent = minuend.sub_(subtrahend)
        return self._update_params_or_step_with_next(state)

class Mul(OptimizerModule):
    """multiplies update by `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
    def __init__(self, value: _Value):
        super().__init__({})

        if not isinstance(value, (int, float)):
            self._set_child_('value', value)

        self.value = value

    @torch.no_grad()
    def _update(self, state, ascent):
        if isinstance(self.value, (int, float)):
            return ascent.mul_(self.value)

        state_copy = state.copy(clone_ascent = True)
        v = self.children['value'].return_ascent(state_copy)
        return ascent.mul_(v)


class Div(OptimizerModule):
    """divides update by `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
    def __init__(self, denominator: _Value):
        super().__init__({})

        if not isinstance(denominator, (int, float)):
            self._set_child_('denominator', denominator)

        self.denominator = denominator

    @torch.no_grad()
    def _update(self, state, ascent):
        if isinstance(self.denominator, (int, float)):
            return ascent.div_(self.denominator)

        state_copy = state.copy(clone_ascent = True)
        denominator = self.children['denominator'].return_ascent(state_copy)
        return ascent.div_(denominator)

class RDiv(OptimizerModule):
    """divides `value` by update. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
    def __init__(self, numerator: _Value):
        super().__init__({})

        if not isinstance(numerator, (int, float)):
            self._set_child_('numerator', numerator)

        self.numerator = numerator

    @torch.no_grad()
    def _update(self, state, ascent):
        if isinstance(self.numerator, (int, float)):
            return ascent.reciprocal_().mul_(self.numerator)

        state_copy = state.copy(clone_ascent = True)
        numerator = self.children['numerator'].return_ascent(state_copy)
        return ascent.reciprocal_().mul_(numerator)

class Divide(OptimizerModule):
    """calculates *numerator / denominator*"""
    def __init__(
        self,
        numerator: OptimizerModule | Iterable[OptimizerModule],
        denominator: OptimizerModule | Iterable[OptimizerModule],
    ):
        super().__init__({})
        self._set_child_('numerator', numerator)
        self._set_child_('denominator', denominator)

    @torch.no_grad
    def step(self, state):
        state_copy = state.copy(clone_ascent = True)
        numerator = self.children['numerator'].return_ascent(state_copy)
        state.update_attrs_(state_copy)
        denominator = self.children['denominator'].return_ascent(state)

        state.ascent = numerator.div_(denominator)
        return self._update_params_or_step_with_next(state)


class Pow(OptimizerModule):
    """takes ascent to the power of `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
    def __init__(self, power: _Value):
        super().__init__({})

        if not isinstance(power, (int, float)):
            self._set_child_('power', power)

        self.power = power

    @torch.no_grad()
    def _update(self, state, ascent):
        if isinstance(self.power, (int, float)):
            return ascent.pow_(self.power)

        state_copy = state.copy(clone_ascent = True)
        power = self.children['power'].return_ascent(state_copy)
        return ascent.pow_(power)

class RPow(OptimizerModule):
    """takes `value` to the power of ascent. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
    def __init__(self, base: _Value):
        super().__init__({})

        if not isinstance(base, (int, float)):
            self._set_child_('base', base)

        self.base = base

    @torch.no_grad()
    def _update(self, state, ascent):
        if isinstance(self.base, (int, float)):
            return self.base ** ascent

        state_copy = state.copy(clone_ascent = True)
        base = self.children['base'].return_ascent(state_copy)
        return base.pow_(ascent)

class Power(OptimizerModule):
    """calculates *base ^ power*"""
    def __init__(
        self,
        base: OptimizerModule | Iterable[OptimizerModule],
        power: OptimizerModule | Iterable[OptimizerModule],
    ):
        super().__init__({})
        self._set_child_('base', base)
        self._set_child_('power', power)

    @torch.no_grad
    def step(self, state):
        state_copy = state.copy(clone_ascent = True)
        base = self.children['base'].return_ascent(state_copy)
        state.update_attrs_(state_copy)
        power = self.children['power'].return_ascent(state)

        state.ascent = base.pow_(power)
        return self._update_params_or_step_with_next(state)


class Lerp(OptimizerModule):
    """Linear interpolation between update and `end` based on scalar `weight`.

    `out = update + weight * (end - update)`"""
    def __init__(self, end: OptimizerModule | Iterable[OptimizerModule], weight: float):
        super().__init__({})

        self._set_child_('end', end)
        self.weight = weight

    @torch.no_grad()
    def _update(self, state, ascent):
        state_copy = state.copy(clone_ascent = True)
        end = self.children['end'].return_ascent(state_copy)
        return ascent.lerp_(end, self.weight)


class Interpolate(OptimizerModule):
    """Does a linear interpolation of two modules' updates - `start` (given by input), and `end`, based on a scalar
    `weight`.

    `out = input + weight * (end - input)`"""
    def __init__(
        self,
        input: OptimizerModule | Iterable[OptimizerModule],
        end: OptimizerModule | Iterable[OptimizerModule],
        weight: float,
    ):
        super().__init__({})
        self._set_child_('input', input)
        self._set_child_('end', end)
        self.weight = weight

    @torch.no_grad
    def step(self, state):
        state_copy = state.copy(clone_ascent = True)
        input = self.children['input'].return_ascent(state_copy)
        state.update_attrs_(state_copy)
        end = self.children['end'].return_ascent(state)

        state.ascent = input.lerp_(end, weight = self.weight)

        return self._update_params_or_step_with_next(state)

class AddMagnitude(OptimizerModule):
    """Add `value` multiplied by sign of the ascent, i.e. this adds `value` to the magnitude of the update.

    Args:
        value (Value): value to add to magnitude, either a float or an OptimizerModule.
        add_to_zero (bool, optional):
            if True, adds `value` to 0s. Otherwise, zeros remain zero.
            Only has effect if value is a float. Defaults to True.
    """
    def __init__(self, value: _Value, add_to_zero=True):
        super().__init__({})

        if not isinstance(value, (int, float)):
            self._set_child_('value', value)

        self.value = value
        self.add_to_zero = add_to_zero

    @torch.no_grad()
    def _update(self, state, ascent):
        if isinstance(self.value, (int, float)):
            if self.add_to_zero: return ascent.add_(ascent.clamp_magnitude(min=1).sign_().mul_(self.value))
            return ascent.add_(ascent.sign_().mul_(self.value))

        state_copy = state.copy(clone_ascent = True)
        v = self.children['value'].return_ascent(state_copy)
        return ascent.add_(v.abs_().mul_(ascent.sign()))
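Taken together, these modules follow a single dispatch rule: a scalar operand is applied directly with the matching in-place op, while a module (or sequence of modules) operand is registered as a child whose `return_ascent` output becomes the right-hand side. A minimal construction sketch, assuming the classes can be imported from this file path and instantiated standalone; the values and variable names are illustrative, and wiring the result into an optimizer is not shown in this diff:

from torchzero.modules.operations.multi import Add, Mul

halve = Mul(0.5)        # scalar operand: _update just calls ascent.mul_(0.5)
shift = Add(1e-3)       # scalar operand: ascent.add_(1e-3)

# Module operand: Mul(0.1) becomes the 'value' child of Add, so the Add step
# adds the child's ascent (here, one tenth of the cloned incoming update).
combined = Add(Mul(0.1))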
torchzero/modules/operations/reduction.py
@@ -0,0 +1,134 @@
from collections.abc import Callable, Iterable
import numpy as np
import torch

from ...core import OptimizerModule

_Value = int | float | OptimizerModule | Iterable[OptimizerModule]


class Sum(OptimizerModule):
    """calculates sum of multiple updates.

    Args:
        *modules:
            either OptimizerModules or iterables of OptimizerModules to chain. Scalars are also allowed."""
    def __init__(
        self,
        *modules: _Value,
    ):
        super().__init__({})

        scalars = [i for i in modules if isinstance(i, (int, float))]
        self.scalar = sum(scalars) if len(scalars) > 0 else None

        for i, module in enumerate(i for i in modules if not isinstance(i, (int, float))):
            self._set_child_(i, module)

    @torch.no_grad
    def step(self, state):
        if len(self.children) == 1:
            state.ascent = self.children[0].return_ascent(state)
            if self.scalar is not None: state.ascent += self.scalar
            return self._update_params_or_step_with_next(state)

        sum = None
        for i, c in sorted(self.children.items(), key=lambda x: x[0]):
            if i == len(self.children) - 1: cur_state = state
            else: cur_state = state.copy(clone_ascent = True)

            if sum is None: sum = c.return_ascent(cur_state)
            else: sum += c.return_ascent(cur_state)

            if i != len(self.children) - 1: state.update_attrs_(cur_state)

        assert sum is not None
        if self.scalar is not None: sum += self.scalar
        state.ascent = sum
        return self._update_params_or_step_with_next(state)

class Mean(OptimizerModule):
    """calculates mean of multiple updates.

    Args:
        *modules:
            either OptimizerModules or iterables of OptimizerModules to chain. Scalars are also allowed."""

    def __init__(
        self,
        *modules: _Value,
    ):
        super().__init__({})

        scalars = [i for i in modules if isinstance(i, (int, float))]
        self.scalar = sum(scalars) if len(scalars) > 0 else None

        self.n_values = len(modules)

        for i, module in enumerate(i for i in modules if not isinstance(i, (int, float))):
            self._set_child_(i, module)

    @torch.no_grad
    def step(self, state):
        if len(self.children) == 1:
            state.ascent = self.children[0].return_ascent(state)
            if self.scalar is not None: state.ascent += self.scalar
            if self.n_values > 1: state.ascent /= self.n_values
            return self._update_params_or_step_with_next(state)

        sum = None
        for i, c in sorted(self.children.items(), key=lambda x: x[0]):
            if i == len(self.children) - 1: cur_state = state
            else: cur_state = state.copy(clone_ascent = True)

            if sum is None: sum = c.return_ascent(cur_state)
            else: sum += c.return_ascent(cur_state)

            if i != len(self.children) - 1: state.update_attrs_(cur_state)

        assert sum is not None
        if self.scalar is not None: sum += self.scalar
        if self.n_values > 1: sum /= self.n_values
        state.ascent = sum
        return self._update_params_or_step_with_next(state)

class Product(OptimizerModule):
    """calculates product of multiple updates.

    Args:
        *modules:
            either OptimizerModules or iterables of OptimizerModules to chain. Scalars are also allowed."""

    def __init__(
        self,
        *modules: _Value,
    ):
        super().__init__({})

        scalars = [i for i in modules if isinstance(i, (int, float))]
        self.scalar = np.prod(scalars).item() if len(scalars) > 0 else None

        for i, module in enumerate(i for i in modules if not isinstance(i, (int, float))):
            self._set_child_(i, module)

    @torch.no_grad
    def step(self, state):
        if len(self.children) == 1:
            state.ascent = self.children[0].return_ascent(state)
            if self.scalar is not None: state.ascent *= self.scalar
            return self._update_params_or_step_with_next(state)

        prod = None
        for i, c in sorted(self.children.items(), key=lambda x: x[0]):
            if i == len(self.children) - 1: cur_state = state
            else: cur_state = state.copy(clone_ascent = True)

            if prod is None: prod = c.return_ascent(cur_state)
            else: prod *= c.return_ascent(cur_state)

            if i != len(self.children) - 1: state.update_attrs_(cur_state)

        assert prod is not None
        if self.scalar is not None: prod *= self.scalar
        state.ascent = prod
        return self._update_params_or_step_with_next(state)
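The reduction modules separate scalar arguments from module arguments at construction time: scalars are folded into a single `self.scalar` term, and each remaining module becomes an indexed child whose ascent is accumulated in `step`. A small sketch of `Mean`'s arithmetic, assuming standalone construction works and reusing `Mul` from `multi.py` purely as a placeholder child (names and values are illustrative):

from torchzero.modules.operations.multi import Mul
from torchzero.modules.operations.reduction import Mean

# Three inputs in total: two module children and one scalar. Per the code above,
# Mean produces (ascent_of(Mul(1.0)) + ascent_of(Mul(-1.0)) + 0.5) / 3.
averaged = Mean(Mul(1.0), Mul(-1.0), 0.5)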
torchzero/modules/operations/singular.py
@@ -0,0 +1,113 @@
from collections.abc import Iterable
from operator import methodcaller

import torch

from ...core import OptimizerModule
from ...tensorlist import TensorList


class Operation(OptimizerModule):
    """Applies an operation to the ascent, supported operations:

    `abs`, `sign`, `sin`, `cos`, `tan`, `asin`, `acos`, `atan`, `sinh`, `cosh`,
    `tanh`, `log`, `log1p`, `log2`, `log10`, `erf`, `erfc`, `exp`, `neg`, `reciprocal`,
    `copy`, `zero`, `sqrt`, `floor`, `ceil`, `round`."""
    def __init__(self, operation: str):
        super().__init__({})
        self.operation = methodcaller(f'{operation}_')

    @torch.no_grad
    def _update(self, state, ascent): return self.operation(ascent)

class Reciprocal(OptimizerModule):
    """*1 / update*"""
    def __init__(self,):
        super().__init__({})

    @torch.no_grad()
    def _update(self, state, ascent): return ascent.reciprocal_()

class Negate(OptimizerModule):
    """minus update"""
    def __init__(self,):
        super().__init__({})

    @torch.no_grad()
    def _update(self, state, ascent): return ascent.neg_()


def sign_grad_(params: Iterable[torch.Tensor]):
    """Apply sign function to gradients of an iterable of parameters.

    Args:
        params (abc.Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
    """
    TensorList(params).get_existing_grads().sign_()

class Sign(OptimizerModule):
    """applies sign function to the update"""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def _update(self, state, ascent): return ascent.sign_()

class Abs(OptimizerModule):
    """takes absolute values of the update."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def _update(self, state, ascent): return ascent.abs_()

class Sin(OptimizerModule):
    """applies sin function to the ascent"""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def _update(self, state, ascent): return ascent.sin_()

class Cos(OptimizerModule):
    """applies cos function to the ascent"""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def _update(self, state, ascent): return ascent.cos_()


class NanToNum(OptimizerModule):
    """Convert `nan`, `inf` and `-inf` to numbers.

    Args:
        nan (optional): the value to replace NaNs with. Default is zero.
        posinf (optional): if a Number, the value to replace positive infinity values with.
            If None, positive infinity values are replaced with the greatest finite value
            representable by input's dtype. Default is None.
        neginf (optional): if a Number, the value to replace negative infinity values with.
            If None, negative infinity values are replaced with the lowest finite value
            representable by input's dtype. Default is None.
    """
    def __init__(self, nan=None, posinf=None, neginf=None):
        super().__init__({})
        self.nan = nan
        self.posinf = posinf
        self.neginf = neginf

    @torch.no_grad()
    def _update(self, state, ascent): return ascent.nan_to_num_(self.nan, self.posinf, self.neginf)


class MagnitudePower(OptimizerModule):
    """Raises the update to the `value` power while preserving its sign (odd powers already keep the sign and are applied directly; other powers use `abs(update) ** value * sign(update)`)."""
    def __init__(self, value: int | float):
        super().__init__({})
        self.value = value

    @torch.no_grad()
    def _update(self, state, ascent):
        if self.value % 2 == 1: return ascent.pow_(self.value)
        return ascent.abs().pow_(self.value) * ascent.sign()
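`Operation` resolves the method to call once, at construction, via `operator.methodcaller` with a trailing underscore, so every supported name maps to the corresponding in-place tensor method. The same dispatch can be reproduced with plain `torch` tensors (standard library and PyTorch only; the variable names below are illustrative):

from operator import methodcaller
import torch

op = methodcaller('sqrt_')          # what Operation('sqrt') builds internally
t = torch.tensor([4.0, 9.0])
op(t)                               # calls t.sqrt_() in place; t is now [2.0, 3.0]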
torchzero/modules/optimizers/adagrad.py
@@ -0,0 +1,49 @@
from collections import abc

import torch

from ...tensorlist import TensorList
from ...core import OptimizerModule

def _adagrad_step_(ascent: TensorList, grad_sum: TensorList, alpha: TensorList, lr_decay: TensorList, eps: TensorList, step: int):
    clr = alpha / (1 + step * lr_decay)
    grad_sum.addcmul_(ascent, ascent)
    return ascent.div_(grad_sum.sqrt().add_(eps)).mul_(clr)

class Adagrad(OptimizerModule):
    """
    Divides the ascent direction by the square root of the sum of squares of all past ascent directions (plus `eps`).

    Exactly matches `torch.optim.Adagrad`.

    Args:
        lr_decay (float, optional): learning rate decay. Defaults to 0.
        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
        eps (float, optional): term added to the denominator to improve numerical stability. Defaults to 1e-10.
        alpha (float, optional): learning rate. Defaults to 1.

    Reference:
        https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
    """
    def __init__(self, lr_decay: float = 0, initial_accumulator_value: float = 0, eps: float = 1e-10, alpha: float = 1):
        defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value, eps = eps)
        super().__init__(defaults)
        self.cur_step = 0

    @torch.no_grad
    def _update(self, state, ascent):
        settings = self.get_all_group_keys()
        if self.cur_step == 0: init = ascent.full_like(settings['initial_accumulator_value'])
        else: init = None
        grad_sum = self.get_state_key('grad_sum', init = init) # type:ignore

        updated_direction = _adagrad_step_(
            ascent=ascent,
            grad_sum=grad_sum,
            alpha=settings["alpha"],
            eps=settings["eps"],
            lr_decay=settings["lr_decay"],
            step=self.cur_step,
        )
        self.cur_step += 1
        return updated_direction