torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/variance_reduction/svrg.py
@@ -0,0 +1,208 @@
+import warnings
+from functools import partial
+
+import torch
+
+from ...core.module import Module
+from ...utils import tofloat
+
+
+def _reset_except_self(optimizer, var, self: Module):
+    for m in optimizer.unrolled_modules:
+        if m is not self:
+            m.reset()
+
+class SVRG(Module):
+    """Stochastic variance reduced gradient method (SVRG).
+
+    To use, put SVRG as the first module; it can be combined with any other modules.
+    To reduce the variance of a gradient estimator, put the gradient estimator before SVRG.
+
+    First, it uses the first ``accum_steps`` batches to compute the full gradient at the initial
+    parameters using gradient accumulation; the model is not updated during this phase.
+
+    Then it performs ``svrg_steps`` SVRG steps, each of which requires two forward and backward passes.
+
+    After ``svrg_steps`` steps, it goes back to the full gradient computation step.
+
+    As an alternative to gradient accumulation, you can pass a "full_closure" argument to the ``step`` method,
+    which should compute full gradients, set them to the ``.grad`` attributes of the parameters,
+    and return the full loss.
+
+    Args:
+        svrg_steps (int): number of steps before recalculating the full gradient. This can be set to the length of the dataloader.
+        accum_steps (int | None, optional):
+            number of steps to accumulate the gradient for. Not used if "full_closure" is passed to the ``step`` method. If None, uses the value of ``svrg_steps``. Defaults to None.
+        reset_before_accum (bool, optional):
+            whether to reset all other modules when re-calculating the full gradient. Defaults to True.
+        svrg_loss (bool, optional):
+            whether to replace the loss with the SVRG loss (calculated by the same formula as the SVRG gradient). Defaults to True.
+        alpha (float, optional):
+            multiplier to the ``g_full(x_0) - g_batch(x_0)`` term; can be annealed linearly from 1 to 0 as suggested in https://arxiv.org/pdf/2311.05589#page=6
+
+    ## Examples:
+    SVRG-LBFGS
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SVRG(len(dataloader)),
+        tz.m.LBFGS(),
+        tz.m.Backtracking(),
+    )
+    ```
+
+    For extra variance reduction one can use Online versions of algorithms, although it won't always help.
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SVRG(len(dataloader)),
+        tz.m.Online(tz.m.LBFGS()),
+        tz.m.Backtracking(),
+    )
+    ```
+    Variance reduction can also be applied to gradient estimators.
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SPSA(),
+        tz.m.SVRG(100),
+        tz.m.LR(1e-2),
+    )
+    ```
+    ## Notes
+
+    The SVRG gradient is computed as ``g_b(x) - alpha * (g_b(x_0) - g_f(x_0))``, where:
+    - ``x`` is the current parameters
+    - ``x_0`` is the initial parameters, where the full gradient was computed
+    - ``g_b`` refers to the mini-batch gradient at ``x`` or ``x_0``
+    - ``g_f`` refers to the full gradient at ``x_0``.
+
+    The SVRG loss is computed using the same formula.
+    """
+    def __init__(self, svrg_steps: int, accum_steps: int | None = None, reset_before_accum:bool=True, svrg_loss:bool=True, alpha:float=1):
+        defaults = dict(svrg_steps = svrg_steps, accum_steps=accum_steps, reset_before_accum=reset_before_accum, svrg_loss=svrg_loss, alpha=alpha)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        params = var.params
+        closure = var.closure
+        assert closure is not None
+
+        if "full_grad" not in self.global_state:
+
+            # -------------------------- calculate full gradient ------------------------- #
+            if "full_closure" in var.storage:
+                full_closure = var.storage['full_closure']
+                with torch.enable_grad():
+                    full_loss = full_closure()
+                if all(p.grad is None for p in params):
+                    warnings.warn("all gradients are None after evaluating full_closure.")
+
+                full_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                self.global_state["full_loss"] = full_loss
+                self.global_state["full_grad"] = full_grad
+                self.global_state['x_0'] = [p.clone() for p in params]
+
+                # current batch will be used for svrg update
+
+            else:
+                # accumulate gradients over n steps
+                accum_steps = self.defaults['accum_steps']
+                if accum_steps is None: accum_steps = self.defaults['svrg_steps']
+
+                current_accum_step = self.global_state.get('current_accum_step', 0) + 1
+                self.global_state['current_accum_step'] = current_accum_step
+
+                # accumulate grads
+                accumulator = self.get_state(params, 'accumulator')
+                grad = var.get_grad()
+                torch._foreach_add_(accumulator, grad)
+
+                # accumulate loss
+                loss_accumulator = self.global_state.get('loss_accumulator', 0)
+                loss_accumulator += tofloat(var.loss)
+                self.global_state['loss_accumulator'] = loss_accumulator
+
+                # on nth step, use the accumulated gradient
+                if current_accum_step >= accum_steps:
+                    torch._foreach_div_(accumulator, accum_steps)
+                    self.global_state["full_grad"] = accumulator
+                    self.global_state["full_loss"] = loss_accumulator / accum_steps
+
+                    self.global_state['x_0'] = [p.clone() for p in params]
+                    self.clear_state_keys('accumulator')
+                    del self.global_state['current_accum_step']
+
+                # otherwise skip update until enough grads are accumulated
+                else:
+                    var.update = None
+                    var.stop = True
+                    var.skip_update = True
+                    return var
+
+
+        svrg_steps = self.defaults['svrg_steps']
+        current_svrg_step = self.global_state.get('current_svrg_step', 0) + 1
+        self.global_state['current_svrg_step'] = current_svrg_step
+
+        # --------------------------- SVRG gradient closure -------------------------- #
+        x0 = self.global_state['x_0']
+        gf_x0 = self.global_state["full_grad"]
+        ff_x0 = self.global_state['full_loss']
+        use_svrg_loss = self.defaults['svrg_loss']
+        alpha = self.get_settings(params, 'alpha')
+        alpha_0 = alpha[0]
+        if all(a == 1 for a in alpha): alpha = None
+
+        def svrg_closure(backward=True):
+            # g_b(x) - α * (g_b(x_0) - g_f(x_0)), and same for loss
+            with torch.no_grad():
+                x = [p.clone() for p in params]
+
+            if backward:
+                # f and g at x
+                with torch.enable_grad(): fb_x = closure()
+                gb_x = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+                # f and g at x_0
+                torch._foreach_copy_(params, x0)
+                with torch.enable_grad(): fb_x0 = closure()
+                gb_x0 = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                torch._foreach_copy_(params, x)
+
+                # g_svrg = gb_x - alpha * (gb_x0 - gf_x0)
+                correction = torch._foreach_sub(gb_x0, gf_x0)
+                if alpha is not None: torch._foreach_mul_(correction, alpha)
+                g_svrg = torch._foreach_sub(gb_x, correction)
+
+                f_svrg = fb_x - alpha_0 * (fb_x0 - ff_x0)
+                for p, g in zip(params, g_svrg):
+                    p.grad = g
+
+                if use_svrg_loss: return f_svrg
+                return fb_x
+
+            # no backward
+            if use_svrg_loss:
+                fb_x = closure(False)
+                torch._foreach_copy_(params, x0)
+                fb_x0 = closure(False)
+                torch._foreach_copy_(params, x)
+                f_svrg = fb_x - alpha_0 * (fb_x0 - ff_x0)
+                return f_svrg
+
+            return closure(False)
+
+        var.closure = svrg_closure
+
+        # --- after svrg_steps steps reset so that new full gradient is calculated on next step --- #
+        if current_svrg_step >= svrg_steps:
+            del self.global_state['current_svrg_step']
+            del self.global_state['full_grad']
+            del self.global_state['full_loss']
+            del self.global_state['x_0']
+            if self.defaults['reset_before_accum']:
+                var.post_step_hooks.append(partial(_reset_except_self, self=self))
+
+        return var
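The variance-reduction identity that ``svrg_closure`` implements above can be checked in isolation. Below is a minimal sketch in plain PyTorch on a made-up quadratic objective; the names ``batch_grad``, ``g_full_x0``, ``g_batch_x``, ``g_batch_x0`` and the toy data are illustrative only and not part of torchzero's API.

```python
import torch

# Toy least-squares objective: f(x) = mean_i 0.5 * (a_i * x - b_i)^2
torch.manual_seed(0)
a, b = torch.randn(100), torch.randn(100)

def batch_grad(x, idx):
    # mini-batch gradient of the toy objective at x
    return (a[idx] * (a[idx] * x - b[idx])).mean()

x0 = torch.tensor(1.5)   # snapshot point where the full gradient was accumulated
x = torch.tensor(1.3)    # current iterate
alpha = 1.0              # correction multiplier (can be annealed toward 0)

g_full_x0 = (a * (a * x0 - b)).mean()   # full gradient at the snapshot
idx = torch.randint(0, 100, (10,))      # one mini-batch
g_batch_x = batch_grad(x, idx)
g_batch_x0 = batch_grad(x0, idx)

# SVRG gradient: g_b(x) - alpha * (g_b(x_0) - g_f(x_0)),
# mirroring `correction = gb_x0 - gf_x0; g_svrg = gb_x - alpha * correction` above.
g_svrg = g_batch_x - alpha * (g_batch_x0 - g_full_x0)

# compare against the exact full gradient at x
print(float(g_svrg), float((a * (a * x - b)).mean()))
```

In expectation over mini-batches the estimator equals the full gradient at ``x``, and its variance shrinks as ``x`` approaches the snapshot ``x_0``.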
torchzero/modules/weight_decay/__init__.py
@@ -1 +1 @@
-from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_,
+from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, RelativeWeightDecay
torchzero/modules/weight_decay/weight_decay.py
@@ -4,7 +4,7 @@ from typing import Literal
 import torch

 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states, Metrics


 @torch.no_grad
@@ -14,7 +14,7 @@ def weight_decay_(
     weight_decay: float | NumberList,
     ord: int = 2
 ):
-    """returns
+    """modifies in-place and returns ``grad_``."""
     if ord == 1: return grad_.add_(params.sign().mul_(weight_decay))
     if ord == 2: return grad_.add_(params.mul(weight_decay))
     if ord - 1 % 2 != 0: return grad_.add_(params.pow(ord-1).mul_(weight_decay))
@@ -22,34 +22,113 @@ def weight_decay_(


 class WeightDecay(Transform):
+    """Weight decay.
+
+    Args:
+        weight_decay (float): weight decay scale.
+        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
+        target (Target, optional): what to set on var. Defaults to 'update'.
+
+    ### Examples:
+
+    Adam with non-decoupled weight decay
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.WeightDecay(1e-3),
+        tz.m.Adam(),
+        tz.m.LR(1e-3)
+    )
+    ```
+
+    Adam with decoupled weight decay that still scales with the learning rate
+    ```python
+
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Adam(),
+        tz.m.WeightDecay(1e-3),
+        tz.m.LR(1e-3)
+    )
+    ```
+
+    Adam with fully decoupled weight decay that doesn't scale with the learning rate
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Adam(),
+        tz.m.LR(1e-3),
+        tz.m.WeightDecay(1e-6)
+    )
+    ```
+
+    """
     def __init__(self, weight_decay: float, ord: int = 2, target: Target = 'update'):
+
         defaults = dict(weight_decay=weight_decay, ord=ord)
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         weight_decay = NumberList(s['weight_decay'] for s in settings)
         ord = settings[0]['ord']

         return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)

-class
+class RelativeWeightDecay(Transform):
+    """Weight decay relative to the mean absolute value of the update, gradient or parameters, depending on the value of the ``norm_input`` argument.
+
+    Args:
+        weight_decay (float): relative weight decay scale.
+        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
+        norm_input (str, optional):
+            determines what the weight decay should be relative to. "update", "grad" or "params".
+            Defaults to "update".
+        metric (Metrics, optional):
+            metric (norm, etc.) that the weight decay should be relative to.
+            Defaults to 'mad' (mean absolute deviation).
+        target (Target, optional): what to set on var. Defaults to 'update'.
+
+    ### Examples:
+
+    Adam with non-decoupled relative weight decay
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.RelativeWeightDecay(1e-1),
+        tz.m.Adam(),
+        tz.m.LR(1e-3)
+    )
+    ```
+
+    Adam with decoupled relative weight decay
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Adam(),
+        tz.m.RelativeWeightDecay(1e-1),
+        tz.m.LR(1e-3)
+    )
+    ```
+    """
     def __init__(
         self,
         weight_decay: float = 0.1,
-        ord: int
+        ord: int = 2,
         norm_input: Literal["update", "grad", "params"] = "update",
+        metric: Metrics = 'mad',
         target: Target = "update",
     ):
-        defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input)
+        defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input, metric=metric)
         super().__init__(defaults, uses_grad=norm_input == 'grad', target=target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         weight_decay = NumberList(s['weight_decay'] for s in settings)

         ord = settings[0]['ord']
         norm_input = settings[0]['norm_input']
+        metric = settings[0]['metric']

         if norm_input == 'update': src = TensorList(tensors)
         elif norm_input == 'grad':
@@ -60,8 +139,7 @@ class NormalizedWeightDecay(Transform):
         else:
             raise ValueError(norm_input)

-        norm = src.
-
+        norm = src.global_metric(metric)
         return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay * norm, ord)


@@ -72,7 +150,12 @@ def decay_weights_(params: Iterable[torch.Tensor], weight_decay: float | NumberL
     weight_decay_(params, params, -weight_decay, ord)

 class DirectWeightDecay(Module):
-    """
+    """Directly applies weight decay to parameters.
+
+    Args:
+        weight_decay (float): weight decay scale.
+        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
+    """
     def __init__(self, weight_decay: float, ord: int = 2,):
         defaults = dict(weight_decay=weight_decay, ord=ord)
         super().__init__(defaults)
@@ -80,7 +163,7 @@ class DirectWeightDecay(Module):
     @torch.no_grad
     def step(self, var):
         weight_decay = self.get_settings(var.params, 'weight_decay', cls=NumberList)
-        ord = self.
+        ord = self.defaults['ord']

         decay_weights_(var.params, weight_decay, ord)
         return var
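To make concrete what ``RelativeWeightDecay`` rescales, here is a standalone sketch in plain PyTorch. The tensors are placeholders and treating 'mad' as the mean absolute value of the chosen input is an assumption for illustration; the module itself routes this through ``TensorList.global_metric`` rather than the loop below.

```python
import torch

params = [torch.randn(10), torch.randn(3, 3)]
update = [torch.randn_like(p) * 0.01 for p in params]  # e.g. the update another module produced
wd = 0.1                                                # relative decay scale

# Plain weight decay would add wd * p to each update tensor.
# Relative weight decay first measures the update's mean absolute value ('mad'),
# so the penalty keeps a fixed size relative to the step actually being taken.
numel = sum(u.numel() for u in update)
mad = sum(u.abs().sum() for u in update) / numel

for p, u in zip(params, update):
    u.add_(p, alpha=wd * float(mad))  # u += wd * mad * p
```

As in the docstring examples above, placing the module before ``tz.m.Adam()`` gives non-decoupled decay (the penalty passes through Adam's preconditioner), while placing it after gives decoupled decay.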
torchzero/modules/wrappers/optim_wrapper.py
@@ -7,7 +7,35 @@ from ...utils import Params, _copy_param_groups, _make_param_groups


 class Wrap(Module):
-    """
+    """
+    Wraps a pytorch optimizer to use it as a module.
+
+    .. note::
+        Custom param groups are supported only by `set_param_groups`, settings passed to Modular will be ignored.
+
+    Args:
+        opt_fn (Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer):
+            function that takes in parameters and returns the optimizer, for example :code:`torch.optim.Adam`
+            or :code:`lambda parameters: torch.optim.Adam(parameters, lr=1e-3)`
+        *args:
+        **kwargs:
+            Extra args to be passed to opt_fn. The function is called as :code:`opt_fn(parameters, *args, **kwargs)`.
+
+    Example:
+        wrapping pytorch_optimizer.StableAdamW
+
+        .. code-block:: py
+
+            from pytorch_optimizer import StableAdamW
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Wrap(StableAdamW, lr=1),
+                tz.m.Cautious(),
+                tz.m.LR(1e-2)
+            )
+
+
+    """
     def __init__(self, opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer, *args, **kwargs):
         super().__init__()
         self._opt_fn = opt_fn
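As a usage note for the ``Wrap`` docstring above, the same factory pattern works with the built-in torch optimizers. A brief sketch, assuming ``torchzero`` is imported as ``tz`` as in the docstring examples; the model and hyperparameters are placeholders, not recommendations.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 2)

# Per the docstring, Wrap calls opt_fn(parameters, *args, **kwargs), so extra
# keyword arguments are forwarded to the wrapped optimizer's constructor.
opt = tz.Modular(
    model.parameters(),
    tz.m.Wrap(torch.optim.SGD, lr=1.0, momentum=0.9),
    tz.m.LR(1e-2),
)
```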
torchzero/modules/zeroth_order/__init__.py
@@ -0,0 +1 @@
+from .cd import CD, CCD, CCDLS