torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/adaptive/natural_gradient.py

@@ -1,12 +1,12 @@
 import torch
-from ...core import
+from ...core import Transform

 from ...utils.derivatives import jacobian_wrt, flatten_jacobian
-from ...utils import vec_to_tensors
-from ...
+from ...utils import vec_to_tensors
+from ...linalg import linear_operator
 from .lmadagrad import lm_adagrad_apply, lm_adagrad_update

-class NaturalGradient(
+class NaturalGradient(Transform):
 """Natural gradient approximated via empirical fisher information matrix.

 To use this, either pass vector of per-sample losses to the step method, or make sure
@@ -27,9 +27,9 @@ class NaturalGradient(Module):
 with a vector that isn't strictly per-sample gradients, but rather for example different losses.
 gn_grad (bool, optional):
 if True, uses Gauss-Newton G^T @ f as the gradient, which is effectively sum weighted by value
-and is equivalent to squaring the values.
-
-This has an effect when ``sqrt=
+and is equivalent to squaring the values. That makes the kernel trick solver incorrect, but for
+some reason it still works. If False, uses sum of per-sample gradients.
+This has an effect when ``sqrt=False``, and affects the ``grad`` attribute.
 Defaults to False.
 batched (bool, optional): whether to use vmapping. Defaults to True.

@@ -97,20 +97,21 @@ class NaturalGradient(Module):
 super().__init__(defaults=dict(batched=batched, reg=reg, sqrt=sqrt, gn_grad=gn_grad))

 @torch.no_grad
-def
-params =
-
-
+def update_states(self, objective, states, settings):
+params = objective.params
+fs = settings[0]
+batched = fs['batched']
+gn_grad = fs['gn_grad']

-closure =
+closure = objective.closure
 assert closure is not None

 with torch.enable_grad():
-f =
+f = objective.get_loss(backward=False) # n_out
 assert isinstance(f, torch.Tensor)
 G_list = jacobian_wrt([f.ravel()], params, batched=batched)

-
+objective.loss = f.sum()
 G = self.global_state["G"] = flatten_jacobian(G_list) # (n_samples, ndim)

 if gn_grad:
@@ -119,13 +120,13 @@ class NaturalGradient(Module):
 else:
 g = self.global_state["g"] = G.sum(0)

-
+objective.grads = vec_to_tensors(g, params)

 # set closure to calculate scalar value for line searches etc
-if
+if objective.closure is not None:
 def ngd_closure(backward=True):
 if backward:
-
+objective.zero_grad()
 with torch.enable_grad():
 loss = closure(False)
 if gn_grad: loss = loss.pow(2)
@@ -137,13 +138,14 @@ class NaturalGradient(Module):
 if gn_grad: loss = loss.pow(2)
 return loss.sum()

-
+objective.closure = ngd_closure

 @torch.no_grad
-def
-params =
-
-
+def apply_states(self, objective, states, settings):
+params = objective.params
+fs = settings[0]
+reg = fs['reg']
+sqrt = fs['sqrt']

 G: torch.Tensor = self.global_state['G'] # (n_samples, n_dim)

@@ -151,12 +153,15 @@ class NaturalGradient(Module):
 # this computes U, S <- SVD(M), then calculate update as U S^-1 Uᵀg,
 # but it computes it through eigendecompotision
 U, L = lm_adagrad_update(G.H, reg, 0)
-if U is None or L is None: return
+if U is None or L is None: return objective

 v = lm_adagrad_apply(self.global_state["g"], U, L)
-
-return
+objective.updates = vec_to_tensors(v, params)
+return objective

+# we need (G^T G)v = g
+# where g = G^T
+# so we need to solve (G^T G)v = G^T
 GGT = G @ G.H # (n_samples, n_samples)

 if reg != 0:
@@ -165,11 +170,11 @@ class NaturalGradient(Module):
 z, _ = torch.linalg.solve_ex(GGT, torch.ones_like(GGT[0])) # pylint:disable=not-callable
 v = G.H @ z

-
-return
+objective.updates = vec_to_tensors(v, params)
+return objective


-def get_H(self,
+def get_H(self, objective=...):
 if "G" not in self.global_state: return linear_operator.ScaledIdentity()
 G = self.global_state['G']
 return linear_operator.AtA(G)
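The new comments in `apply_states` describe the kernel trick this module relies on: with the per-sample gradient matrix G of shape (n_samples, ndim), the empirical-Fisher system (GᵀG)v = g with g = Gᵀ1 is solved through the much smaller n_samples × n_samples system (GGᵀ)z = 1 followed by v = Gᵀz, since Gᵀ(GGᵀz) = Gᵀ1 = g. A minimal standalone sketch of that identity in plain torch (illustrative variable names only, not the torchzero API):

```python
import torch

torch.manual_seed(0)
n_samples, ndim = 8, 50
G = torch.randn(n_samples, ndim, dtype=torch.float64)  # per-sample gradients, one row per sample
g = G.sum(0)                                            # g = G^T @ ones, the summed gradient

# small (n_samples x n_samples) system: (G G^T) z = 1, then lift back with v = G^T z
GGT = G @ G.T
z = torch.linalg.solve(GGT, torch.ones(n_samples, dtype=torch.float64))
v = G.T @ z

# v satisfies the full-size empirical-Fisher system (G^T G) v = g
print(torch.allclose(G.T @ (G @ v), g))  # True
```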
torchzero/modules/adaptive/orthograd.py

@@ -1,13 +1,9 @@
-from
-import math
-import warnings
-from collections.abc import Iterable, Sequence
-from typing import Literal
+from collections.abc import Iterable

 import torch

-from ...core import
-from ...utils import
+from ...core import TensorTransform
+from ...utils import TensorList

 def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
 """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.
@@ -19,29 +15,29 @@ def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
 reference
 https://arxiv.org/abs/2501.04697
 """
-params =
+params = TensorList(params).with_grad()
 grad = params.grad
 grad -= (params.dot(grad)/(params.dot(params) + eps)) * params


-class OrthoGrad(
+class OrthoGrad(TensorTransform):
 """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

 Args:
 eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
 renormalize (bool, optional): whether to graft projected gradient to original gradient norm. Defaults to True.
-target (Target, optional): what to set on var. Defaults to 'update'.
 """
-def __init__(self, eps: float = 1e-8, renormalize=True
+def __init__(self, eps: float = 1e-8, renormalize=True):
 defaults = dict(eps=eps, renormalize=renormalize)
-super().__init__(defaults
+super().__init__(defaults)

-
+@torch.no_grad
+def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
 eps = settings[0]['eps']
 renormalize = settings[0]['renormalize']

-params =
-target =
+params = TensorList(params)
+target = TensorList(tensors)

 scale = params.dot(target)/(params.dot(params) + eps)
 if renormalize:
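The `orthograd_` helper shown above removes the component of the gradient that points along the weights, g ← g − (⟨p, g⟩ / (⟨p, p⟩ + eps))·p. A single-tensor sketch of that projection in plain torch (torchzero's version operates on a whole `TensorList` of parameters; the snippet below is illustrative only):

```python
import torch

torch.manual_seed(0)
p = torch.randn(1000, dtype=torch.float64)  # weights
g = torch.randn(1000, dtype=torch.float64)  # gradient
eps = 1e-30

# remove the component of g that points along p
g_perp = g - (p.dot(g) / (p.dot(p) + eps)) * p

print(p.dot(g_perp).abs().item())  # ~0: the projected gradient is orthogonal to the weights
```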
torchzero/modules/adaptive/rmsprop.py

@@ -1,45 +1,11 @@
-from operator import itemgetter
 from typing import Literal

 import torch

-from ...core import
+from ...core import TensorTransform, Chainable
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-
-
-
-def rmsprop_(
-tensors_: TensorList,
-exp_avg_sq_: TensorList,
-smoothing: float | NumberList,
-eps: float | NumberList,
-debiased: bool,
-step: int,
-exp_avg_: TensorList | None = None,
-max_exp_avg_sq_: TensorList | None = None,
-pow: float = 2,
-
-# inner args
-inner: Module | None = None,
-params: list[torch.Tensor] | None = None,
-grads: list[torch.Tensor] | None = None,
-):
-"""returns `tensors_`"""
-if exp_avg_ is not None:
-sqrt_exp_avg_sq = sqrt_centered_ema_sq_(tensors=tensors_, exp_avg_=exp_avg_,
-exp_avg_sq_=exp_avg_sq_,max_exp_avg_sq_=max_exp_avg_sq_,
-beta=smoothing,debiased=debiased,step=step,pow=pow)
-else:
-sqrt_exp_avg_sq = sqrt_ema_sq_(tensors=tensors_,exp_avg_sq_=exp_avg_sq_,max_exp_avg_sq_=max_exp_avg_sq_,
-beta=smoothing,debiased=debiased,step=step,pow=pow)
-
-if inner is not None:
-assert params is not None
-tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
-
-return tensors_.div_(sqrt_exp_avg_sq.add_(eps))
-
-class RMSprop(Transform):
+
+class RMSprop(TensorTransform):
 """Divides graient by EMA of gradient squares.

 This implementation is identical to :code:`torch.optim.RMSprop`.
@@ -48,7 +14,7 @@ class RMSprop(Transform):
 smoothing (float, optional): beta for exponential moving average of gradient squares. Defaults to 0.99.
 eps (float, optional): epsilon for division. Defaults to 1e-8.
 centered (bool, optional): whether to center EMA of gradient squares using an additional EMA. Defaults to False.
-
+debias (bool, optional): applies Adam debiasing. Defaults to False.
 amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
 pow (float, optional): power used in second momentum power and root. Defaults to 2.
 init (str, optional): how to initialize EMA, either "update" to use first update or "zeros". Defaults to "update".
@@ -60,44 +26,86 @@ class RMSprop(Transform):
 smoothing: float = 0.99,
 eps: float = 1e-8,
 centered: bool = False,
-
+debias: bool = False,
 amsgrad: bool = False,
-pow: float = 2,
 init: Literal["zeros", "update"] = "zeros",
+
 inner: Chainable | None = None,
+exp_avg_sq_tfm: Chainable | None = None,
 ):
-defaults =
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+defaults = locals().copy()
+del defaults['self'], defaults["inner"], defaults["exp_avg_sq_tfm"]
+super().__init__(defaults, inner=inner)
+
+self.set_child('exp_avg_sq', exp_avg_sq_tfm)
+
+@torch.no_grad
+def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+if setting["init"] == "zeros":
+state["exp_avg_sq"] = torch.zeros_like(tensor)
+if setting["centered"]: state["exp_avg"] = torch.zeros_like(tensor)
+if setting["amsgrad"]: state["amsgrad"] = torch.zeros_like(tensor)
+
+else:
+state["exp_avg_sq"] = tensor ** 2
+if setting["centered"]: state["exp_avg"] = tensor.clone()
+if setting["amsgrad"]: state["amsgrad"] = tensor ** 2
+
+@torch.no_grad
+def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+self.increment_counter("step", start = 0)
+fs = settings[0]
+
+exp_avg_sq = unpack_states(states, tensors, "exp_avg_sq", cls=TensorList)
+
+# update exponential average
+smoothing = NumberList(s["smoothing"] for s in settings)
+exp_avg_sq.mul_(smoothing).addcmul_(tensors, tensors, value=1-smoothing)
+
+# update mean estimate if centered
+if fs["centered"]:
+exp_avg = unpack_states(states, tensors, "exp_avg", cls=TensorList)
+exp_avg.lerp_(tensors, 1-smoothing)
+
+# amsgrad
+if fs["amsgrad"]:
+exp_avg_sq_max = unpack_states(states, tensors, "exp_avg_sq_max", cls=TensorList)
+exp_avg_sq_max.maximum_(exp_avg_sq)
+
+@torch.no_grad
+def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+tensors = TensorList(tensors)
+step = self.global_state["step"] # 0 on 1st step
+eps = NumberList(s["eps"] for s in settings)
+fs = settings[0]
+
+if fs["amsgrad"]: key = "max_exp_avg_sq"
+else: key = "exp_avg_sq"
+exp_avg_sq = TensorList(s[key] for s in states)
+
+# load mean estimate if centered
+exp_avg = None
+if fs['centered']:
+exp_avg = TensorList(s["exp_avg"] for s in states)
+
+# debias exp_avg_sq and exp_avg
+if fs["debias"]:
+smoothing = NumberList(s["smoothing"] for s in settings)
+bias_correction = 1 - (smoothing ** (step + 1))
+exp_avg_sq = exp_avg_sq / bias_correction
+
+if fs['centered']:
+assert exp_avg is not None
+exp_avg = exp_avg / bias_correction
+
+# apply transform to potentially debiased exp_avg_sq
+exp_avg_sq = TensorList(self.inner_step_tensors(
+"exp_avg_sq", exp_avg_sq, params=params, grads=grads, loss=loss, clone=True, must_exist=False
+))
+
+# center
+if fs["centered"]:
+assert exp_avg is not None
+exp_avg_sq = exp_avg_sq.addcmul(exp_avg, exp_avg, value=-1)
+
+return tensors.div_(exp_avg_sq.sqrt().add_(eps))
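The rewritten `RMSprop` splits the work across `single_tensor_initialize`, `multi_tensor_update` and `multi_tensor_apply` hooks, but the underlying rule is unchanged: keep an EMA of squared gradients, optionally center and debias it, then divide the update by its square root. A rough single-tensor sketch of that rule in plain torch (a hypothetical helper, not the torchzero API; the amsgrad and `exp_avg_sq_tfm` paths are omitted):

```python
import torch

def rmsprop_step(grad, state, smoothing=0.99, eps=1e-8, centered=False, debias=False):
    """One RMSprop-style update on a single gradient tensor; `state` is a plain dict."""
    step = state["step"] = state.get("step", -1) + 1   # 0 on the first step
    exp_avg_sq = state.setdefault("exp_avg_sq", torch.zeros_like(grad))
    exp_avg_sq.mul_(smoothing).addcmul_(grad, grad, value=1 - smoothing)

    denom_sq = exp_avg_sq
    if centered:  # track an EMA of the gradient itself
        exp_avg = state.setdefault("exp_avg", torch.zeros_like(grad))
        exp_avg.lerp_(grad, 1 - smoothing)

    if debias:  # Adam-style bias correction
        bias_correction = 1 - smoothing ** (step + 1)
        denom_sq = denom_sq / bias_correction

    if centered:  # subtract the squared mean estimate, turning the EMA into a variance estimate
        mean = state["exp_avg"]
        if debias:
            mean = mean / bias_correction
        # clamp guards against tiny negative values caused by rounding
        denom_sq = denom_sq.addcmul(mean, mean, value=-1).clamp_(min=0)

    return grad / (denom_sq.sqrt() + eps)

state = {}
for _ in range(5):
    update = rmsprop_step(torch.randn(10), state, debias=True)
```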
torchzero/modules/adaptive/rprop.py

@@ -1,8 +1,8 @@

 import torch

-from ...core import
-from ...utils import NumberList, TensorList,
+from ...core import TensorTransform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states


 def _bool_ones_like(x):
@@ -126,7 +126,7 @@ def rprop_(



-class Rprop(
+class Rprop(TensorTransform):
 """
 Resilient propagation. The update magnitude gets multiplied by `nplus` if gradient didn't change the sign,
 or `nminus` if it did. Then the update is applied with the sign of the current gradient.
@@ -165,7 +165,7 @@ class Rprop(Transform):
 super().__init__(defaults, uses_grad=False)

 @torch.no_grad
-def
+def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
 step = self.global_state.get('step', 0)
 self.global_state['step'] = step + 1

@@ -178,7 +178,7 @@ class Rprop(Transform):
 )

 tensors = rprop_(
-tensors_ =
+tensors_ = TensorList(tensors),
 prev_ = prev,
 allowed_ = allowed,
 magnitudes_ = magnitudes,
@@ -194,7 +194,7 @@ class Rprop(Transform):
 return tensors


-class ScaleLRBySignChange(
+class ScaleLRBySignChange(TensorTransform):
 """
 learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign,
 or `nminus` if it did.
@@ -218,19 +218,19 @@ class ScaleLRBySignChange(Transform):
 ub=50.0,
 alpha=1.0,
 use_grad=False,
-target: Target = "update",
 ):
 defaults = dict(nplus=nplus, nminus=nminus, alpha=alpha, lb=lb, ub=ub, use_grad=use_grad)
-super().__init__(defaults, uses_grad=use_grad
+super().__init__(defaults, uses_grad=use_grad)

 @torch.no_grad
-def
+def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
 step = self.global_state.get('step', 0)
 self.global_state['step'] = step + 1

-tensors =
-
-
+tensors = TensorList(tensors)
+if self._uses_grad:
+assert grads is not None
+cur = TensorList(grads)
 else: cur = tensors

 nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
@@ -252,7 +252,7 @@ class ScaleLRBySignChange(Transform):
 )
 return tensors

-class BacktrackOnSignChange(
+class BacktrackOnSignChange(TensorTransform):
 """Negates or undoes update for parameters where where gradient or update sign changes.

 This is part of RProp update rule.
@@ -266,20 +266,21 @@ class BacktrackOnSignChange(Transform):
 Defaults to True.

 """
-def __init__(self, use_grad = False, backtrack = True
-defaults = dict(use_grad=use_grad, backtrack=backtrack
+def __init__(self, use_grad = False, backtrack = True):
+defaults = dict(use_grad=use_grad, backtrack=backtrack)
 super().__init__(defaults, uses_grad=use_grad)

 @torch.no_grad
-def
+def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
 step = self.global_state.get('step', 0)
 self.global_state['step'] = step + 1

-tensors =
-use_grad = settings[0]['use_grad']
+tensors = TensorList(tensors)
 backtrack = settings[0]['backtrack']

-if
+if self._uses_grad:
+assert grads is not None
+cur = TensorList(grads)
 else: cur = tensors

 tensors = backtrack_on_sign_change_(
@@ -292,54 +293,55 @@ class BacktrackOnSignChange(Transform):

 return tensors

-class SignConsistencyMask(
+class SignConsistencyMask(TensorTransform):
 """
 Outputs a mask of sign consistency of current and previous inputs.

 The output is 0 for weights where input sign changed compared to previous input, 1 otherwise.

-Examples:
-
-GD that skips update for weights where gradient sign changed compared to previous gradient.
+### Examples:

-
+GD that skips update for weights where gradient sign changed compared to previous gradient.

-
-
-
-
-
+```python
+opt = tz.Modular(
+model.parameters(),
+tz.m.Mul(tz.m.SignConsistencyMask()),
+tz.m.LR(1e-2)
+)
+```

 """
-def __init__(self
-super().__init__(
+def __init__(self):
+super().__init__()

 @torch.no_grad
-def
+def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
 prev = unpack_states(states, tensors, 'prev', cls=TensorList)
 mask = prev.mul_(tensors).gt_(0)
 prev.copy_(tensors)
 return mask


-class SignConsistencyLRs(
+class SignConsistencyLRs(TensorTransform):
 """Outputs per-weight learning rates based on consecutive sign consistency.

-The learning rate for a weight is multiplied by
+The learning rate for a weight is multiplied by ``nplus`` when two consecutive update signs are the same, otherwise it is multiplied by ``nplus``. The learning rates are bounded to be in ``(lb, ub)`` range.

-Examples:
+### Examples:

-
+GD scaled by consecutive gradient sign consistency

-
+```python

-
-
-
-
-
+opt = tz.Modular(
+model.parameters(),
+tz.m.Mul(tz.m.SignConsistencyLRs()),
+tz.m.LR(1e-2)
+)
+```

-
+"""
 def __init__(
 self,
 nplus: float = 1.2,
@@ -347,17 +349,16 @@ class SignConsistencyLRs(Transform):
 lb: float | None = 1e-6,
 ub: float | None = 50,
 alpha: float = 1,
-target: Target = 'update'
 ):
 defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
-super().__init__(defaults, uses_grad=False
+super().__init__(defaults, uses_grad=False)

 @torch.no_grad
-def
+def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
 step = self.global_state.get('step', 0)
 self.global_state['step'] = step + 1

-target =
+target = TensorList(tensors)
 nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
 prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)

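The new `SignConsistencyMask.multi_tensor_apply` reduces to three lines: multiply the stored previous input by the current one, keep the positive entries as the mask, and store the current input for the next step. A plain-torch sketch of that computation on a single tensor (illustrative only):

```python
import torch

prev = torch.tensor([ 1.0, -2.0,  0.5, -0.1])   # previous update/gradient
cur  = torch.tensor([ 2.0,  3.0, -0.5, -0.3])   # current update/gradient

mask = (prev * cur) > 0      # True where the sign is unchanged -> tensor([True, False, False, True])
prev = cur.clone()           # stored for the next step

masked = cur * mask          # zero out weights whose sign flipped, as in tz.m.Mul(tz.m.SignConsistencyMask())
print(mask)
print(masked)
```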