torchzero 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchzero/__init__.py +4 -0
- torchzero/core/__init__.py +13 -0
- torchzero/core/module.py +471 -0
- torchzero/core/tensorlist_optimizer.py +219 -0
- torchzero/modules/__init__.py +21 -0
- torchzero/modules/adaptive/__init__.py +4 -0
- torchzero/modules/adaptive/adaptive.py +192 -0
- torchzero/modules/experimental/__init__.py +19 -0
- torchzero/modules/experimental/experimental.py +294 -0
- torchzero/modules/experimental/quad_interp.py +104 -0
- torchzero/modules/experimental/subspace.py +259 -0
- torchzero/modules/gradient_approximation/__init__.py +7 -0
- torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
- torchzero/modules/gradient_approximation/base_approximator.py +110 -0
- torchzero/modules/gradient_approximation/fdm.py +125 -0
- torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
- torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
- torchzero/modules/gradient_approximation/rfdm.py +125 -0
- torchzero/modules/line_search/__init__.py +30 -0
- torchzero/modules/line_search/armijo.py +56 -0
- torchzero/modules/line_search/base_ls.py +139 -0
- torchzero/modules/line_search/directional_newton.py +217 -0
- torchzero/modules/line_search/grid_ls.py +158 -0
- torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
- torchzero/modules/meta/__init__.py +12 -0
- torchzero/modules/meta/alternate.py +65 -0
- torchzero/modules/meta/grafting.py +195 -0
- torchzero/modules/meta/optimizer_wrapper.py +173 -0
- torchzero/modules/meta/return_overrides.py +46 -0
- torchzero/modules/misc/__init__.py +10 -0
- torchzero/modules/misc/accumulate.py +43 -0
- torchzero/modules/misc/basic.py +115 -0
- torchzero/modules/misc/lr.py +96 -0
- torchzero/modules/misc/multistep.py +51 -0
- torchzero/modules/misc/on_increase.py +53 -0
- torchzero/modules/momentum/__init__.py +4 -0
- torchzero/modules/momentum/momentum.py +106 -0
- torchzero/modules/operations/__init__.py +29 -0
- torchzero/modules/operations/multi.py +298 -0
- torchzero/modules/operations/reduction.py +134 -0
- torchzero/modules/operations/singular.py +113 -0
- torchzero/modules/optimizers/__init__.py +10 -0
- torchzero/modules/optimizers/adagrad.py +49 -0
- torchzero/modules/optimizers/adam.py +118 -0
- torchzero/modules/optimizers/lion.py +28 -0
- torchzero/modules/optimizers/rmsprop.py +51 -0
- torchzero/modules/optimizers/rprop.py +99 -0
- torchzero/modules/optimizers/sgd.py +54 -0
- torchzero/modules/orthogonalization/__init__.py +2 -0
- torchzero/modules/orthogonalization/newtonschulz.py +159 -0
- torchzero/modules/orthogonalization/svd.py +86 -0
- torchzero/modules/quasi_newton/__init__.py +4 -0
- torchzero/modules/regularization/__init__.py +22 -0
- torchzero/modules/regularization/dropout.py +34 -0
- torchzero/modules/regularization/noise.py +77 -0
- torchzero/modules/regularization/normalization.py +328 -0
- torchzero/modules/regularization/ortho_grad.py +78 -0
- torchzero/modules/regularization/weight_decay.py +92 -0
- torchzero/modules/scheduling/__init__.py +2 -0
- torchzero/modules/scheduling/lr_schedulers.py +131 -0
- torchzero/modules/scheduling/step_size.py +80 -0
- torchzero/modules/second_order/__init__.py +4 -0
- torchzero/modules/second_order/newton.py +165 -0
- torchzero/modules/smoothing/__init__.py +5 -0
- torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
- torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
- torchzero/modules/weight_averaging/__init__.py +2 -0
- torchzero/modules/weight_averaging/ema.py +72 -0
- torchzero/modules/weight_averaging/swa.py +171 -0
- torchzero/optim/__init__.py +10 -0
- torchzero/optim/experimental/__init__.py +20 -0
- torchzero/optim/experimental/experimental.py +343 -0
- torchzero/optim/experimental/ray_search.py +83 -0
- torchzero/optim/first_order/__init__.py +18 -0
- torchzero/optim/first_order/cautious.py +158 -0
- torchzero/optim/first_order/forward_gradient.py +70 -0
- torchzero/optim/first_order/optimizers.py +570 -0
- torchzero/optim/modular.py +132 -0
- torchzero/optim/quasi_newton/__init__.py +1 -0
- torchzero/optim/quasi_newton/directional_newton.py +58 -0
- torchzero/optim/second_order/__init__.py +1 -0
- torchzero/optim/second_order/newton.py +94 -0
- torchzero/optim/wrappers/__init__.py +0 -0
- torchzero/optim/wrappers/nevergrad.py +113 -0
- torchzero/optim/wrappers/nlopt.py +165 -0
- torchzero/optim/wrappers/scipy.py +439 -0
- torchzero/optim/zeroth_order/__init__.py +4 -0
- torchzero/optim/zeroth_order/fdm.py +87 -0
- torchzero/optim/zeroth_order/newton_fdm.py +146 -0
- torchzero/optim/zeroth_order/rfdm.py +217 -0
- torchzero/optim/zeroth_order/rs.py +85 -0
- torchzero/random/__init__.py +1 -0
- torchzero/random/random.py +46 -0
- torchzero/tensorlist.py +819 -0
- torchzero/utils/__init__.py +0 -0
- torchzero/utils/compile.py +39 -0
- torchzero/utils/derivatives.py +99 -0
- torchzero/utils/python_tools.py +25 -0
- torchzero/utils/torch_tools.py +92 -0
- torchzero-0.0.1.dist-info/LICENSE +21 -0
- torchzero-0.0.1.dist-info/METADATA +118 -0
- torchzero-0.0.1.dist-info/RECORD +104 -0
- torchzero-0.0.1.dist-info/WHEEL +5 -0
- torchzero-0.0.1.dist-info/top_level.txt +1 -0
torchzero/modules/regularization/noise.py
@@ -0,0 +1,77 @@
from collections import abc
from typing import Literal

import torch

from ...core import OptimizerModule
from ...tensorlist import Distributions, TensorList, _Scalar, _ScalarSequence


def add_noise_(
    grads: abc.Iterable[torch.Tensor],
    alpha: "_Scalar | _ScalarSequence" = 1e-2,
    distribution: Distributions = "normal",
    mode: Literal["absolute", "global", "param", "channel"] = "param",
):
    if not isinstance(grads, TensorList): grads = TensorList(grads)
    if mode == 'absolute':
        grads += grads.sample_like(alpha, distribution)

    elif mode == 'global':
        grads += grads.sample_like((grads.total_vector_norm(1)/grads.total_numel() * alpha).detach().cpu().item(), distribution)

    elif mode == 'param':
        grads += grads.sample_like(grads.abs().mean()*alpha, distribution)

    elif mode == 'channel':
        grads = grads.unbind_channels()
        grads += grads.sample_like(grads.abs().mean()*alpha, distribution)

class AddNoise(OptimizerModule):
    """Add noise to update. By default noise magnitude is relative to the mean of each parameter.

    Args:
        alpha (float, optional): magnitude of noise. Defaults to 1e-2.
        distribution (Distributions, optional): distribution of noise. Defaults to 'normal'.
        mode (str, optional):
            how to calculate noise magnitude.

            - "absolute": ignores gradient magnitude and always uses `alpha` as magnitude.

            - "global": multiplies `alpha` by mean of the entire gradient, as if it was a single vector.

            - "param": multiplies `alpha` by mean of each individual parameter (default).

            - "channel": multiplies `alpha` by mean of each channel of each parameter.
    """

    def __init__(
        self,
        alpha: float = 1.,
        distribution: Distributions = "normal",
        mode: Literal["absolute", "global", "param", "channel"] = "param",
    ):
        defaults = dict(alpha = alpha)
        super().__init__(defaults)
        self.distribution: Distributions = distribution
        self.mode: Literal["absolute", "global", "param", "channel"] = mode

    @torch.no_grad
    def _update(self, state, ascent):
        alpha = self.get_group_key('alpha')

        add_noise_(ascent, alpha, self.distribution, self.mode)
        return ascent

class Random(OptimizerModule):
    """uses a random vector as the update. The vector is completely random and isn't checked to be descent direction.
    This is therefore mainly useful in combination with other modules like Sum, Multiply, etc."""
    def __init__(self, alpha: float = 1, distribution: Distributions = "normal"):
        defaults = dict(alpha = alpha)
        super().__init__(defaults)
        self.distribution: Distributions = distribution

    @torch.no_grad
    def _update(self, state, ascent):
        alpha = self.get_group_key('alpha')
        return ascent.sample_like(alpha, self.distribution)
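The "param" mode above scales the noise by the mean absolute value of each parameter's gradient. As a point of reference, here is a minimal standalone sketch of that behaviour in plain PyTorch, without torchzero's TensorList; the helper name add_relative_noise_ is made up for illustration:

```python
import torch

def add_relative_noise_(grads, alpha=1e-2):
    """Rough equivalent of AddNoise's "param" mode: per-tensor noise whose
    scale is alpha times the mean absolute value of that gradient."""
    for g in grads:
        scale = g.abs().mean() * alpha
        g.add_(torch.randn_like(g) * scale)  # normal-distributed noise

# e.g.: add_relative_noise_(p.grad for p in model.parameters() if p.grad is not None)
```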
torchzero/modules/regularization/normalization.py
@@ -0,0 +1,328 @@
from collections import abc
import typing
import torch

from ...tensorlist import TensorList
from ...core import OptimizerModule


def _normalize_grad_(
    grads: abc.Iterable[torch.Tensor],
    norm_value: float = 1,
    ord: float = 2,
    min: float = 0,
    mode: typing.Literal["global", "param", "channel"] = "param",
    min_numel=2,
):
    if mode in ('param', 'channel'):
        for grad in grads:
            if grad.numel() >= min_numel:
                if mode == 'channel' and grad.ndim >= 2:
                    norm = torch.linalg.vector_norm(grad, ord, dim=tuple(range(1, grad.ndim)), keepdim=True) # pylint:disable=not-callable
                    norm[norm<=min] = 1
                    grad /= norm / norm_value
                else: # mode = 'param' or 1d grad
                    norm = torch.linalg.vector_norm(grad, ord) # pylint:disable=not-callable
                    if norm > min:
                        grad /= norm / norm_value
    else:
        if not isinstance(grads, TensorList): grads = TensorList(grads)
        norm = grads.total_vector_norm(ord)
        if norm > min:
            grads /= norm / norm_value

@torch.no_grad
def normalize_grad_(
    params: abc.Iterable[torch.Tensor],
    norm_value: float = 1,
    ord: float = 2,
    min: float = 0,
    mode: typing.Literal["global", "param", "channel"] = "global",
    min_numel=2,
):
    """Normalizes gradients of an iterable of parameters.

    Args:
        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to normalize.
        norm_value (float, optional): value to normalize to. Defaults to 1.
        ord (float, optional): order of the norm. Defaults to 2.
        min (float, optional):
            won't normalize when gradient is below this norm, you can increase this
            to avoid amplifying extremely small gradients. Defaults to 0.
        mode (str, optional):
            what to normalize.

            - "global": normalize the entire gradient, as if it was a single vector.

            - "param": normalize each param's gradient (default).

            - "channel": normalize gradient of each channel of each param.
        min_numel (int, optional):
            skips parameters with less than this many elements. This avoids the issue where
            parameters that have a single element always get set to the value of 1.
            Ignored when mode is 'global'.

    Example:
        >>> normalize_grad_(model.parameters())
    """
    _normalize_grad_(
        (p.grad for p in params if p.grad is not None),
        norm_value = norm_value,
        ord = ord,
        min = min,
        mode = mode,
        min_numel = min_numel,
    )

class Normalize(OptimizerModule):
    """Normalizes update to the given norm value.

    Args:
        norm_value (float, optional): value to normalize to. Defaults to 1.
        ord (float, optional): order of the norm. Defaults to 2.
        min (float, optional):
            won't normalize when gradient is below this norm, you can increase this
            to avoid amplifying extremely small gradients. Defaults to 0.
        mode (str, optional):
            what to normalize.

            - "global": normalize the entire gradient, as if it was a single vector.

            - "param": normalize each param's gradient (default).

            - "channel": normalize gradient of each channel of each param.
        min_numel (int, optional):
            skips parameters with less than this many elements. This avoids the issue where
            parameters that have a single element always get set to the value of 1.
            Ignored when mode is 'global'.
    """
    def __init__(
        self,
        norm_value: float = 1,
        ord: float = 2,
        min: float = 0,
        mode: typing.Literal["global", "param", "channel"] = "param",
        min_numel=2,
    ):
        super().__init__({})
        self.norm_value = norm_value
        self.ord = ord
        self.min = min
        self.mode: typing.Literal["global", "param", "channel"] = mode
        self.min_numel = min_numel

    @torch.no_grad
    def _update(self, state, ascent):
        _normalize_grad_(
            ascent,
            norm_value = self.norm_value,
            ord = self.ord,
            min = self.min,
            mode = self.mode,
            min_numel = self.min_numel,
        )
        return ascent


def _centralize_grad_(
    grads: abc.Iterable[torch.Tensor],
    mode: typing.Literal["global", "param", "channel"] = "channel",
    min_ndim=2,
    min_numel=2,
):
    if mode in ('param', 'channel'):
        if mode == 'channel': min_ndim = max(min_ndim, 2)
        for grad in grads:
            if grad.numel() >= min_numel and grad.ndim > min_ndim:
                if mode == 'channel':
                    grad -= grad.mean(dim=tuple(range(1, grad.ndim)), keepdim=True)
                else: # mode = 'param'
                    grad -= grad.mean()
    else:
        if not isinstance(grads, TensorList): grads = TensorList(grads)
        grads -= grads.mean()

@torch.no_grad
def centralize_grad_(
    params: abc.Iterable[torch.Tensor],
    mode: typing.Literal["global", "param", "channel"] = "channel",
    min_ndim=2,
    min_numel=2,
):
    """Centralizes gradients of an iterable of parameters.

    Args:
        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to centralize.
        mode (str, optional):
            what to centralize.

            - "global": centralize the entire gradient (uses mean of entire gradient).

            - "param": centralize each param's gradient.

            - "channel": centralize gradient of each channel of each param (default).
        min_numel (int, optional):
            skips parameters with less than this many elements. This avoids negating updates for
            parameters that have a single element since subtracting mean always makes it 0.
            Ignored when mode is 'global'.
        min_ndim (int, optional):
            skips parameters with less than this many dimensions.
            bias usually has 1 dimension and you don't want to centralize it.
            Ignored when mode is 'global'.

    reference
        *Yong, H., Huang, J., Hua, X., & Zhang, L. (2020).
        Gradient centralization: A new optimization technique for deep neural networks.
        In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK,
        August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing.*

    Example:
        >>> centralize_grad_(model.parameters())
    """
    _centralize_grad_(
        (p.grad for p in params if p.grad is not None),
        mode = mode,
        min_ndim = min_ndim,
        min_numel = min_numel,
    )

class Centralize(OptimizerModule):
    """Centralizes the update.

    Args:
        mode (str, optional):
            what to centralize.

            - "global": centralize the entire gradient (uses mean of entire gradient).

            - "param": centralize each param's gradient.

            - "channel": centralize gradient of each channel of each param (default).
        min_numel (int, optional):
            skips parameters with less than this many elements. This avoids negating updates for
            parameters that have a single element since subtracting mean always makes it 0.
            Ignored when mode is 'global'.
        min_ndim (int, optional):
            skips parameters with less than this many dimensions.
            bias usually has 1 dimension and you don't want to centralize it.
            Ignored when mode is 'global'.

    reference
        *Yong, H., Huang, J., Hua, X., & Zhang, L. (2020).
        Gradient centralization: A new optimization technique for deep neural networks.
        In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK,
        August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing.*
    """
    def __init__(
        self,
        mode: typing.Literal["global", "param", "channel"] = "channel",
        min_ndim=2,
        min_numel=2,
    ):
        super().__init__({})
        self.mode: typing.Literal["global", "param", "channel"] = mode
        self.min_ndim = min_ndim
        self.min_numel = min_numel

    @torch.no_grad
    def _update(self, state, ascent):
        _centralize_grad_(
            ascent,
            mode = self.mode,
            min_ndim = self.min_ndim,
            min_numel = self.min_numel,
        )
        return ascent


def clip_grad_value_(params: abc.Iterable[torch.Tensor], value:float):
    """Clip the gradients of an iterable of parameters at specified value.

    Args:
        params (abc.Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor that will have gradients clipped.
        value (float, optional):
            maximum allowed magnitude of the gradients.
            The gradients are clipped in the range `[-clip_value, clip_value]`
    """
    TensorList(params).get_existing_grads().clamp_(-value, value)

class ClipValue(OptimizerModule):
    """Clip the update at specified value.

    Args:
        value (float, optional): maximum allowed magnitude of the gradients.
            The gradients are clipped in the range `[-clip_value, clip_value]`
    """
    def __init__(self, value: float):
        defaults = dict(value = value)
        super().__init__(defaults)

    @torch.no_grad
    def _update(self, state, ascent):
        value = self.get_group_key('value')
        ascent.clamp_(-value, value)
        return ascent

def clip_grad_norm_(
    params: abc.Iterable[torch.Tensor],
    max_norm: float,
    ord: float = 2,
    mode: typing.Literal["global", "param", "channel"] = "param",
):
    """Clip the gradient norm of an iterable of parameters.

    Args:
        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to clip the norm of.
        max_norm (float, optional): norm value to clip to.
        ord (float, optional): order of the norm. Defaults to 2.
        mode (str, optional):
            what to calculate the norm over.

            - "global": calculates and clips the norm of the entire gradient, as if it was a single vector.

            - "param": calculates and clips each param's gradient norm (default).

            - "channel": calculate and clip the norm of gradient of each channel of each param.

    Example:
        >>> clip_grad_norm_(model.parameters(), max_norm=1)
    """
    _normalize_grad_(
        (p.grad for p in params if p.grad is not None),
        norm_value = max_norm,
        min = max_norm,
        ord = ord,
        mode = mode,
    )

class ClipNorm(OptimizerModule):
    """Clip the gradient norm of an iterable of parameters.

    Args:
        max_norm (float, optional): norm value to clip to.
        ord (float, optional): order of the norm. Defaults to 2.
        mode (str, optional):
            what to calculate the norm over.

            - "global": calculates and clips the norm of the entire gradient, as if it was a single vector.

            - "param": calculates and clips each param's gradient norm (default).

            - "channel": calculate and clip the norm of gradient of each channel of each param.
    """
    def __init__(self, max_norm: float, ord:float=2, mode: typing.Literal["global", "param", "channel"] = "param",):
        super().__init__({})
        self.max_norm = max_norm
        self.ord = ord
        self.mode: typing.Literal["global", "param", "channel"] = mode

    @torch.no_grad
    def _update(self, state, ascent):
        _normalize_grad_(
            ascent,
            norm_value = self.max_norm,
            min = self.max_norm,
            ord = self.ord,
            mode = self.mode,
        )
        return ascent
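In its default "param" mode, ClipNorm reduces to a per-parameter operation: compute the gradient norm and rescale only when it exceeds max_norm (this is what _normalize_grad_ does when min and norm_value are both set to max_norm). Gradient centralization in "channel" mode likewise just subtracts a per-output-channel mean. Two rough standalone sketches in plain PyTorch; the helper names are made up, and the min_numel / min_ndim edge handling of the real code is omitted:

```python
import torch

def clip_each_grad_norm_(params, max_norm, ord=2):
    """Per-parameter ("param" mode) gradient norm clipping."""
    for p in params:
        if p.grad is None:
            continue
        norm = torch.linalg.vector_norm(p.grad, ord)
        if norm > max_norm:              # small gradients are left untouched
            p.grad.mul_(max_norm / norm)

def centralize_channels_(params):
    """Per-channel gradient centralization (Yong et al., 2020)."""
    for p in params:
        if p.grad is None or p.grad.ndim < 2:
            continue                     # skip biases and other 1-D tensors
        p.grad.sub_(p.grad.mean(dim=tuple(range(1, p.grad.ndim)), keepdim=True))
```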
torchzero/modules/regularization/ortho_grad.py
@@ -0,0 +1,78 @@
"""
⟂Grad (read “ortho-grad”) was proposed in https://arxiv.org/abs/2501.04697.

"""
from collections.abc import Iterable

import torch

from ...tensorlist import TensorList
from ...core import OptimizerModule, _Targets


def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
    """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

    Args:
        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to apply ⟂Grad to.
        eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)

    reference
        https://arxiv.org/abs/2501.04697
    """
    if not isinstance(params, TensorList): params = TensorList(params)
    params = params.with_grad()
    grad = params.grad
    grad -= ((params*grad).total_sum() / ((params*params).total_sum() + eps)) * params

class OrthoGrad(OptimizerModule):
    """⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

    Args:
        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to apply ⟂Grad to.
        eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
        renormalize (bool, optional): whether to renormalize gradients back to original norm (default: True).
        sqrt_scale (bool, optional):
            uses the square root of the scale to make it more impactful; experimental setting that doesn't really work (default: False).
        add (bool, optional):
            Experimental option that changes the subtraction to addition.
            I don't think it has any geometric meaning, but it drives weights towards zero instead of away from it,
            and it seems to work well with sqrt_scale = True. It speeds up convergence by a lot compared to the vanilla gradient,
            but also overfits severely.
        target (str, optional):
            determines what this module updates.

            "ascent" - it updates the ascent (default).

            "grad" - it updates the gradient (and sets `.grad` attributes to the updated gradient).

            "closure" - it makes a new closure that sets the updated ascent to the `.grad` attributes.

    reference
        https://arxiv.org/abs/2501.04697
    """
    def __init__(self, eps: float = 1e-30, renormalize=True, sqrt_scale = False, add=False, target: _Targets = 'ascent'):
        super().__init__({}, target=target)
        self.eps = eps
        self.add = add
        self.renormalize = renormalize
        self.sqrt_scale = sqrt_scale

    def _update(self, state, ascent):
        params = self.get_params()

        if self.renormalize: orig_norm = ascent.norm(2) + self.eps
        else: orig_norm = 1

        scale = (params*ascent).total_sum() / ((params*params).total_sum() + self.eps)
        if self.sqrt_scale:
            scale = scale.abs().sqrt() * scale.sign()

        if self.add: ascent += params * scale
        else: ascent -= params * scale

        if self.renormalize:
            ascent *= (orig_norm / ascent.norm(2))

        return ascent
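Written out, the projection above has a simple closed form: with the weights w and gradient g flattened into vectors, ⟂Grad replaces g by g - (⟨w, g⟩ / (⟨w, w⟩ + eps)) · w, optionally rescaling the result back to the original gradient norm. A self-contained sketch on a single parameter tensor (the module itself works on a TensorList spanning all parameters; orthograd_single_ is a hypothetical name):

```python
import torch

def orthograd_single_(param: torch.Tensor, eps: float = 1e-30, renormalize: bool = True):
    """Remove the component of param.grad that is parallel to the weights."""
    g, w = param.grad, param.detach()
    orig_norm = g.norm()
    scale = (w * g).sum() / ((w * w).sum() + eps)  # <w,g> / (<w,w> + eps)
    g.sub_(w * scale)                              # g <- g - scale * w
    if renormalize:
        g.mul_(orig_norm / (g.norm() + eps))       # restore the original norm
```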
torchzero/modules/regularization/weight_decay.py
@@ -0,0 +1,92 @@
from typing import Literal
from collections.abc import Iterable

import torch

from ...tensorlist import TensorList
from ...core import OptimizerModule, _Targets


def l2_regularize_(params: Iterable[torch.Tensor], alpha: float = 1e-2):
    """Adds L2 weight regularization term to the gradients in-place.

    Args:
        params (Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
        alpha (float, optional): multiplier to the regularizer. Defaults to 1e-2.
    """
    p = TensorList(params).with_requires_grad()
    p.ensure_grad_()
    p.grad.add_(p, alpha = alpha)

def l1_regularize_(params: Iterable[torch.Tensor], alpha: float = 1e-2):
    """Adds L1 weight regularization term to the gradients in-place.

    Args:
        params (Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
        alpha (float, optional): multiplier to the regularizer. Defaults to 1e-2.
    """
    p = TensorList(params).with_requires_grad()
    p.ensure_grad_()
    p.grad.add_(p.sign(), alpha = alpha)

def weight_decay_penalty(params: Iterable[torch.Tensor], alpha: float = 1e-2, ord:float = 2):
    """Calculate the weight decay penalty term that can be added to the loss.

    Args:
        params (Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
        alpha (float): multiplier to the regularizer.
        ord (int, optional): order of the norm. Defaults to 2.
    """
    return TensorList(params).norm(ord) * alpha

def decay_weights_(params: Iterable[torch.Tensor], alpha: float = 1e-2, ord:Literal[1, 2] = 2):
    """Apply weight decay directly to parameters in-place.

    Args:
        params (Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor to decay.
        alpha (float): by how much to decay parameters (default: 1e-2)
        ord (float, optional):
            order of the penalty, 1 and 2 are currently supported (L1 and L2 regularization) (default: 2)
    """
    params = TensorList(params)
    if ord == 2: params.mul_(1-alpha)
    elif ord == 1: params.sub_(params.sign().mul_(alpha))
    else: raise NotImplementedError(f'order {ord} is not supported')


class WeightDecay(OptimizerModule):
    """Adds a weight decay term (L1 or L2 regularization) to the ascent direction.

    Put this at the end to make it decoupled.

    Args:
        alpha (float, optional): multiplier to the regularizer (default: 1e-2)
        ord (Literal[1, 2], optional):
            order of the penalty, 1 and 2 are currently supported (L1 and L2 regularization).
            Defaults to 2.
        target (str, optional):
            determines what this module updates.

            "ascent" - it updates the ascent (default).

            "grad" - it updates the gradient (and sets `.grad` attributes to the updated gradient).

            "closure" - it makes a new closure that sets the updated ascent to the `.grad` attributes.
    """
    def __init__(self, alpha: float = 1e-2, ord:Literal[1, 2] = 2, target: _Targets = "ascent"):
        defaults = dict(alpha = alpha)
        super().__init__(defaults, target = target)
        self.ord = ord

    @torch.no_grad
    def _update(self, state, ascent):
        params = self.get_params()
        alpha = self.get_group_key('alpha')

        if any(i != 0 for i in alpha):

            if self.ord == 1: ascent.add_(params.sign() * alpha)
            elif self.ord == 2: ascent.add_(params * alpha)
            else: raise NotImplementedError(f'weight decay of order {self.ord} is not implemented.')

        return ascent