torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
--- a/torchzero/modules/regularization/normalization.py
+++ /dev/null
@@ -1,328 +0,0 @@
-from collections import abc
-import typing
-import torch
-
-from ...tensorlist import TensorList
-from ...core import OptimizerModule
-
-
-def _normalize_grad_(
-    grads: abc.Iterable[torch.Tensor],
-    norm_value: float = 1,
-    ord: float = 2,
-    min: float = 0,
-    mode: typing.Literal["global", "param", "channel"] = "param",
-    min_numel=2,
-):
-    if mode in ('param', 'channel'):
-        for grad in grads:
-            if grad.numel() >= min_numel:
-                if mode == 'channel' and grad.ndim >= 2:
-                    norm = torch.linalg.vector_norm(grad, ord, dim=tuple(range(1, grad.ndim)), keepdim=True) # pylint:disable=not-callable
-                    norm[norm<=min] = 1
-                    grad /= norm / norm_value
-                else: # mode = 'param' or 1d grad
-                    norm = torch.linalg.vector_norm(grad, ord) # pylint:disable=not-callable
-                    if norm > min:
-                        grad /= norm / norm_value
-    else:
-        if not isinstance(grads, TensorList): grads = TensorList(grads)
-        norm = grads.total_vector_norm(ord)
-        if norm > min:
-            grads /= norm / norm_value # type:ignore
-
-@torch.no_grad
-def normalize_grad_(
-    params: abc.Iterable[torch.Tensor],
-    norm_value: float = 1,
-    ord: float = 2,
-    min: float = 0,
-    mode: typing.Literal["global", "param", "channel"] = "global",
-    min_numel=2,
-):
-    """Normalizes gradients of an iterable of parameters.
-
-    Args:
-        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to normalize.
-        norm_value (float, optional): value to normalize to. Defaults to 1.
-        ord (float, optional): order of the norm. Defaults to 2.
-        min (float, optional):
-            won't normalize when gradient is below this norm, you can increase this
-            to avoid amplifying extremely small gradients. Defaults to 0.
-        mode (str, optional):
-            what to normalize.
-
-            - "global": normalize the entire gradient, as if it was a single vector.
-
-            - "param": normalize each param's gradient (default).
-
-            - "channel": normalize gradient of each channel of each param.
-        min_numel (int, optional):
-            skips parameters with less than this many elements. This avoids the issue where
-            parameters that have a single element always get set to the value of 1.
-            Ignored when mode is 'global'.
-
-    Example:
-        >>> normalize_grad_(model.parameters())
-    """
-    _normalize_grad_(
-        (p.grad for p in params if p.grad is not None),
-        norm_value = norm_value,
-        ord = ord,
-        min = min,
-        mode = mode,
-        min_numel = min_numel,
-    )
-
-class Normalize(OptimizerModule):
-    """Normalizes update to the given norm value.
-
-    Args:
-        norm_value (float, optional): value to normalize to. Defaults to 1.
-        ord (float, optional): order of the norm. Defaults to 2.
-        min (float, optional):
-            won't normalize when gradient is below this norm, you can increase this
-            to avoid amplifying extremely small gradients. Defaults to 0.
-        mode (str, optional):
-            what to normalize.
-
-            - "global": normalize the entire gradient, as if it was a single vector.
-
-            - "param": normalize each param's gradient (default).
-
-            - "channel": normalize gradient of each channel of each param.
-        min_numel (int, optional):
-            skips parameters with less than this many elements. This avoids the issue where
-            parameters that have a single element always get set to the value of 1.
-            Ignored when mode is 'global'.
-    """
-    def __init__(
-        self,
-        norm_value: float = 1,
-        ord: float = 2,
-        min: float = 0,
-        mode: typing.Literal["global", "param", "channel"] = "param",
-        min_numel=2,
-    ):
-        super().__init__({})
-        self.norm_value = norm_value
-        self.ord = ord
-        self.min = min
-        self.mode: typing.Literal["global", "param", "channel"] = mode
-        self.min_numel = min_numel
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        _normalize_grad_(
-            ascent,
-            norm_value = self.norm_value,
-            ord = self.ord,
-            min = self.min,
-            mode = self.mode,
-            min_numel = self.min_numel,
-        )
-        return ascent
-
-
-def _centralize_grad_(
-    grads: abc.Iterable[torch.Tensor],
-    mode: typing.Literal["global", "param", "channel"] = "channel",
-    min_ndim=2,
-    min_numel=2,
-):
-    if mode in ('param', 'channel'):
-        if mode == 'channel': min_ndim = max(min_ndim, 2)
-        for grad in grads:
-            if grad.numel() >= min_numel and grad.ndim > min_ndim:
-                if mode == 'channel':
-                    grad -= grad.mean(dim=tuple(range(1, grad.ndim)), keepdim=True)
-                else: # mode = 'param'
-                    grad -= grad.mean()
-    else:
-        if not isinstance(grads, TensorList): grads = TensorList(grads)
-        grads -= grads.mean()
-
-@torch.no_grad
-def centralize_grad_(
-    params: abc.Iterable[torch.Tensor],
-    mode: typing.Literal["global", "param", "channel"] = "channel",
-    min_ndim=2,
-    min_numel=2,
-):
-    """Centralizes gradients of an iterable of parameters.
-
-    Args:
-        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to centralize.
-        mode (str, optional):
-            what to centralize.
-
-            - "global": centralize the entire gradient (uses mean of entire gradient).
-
-            - "param": centralize each param's gradient.
-
-            - "channel": centralize gradient of each channel of each param (default).
-        min_numel (int, optional):
-            skips parameters with less than this many elements. This avoids negating updates for
-            parameters that have a single element since subtracting mean always makes it 0.
-            Ignored when mode is 'global'.
-        min_ndim (int, optional):
-            skips parameters with less than this many dimensions.
-            bias usually has 1 dimension and you don't want to centralize it.
-            Ignored when mode is 'global'.
-
-    reference
-        *Yong, H., Huang, J., Hua, X., & Zhang, L. (2020).
-        Gradient centralization: A new optimization technique for deep neural networks.
-        In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK,
-        August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing.*
-
-    Example:
-        >>> centralize_grad_(model.parameters())
-    """
-    _centralize_grad_(
-        (p.grad for p in params if p.grad is not None),
-        mode = mode,
-        min_ndim = min_ndim,
-        min_numel = min_numel,
-    )
-
-class Centralize(OptimizerModule):
-    """Centralizes the update.
-
-    Args:
-        mode (str, optional):
-            what to centralize.
-
-            - "global": centralize the entire gradient (uses mean of entire gradient).
-
-            - "param": centralize each param's gradient.
-
-            - "channel": centralize gradient of each channel of each param (default).
-        min_numel (int, optional):
-            skips parameters with less than this many elements. This avoids negating updates for
-            parameters that have a single element since subtracting mean always makes it 0.
-            Ignored when mode is 'global'.
-        min_ndim (int, optional):
-            skips parameters with less than this many dimensions.
-            bias usually has 1 dimension and you don't want to centralize it.
-            Ignored when mode is 'global'.
-
-    reference
-        *Yong, H., Huang, J., Hua, X., & Zhang, L. (2020).
-        Gradient centralization: A new optimization technique for deep neural networks.
-        In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK,
-        August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing.*
-    """
-    def __init__(
-        self,
-        mode: typing.Literal["global", "param", "channel"] = "channel",
-        min_ndim=2,
-        min_numel=2,
-    ):
-        super().__init__({})
-        self.mode: typing.Literal["global", "param", "channel"] = mode
-        self.min_ndim = min_ndim
-        self.min_numel = min_numel
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        _centralize_grad_(
-            ascent,
-            mode = self.mode,
-            min_ndim = self.min_ndim,
-            min_numel = self.min_numel,
-        )
-        return ascent
-
-
-def clip_grad_value_(params: abc.Iterable[torch.Tensor], value:float):
-    """Clip the gradients of an iterable of parameters at specified value.
-
-    Args:
-        params (abc.Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor that will have gradients clipped.
-        value (float, optional):
-            maximum allowed magnitude of the gradients.
-            The gradients are clipped in the range `[-clip_value, clip_value]`
-    """
-    TensorList(params).get_existing_grads().clamp_(-value, value)
-
-class ClipValue(OptimizerModule):
-    """Clip the update at specified value.
-
-    Args:
-        value (float, optional): maximum allowed magnitude of the gradients.
-            The gradients are clipped in the range `[-clip_value, clip_value]`
-    """
-    def __init__(self, value: float):
-        defaults = dict(value = value)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        value = self.get_group_key('value')
-        ascent.clamp_(-value, value)
-        return ascent
-
-def clip_grad_norm_(
-    params: abc.Iterable[torch.Tensor],
-    max_norm: float,
-    ord: float = 2,
-    mode: typing.Literal["global", "param", "channel"] = "param",
-):
-    """Clip the gradient norm of an iterable of parameters.
-
-    Args:
-        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to clip the norm of.
-        max_norm (float, optional): norm value to clip to.
-        ord (float, optional): order of the norm. Defaults to 2.
-        mode (str, optional):
-            what to calculate the norm over.
-
-            - "global": calculates and clips the norm of the entire gradient, as if it was a single vector.
-
-            - "param": calculates and clips each param's gradient norm (default).
-
-            - "channel": calculate and clip the norm of gradient of each channel of each param.
-
-    Example:
-        >>> clip_grad_norm_(model.parameters())
-    """
-    _normalize_grad_(
-        (p.grad for p in params if p.grad is not None),
-        norm_value = max_norm,
-        min = max_norm,
-        ord = ord,
-        mode = mode,
-    )
-
-class ClipNorm(OptimizerModule):
-    """Clip the gradient norm of an iterable of parameters.
-
-    Args:
-        max_norm (float, optional): norm value to clip to.
-        ord (float, optional): order of the norm. Defaults to 2.
-        mode (str, optional):
-            what to calculate the norm over.
-
-            - "global": calculates and clips the norm of the entire gradient, as if it was a single vector.
-
-            - "param": calculates and clips each param's gradient norm (default).
-
-            - "channel": calculate and clip the norm of gradient of each channel of each param.
-    """
-    def __init__(self, max_norm: float, ord:float=2, mode: typing.Literal["global", "param", "channel"] = "param",):
-        super().__init__({})
-        self.max_norm = max_norm
-        self.ord = ord
-        self.mode: typing.Literal["global", "param", "channel"] = mode
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        _normalize_grad_(
-            ascent,
-            norm_value = self.max_norm,
-            min = self.max_norm,
-            ord = self.ord,
-            mode = self.mode,
-        )
-        return ascent
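
Note: besides the optimizer modules, the removed normalization.py exposed functional helpers that mutate `.grad` in place. A minimal usage sketch based only on the signatures and docstring examples in the hunk above; the model, the dummy loss, and the direct import from the 0.1.8 module path are illustrative assumptions, not something the diff itself documents:

    import torch
    from torchzero.modules.regularization.normalization import (  # 0.1.8 path, removed in 0.3.2
        normalize_grad_, centralize_grad_, clip_grad_norm_, clip_grad_value_,
    )

    model = torch.nn.Linear(4, 2)              # hypothetical model
    model(torch.randn(8, 4)).sum().backward()  # populate .grad attributes

    normalize_grad_(model.parameters(), norm_value=1, ord=2, mode="global")  # rescale grads to the given norm
    centralize_grad_(model.parameters(), mode="channel")                     # subtract the per-channel mean
    clip_grad_norm_(model.parameters(), max_norm=1.0)                        # or clip the norm instead of normalizing
    clip_grad_value_(model.parameters(), value=0.5)                          # clamp each grad to [-0.5, 0.5]
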
--- a/torchzero/modules/regularization/ortho_grad.py
+++ /dev/null
@@ -1,78 +0,0 @@
-"""
-⟂Grad (read “ortho-grad”) was proposed in https://arxiv.org/abs/2501.04697.
-
-"""
-from collections.abc import Iterable
-
-import torch
-
-from ...tensorlist import TensorList
-from ...core import OptimizerModule, _Targets
-
-
-def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
-    """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.
-
-    Args:
-        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to apply ⟂Grad to.
-        eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
-
-    reference
-        https://arxiv.org/abs/2501.04697
-    """
-    if not isinstance(params, TensorList): params = TensorList(params)
-    params = params.with_grad()
-    grad = params.grad
-    grad -= (((params*grad).total_sum())/(params*params).total_sum() + eps) * params
-
-class OrthoGrad(OptimizerModule):
-    """⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.
-
-    Args:
-        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to apply ⟂Grad to.
-        eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
-        renormalize (bool, optional): whether to renormalize gradients back to original norm (default: True).
-        sqrt_scale (bool, optional):
-            uses square root of the scale to make it more impactful, experimental setting and doesn't really work (default: False).
-        add (bool, optional):
-            Experimental option that changes subtraction to addition.
-            I don't think it has any geometric meaning but it drives weights towards zero instead of away from it.
-            and it seems to work well with sqrt_scale = True. It speeds up convergence by a lot compared to using vanilla gradient,
-            but also has INSANE overfitting.
-        target (str, optional):
-            determines what this module updates.
-
-            "ascent" - it updates the ascent (default).
-
-            "grad" - it updates the gradient (and sets `.grad` attributes to updated gradient).
-
-            "closure" - it makes a new closure that sets the updated ascent to the .`grad` attributes.
-
-    reference
-        https://arxiv.org/abs/2501.04697
-    """
-    def __init__(self, eps: float = 1e-30, renormalize=True, sqrt_scale = False, add=False, target: _Targets = 'ascent'):
-        super().__init__({}, target=target)
-        self.eps = eps
-        self.add = add
-        self.renormalize = renormalize
-        self.sqrt_scale = sqrt_scale
-
-    def _update(self, vars, ascent):
-        params = self.get_params()
-
-        if self.renormalize: orig_norm = ascent.norm(2) + self.eps
-        else: orig_norm = 1
-
-        scale = (params*ascent).total_sum() / ((params*params).total_sum() + self.eps)
-        if self.sqrt_scale:
-            scale = scale.abs().sqrt() * scale.sign()
-
-        if self.add: ascent += params * scale
-        else: ascent -= params * scale
-
-        if self.renormalize:
-            ascent *= (orig_norm / ascent.norm(2))
-
-        return ascent
-
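
The projection removed here is, per its own docstring, g ← g − (⟨w, g⟩ / (⟨w, w⟩ + eps)) · w computed over all parameters jointly. A plain-PyTorch sketch of that update, written only to illustrate the formula; it does not use torchzero's TensorList and is not the removed implementation verbatim:

    import torch

    @torch.no_grad()
    def orthograd_sketch(params, eps: float = 1e-30):
        """Project gradients to be orthogonal to the weights (arXiv:2501.04697)."""
        params = [p for p in params if p.grad is not None]
        dot = sum((p * p.grad).sum() for p in params)   # <w, g> over all parameters
        wnorm2 = sum((p * p).sum() for p in params)     # <w, w> over all parameters
        scale = dot / (wnorm2 + eps)
        for p in params:
            p.grad -= scale * p                         # g <- g - scale * w

    # usage (hypothetical): after loss.backward(), call orthograd_sketch(model.parameters())
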
--- a/torchzero/modules/regularization/weight_decay.py
+++ /dev/null
@@ -1,92 +0,0 @@
-from typing import Literal
-from collections.abc import Iterable
-
-import torch
-
-from ...tensorlist import TensorList
-from ...core import OptimizerModule, _Targets
-
-
-def l2_regularize_(params: Iterable[torch.Tensor], alpha: float = 1e-2):
-    """Adds L2 weight regularization term to the gradients in-place.
-
-    Args:
-        params (Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
-        alpha (float, optional): multiplier to the regularizer. Defaults to 1e-2.
-    """
-    p = TensorList(params).with_requires_grad()
-    p.ensure_grad_()
-    p.grad.add_(p, alpha = alpha)
-
-def l1_regularize_(params: Iterable[torch.Tensor], alpha: float = 1e-2):
-    """Adds L1 weight regularization term to the gradients in-place.
-
-    Args:
-        params (Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
-        alpha (float, optional): multiplier to the regularizer. Defaults to 1e-2.
-    """
-    p = TensorList(params).with_requires_grad()
-    p.ensure_grad_()
-    p.grad.add_(p.sign(), alpha = alpha)
-
-def weight_decay_penalty(params: Iterable[torch.Tensor], alpha: float = 1e-2, ord:float = 2):
-    """Calculate the weight decay penalty term that can be added to the loss.
-
-    Args:
-        params (Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
-        alpha (float): multiplier to the regularizer.
-        ord (int, optional): order of the norm. Defaults to 2.
-    """
-    return TensorList(params).norm(ord) * alpha
-
-def decay_weights_(params: Iterable[torch.Tensor], alpha: float = 1e-2, ord:Literal[1, 2] = 2):
-    """Apply weight decay directly to parameters in-place.
-
-    Args:
-        params (Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor to decay.
-        alpha (float): by how much to decay parameters (default: 1e-2)
-        ord (float, optional):
-            order of the penalty, 1 and 2 are currently supported (L1 and L2 regularization) (default: 2)
-    """
-    params = TensorList(params)
-    if ord == 2: params.mul_(1-alpha)
-    elif ord == 1: params.sub_(params.sign().mul_(alpha))
-    else: raise NotImplementedError(f'order {ord} is not supported')
-
-
-class WeightDecay(OptimizerModule):
-    """Adds weight decay term (L1 or L2 regularization) to the ascent direction.
-
-    Put this at the end to make it decoupled.
-
-    Args:
-        alpha (float, optional): multiplier to the regularizer (default: 1e-2)
-        ord (Literal[1, 2], optional):
-            order of the penalty, 1 and 2 are currently supported (L1 and L2 regularization).
-            Defaults to 2.
-        target (str, optional):
-            determines what this module updates.
-
-            "ascent" - it updates the ascent
-
-            "grad" - it updates the gradient (and sets `.grad` attributes to updated gradient).
-
-            "closure" - it makes a new closure that sets the updated ascent to the .`grad` attributes.
-    """
-    def __init__(self, alpha: float = 1e-2, ord:Literal[1, 2] = 2, target: _Targets = "ascent"):
-        defaults = dict(alpha = alpha)
-        super().__init__(defaults, target = target)
-        self.ord = ord
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        params = self.get_params()
-        alpha = self.get_group_key('alpha')
-
-        if any(i != 0 for i in alpha):
-
-            if self.ord == 1: ascent.add_(params.sign() * alpha)
-            elif self.ord == 2: ascent.add_(params * alpha)
-            else: raise NotImplementedError(f'weight descent of order {self.ord} not implemented.')
-
-        return ascent
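
The removed weight_decay.py distinguishes adding a penalty to the loss, adding alpha·w to the gradient, and decoupled decay that shrinks the weights directly. A short sketch of the first and last of these, assuming the helpers behave as their docstrings above describe; the import path, model, and loss are illustrative assumptions:

    import torch
    from torchzero.modules.regularization.weight_decay import (  # 0.1.8 path, removed in 0.3.2
        weight_decay_penalty, decay_weights_,
    )

    model = torch.nn.Linear(4, 2)          # hypothetical model
    loss = model(torch.randn(8, 4)).sum()

    # coupled decay: add alpha * ||w|| to the loss so it flows through backward()
    (loss + weight_decay_penalty(model.parameters(), alpha=1e-2, ord=2)).backward()

    # decoupled decay: shrink the weights in place, w <- w * (1 - alpha) for ord=2
    with torch.no_grad():
        decay_weights_(model.parameters(), alpha=1e-2, ord=2)
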
--- a/torchzero/modules/scheduling/lr_schedulers.py
+++ /dev/null
@@ -1,131 +0,0 @@
-from collections.abc import Callable
-from functools import partial
-from typing import Any, overload, TYPE_CHECKING
-import random
-
-import torch
-from ...core import OptimizerModule
-
-
-if TYPE_CHECKING:
-    from ...optim import Modular
-
-
-# LR SCHEDULING MOVED TO LR MODULE
-
-# def _set_momentum_hook(optimizer, state, momentum):
-#     for module in optimizer.unrolled_modules:
-#         if 'momentum' in module.defaults:
-#             for g in module.param_groups:
-#                 g['momentum'] = momentum
-#         if 'beta1' in module.defaults:
-#             for g in module.param_groups:
-#                 g['beta1'] = momentum
-
-# def _add_scheduler_hook(opt: "Modular", scheduler_cls, id):
-#     """post-init hook that sets `scheduler_step_fn` to the scheduler step."""
-#     # get LR module
-#     lr_module = opt.get_lr_module()
-
-#     # get current LRScheduler module
-#     scheds = [i for i in opt.unrolled_modules if isinstance(i, LRScheduler)]
-#     scheds = [i for i in scheds if i.id == id]
-#     if len(scheds) != 1:
-#         raise RuntimeError(f"more than 1 module with id {id}: {scheds}")
-
-#     sch_module = scheds[0]
-
-#     # make a scheduler and save the step function
-#     scheduler = scheduler_cls(lr_module)
-#     sch_module.scheduler_step_fn = scheduler.step
-
-
-# class LRScheduler(OptimizerModule):
-#     """Use any pytorch lr scheduler.
-
-#     Important - the lr is applied multiplicatively and multiplies with learning rate of other modules,
-#     so usually base learning rate of the lr scheduler, such as `max_lr` for OneCycleLR, should be set to 1.
-
-#     Args:
-#         lr_scheduler (Callable[[torch.optim.Optimizer], torch.optim.lr_scheduler.LRScheduler | Any]):
-#             something like:
-#             .. code:: py
-#                 lambda opt: OneCycleLR(opt, max_lr = 1, total_steps = 60000)
-#         update_every (int, optional):
-#             call `step` every n steps, useful for schedulers that only step once per epoch. Defaults to 1.
-#         cycle_momentum (bool, optional):
-#             enables support for cycling momentum with schedulers that support it, such as `OneCycleLR`.
-#             Unlike lr, momentum is not applied multiplicatively, but set to all other modules with
-#             `momentum` or `beta` settings. Has no effect if there are no modules that support momentum. Defaults to False.
-#         init_lr (float, optional):
-#             initial lr, I believe most lr schedulers ignore this. Defaults to 1.
-#         init_momentum (float, optional):
-#             initial init_momentum, I believe most lr schedulers ignore this.
-#             Has no effect if `cycle_momentum` is False or there are no modules that support momentum. Defaults to 0.
-#     """
-#     def __init__(
-#         self,
-#         lr_scheduler: Callable[[torch.optim.Optimizer], torch.optim.lr_scheduler.LRScheduler | Any],
-#         step_every: int = 1,
-#         cycle_momentum: bool = True,
-#     ):
-#         super().__init__({})
-#         scheduler = lr_scheduler(self.dummy_opt)
-#         self.update_every = step_every
-#         self.cycle_momentum = cycle_momentum
-
-#         self.scheduler_step_fn = scheduler.step
-#         self.cur = 0
-#         self.cur_lr = init_lr
-#         self.cur_momentum = init_momentum
-
-#         self.id = random.random()
-
-#     def step(self, vars):
-#         if self.cur % self.update_every == 0:
-#             self.scheduler_step_fn()
-#             self.cur_lr = self.dummy_opt.first_param_group['lr']
-#             self.cur_momentum = self.dummy_opt.first_param_group['momentum']
-
-#         params = self.get_params()
-#         ascent = state.maybe_use_grad_(params)
-#         ascent *= self.cur_lr
-
-#         if self.cycle_momentum:
-#             state.add_post_step_hook(partial(_set_momentum_hook, momentum = self.cur_momentum))
-
-class LRWarmup(OptimizerModule):
-    """Linear learning rate warmup.
-
-    Args:
-        n_steps (int): number of warmup steps.
-        start_lr (float, optional): initial lr. Defaults to 1e-8.
-        end_lr (float, optional): final lr. Defaults to 1.
-        delay_steps (int, optional): number of `start_lr` steps before starting the warmup. Defaults to 0.
-    """
-    def __init__(self, n_steps: int, start_lr: float = 1e-8, end_lr: float = 1, delay_steps: int = 0):
-
-        super().__init__({})
-        self.n_steps = n_steps
-        self.start_lr = start_lr
-        self.end_lr = end_lr
-        self.delay_steps = delay_steps
-
-        self.cur = 0
-
-    def _update(self, vars, ascent):
-        if self.cur < self.delay_steps:
-            if self.start_lr != 1: ascent *= self.start_lr
-
-        elif self.cur >= self.n_steps + self.delay_steps:
-            if self.end_lr != 1: ascent *= self.end_lr
-
-        else:
-            remaining = (self.n_steps - (self.cur-self.delay_steps)) / self.n_steps
-            lr = (self.start_lr * remaining) + self.end_lr * (1 - remaining)
-            ascent *= lr
-
-        self.cur += 1
-        return ascent
-
-
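
For reference, the multiplier the removed LRWarmup module applied to the update is a plain linear interpolation from start_lr to end_lr over n_steps, optionally after delay_steps of start_lr. A standalone sketch of just that schedule, transcribed from the _update logic above; the function name is illustrative:

    def warmup_multiplier(step: int, n_steps: int, start_lr: float = 1e-8,
                          end_lr: float = 1.0, delay_steps: int = 0) -> float:
        """Learning-rate multiplier at a given step, mirroring LRWarmup._update."""
        if step < delay_steps:
            return start_lr
        if step >= n_steps + delay_steps:
            return end_lr
        remaining = (n_steps - (step - delay_steps)) / n_steps
        return start_lr * remaining + end_lr * (1 - remaining)

    # e.g. with n_steps=100: step 0 -> ~1e-8, step 50 -> ~0.5, step 100 and beyond -> 1.0
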