torchzero-0.3.11-py3-none-any.whl → torchzero-0.3.14-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/modules/misc/switch.py
CHANGED

@@ -53,7 +53,7 @@ class Alternate(Module):
         var = module.step(var.clone(clone_update=False))
 
         # number of steps until next module
-        steps = self.
+        steps = self.defaults['steps']
         if isinstance(steps, int): steps = [steps]*len(self.children)
 
         if 'steps_to_next' not in self.global_state:
torchzero/modules/momentum/__init__.py
CHANGED

@@ -6,9 +6,5 @@ from .cautious import (
     ScaleModulesByCosineSimilarity,
     UpdateGradientSignConsistency,
 )
-from .ema import EMA, Debias, Debias2, EMASquared, SqrtEMASquared, CenteredEMASquared, CenteredSqrtEMASquared
-from .experimental import CoordinateMomentum
-# from .matrix_momentum import MatrixMomentum
 
-from .momentum import NAG, HeavyBall
-from .matrix_momentum import MatrixMomentum, AdaptiveMatrixMomentum
+from .momentum import NAG, HeavyBall, EMA
torchzero/modules/momentum/averaging.py
CHANGED

@@ -10,7 +10,7 @@ from ...utils import tolist
 
 
 class Averaging(TensorwiseTransform):
-    """Average of past
+    """Average of past ``history_size`` updates.
 
     Args:
         history_size (int): Number of past updates to average
@@ -35,7 +35,7 @@ class Averaging(TensorwiseTransform):
         return average / len(history)
 
 class WeightedAveraging(TensorwiseTransform):
-    """Weighted average of past
+    """Weighted average of past ``len(weights)`` updates.
 
     Args:
         weights (Sequence[float]): a sequence of weights from oldest to newest.
@@ -69,7 +69,7 @@ class WeightedAveraging(TensorwiseTransform):
 
 
 class MedianAveraging(TensorwiseTransform):
-    """Median of past
+    """Median of past ``history_size`` updates.
 
     Args:
         history_size (int): Number of past updates to average
torchzero/modules/momentum/cautious.py
CHANGED

@@ -48,24 +48,22 @@ class Cautious(Transform):
         eps (float, optional): epsilon for normalization. Defaults to 1e-6.
         mode (str, optional):
             what to do with updates with inconsistent signs.
+            - "zero" - set them to zero (as in paper)
+            - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
+            - "backtrack" - negate them
 
-
+    ## Examples:
 
-
+    Cautious Adam
 
-
-
-
-
-
-
-
-
-            bench.parameters(),
-            tz.m.Adam(),
-            tz.m.Cautious(),
-            tz.m.LR(1e-2)
-        )
+    ```python
+    opt = tz.Modular(
+        bench.parameters(),
+        tz.m.Adam(),
+        tz.m.Cautious(),
+        tz.m.LR(1e-2)
+    )
+    ```
 
     References:
         Cautious Optimizers: Improving Training with One Line of Code. Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu
@@ -120,12 +118,9 @@ class IntermoduleCautious(Module):
         eps (float, optional): epsilon for normalization. Defaults to 1e-6.
         mode (str, optional):
             what to do with updates with inconsistent signs.
-
-            "
-
-            "grad" - set them to the gradient
-
-            "backtrack" - negate them (same as using update magnitude and gradient sign)
+            - "zero" - set them to zero (as in paper)
+            - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
+            - "backtrack" - negate them
     """
     def __init__(
         self,
@@ -153,7 +148,7 @@ class IntermoduleCautious(Module):
         compare_var = compare.step(var.clone(clone_update=True))
         var.update_attrs_from_clone_(compare_var)
 
-        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.
+        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.defaults)
         var.update = cautious_(
             TensorList(main_var.get_update()),
             TensorList(compare_var.get_update()),
@@ -171,17 +166,17 @@ class ScaleByGradCosineSimilarity(Transform):
     Args:
         eps (float, optional): epsilon for division. Defaults to 1e-6.
 
-    Examples:
-
-
-
-
-
-
-
-
-
-
+    ## Examples:
+
+    Scaled Adam
+    ```python
+    opt = tz.Modular(
+        bench.parameters(),
+        tz.m.Adam(),
+        tz.m.ScaleByGradCosineSimilarity(),
+        tz.m.LR(1e-2)
+    )
+    ```
     """
     def __init__(
         self,
@@ -209,19 +204,19 @@ class ScaleModulesByCosineSimilarity(Module):
         compare (Chainable): module or sequence of modules to compare to
         eps (float, optional): epsilon for division. Defaults to 1e-6.
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+    ## Examples:
+
+    Adam scaled by similarity to RMSprop
+    ```python
+    opt = tz.Modular(
+        bench.parameters(),
+        tz.m.ScaleModulesByCosineSimilarity(
+            main = tz.m.Adam(),
+            compare = tz.m.RMSprop(0.999, debiased=True),
+        ),
+        tz.m.LR(1e-2)
+    )
+    ```
     """
     def __init__(
         self,
@@ -248,7 +243,7 @@ class ScaleModulesByCosineSimilarity(Module):
 
         m = TensorList(main_var.get_update())
         c = TensorList(compare_var.get_update())
-        eps = self.
+        eps = self.defaults['eps']
 
         cos_sim = m.dot(c) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)
 
torchzero/modules/momentum/momentum.py
CHANGED

@@ -1,10 +1,44 @@
+from collections import deque
+from operator import itemgetter
 from typing import Literal
 
 import torch
 
 from ...core import Target, Transform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from
+from ..functional import debias, ema_
+
+
+class EMA(Transform):
+    """Maintains an exponential moving average of update.
+
+    Args:
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        dampening (float, optional): momentum dampening. Defaults to 0.
+        debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
+        lerp (bool, optional): whether to use linear interpolation. Defaults to True.
+        ema_init (str, optional): initial values for the EMA, "zeros" or "update".
+        target (Target, optional): target to apply EMA to. Defaults to 'update'.
+    """
+    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
+        defaults = dict(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
+
+        exp_avg = unpack_states(states, tensors, 'exp_avg',
+                                init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
+        momentum, dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
+
+        exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg, beta=momentum, dampening=dampening, lerp=lerp)
+
+        if debiased: return debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
+        else: return exp_avg.clone() # this has exp_avg storage so needs to be cloned
+
 
 
 class HeavyBall(EMA):
torchzero/modules/ops/__init__.py
CHANGED

@@ -27,6 +27,14 @@ from .binary import (
     Sub,
     Threshold,
 )
+from .higher_level import (
+    CenteredEMASquared,
+    CenteredSqrtEMASquared,
+    Debias,
+    Debias2,
+    EMASquared,
+    SqrtEMASquared,
+)
 from .multi import (
     ClipModules,
     DivModules,
@@ -64,7 +72,7 @@ from .utility import (
     Grad,
     GradToNone,
     Identity,
-
+    Noop,
     Ones,
     Params,
     Randn,
torchzero/modules/ops/binary.py
CHANGED

@@ -57,8 +57,8 @@ class Add(BinaryOperationBase):
 
     @torch.no_grad
     def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
-        if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.
-        else: torch._foreach_add_(update, other, alpha=self.
+        if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.defaults['alpha'])
+        else: torch._foreach_add_(update, other, alpha=self.defaults['alpha'])
         return update
 
 class Sub(BinaryOperationBase):
@@ -72,8 +72,8 @@ class Sub(BinaryOperationBase):
 
     @torch.no_grad
     def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
-        if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.
-        else: torch._foreach_sub_(update, other, alpha=self.
+        if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.defaults['alpha'])
+        else: torch._foreach_sub_(update, other, alpha=self.defaults['alpha'])
         return update
 
 class RSub(BinaryOperationBase):
@@ -219,7 +219,7 @@ class Graft(BinaryOperationBase):
 
     @torch.no_grad
     def transform(self, var, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
         return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)
 
 class RGraft(BinaryOperationBase):
@@ -231,7 +231,7 @@ class RGraft(BinaryOperationBase):
 
     @torch.no_grad
     def transform(self, var, update: list[torch.Tensor], direction: list[torch.Tensor]):
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
         return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)
 
 GraftToUpdate = RGraft
@@ -265,7 +265,8 @@ class GramSchimdt(BinaryOperationBase):
     @torch.no_grad
     def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         update = TensorList(update); other = TensorList(other)
-
+        min = torch.finfo(update[0].dtype).tiny * 2
+        return update - (other*update) / (other*other).clip(min=min)
 
 
 class Threshold(BinaryOperationBase):
@@ -276,7 +277,7 @@ class Threshold(BinaryOperationBase):
 
     @torch.no_grad
     def transform(self, var, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
-        update_above = self.
+        update_above = self.defaults['update_above']
         update = TensorList(update)
         if update_above:
             if isinstance(value, list): return update.where_(update>threshold, value)
torchzero/modules/{momentum/ema.py → ops/higher_level.py}
CHANGED

@@ -5,39 +5,16 @@ from typing import Literal
 import torch
 
 from ...core import Target, Transform
-from ...utils import
-from ..functional import
-
-
-
-
-
-
-
-
-        debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
-        lerp (bool, optional): whether to use linear interpolation. Defaults to True.
-        ema_init (str, optional): initial values for the EMA, "zeros" or "update".
-        target (Target, optional): target to apply EMA to. Defaults to 'update'.
-    """
-    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
-        defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-        debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
-
-        exp_avg = unpack_states(states, tensors, 'exp_avg',
-                                init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
-        momentum, dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
-
-        exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)
-
-        if debiased: return debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
-        else: return exp_avg.clone() # this has exp_avg storage so needs to be cloned
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+from ..functional import (
+    centered_ema_sq_,
+    debias,
+    debias_second_momentum,
+    ema_,
+    ema_sq_,
+    sqrt_centered_ema_sq_,
+    sqrt_ema_sq_,
+)
 
 
 class EMASquared(Transform):
torchzero/modules/ops/multi.py
CHANGED

@@ -8,7 +8,7 @@ from typing import Any, Literal
 import torch
 
 from ...core import Chainable, Module, Target, Var, maybe_chain
-from ...utils import TensorList, tensorlist
+from ...utils import TensorList, tensorlist, Metrics
 
 
 class MultiOperationBase(Module, ABC):
@@ -59,7 +59,7 @@ class SubModules(MultiOperationBase):
 
     @torch.no_grad
    def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
-        alpha = self.
+        alpha = self.defaults['alpha']
 
         if isinstance(input, (int,float)):
             assert isinstance(other, list)
@@ -112,7 +112,7 @@ class LerpModules(MultiOperationBase):
 
     @torch.no_grad
     def transform(self, var: Var, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
-        torch._foreach_lerp_(input, end, weight=self.
+        torch._foreach_lerp_(input, end, weight=self.defaults['weight'])
         return input
 
 class ClipModules(MultiOperationBase):
@@ -154,45 +154,45 @@ class GraftModules(MultiOperationBase):
     Reference:
         Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803. https://arxiv.org/pdf/2002.11803
     """
-    def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:
+    def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:Metrics=2, eps:float = 1e-6, strength:float=1):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
         super().__init__(defaults, direction=direction, magnitude=magnitude)
 
     @torch.no_grad
     def transform(self, var, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
-        tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.
+        tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.defaults)
         return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
 
 class MultiplyByModuleNorm(MultiOperationBase):
     """Outputs :code:`input` multiplied by norm of the :code:`norm` output."""
-    def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:
+    def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:Metrics=2):
         defaults = dict(tensorwise=tensorwise, ord=ord)
         super().__init__(defaults, input=input, norm=norm)
 
     @torch.no_grad
     def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
-        tensorwise, ord = itemgetter('tensorwise','ord')(self.
+        tensorwise, ord = itemgetter('tensorwise','ord')(self.defaults)
         if tensorwise:
-
-
-
+            n = TensorList(norm).metric(ord)
+        else:
+            n = TensorList(norm).global_metric(ord)
 
         torch._foreach_mul_(input, n)
         return input
 
 class DivideByModuleNorm(MultiOperationBase):
     """Outputs :code:`input` divided by norm of the :code:`norm` output."""
-    def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:
+    def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:Metrics=2):
         defaults = dict(tensorwise=tensorwise, ord=ord)
         super().__init__(defaults, input=input, norm=norm)
 
     @torch.no_grad
     def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
-        tensorwise, ord = itemgetter('tensorwise','ord')(self.
+        tensorwise, ord = itemgetter('tensorwise','ord')(self.defaults)
         if tensorwise:
-
-
-
+            n = TensorList(norm).metric(ord)
+        else:
+            n = TensorList(norm).global_metric(ord)
 
         torch._foreach_div_(input, n)
         return input
torchzero/modules/ops/reduce.py
CHANGED

@@ -81,7 +81,7 @@ class WeightedSum(ReduceOperationBase):
     @torch.no_grad
     def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
-        weights = self.
+        weights = self.defaults['weights']
         sum = cast(list, sorted_inputs[0])
         torch._foreach_mul_(sum, weights[0])
         if len(sorted_inputs) > 1:
torchzero/modules/ops/utility.py
CHANGED

@@ -4,7 +4,7 @@ import torch
 
 from ...core import Module, Target, Transform
 from ...utils.tensorlist import Distributions, TensorList
-
+from ...utils.linalg.linear_operator import ScaledIdentity
 
 class Clone(Module):
     """Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations"""
@@ -64,15 +64,15 @@ class Fill(Module):
 
 class RandomSample(Module):
     """Outputs tensors filled with random numbers from distribution depending on value of :code:`distribution`."""
-    def __init__(self,
-        defaults = dict(
+    def __init__(self, distribution: Distributions = 'normal', variance:float | None = None):
+        defaults = dict(distribution=distribution, variance=variance)
         super().__init__(defaults)
 
     @torch.no_grad
     def step(self, var):
-
-
-        )
+        distribution = self.defaults['distribution']
+        variance = self.get_settings(var.params, 'variance')
+        var.update = TensorList(var.params).sample_like(distribution=distribution, variance=variance)
         return var
 
 class Randn(Module):
@@ -112,9 +112,13 @@ class UpdateToNone(Module):
         return var
 
 class Identity(Module):
-    """
+    """Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods."""
     def __init__(self, *args, **kwargs): super().__init__()
     def step(self, var): return var
+    def get_H(self, var):
+        n = sum(p.numel() for p in var.params)
+        p = var.params[0]
+        return ScaledIdentity(shape=(n,n), device=p.device, dtype=p.dtype)
 
-
+Noop = Identity
 """A placeholder identity operator that is argument-insensitive."""
torchzero/modules/projections/projection.py
CHANGED

@@ -1,7 +1,7 @@
 import math
 import warnings
 from abc import ABC, abstractmethod
-from collections import
+from collections import ChainMap, defaultdict
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
 from typing import Any, Literal
@@ -9,7 +9,7 @@ from typing import Any, Literal
 import torch
 
 from ...core import Chainable, Module, Var
-from ...utils import
+from ...utils import set_storage_, vec_to_tensors
 
 
 def _make_projected_closure(closure, project_fn, unproject_fn,
@@ -166,7 +166,7 @@ class ProjectionBase(Module, ABC):
             current=current,
         ))
 
-        projected_var = var.clone(clone_update=False)
+        projected_var = var.clone(clone_update=False, parent=var)
 
         closure = var.closure
 
@@ -278,7 +278,7 @@ class ProjectionBase(Module, ABC):
         unprojected_var = projected_var.clone(clone_update=False)
         unprojected_var.closure = var.closure
         unprojected_var.params = var.params
-        unprojected_var.grad = var.grad
+        unprojected_var.grad = var.grad # this may also be set by projected_var since it has var as parent
 
         if self._project_update:
             assert projected_var.update is not None
torchzero/modules/quasi_newton/__init__.py
CHANGED

@@ -1,14 +1,3 @@
-from .cg import (
-    ConjugateDescent,
-    DaiYuan,
-    FletcherReeves,
-    HagerZhang,
-    HestenesStiefel,
-    HybridHS_DY,
-    LiuStorey,
-    PolakRibiere,
-    ProjectedGradientMethod,
-)
 from .diagonal_quasi_newton import (
     DNRTR,
     DiagonalBFGS,
@@ -19,9 +8,6 @@ from .diagonal_quasi_newton import (
 )
 from .lbfgs import LBFGS
 from .lsr1 import LSR1
-# from .olbfgs import OnlineLBFGS
-
-# from .experimental import ModularLBFGS
 from .quasi_newton import (
     BFGS,
     DFP,
@@ -40,7 +26,6 @@ from .quasi_newton import (
     NewSSM,
     Pearson,
     ProjectedNewtonRaphson,
-    ThomasOptimalMethod,
     ShorR,
+    ThomasOptimalMethod,
 )
-from .trust_region import CubicRegularization, TrustCG, TrustRegionBase