torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/misc/multistep.py

@@ -0,0 +1,194 @@
+from collections.abc import Iterable
+
+import torch
+
+from ...core import Chainable, Module, Var
+from ...utils import TensorList
+
+def _sequential_step(self: Module, var: Var, sequential: bool):
+    params = var.params
+    steps = self.settings[params[0]]['steps']
+
+    if sequential: modules = self.get_children_sequence() * steps
+    else: modules = [self.children['module']] * steps
+
+    if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
+
+    # store original params unless this is last module and can update params directly
+    params_before_steps = None if (var.is_last and var.last_module_lrs is None) else [p.clone() for p in params]
+
+    # first step - pass var as usual
+    var = modules[0].step(var)
+    new_var = var
+
+    # subsequent steps - update parameters and create new var
+    if len(modules) > 1:
+        for m in modules[1:]:
+
+            # update params
+            if (not new_var.skip_update):
+                if new_var.last_module_lrs is not None:
+                    torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
+
+                torch._foreach_sub_(params, new_var.get_update())
+
+            # create new var since we are at a new point, that means grad, update and loss will be None
+            new_var = Var(params=new_var.params, closure=new_var.closure,
+                          model=new_var.model, current_step=new_var.current_step + 1)
+
+            # step
+            new_var = m.step(new_var)
+
+    # final parameter update
+    if (not new_var.skip_update):
+        if new_var.last_module_lrs is not None:
+            torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
+
+        torch._foreach_sub_(params, new_var.get_update())
+
+    # if last module, update is applied so return new var
+    if params_before_steps is None:
+        new_var.stop = True
+        new_var.skip_update = True
+        return new_var
+
+    # otherwise use parameter difference as update
+    var.update = list(torch._foreach_sub(params_before_steps, params))
+    for p, bef in zip(params, params_before_steps):
+        p.set_(bef) # pyright:ignore[reportArgumentType]
+    return var
+
+class Multistep(Module):
+    """Performs :code:`steps` inner steps with :code:`module` per each step.
+
+    The update is taken to be the parameter difference between parameters before and after the inner loop."""
+    def __init__(self, module: Chainable, steps: int):
+        defaults = dict(steps=steps)
+        super().__init__(defaults)
+        self.set_child('module', module)
+
+    @torch.no_grad
+    def step(self, var):
+        return _sequential_step(self, var, sequential=False)
+
+class Sequential(Module):
+    """On each step, this sequentially steps with :code:`modules` :code:`steps` times.
+
+    The update is taken to be the parameter difference between parameters before and after the inner loop."""
+    def __init__(self, modules: Iterable[Chainable], steps: int=1):
+        defaults = dict(steps=steps)
+        super().__init__(defaults)
+        self.set_children_sequence(modules)
+
+    @torch.no_grad
+    def step(self, var):
+        return _sequential_step(self, var, sequential=True)
+
+
+class NegateOnLossIncrease(Module):
+    """Uses an extra forward pass to evaluate loss at :code:`parameters+update`,
+    if loss is larger than at :code:`parameters`,
+    the update is set to 0 if :code:`backtrack=False` and to :code:`-update` otherwise"""
+    def __init__(self, backtrack=False):
+        defaults = dict(backtrack=backtrack)
+        super().__init__(defaults=defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
+        backtrack = self.defaults['backtrack']
+
+        update = var.get_update()
+        f_0 = var.get_loss(backward=False)
+
+        torch._foreach_sub_(var.params, update)
+        f_1 = closure(False)
+
+        if f_1 <= f_0:
+            if var.is_last and var.last_module_lrs is None:
+                var.stop = True
+                var.skip_update = True
+                return var
+
+            torch._foreach_add_(var.params, update)
+            return var
+
+        torch._foreach_add_(var.params, update)
+        if backtrack:
+            torch._foreach_neg_(var.update)
+        else:
+            torch._foreach_zero_(var.update)
+        return var
+
+
+class Online(Module):
+    """Allows certain modules to be used for mini-batch optimization.
+
+    Examples:
+
+    Online L-BFGS with Backtracking line search
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Online(tz.m.LBFGS()),
+        tz.m.Backtracking()
+    )
+    ```
+
+    Online L-BFGS trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.TrustCG(tz.m.Online(tz.m.LBFGS()))
+    )
+    ```
+
+    """
+    def __init__(self, *modules: Module,):
+        super().__init__()
+
+        self.set_child('module', modules)
+
+    @torch.no_grad
+    def update(self, var):
+        closure = var.closure
+        if closure is None: raise ValueError("Closure must be passed for Online")
+
+        step = self.global_state.get('step', 0) + 1
+        self.global_state['step'] = step
+
+        params = TensorList(var.params)
+        p_cur = params.clone()
+        p_prev = self.get_state(params, 'p_prev', cls=TensorList)
+
+        module = self.children['module']
+        var_c = var.clone(clone_update=False)
+
+        # on 1st step just step and store previous params
+        if step == 1:
+            p_prev.copy_(params)
+
+            module.update(var_c)
+            var.update_attrs_from_clone_(var_c)
+            return
+
+        # restore previous params and update
+        var_prev = Var(params=params, closure=closure, model=var.model, current_step=var.current_step)
+        params.set_(p_prev)
+        module.reset_for_online()
+        module.update(var_prev)
+
+        # restore current params and update
+        params.set_(p_cur)
+        p_prev.copy_(params)
+        module.update(var_c)
+        var.update_attrs_from_clone_(var_c)
+
+    @torch.no_grad
+    def apply(self, var):
+        module = self.children['module']
+        return module.apply(var.clone(clone_update=False))
+
+    def get_H(self, var):
+        return self.children['module'].get_H(var)
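Since `Multistep` and `Sequential` require a closure whenever more than one inner step is taken, here is a minimal usage sketch. The model, data, and training loop are hypothetical; the `tz.Modular` / `tz.m` wiring follows the docstring examples elsewhere in this diff, the closure convention mirrors the `closure()` / `closure(False)` calls in the code above, and passing a list of modules as the `module` argument is an assumption based on how `Chainable` is used in other examples.

```python
import torch
import torch.nn as nn
import torchzero as tz

model = nn.Linear(10, 1)
x, y = torch.randn(64, 10), torch.randn(64, 1)

# hypothetical wiring: 5 inner heavy-ball steps per outer step; the outer update
# is the parameter difference accumulated by the inner loop (see _sequential_step)
opt = tz.Modular(
    model.parameters(),
    tz.m.Multistep([tz.m.HeavyBall(), tz.m.LR(1e-2)], steps=5),
)

def closure(backward=True):
    loss = ((model(x) - y) ** 2).mean()
    if backward:
        model.zero_grad()
        loss.backward()
    return loss

for _ in range(10):
    opt.step(closure)
```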
torchzero/modules/misc/regularization.py

@@ -0,0 +1,167 @@
+import torch
+
+from ...core import Chainable, Module, Target, Transform
+from ...core.reformulation import Reformulation
+from ...utils import Distributions, NumberList, TensorList
+
+
+class Dropout(Transform):
+    """Applies dropout to the update.
+
+    For each weight the update to that weight has :code:`p` probability to be set to 0.
+    This can be used to implement gradient dropout or update dropout depending on placement.
+
+    Args:
+        p (float, optional): probability that update for a weight is replaced with 0. Defaults to 0.5.
+        graft (bool, optional):
+            if True, update after dropout is rescaled to have the same norm as before dropout. Defaults to False.
+        target (Target, optional): what to set on var, refer to documentation. Defaults to 'update'.
+
+
+    Examples:
+        Gradient dropout.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Dropout(0.5),
+                tz.m.Adam(),
+                tz.m.LR(1e-3)
+            )
+
+        Update dropout.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.Dropout(0.5),
+                tz.m.LR(1e-3)
+            )
+
+    """
+    def __init__(self, p: float = 0.5, graft: bool=False, target: Target = 'update'):
+        defaults = dict(p=p, graft=graft)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        p = NumberList(s['p'] for s in settings)
+        graft = settings[0]['graft']
+
+        if graft:
+            target_norm = tensors.global_vector_norm()
+            tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
+            return tensors.mul_(target_norm / tensors.global_vector_norm()) # graft
+
+        return tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
+
+def _bernoulli_like(tensor, p = 0.5, generator = None):
+    """p is probability of a 1, other values will be 0."""
+    return torch.bernoulli(torch.full_like(tensor, p), generator = generator)
+
+class WeightDropout(Module):
+    """
+    Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.
+
+    Dropout can be disabled for a parameter by setting :code:`use_dropout=False` in corresponding parameter group.
+
+    Args:
+        p (float, optional): probability that any weight is replaced with 0. Defaults to 0.5.
+        graft (bool, optional):
+            if True, parameters after dropout are rescaled to have the same norm as before dropout. Defaults to False.
+    """
+    def __init__(self, p: float = 0.5, graft: bool = True):
+        defaults = dict(p=p, graft=graft, use_dropout=True)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('WeightDropout requires closure')
+        params = TensorList(var.params)
+        p = NumberList(self.settings[p]['p'] for p in params)
+
+        # create masks
+        mask = []
+        for p, m in zip(params, mask):
+            prob = self.settings[p]['p']
+            use_dropout = self.settings[p]['use_dropout']
+            if use_dropout: mask.append(_bernoulli_like(p, prob))
+            else: mask.append(torch.ones_like(p))
+
+        @torch.no_grad
+        def dropout_closure(backward=True):
+            orig_params = params.clone()
+            params.mul_(mask)
+            if backward:
+                with torch.enable_grad(): loss = closure()
+            else:
+                loss = closure(False)
+            params.copy_(orig_params)
+            return loss
+
+        var.closure = dropout_closure
+        return var
+
+
+class PerturbWeights(Module):
+    """
+    Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.
+
+    Can be disabled for a parameter by setting :code:`perturb=False` in corresponding parameter group.
+
+    Args:
+        alpha (float, optional): multiplier for perturbation magnitude. Defaults to 0.1.
+        relative (bool, optional): whether to multiply perturbation by mean absolute value of the parameter. Defaults to True.
+        distribution (bool, optional):
+            distribution of the random perturbation. Defaults to False.
+    """
+    def __init__(self, alpha: float = 0.1, relative:bool=True, distribution:Distributions = 'normal'):
+        defaults = dict(alpha=alpha, relative=relative, distribution=distribution, perturb=True)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('WeightDropout requires closure')
+        params = TensorList(var.params)
+
+        # create perturbations
+        perts = []
+        for p in params:
+            settings = self.settings[p]
+            if not settings['perturb']:
+                perts.append(torch.zeros_like(p))
+                continue
+
+            alpha = settings['alpha']
+            if settings['relative']:
+                alpha *= p.abs().mean()
+
+            distribution = self.settings[p]['distribution'].lower()
+            if distribution in ('normal', 'gaussian'):
+                perts.append(torch.randn_like(p).mul_(alpha))
+            elif distribution == 'uniform':
+                perts.append(torch.empty_like(p).uniform_(-alpha,alpha))
+            elif distribution == 'sphere':
+                r = torch.randn_like(p)
+                perts.append((r * alpha) / torch.linalg.vector_norm(r)) # pylint:disable=not-callable
+            else:
+                raise ValueError(distribution)
+
+        @torch.no_grad
+        def perturbed_closure(backward=True):
+            params.add_(perts)
+            if backward:
+                with torch.enable_grad(): loss = closure()
+            else:
+                loss = closure(False)
+            params.sub_(perts)
+            return loss
+
+        var.closure = perturbed_closure
+        return var
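To see what the `graft` option of `Dropout` does, here is a minimal standalone sketch in plain PyTorch (not the torchzero API): entries of the update are zeroed with probability `p`, then the surviving entries are rescaled so the global norm matches the pre-dropout norm, mirroring the `graft` branch in `apply_tensors` above.

```python
import torch

def dropout_with_graft(update: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    target_norm = torch.linalg.vector_norm(update)
    keep = (torch.rand_like(update) >= p).to(update.dtype)   # keep with probability 1 - p
    dropped = update * keep
    # rescale so the norm is preserved, matching the "graft" branch above
    return dropped * (target_norm / torch.linalg.vector_norm(dropped).clamp_min(1e-12))

u = torch.randn(1000)
v = dropout_with_graft(u, p=0.5)
print(torch.linalg.vector_norm(u), torch.linalg.vector_norm(v))  # equal norms
```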
torchzero/modules/misc/split.py

@@ -0,0 +1,123 @@
+import warnings
+from collections.abc import Callable, Sequence, Iterable
+from typing import cast
+
+import torch
+
+from ...core import Chainable, Module, Var
+
+
+def _split(
+    module: Module,
+    idxs,
+    params,
+    var: Var,
+):
+    split_params = [p for i,p in enumerate(params) if i in idxs]
+
+    split_grad = None
+    if var.grad is not None:
+        split_grad = [g for i,g in enumerate(var.grad) if i in idxs]
+
+    split_update = None
+    if var.update is not None:
+        split_update = [u for i,u in enumerate(var.update) if i in idxs]
+
+    split_var = var.clone(clone_update=False, parent=var)
+    split_var.params = split_params
+    split_var.grad = split_grad
+    split_var.update = split_update
+
+    split_var = module.step(split_var)
+
+    # those should be set due to var being parent
+    if split_var.grad is not None:
+        assert var.grad is not None
+
+    if split_var.loss is not None:
+        assert var.loss is not None
+
+    if split_var.update is not None:
+
+        # make sure update is set, it will be filled with ``true`` and ``false`` tensors
+        if var.update is None:
+            if var.grad is None: var.update = [cast(torch.Tensor, None) for _ in var.params]
+            else: var.update = [g.clone() for g in var.grad]
+
+        # set all tensors from this split
+        for idx, u in zip(idxs, split_var.update):
+            var.update[idx] = u
+
+    return var
+
+_SingleFilter = Callable[[torch.Tensor], bool] | torch.Tensor | Iterable[torch.Tensor] | torch.nn.Module | Iterable[torch.nn.Module]
+Filter = _SingleFilter | Iterable[_SingleFilter]
+
+def _make_filter(filter: Filter):
+    if callable(filter): return filter
+    if isinstance(filter, torch.Tensor):
+        return lambda x: x is filter
+    if isinstance(filter, torch.nn.Module):
+        return _make_filter(filter.parameters())
+
+    # iterable
+    filters = [_make_filter(f) for f in filter]
+    return lambda x: any(f(x) for f in filters)
+
+class Split(Module):
+    """Apply ``true`` modules to all parameters filtered by ``filter``, apply ``false`` modules to all other parameters.
+
+    Args:
+        filter (Filter, bool]):
+            a filter that selects tensors to be optimized by ``true``.
+            - tensor or iterable of tensors (e.g. ``encoder.parameters()``).
+            - function that takes in tensor and outputs a bool (e.g. ``lambda x: x.ndim >= 2``).
+            - a sequence of above (acts as "or", so returns true if any of them is true).
+
+        true (Chainable | None): modules that are applied to tensors where ``filter`` is ``True``.
+        false (Chainable | None): modules that are applied to tensors where ``filter`` is ``False``.
+
+    ### Examples:
+
+    Muon with Adam fallback using same hyperparams as https://github.com/KellerJordan/Muon
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.NAG(0.95),
+        tz.m.Split(
+            lambda p: p.ndim >= 2,
+            true = tz.m.Orthogonalize(),
+            false = [tz.m.Adam(0.9, 0.95), tz.m.Mul(1/66)],
+        ),
+        tz.m.LR(1e-2),
+    )
+    ```
+    """
+    def __init__(self, filter: Filter, true: Chainable | None, false: Chainable | None):
+        defaults = dict(filter=filter)
+        super().__init__(defaults)
+
+        if true is not None: self.set_child('true', true)
+        if false is not None: self.set_child('false', false)
+
+    def step(self, var):
+
+        params = var.params
+        filter = _make_filter(self.settings[params[0]]['filter'])
+
+        true_idxs = []
+        false_idxs = []
+        for i,p in enumerate(params):
+            if filter(p): true_idxs.append(i)
+            else: false_idxs.append(i)
+
+        if 'true' in self.children and len(true_idxs) > 0:
+            true = self.children['true']
+            var = _split(true, idxs=true_idxs, params=params, var=var)
+
+        if 'false' in self.children and len(false_idxs) > 0:
+            false = self.children['false']
+            var = _split(false, idxs=false_idxs, params=params, var=var)
+
+        return var
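The `filter` argument of `Split` accepts several forms; the following standalone sketch reproduces the same resolution logic as `_make_filter` above (minus the `nn.Module` branch) to show how a tensor filter and a predicate combine with "or":

```python
import torch

def make_filter(flt):
    if callable(flt):
        return flt
    if isinstance(flt, torch.Tensor):
        return lambda x: x is flt          # identity, not value, comparison
    fs = [make_filter(f) for f in flt]     # iterable: "or" over sub-filters
    return lambda x: any(f(x) for f in fs)

weight = torch.randn(4, 4)
bias = torch.randn(4)
f = make_filter([bias, lambda t: t.ndim >= 2])
print(f(weight), f(bias), f(torch.randn(3)))  # True True False
```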
torchzero/modules/{ops → misc}/switch.py

@@ -7,7 +7,28 @@ from ...core import Chainable, Module


 class Alternate(Module):
-    """
+    """Alternates between stepping with :code:`modules`.
+
+    That is, first step is performed with 1st module, second step with second module, etc.
+
+    Args:
+        steps (int | Iterable[int], optional): number of steps to perform with each module. Defaults to 1.
+
+    Examples:
+        Alternate between Adam, SignSGD and RMSprop
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Alternate(
+                    tz.m.Adam(),
+                    [tz.m.SignSGD(), tz.m.Mul(0.5)],
+                    tz.m.RMSprop(),
+                ),
+                tz.m.LR(1e-3),
+            )
+    """
     LOOP = True
     def __init__(self, *modules: Chainable, steps: int | Iterable[int] = 1):
         if isinstance(steps, Iterable):
@@ -32,7 +53,7 @@ class Alternate(Module):
         var = module.step(var.clone(clone_update=False))

         # number of steps until next module
-        steps = self.
+        steps = self.defaults['steps']
         if isinstance(steps, int): steps = [steps]*len(self.children)

         if 'steps_to_next' not in self.global_state:
@@ -54,14 +75,34 @@ class Alternate(Module):
         return var

 class Switch(Alternate):
-    """
+    """After :code:`steps` steps switches to the next module.
+
+    Args:
+        steps (int | Iterable[int]): Number of steps to perform with each module.
+
+    Examples:
+        Start with Adam, switch to L-BFGS after 1000th step and Truncated Newton on 2000th step.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Switch(
+                    [tz.m.Adam(), tz.m.LR(1e-3)],
+                    [tz.m.LBFGS(), tz.m.Backtracking()],
+                    [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
+                    steps = (1000, 2000)
+                )
+            )
+    """
+
     LOOP = False
     def __init__(self, *modules: Chainable, steps: int | Iterable[int]):

         if isinstance(steps, Iterable):
             steps = list(steps)
             if len(steps) != len(modules) - 1:
-                raise ValueError(f"steps must be the same length as modules, got {len(modules) = }, {len(steps) = }")
+                raise ValueError(f"steps must be the same length as modules minus 1, got {len(modules) = }, {len(steps) = }")

             steps.append(1)

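The new `Switch` docstring describes a schedule where each entry of `steps` marks the step at which the next module takes over. A standalone sketch of that documented behaviour (boundary handling in the actual module may differ slightly):

```python
import bisect

def active_module(step: int, switch_points=(1000, 2000)) -> int:
    """Index of the module handling ``step`` for tz.m.Switch(..., steps=(1000, 2000))."""
    return bisect.bisect_right(switch_points, step)

print([active_module(s) for s in (0, 999, 1000, 1999, 2000, 5000)])
# [0, 0, 1, 1, 2, 2]
```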
torchzero/modules/momentum/__init__.py

@@ -6,9 +6,5 @@ from .cautious import (
     ScaleModulesByCosineSimilarity,
     UpdateGradientSignConsistency,
 )
-from .ema import EMA, Debias, Debias2, EMASquared, SqrtEMASquared, CenteredEMASquared, CenteredSqrtEMASquared
-from .experimental import CoordinateMomentum
-# from .matrix_momentum import MatrixMomentum

-from .momentum import NAG, HeavyBall
-from .matrix_momentum import MatrixMomentum, AdaptiveMatrixMomentum
+from .momentum import NAG, HeavyBall, EMA
torchzero/modules/momentum/averaging.py

@@ -10,7 +10,7 @@ from ...utils import tolist


 class Averaging(TensorwiseTransform):
-    """Average of past
+    """Average of past ``history_size`` updates.

     Args:
         history_size (int): Number of past updates to average
@@ -21,8 +21,8 @@ class Averaging(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)

     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state,
-        history_size =
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        history_size = setting['history_size']
         if 'history' not in state:
             state['history'] = deque(maxlen=history_size)
             state['average'] = torch.zeros_like(tensor)
@@ -35,7 +35,7 @@ class Averaging(TensorwiseTransform):
         return average / len(history)

 class WeightedAveraging(TensorwiseTransform):
-    """Weighted average of past
+    """Weighted average of past ``len(weights)`` updates.

     Args:
         weights (Sequence[float]): a sequence of weights from oldest to newest.
@@ -46,8 +46,8 @@ class WeightedAveraging(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)

     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state,
-        weights =
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        weights = setting['weights']

         if 'history' not in state:
             state['history'] = deque(maxlen=len(weights))
@@ -69,7 +69,7 @@ class WeightedAveraging(TensorwiseTransform):


 class MedianAveraging(TensorwiseTransform):
-    """Median of past
+    """Median of past ``history_size`` updates.

     Args:
         history_size (int): Number of past updates to average
@@ -80,8 +80,8 @@ class MedianAveraging(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)

     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state,
-        history_size =
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        history_size = setting['history_size']

         if 'history' not in state:
             state['history'] = deque(maxlen=history_size)