torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -48,16 +48,25 @@ class Cautious(Transform):
         eps (float, optional): epsilon for normalization. Defaults to 1e-6.
         mode (str, optional):
             what to do with updates with inconsistent signs.
+            - "zero" - set them to zero (as in paper)
+            - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
+            - "backtrack" - negate them
 
-
+    ## Examples:
 
-
+    Cautious Adam
 
-
+    ```python
+    opt = tz.Modular(
+        bench.parameters(),
+        tz.m.Adam(),
+        tz.m.Cautious(),
+        tz.m.LR(1e-2)
+    )
+    ```
 
-
-
-    Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu*
+    References:
+        Cautious Optimizers: Improving Training with One Line of Code. Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu
     """
 
     def __init__(
@@ -70,7 +79,7 @@ class Cautious(Transform):
         super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(settings[0])
         return cautious_(TensorList(tensors), TensorList(grads), normalize=normalize, eps=eps, mode=mode)
@@ -89,7 +98,7 @@ class UpdateGradientSignConsistency(Transform):
         super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         normalize, eps = itemgetter('normalize', 'eps')(settings[0])
 
@@ -109,12 +118,9 @@ class IntermoduleCautious(Module):
         eps (float, optional): epsilon for normalization. Defaults to 1e-6.
         mode (str, optional):
             what to do with updates with inconsistent signs.
-
-            "
-
-            "grad" - set them to the gradient
-
-            "backtrack" - negate them (same as using update magnitude and gradient sign)
+            - "zero" - set them to zero (as in paper)
+            - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
+            - "backtrack" - negate them
     """
     def __init__(
         self,
@@ -142,7 +148,7 @@ class IntermoduleCautious(Module):
         compare_var = compare.step(var.clone(clone_update=True))
         var.update_attrs_from_clone_(compare_var)
 
-        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.
+        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.defaults)
         var.update = cautious_(
             TensorList(main_var.get_update()),
             TensorList(compare_var.get_update()),
@@ -159,6 +165,18 @@ class ScaleByGradCosineSimilarity(Transform):
 
     Args:
         eps (float, optional): epsilon for division. Defaults to 1e-6.
+
+    ## Examples:
+
+    Scaled Adam
+    ```python
+    opt = tz.Modular(
+        bench.parameters(),
+        tz.m.Adam(),
+        tz.m.ScaleByGradCosineSimilarity(),
+        tz.m.LR(1e-2)
+    )
+    ```
     """
     def __init__(
         self,
@@ -168,12 +186,12 @@ class ScaleByGradCosineSimilarity(Transform):
         super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         eps = settings[0]['eps']
         tensors = TensorList(tensors)
         grads = TensorList(grads)
-        cos_sim =
+        cos_sim = tensors.dot(grads) / (tensors.global_vector_norm() * grads.global_vector_norm()).clip(min=eps)
 
         return tensors.mul_(cos_sim)
 
@@ -185,6 +203,20 @@ class ScaleModulesByCosineSimilarity(Module):
         main (Chainable): main module or sequence of modules whose update will be scaled.
         compare (Chainable): module or sequence of modules to compare to
         eps (float, optional): epsilon for division. Defaults to 1e-6.
+
+    ## Examples:
+
+    Adam scaled by similarity to RMSprop
+    ```python
+    opt = tz.Modular(
+        bench.parameters(),
+        tz.m.ScaleModulesByCosineSimilarity(
+            main = tz.m.Adam(),
+            compare = tz.m.RMSprop(0.999, debiased=True),
+        ),
+        tz.m.LR(1e-2)
+    )
+    ```
     """
     def __init__(
         self,
@@ -211,9 +243,9 @@ class ScaleModulesByCosineSimilarity(Module):
 
         m = TensorList(main_var.get_update())
         c = TensorList(compare_var.get_update())
-        eps = self.
+        eps = self.defaults['eps']
 
-        cos_sim =
+        cos_sim = m.dot(c) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)
 
         var.update = m.mul_(cos_sim)
         return var
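The cautious rule documented in the hunks above keeps only the update components whose sign agrees with the current gradient. A rough standalone sketch of the "zero" mode on a single tensor (an illustration only, not torchzero's `cautious_` helper, whose normalization details may differ):

```python
import torch

def cautious_zero_mode(update: torch.Tensor, grad: torch.Tensor,
                       normalize: bool = True, eps: float = 1e-6) -> torch.Tensor:
    # keep only components whose sign agrees with the gradient
    mask = (update * grad > 0).to(update.dtype)
    out = update * mask
    if normalize:
        # renormalize by the fraction of surviving components, as in the referenced paper
        out = out / mask.mean().clamp(min=eps)
    return out
```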
@@ -1,10 +1,44 @@
+from collections import deque
+from operator import itemgetter
 from typing import Literal
 
 import torch
 
 from ...core import Target, Transform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from
+from ..functional import debias, ema_
+
+
+class EMA(Transform):
+    """Maintains an exponential moving average of update.
+
+    Args:
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        dampening (float, optional): momentum dampening. Defaults to 0.
+        debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
+        lerp (bool, optional): whether to use linear interpolation. Defaults to True.
+        ema_init (str, optional): initial values for the EMA, "zeros" or "update".
+        target (Target, optional): target to apply EMA to. Defaults to 'update'.
+    """
+    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
+        defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
+
+        exp_avg = unpack_states(states, tensors, 'exp_avg',
+                                init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
+        momentum, dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
+
+        exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)
+
+        if debiased: return debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
+        else: return exp_avg.clone() # this has exp_avg storage so needs to be cloned
+
 
 
 class HeavyBall(EMA):
@@ -55,9 +89,10 @@ class NAG(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
         lerp = self.settings[params[0]]['lerp']
 
         momentum,dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
         return nag_(TensorList(tensors), velocity_=velocity,momentum=momentum,dampening=dampening,lerp=lerp)
+
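For reference, the debiasing option mentioned in the new EMA docstring is the usual Adam-style correction: divide the running average by 1 - beta^t. A minimal single-tensor sketch (the module itself operates on TensorLists through `ema_` and `debias`, and also supports dampening, which is omitted here):

```python
import torch

def ema_step(exp_avg: torch.Tensor, update: torch.Tensor, step: int,
             beta: float = 0.9, lerp: bool = True, debiased: bool = False) -> torch.Tensor:
    if lerp:
        exp_avg.lerp_(update, 1 - beta)      # exp_avg = beta * exp_avg + (1 - beta) * update
    else:
        exp_avg.mul_(beta).add_(update)      # exp_avg = beta * exp_avg + update
    if debiased:
        return exp_avg / (1 - beta ** step)  # Adam-style bias correction, step is 1-based
    return exp_avg.clone()                   # cloned so the output does not share exp_avg's storage
```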
@@ -7,7 +7,7 @@ from .accumulate import (
 )
 from .binary import (
     Add,
-
+    BinaryOperationBase,
     Clip,
     CopyMagnitude,
     CopySign,
@@ -27,37 +27,20 @@ from .binary import (
     Sub,
     Threshold,
 )
-from .
-
-
-
-
-
-
-    GraftGradToUpdate,
-    GraftToGrad,
-    GraftToParams,
-    LastAbsoluteRatio,
-    LastDifference,
-    LastGradDifference,
-    LastProduct,
-    LastRatio,
-    MulByLoss,
-    Multistep,
-    NegateOnLossIncrease,
-    NoiseSign,
-    Previous,
-    Relative,
-    Sequential,
-    UpdateSign,
-    WeightDropout,
+from .higher_level import (
+    CenteredEMASquared,
+    CenteredSqrtEMASquared,
+    Debias,
+    Debias2,
+    EMASquared,
+    SqrtEMASquared,
 )
 from .multi import (
     ClipModules,
     DivModules,
     GraftModules,
     LerpModules,
-
+    MultiOperationBase,
     PowModules,
     SubModules,
 )
@@ -66,13 +49,11 @@ from .reduce import (
     Mean,
     MinimumModules,
     Prod,
-
+    ReduceOperationBase,
     Sum,
     WeightedMean,
     WeightedSum,
 )
-from .split import Split
-from .switch import Alternate, Switch
 from .unary import (
     Abs,
     CustomUnaryOperation,
@@ -91,13 +72,12 @@ from .utility import (
     Grad,
     GradToNone,
     Identity,
-
+    Noop,
     Ones,
     Params,
     Randn,
     RandomSample,
     Uniform,
-    Update,
     UpdateToNone,
     Zeros,
 )
@@ -1,11 +1,7 @@
-from collections import deque
-from operator import itemgetter
-from typing import Literal
-
 import torch
 
 from ...core import Target, Transform
-from ...utils import TensorList,
+from ...utils import TensorList, unpack_states
 
 class AccumulateSum(Transform):
     """Accumulates sum of all past updates.
@@ -19,7 +15,7 @@ class AccumulateSum(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         sum = unpack_states(states, tensors, 'sum', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return sum.add_(tensors).lazy_mul(decay, clone=True)
@@ -36,7 +32,7 @@ class AccumulateMean(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
         mean = unpack_states(states, tensors, 'mean', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
@@ -54,7 +50,7 @@ class AccumulateProduct(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         prod = unpack_states(states, tensors, 'prod', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return prod.mul_(tensors).lazy_mul(decay, clone=True)
@@ -71,7 +67,7 @@ class AccumulateMaximum(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         maximum = unpack_states(states, tensors, 'maximum', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return maximum.maximum_(tensors).lazy_mul(decay, clone=True)
@@ -88,7 +84,7 @@ class AccumulateMinimum(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         minimum = unpack_states(states, tensors, 'minimum', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return minimum.minimum_(tensors).lazy_mul(decay, clone=True)
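All of the accumulators above share one pattern: fold the incoming tensors into a running aggregate in place, then return that aggregate scaled by 1 - decay. A single-tensor sketch of the AccumulateSum case (assuming `lazy_mul(x, clone=True)` returns a scaled copy and skips the multiply when the factor is 1):

```python
import torch

def accumulate_sum_step(running_sum: torch.Tensor, update: torch.Tensor,
                        decay: float = 0.0) -> torch.Tensor:
    running_sum.add_(update)            # state is updated in place
    return running_sum * (1.0 - decay)  # a scaled copy is returned; the state itself stays unscaled
```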
torchzero/modules/ops/binary.py CHANGED

@@ -1,5 +1,4 @@
 #pyright: reportIncompatibleMethodOverride=false
-""""""
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Sequence
 from operator import itemgetter
@@ -11,7 +10,7 @@ from ...core import Chainable, Module, Target, Var, maybe_chain
 from ...utils import TensorList, tensorlist
 
 
-class
+class BinaryOperationBase(Module, ABC):
     """Base class for operations that use update as the first operand. This is an abstract class, subclass it and override `transform` method to use it."""
     def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
         super().__init__(defaults=defaults)
@@ -47,29 +46,41 @@ class BinaryOperation(Module, ABC):
         return var
 
 
-class Add(
+class Add(BinaryOperationBase):
+    """Add :code:`other` to tensors. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`tensors + other(tensors)`
+    """
     def __init__(self, other: Chainable | float, alpha: float = 1):
         defaults = dict(alpha=alpha)
         super().__init__(defaults, other=other)
 
     @torch.no_grad
     def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
-        if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.
-        else: torch._foreach_add_(update, other, alpha=self.
+        if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.defaults['alpha'])
+        else: torch._foreach_add_(update, other, alpha=self.defaults['alpha'])
         return update
 
-class Sub(
+class Sub(BinaryOperationBase):
+    """Subtract :code:`other` from tensors. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`tensors - other(tensors)`
+    """
     def __init__(self, other: Chainable | float, alpha: float = 1):
         defaults = dict(alpha=alpha)
         super().__init__(defaults, other=other)
 
     @torch.no_grad
     def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
-        if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.
-        else: torch._foreach_sub_(update, other, alpha=self.
+        if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.defaults['alpha'])
+        else: torch._foreach_sub_(update, other, alpha=self.defaults['alpha'])
         return update
 
-class RSub(
+class RSub(BinaryOperationBase):
+    """Subtract tensors from :code:`other`. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`other(tensors) - tensors`
+    """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)
 
@@ -77,7 +88,11 @@ class RSub(BinaryOperation):
     def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         return other - TensorList(update)
 
-class Mul(
+class Mul(BinaryOperationBase):
+    """Multiply tensors by :code:`other`. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`tensors * other(tensors)`
+    """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)
 
@@ -86,7 +101,11 @@ class Mul(BinaryOperation):
         torch._foreach_mul_(update, other)
         return update
 
-class Div(
+class Div(BinaryOperationBase):
+    """Divide tensors by :code:`other`. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`tensors / other(tensors)`
+    """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)
 
@@ -95,7 +114,11 @@ class Div(BinaryOperation):
         torch._foreach_div_(update, other)
         return update
 
-class RDiv(
+class RDiv(BinaryOperationBase):
+    """Divide :code:`other` by tensors. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`other(tensors) / tensors`
+    """
    def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)
 
@@ -103,7 +126,11 @@ class RDiv(BinaryOperation):
     def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         return other / TensorList(update)
 
-class Pow(
+class Pow(BinaryOperationBase):
+    """Take tensors to the power of :code:`exponent`. :code:`exponent` can be a number or a module.
+
+    If :code:`exponent` is a module, this calculates :code:`tensors ^ exponent(tensors)`
+    """
     def __init__(self, exponent: Chainable | float):
         super().__init__({}, exponent=exponent)
 
@@ -112,7 +139,11 @@ class Pow(BinaryOperation):
         torch._foreach_pow_(update, exponent)
         return update
 
-class RPow(
+class RPow(BinaryOperationBase):
+    """Take :code:`other` to the power of tensors. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`other(tensors) ^ tensors`
+    """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)
 
@@ -122,7 +153,11 @@ class RPow(BinaryOperation):
         torch._foreach_pow_(other, update)
         return other
 
-class Lerp(
+class Lerp(BinaryOperationBase):
+    """Does a linear interpolation of tensors and :code:`end` module based on a scalar :code:`weight`.
+
+    The output is given by :code:`output = tensors + weight * (end(tensors) - tensors)`
+    """
     def __init__(self, end: Chainable, weight: float):
         defaults = dict(weight=weight)
         super().__init__(defaults, end=end)
@@ -132,7 +167,8 @@ class Lerp(BinaryOperation):
         torch._foreach_lerp_(update, end, weight=self.get_settings(var.params, 'weight'))
         return update
 
-class CopySign(
+class CopySign(BinaryOperationBase):
+    """Returns tensors with sign copied from :code:`other(tensors)`."""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)
 
@@ -140,7 +176,8 @@ class CopySign(BinaryOperation):
     def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         return [u.copysign_(o) for u, o in zip(update, other)]
 
-class RCopySign(
+class RCopySign(BinaryOperationBase):
+    """Returns :code:`other(tensors)` with sign copied from tensors."""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)
 
@@ -149,7 +186,11 @@ class RCopySign(BinaryOperation):
         return [o.copysign_(u) for u, o in zip(update, other)]
 CopyMagnitude = RCopySign
 
-class Clip(
+class Clip(BinaryOperationBase):
+    """clip tensors to be in :code:`(min, max)` range. :code:`min` and :code:`max: can be None, numbers or modules.
+
+    If code:`min` and :code:`max`: are modules, this calculates :code:`tensors.clip(min(tensors), max(tensors))`.
+    """
     def __init__(self, min: float | Chainable | None = None, max: float | Chainable | None = None):
         super().__init__({}, min=min, max=max)
 
@@ -157,8 +198,11 @@ class Clip(BinaryOperation):
     def transform(self, var, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
         return TensorList(update).clamp_(min=min, max=max)
 
-class MirroredClip(
-    """clip
+class MirroredClip(BinaryOperationBase):
+    """clip tensors to be in :code:`(-value, value)` range. :code:`value` can be a number or a module.
+
+    If :code:`value` is a module, this calculates :code:`tensors.clip(-value(tensors), value(tensors))`
+    """
     def __init__(self, value: float | Chainable):
         super().__init__({}, value=value)
 
@@ -167,19 +211,19 @@ class MirroredClip(BinaryOperation):
         min = -value if isinstance(value, (int,float)) else [-v for v in value]
         return TensorList(update).clamp_(min=min, max=value)
 
-class Graft(
-    """
+class Graft(BinaryOperationBase):
+    """Outputs tensors rescaled to have the same norm as :code:`magnitude(tensors)`."""
     def __init__(self, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
         super().__init__(defaults, magnitude=magnitude)
 
     @torch.no_grad
     def transform(self, var, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
         return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)
 
-class RGraft(
-    """
+class RGraft(BinaryOperationBase):
+    """Outputs :code:`magnitude(tensors)` rescaled to have the same norm as tensors"""
 
     def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
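Per its new docstring, Graft keeps the direction of the incoming update and borrows the norm from the magnitude module's output. A rough per-tensor sketch of that rescaling (the real `TensorList.graft_` also supports global rather than tensorwise norms and arbitrary `ord`):

```python
import torch

def graft(update: torch.Tensor, magnitude: torch.Tensor,
          ord: float = 2, eps: float = 1e-6) -> torch.Tensor:
    target_norm = torch.linalg.vector_norm(magnitude, ord=ord)
    current_norm = torch.linalg.vector_norm(update, ord=ord).clamp(min=eps)
    return update * (target_norm / current_norm)  # same direction, magnitude's norm
```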
@@ -187,12 +231,13 @@ class RGraft(BinaryOperation):
 
     @torch.no_grad
     def transform(self, var, update: list[torch.Tensor], direction: list[torch.Tensor]):
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
         return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)
 
 GraftToUpdate = RGraft
 
-class Maximum(
+class Maximum(BinaryOperationBase):
+    """Outputs :code:`maximum(tensors, other(tensors))`"""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)
 
@@ -201,7 +246,8 @@ class Maximum(BinaryOperation):
         torch._foreach_maximum_(update, other)
         return update
 
-class Minimum(
+class Minimum(BinaryOperationBase):
+    """Outputs :code:`minimum(tensors, other(tensors))`"""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)
 
@@ -211,26 +257,27 @@ class Minimum(BinaryOperation):
         return update
 
 
-class GramSchimdt(
-    """
+class GramSchimdt(BinaryOperationBase):
+    """outputs tensors made orthogonal to `other(tensors)` via Gram-Schmidt."""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)
 
     @torch.no_grad
     def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         update = TensorList(update); other = TensorList(other)
-
+        min = torch.finfo(update[0].dtype).tiny * 2
+        return update - (other*update) / (other*other).clip(min=min)
 
 
-class Threshold(
-    """
+class Threshold(BinaryOperationBase):
+    """Outputs tensors thresholded such that values above :code:`threshold` are set to :code:`value`."""
     def __init__(self, threshold: Chainable | float, value: Chainable | float, update_above: bool):
         defaults = dict(update_above=update_above)
         super().__init__(defaults, threshold=threshold, value=value)
 
     @torch.no_grad
     def transform(self, var, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
-        update_above = self.
+        update_above = self.defaults['update_above']
         update = TensorList(update)
         if update_above:
             if isinstance(value, list): return update.where_(update>threshold, value)