torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/momentum/averaging.py

@@ -1,27 +1,27 @@
 """Modules that perform averaging over a history of past updates."""
 from collections import deque
 from collections.abc import Sequence
-from typing import Any
+from typing import Any

 import torch

-from ...core import
+from ...core import TensorTransform
 from ...utils import tolist


-class Averaging(TensorwiseTransform):
+class Averaging(TensorTransform):
     """Average of past ``history_size`` updates.

     Args:
         history_size (int): Number of past updates to average
         target (Target, optional): target. Defaults to 'update'.
     """
-    def __init__(self, history_size: int
+    def __init__(self, history_size: int):
         defaults = dict(history_size=history_size)
-        super().__init__(
+        super().__init__(defaults=defaults)

     @torch.no_grad
-    def
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         history_size = setting['history_size']
         if 'history' not in state:
             state['history'] = deque(maxlen=history_size)

@@ -34,19 +34,19 @@ class Averaging(TensorwiseTransform):

         return average / len(history)

-class WeightedAveraging(TensorwiseTransform):
+class WeightedAveraging(TensorTransform):
     """Weighted average of past ``len(weights)`` updates.

     Args:
         weights (Sequence[float]): a sequence of weights from oldest to newest.
         target (Target, optional): target. Defaults to 'update'.
     """
-    def __init__(self, weights: Sequence[float] | torch.Tensor | Any
+    def __init__(self, weights: Sequence[float] | torch.Tensor | Any):
         defaults = dict(weights = tolist(weights))
-        super().__init__(
+        super().__init__(defaults=defaults)

     @torch.no_grad
-    def
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         weights = setting['weights']

         if 'history' not in state:

@@ -68,19 +68,19 @@ class WeightedAveraging(TensorwiseTransform):
         return average


-class MedianAveraging(
+class MedianAveraging(TensorTransform):
     """Median of past ``history_size`` updates.

     Args:
         history_size (int): Number of past updates to average
         target (Target, optional): target. Defaults to 'update'.
     """
-    def __init__(self, history_size: int,
+    def __init__(self, history_size: int,):
         defaults = dict(history_size = history_size)
-        super().__init__(
+        super().__init__(defaults=defaults)

     @torch.no_grad
-    def
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         history_size = setting['history_size']

         if 'history' not in state:
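The hunks above move the averaging transforms onto the new ``TensorTransform`` base class with a per-tensor ``single_tensor_apply`` hook. Below is a minimal sketch of a custom transform written against that hook; the import path ``torchzero.core.TensorTransform`` is assumed from the diff's relative import, and the class and its behaviour are illustrative rather than part of the package.

```python
import torch
from torchzero.core import TensorTransform  # import path assumed from `from ...core import TensorTransform`


class ClampUpdate(TensorTransform):
    """Illustrative only: clamps each update tensor to [-max_value, max_value]."""

    def __init__(self, max_value: float = 1.0):
        defaults = dict(max_value=max_value)
        super().__init__(defaults=defaults)

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        # `setting` holds this parameter's resolved options and `state` is its
        # persistent per-tensor dict, mirroring Averaging above.
        return tensor.clamp_(-setting['max_value'], setting['max_value'])
```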
torchzero/modules/momentum/cautious.py

@@ -5,7 +5,7 @@ from typing import Literal

 import torch

-from ...core import
+from ...core import TensorTransform, Module, Chainable
 from ...utils import NumberList, TensorList, unpack_dicts


@@ -36,7 +36,7 @@ def cautious_(
     tensors_ -= tensors_.mul(2).mul_(mask.logical_not_())
     return tensors_

-class Cautious(Transform):
+class Cautious(TensorTransform):
     """Negates update for parameters where update and gradient sign is inconsistent.
     Optionally normalizes the update by the number of parameters that are not masked.
     This is meant to be used after any momentum-based modules.

@@ -79,12 +79,12 @@ class Cautious(Transform):
         super().__init__(defaults, uses_grad=True)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(settings[0])
         return cautious_(TensorList(tensors), TensorList(grads), normalize=normalize, eps=eps, mode=mode)

-class UpdateGradientSignConsistency(Transform):
+class UpdateGradientSignConsistency(TensorTransform):
     """Compares update and gradient signs. Output will have 1s where signs match, and 0s where they don't.

     Args:

@@ -98,7 +98,7 @@ class UpdateGradientSignConsistency(Transform):
         super().__init__(defaults, uses_grad=True)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         normalize, eps = itemgetter('normalize', 'eps')(settings[0])


@@ -108,7 +108,7 @@ class UpdateGradientSignConsistency(Transform):
         return mask

 class IntermoduleCautious(Module):
-    """Negaties update on :code:`main` module where it's sign doesn't match with output of
+    """Negaties update on :code:`main` module where it's sign doesn't match with output of ``compare`` module.

     Args:
         main (Chainable): main module or sequence of modules whose update will be cautioned.

@@ -137,29 +137,32 @@ class IntermoduleCautious(Module):
         self.set_child('main', main)
         self.set_child('compare', compare)

+    def update(self, objective): raise RuntimeError
+    def apply(self, objective): raise RuntimeError
+
     @torch.no_grad
-    def step(self,
+    def step(self, objective):
         main = self.children['main']
         compare = self.children['compare']

-        main_var = main.step(
-
+        main_var = main.step(objective.clone(clone_updates=True))
+        objective.update_attrs_from_clone_(main_var)

-        compare_var = compare.step(
-
+        compare_var = compare.step(objective.clone(clone_updates=True))
+        objective.update_attrs_from_clone_(compare_var)

         mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.defaults)
-
-        TensorList(main_var.
-        TensorList(compare_var.
+        objective.updates = cautious_(
+            TensorList(main_var.get_updates()),
+            TensorList(compare_var.get_updates()),
             normalize=normalize,
             mode=mode,
             eps=eps,
         )

-        return
+        return objective

-class ScaleByGradCosineSimilarity(Transform):
+class ScaleByGradCosineSimilarity(TensorTransform):
     """Multiplies the update by cosine similarity with gradient.
     If cosine similarity is negative, naturally the update will be negated as well.

@@ -186,7 +189,7 @@ class ScaleByGradCosineSimilarity(Transform):
         super().__init__(defaults, uses_grad=True)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         eps = settings[0]['eps']
         tensors = TensorList(tensors)

@@ -196,8 +199,8 @@ class ScaleByGradCosineSimilarity(Transform):
         return tensors.mul_(cos_sim)

 class ScaleModulesByCosineSimilarity(Module):
-    """Scales the output of
-    of
+    """Scales the output of ``main`` module by it's cosine similarity to the output
+    of ``compare`` module.

     Args:
         main (Chainable): main module or sequence of modules whose update will be scaled.

@@ -230,22 +233,25 @@ class ScaleModulesByCosineSimilarity(Module):
         self.set_child('main', main)
         self.set_child('compare', compare)

+    def update(self, objective): raise RuntimeError
+    def apply(self, objective): raise RuntimeError
+
     @torch.no_grad
-    def step(self,
+    def step(self, objective):
         main = self.children['main']
         compare = self.children['compare']

-        main_var = main.step(
-
+        main_var = main.step(objective.clone(clone_updates=True))
+        objective.update_attrs_from_clone_(main_var)

-        compare_var = compare.step(
-
+        compare_var = compare.step(objective.clone(clone_updates=True))
+        objective.update_attrs_from_clone_(compare_var)

-        m = TensorList(main_var.
-        c = TensorList(compare_var.
+        m = TensorList(main_var.get_updates())
+        c = TensorList(compare_var.get_updates())
         eps = self.defaults['eps']

         cos_sim = m.dot(c) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)

-
-        return
+        objective.updates = m.mul_(cos_sim)
+        return objective
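``IntermoduleCautious`` and ``ScaleModulesByCosineSimilarity`` above illustrate the 0.4.0 pattern for modules that drive child modules: override ``step(objective)``, run each child on ``objective.clone(clone_updates=True)``, fold attributes back with ``update_attrs_from_clone_``, and write the result into ``objective.updates``. A rough sketch of that pattern follows; the objective methods used are the ones shown in the diff, but the class itself is hypothetical and the assumption that ``Module`` accepts a defaults dict is not confirmed by the diff.

```python
import torch
from torchzero.core import Module        # assumed export, mirroring `from ...core import ... Module`
from torchzero.utils import TensorList   # assumed export, mirroring the diff's relative import


class ScaleChildUpdate(Module):
    """Hypothetical composite module: steps a child module and scales its update by a constant."""

    def __init__(self, main, scale: float = 0.5):
        super().__init__(dict(scale=scale))  # assumption: Module.__init__ takes a defaults dict
        self.set_child('main', main)

    # composite modules in the diff disable the update/apply split and implement step directly
    def update(self, objective): raise RuntimeError
    def apply(self, objective): raise RuntimeError

    @torch.no_grad
    def step(self, objective):
        main = self.children['main']

        # run the child on a clone of the objective, then propagate its attributes back
        main_var = main.step(objective.clone(clone_updates=True))
        objective.update_attrs_from_clone_(main_var)

        # scale the child's update and hand it on
        objective.updates = TensorList(main_var.get_updates()).mul_(self.defaults['scale'])
        return objective
```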
torchzero/modules/momentum/momentum.py

@@ -4,12 +4,12 @@ from typing import Literal

 import torch

-from ...core import
+from ...core import TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 from ..functional import debias, ema_


-class EMA(Transform):
+class EMA(TensorTransform):
     """Maintains an exponential moving average of update.

     Args:

@@ -20,12 +20,12 @@ class EMA(Transform):
         ema_init (str, optional): initial values for the EMA, "zeros" or "update".
         target (Target, optional): target to apply EMA to. Defaults to 'update'.
     """
-    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros'
+    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros'):
         defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
-        super().__init__(defaults, uses_grad=False
+        super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1

         debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])

@@ -53,8 +53,8 @@ class HeavyBall(EMA):
         ema_init (str, optional): initial values for the EMA, "zeros" or "update".
         target (Target, optional): target to apply EMA to. Defaults to 'update'.
     """
-    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update'
-        super().__init__(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init
+    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update'):
+        super().__init__(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init)

 def nag_(
     tensors_: TensorList,

@@ -74,7 +74,7 @@ def nag_(
     return tensors_


-class NAG(Transform):
+class NAG(TensorTransform):
     """Nesterov accelerated gradient method (nesterov momentum).

     Args:

@@ -84,12 +84,12 @@ class NAG(Transform):
         whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.
         target (Target, optional): target to apply EMA to. Defaults to 'update'.
     """
-    def __init__(self, momentum:float=0.9, dampening:float=0, lerp=False
+    def __init__(self, momentum:float=0.9, dampening:float=0, lerp=False):
         defaults = dict(momentum=momentum,dampening=dampening, lerp=lerp)
-        super().__init__(defaults, uses_grad=False
+        super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
         lerp = self.settings[params[0]]['lerp']

torchzero/modules/ops/__init__.py

@@ -12,8 +12,8 @@ from .binary import (
     CopyMagnitude,
     CopySign,
     Div,
-
-
+    GraftInputToOutput,
+    GraftInputToOutput,
     GramSchimdt,
     Maximum,
     Minimum,

@@ -21,7 +21,7 @@ from .binary import (
     Pow,
     RCopySign,
     RDiv,
-
+    GraftOutputToInput,
     RPow,
     RSub,
     Sub,

@@ -38,7 +38,7 @@ from .higher_level import (
 from .multi import (
     ClipModules,
     DivModules,
-
+    Graft,
     LerpModules,
     MultiOperationBase,
     PowModules,
torchzero/modules/ops/accumulate.py

@@ -1,90 +1,90 @@
 import torch

-from ...core import
+from ...core import TensorTransform
 from ...utils import TensorList, unpack_states

-class AccumulateSum(
+class AccumulateSum(TensorTransform):
     """Accumulates sum of all past updates.

     Args:
         decay (float, optional): decays the accumulator. Defaults to 0.
         target (Target, optional): target. Defaults to 'update'.
     """
-    def __init__(self, decay: float = 0
+    def __init__(self, decay: float = 0):
         defaults = dict(decay=decay)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         sum = unpack_states(states, tensors, 'sum', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return sum.add_(tensors).lazy_mul(decay, clone=True)

-class AccumulateMean(
+class AccumulateMean(TensorTransform):
     """Accumulates mean of all past updates.

     Args:
         decay (float, optional): decays the accumulator. Defaults to 0.
         target (Target, optional): target. Defaults to 'update'.
     """
-    def __init__(self, decay: float = 0
+    def __init__(self, decay: float = 0):
         defaults = dict(decay=decay)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
         mean = unpack_states(states, tensors, 'mean', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return mean.add_(tensors).lazy_mul(decay, clone=True).div_(step)

-class AccumulateProduct(
+class AccumulateProduct(TensorTransform):
     """Accumulates product of all past updates.

     Args:
         decay (float, optional): decays the accumulator. Defaults to 0.
         target (Target, optional): target. Defaults to 'update'.
     """
-    def __init__(self, decay: float = 0, target
+    def __init__(self, decay: float = 0, target = 'update',):
         defaults = dict(decay=decay)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         prod = unpack_states(states, tensors, 'prod', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return prod.mul_(tensors).lazy_mul(decay, clone=True)

-class AccumulateMaximum(
+class AccumulateMaximum(TensorTransform):
     """Accumulates maximum of all past updates.

     Args:
         decay (float, optional): decays the accumulator. Defaults to 0.
         target (Target, optional): target. Defaults to 'update'.
     """
-    def __init__(self, decay: float = 0
+    def __init__(self, decay: float = 0):
         defaults = dict(decay=decay)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         maximum = unpack_states(states, tensors, 'maximum', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return maximum.maximum_(tensors).lazy_mul(decay, clone=True)

-class AccumulateMinimum(
+class AccumulateMinimum(TensorTransform):
     """Accumulates minimum of all past updates.

     Args:
         decay (float, optional): decays the accumulator. Defaults to 0.
         target (Target, optional): target. Defaults to 'update'.
     """
-    def __init__(self, decay: float = 0
+    def __init__(self, decay: float = 0):
         defaults = dict(decay=decay)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         minimum = unpack_states(states, tensors, 'minimum', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return minimum.minimum_(tensors).lazy_mul(decay, clone=True)