torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/adaptive/sophia_h.py
@@ -0,0 +1,185 @@
+from typing import Literal
+from collections.abc import Callable
+import torch
+
+from ...core import Module, Target, Transform, Chainable, apply_transform
+from ...utils import NumberList, TensorList, as_tensorlist
+def sophia_H(
+    tensors: TensorList,
+    h: TensorList | None,
+    exp_avg_: TensorList,
+    h_exp_avg_: TensorList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    update_freq: int,
+    precond_scale: float | NumberList,
+    clip: float | NumberList,
+    eps: float | NumberList,
+    step: int
+):
+    # momentum
+    exp_avg_.lerp_(tensors, 1-beta1)
+
+    # update preconditioner
+    if step % update_freq == 0:
+        assert h is not None
+        h_exp_avg_.lerp_(h, 1-beta2)
+
+    else:
+        assert h is None
+
+    denom = (h_exp_avg_ * precond_scale).clip_(min=eps)
+    return (exp_avg_ / denom).clip_(-clip, clip)
+
+
+class SophiaH(Module):
+    """SophiaH optimizer from https://arxiv.org/abs/2305.14342
+
+    This is similar to Adam, but the second momentum is replaced by an exponential moving average of randomized hessian diagonal estimates, and the update is agressively clipped.
+
+    .. note::
+        In most cases SophiaH should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply SophiaH preconditioning to another module's output.
+
+    .. note::
+        If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    Args:
+        beta1 (float, optional): first momentum. Defaults to 0.96.
+        beta2 (float, optional): momentum for hessian diagonal estimate. Defaults to 0.99.
+        update_freq (int, optional):
+            frequency of updating hessian diagonal estimate via a hessian-vector product. Defaults to 10.
+        precond_scale (float, optional):
+            scale of the preconditioner. Defaults to 1.
+        clip (float, optional):
+            clips update to (-clip, clip). Defaults to 1.
+        eps (float, optional):
+            clips hessian diagonal esimate to be no less than this value. Defaults to 1e-12.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
+        fd_h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+        n_samples (int, optional):
+            number of hessian-vector products with random vectors to evaluate each time when updating
+            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
+        seed (int | None, optional): seed for random vectors. Defaults to None.
+        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+    Examples:
+        Using SophiaH:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SophiaH(),
+                tz.m.LR(0.1)
+            )
+
+        SophiaH preconditioner can be applied to any other module by passing it to the :code:`inner` argument.
+        Turn off SophiaH's first momentum to get just the preconditioning. Here is an example of applying
+        SophiaH preconditioning to nesterov momentum (:code:`tz.m.NAG`):
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
+                tz.m.LR(0.1)
+            )
+
+    """
+    def __init__(
+        self,
+        beta1: float = 0.96,
+        beta2: float = 0.99,
+        update_freq: int = 10,
+        precond_scale: float = 1,
+        clip: float = 1,
+        eps: float = 1e-12,
+        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+        fd_h: float = 1e-3,
+        n_samples = 1,
+        seed: int | None = None,
+        inner: Chainable | None = None
+    ):
+        defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, precond_scale=precond_scale, clip=clip, eps=eps, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, var):
+        params = var.params
+        settings = self.settings[params[0]]
+        hvp_method = settings['hvp_method']
+        fd_h = settings['fd_h']
+        update_freq = settings['update_freq']
+        n_samples = settings['n_samples']
+
+        seed = settings['seed']
+        generator = None
+        if seed is not None:
+            if 'generator' not in self.global_state:
+                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+            generator = self.global_state['generator']
+
+        beta1, beta2, precond_scale, clip, eps = self.get_settings(params,
+            'beta1', 'beta2', 'precond_scale', 'clip', 'eps', cls=NumberList)
+
+        exp_avg, h_exp_avg = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        closure = var.closure
+        assert closure is not None
+
+        h = None
+        if step % update_freq == 0:
+
+            rgrad=None
+            for i in range(n_samples):
+                u = [torch.randn(p.shape, device=p.device, dtype=p.dtype, generator=generator) for p in params]
+
+                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                    h=fd_h, normalize=True, retain_grad=i < n_samples-1)
+                Hvp = tuple(Hvp)
+
+                if h is None: h = Hvp
+                else: torch._foreach_add_(h, Hvp)
+
+            assert h is not None
+            if n_samples > 1: torch._foreach_div_(h, n_samples)
+
+        update = var.get_update()
+        if 'inner' in self.children:
+            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
+
+        var.update = sophia_H(
+            tensors=TensorList(update),
+            h=TensorList(h) if h is not None else None,
+            exp_avg_=exp_avg,
+            h_exp_avg_=h_exp_avg,
+            beta1=beta1,
+            beta2=beta2,
+            update_freq=update_freq,
+            precond_scale=precond_scale,
+            clip=clip,
+            eps=eps,
+            step=step,
+        )
+        return var
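The new SophiaH module's docstring requires the step closure to accept a ``backward`` argument so the loss and gradients can be re-evaluated for hessian-vector products. Below is a minimal sketch of that closure pattern, assuming a standard ``torch.optim``-style setup; the toy model and data are placeholders, not part of the diff.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
X, y = torch.randn(64, 4), torch.randn(64, 1)

opt = tz.Modular(model.parameters(), tz.m.SophiaH(), tz.m.LR(0.1))

def closure(backward=True):
    # SophiaH re-invokes this closure to re-evaluate the loss and gradients
    # when estimating hessian-vector products; the backward flag lets it
    # control whether gradients are computed on a given call.
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(10):
    opt.step(closure)
```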
torchzero/modules/clipping/clipping.py
@@ -1,11 +1,13 @@
+import math
+from collections.abc import Iterable, Sequence
 from operator import itemgetter
 from typing import Literal
-
-import math
+
 import torch

 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList,
+from ...utils import Metrics, NumberList, TensorList
+from ...utils.metrics import _METRICS


 def clip_grad_value_(params: Iterable[torch.Tensor], value: float):
@@ -24,7 +26,7 @@ def _clip_norm_(
     min: float | NumberList | None,
     max: float | NumberList | None,
     norm_value: float | NumberList | None,
-    ord:
+    ord: Metrics,
     dim: int | Sequence[int] | Literal["global"] | None,
     inverse_dims: bool,
     min_size: int,
@@ -35,7 +37,7 @@ def _clip_norm_(
         raise ValueError(f'if norm_value is given then min and max must be None got {min = }; {max = }')

     # if dim is None: return tensors_.mul_(norm_value / tensors_.norm(ord=ord))
-    if dim == 'global': return tensors_.mul_(norm_value / tensors_.
+    if dim == 'global': return tensors_.mul_(norm_value / tensors_.global_metric(ord))

     # if dim is None: return tensors_.clip_norm_(min,max,tensorwise=True,ord=ord)
     if dim == 'global': return tensors_.clip_norm_(min,max,tensorwise=False,ord=ord)
@@ -54,9 +56,13 @@ def _clip_norm_(
         size = math.prod(tensor.size(d) for d in real_dim)
         if size < min_size: continue

-
+        if isinstance(ord, str):
+            norm = _METRICS[ord].evaluate_tensor(tensor, dim=real_dim, keepdim=True)
+        else:
+            norm: torch.Tensor = torch.linalg.vector_norm(tensor, ord=ord, dim=real_dim, keepdim=True) # pylint:disable=not-callable
+
         if norm.numel() == 1 and norm == 0: continue
-        norm = torch.where(norm
+        norm = torch.where(norm <= 1e-12, 1, norm)

         # normalize = True, perform normalization
         norm_v = norm_value[i] if isinstance(norm_value, (list,tuple)) else norm_value
@@ -90,7 +96,7 @@ def _clip_norm_(
 def clip_grad_norm_(
     params: Iterable[torch.Tensor],
     max_norm: float | None,
-    ord:
+    ord: Metrics = 2,
     dim: int | Sequence[int] | Literal["global"] | None = None,
     inverse_dims: bool = False,
     min_size: int = 2,
@@ -101,7 +107,7 @@ def clip_grad_norm_(

     Args:
         params (Iterable[torch.Tensor]): parameters with gradients to clip.
-
+        max_norm (float): value to clip norm to.
         ord (float, optional): norm order. Defaults to 2.
         dim (int | Sequence[int] | str | None, optional):
             calculates norm along those dimensions.
@@ -118,7 +124,7 @@ def clip_grad_norm_(
 def normalize_grads_(
     params: Iterable[torch.Tensor],
     norm_value: float,
-    ord:
+    ord: Metrics = 2,
     dim: int | Sequence[int] | Literal["global"] | None = None,
     inverse_dims: bool = False,
     min_size: int = 1,
@@ -145,13 +151,41 @@ def normalize_grads_(


 class ClipValue(Transform):
-    """Clips update magnitude to be within
+    """Clips update magnitude to be within ``(-value, value)`` range.
+
+    Args:
+        value (float): value to clip to.
+        target (str): refer to ``target argument`` in documentation.
+
+    Examples:
+
+        Gradient clipping:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.ClipValue(1),
+            tz.m.Adam(),
+            tz.m.LR(1e-2),
+        )
+        ```
+
+        Update clipping:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Adam(),
+            tz.m.ClipValue(1),
+            tz.m.LR(1e-2),
+        )
+        ```
+
+    """
     def __init__(self, value: float, target: Target = 'update'):
         defaults = dict(value=value)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         value = [s['value'] for s in settings]
         return TensorList(tensors).clip_([-v for v in value], value)

@@ -159,7 +193,7 @@ class ClipNorm(Transform):
     """Clips update norm to be no larger than `value`.

     Args:
-
+        max_norm (float): value to clip norm to.
         ord (float, optional): norm order. Defaults to 2.
         dim (int | Sequence[int] | str | None, optional):
             calculates norm along those dimensions.
@@ -172,21 +206,43 @@ class ClipNorm(Transform):
             minimal numer of elements in a parameter or slice to clip norm. Defaults to 1.
         target (str, optional):
             what this affects.
+
+    Examples:
+
+        Gradient norm clipping:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.ClipNorm(1),
+            tz.m.Adam(),
+            tz.m.LR(1e-2),
+        )
+        ```
+
+        Update norm clipping:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Adam(),
+            tz.m.ClipNorm(1),
+            tz.m.LR(1e-2),
+        )
+        ```
     """
     def __init__(
         self,
         max_norm: float,
-        ord:
+        ord: Metrics = 2,
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 1,
         target: Target = "update",
     ):
         defaults = dict(max_norm=max_norm,ord=ord,dim=dim,min_size=min_size,inverse_dims=inverse_dims)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         max_norm = NumberList(s['max_norm'] for s in settings)
         ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])
         _clip_norm_(
@@ -205,7 +261,7 @@ class Normalize(Transform):
     """Normalizes the update.

     Args:
-
+        norm_value (float): desired norm value.
         ord (float, optional): norm order. Defaults to 2.
         dim (int | Sequence[int] | str | None, optional):
             calculates norm along those dimensions.
@@ -218,21 +274,43 @@ class Normalize(Transform):
             minimal size of a dimension to normalize along it. Defaults to 1.
         target (str, optional):
             what this affects.
+
+    Examples:
+        Gradient normalization:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Normalize(1),
+            tz.m.Adam(),
+            tz.m.LR(1e-2),
+        )
+        ```
+
+        Update normalization:
+
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Adam(),
+            tz.m.Normalize(1),
+            tz.m.LR(1e-2),
+        )
+        ```
     """
     def __init__(
         self,
         norm_value: float = 1,
-        ord:
+        ord: Metrics = 2,
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 1,
         target: Target = "update",
     ):
         defaults = dict(norm_value=norm_value,ord=ord,dim=dim,min_size=min_size, inverse_dims=inverse_dims)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         norm_value = NumberList(s['norm_value'] for s in settings)
         ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])

@@ -288,8 +366,6 @@ class Centralize(Transform):
     """Centralizes the update.

     Args:
-        value (float): desired norm value.
-        ord (float, optional): norm order. Defaults to 2.
         dim (int | Sequence[int] | str | None, optional):
             calculates norm along those dimensions.
             If list/tuple, tensors are centralized along all dimensios in `dim` that they have.
@@ -299,6 +375,20 @@ class Centralize(Transform):
             if True, the `dims` argument is inverted, and all other dimensions are centralized.
         min_size (int, optional):
             minimal size of a dimension to normalize along it. Defaults to 1.
+
+    Examples:
+
+        Standard gradient centralization:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Centralize(dim=0),
+            tz.m.LR(1e-2),
+        )
+        ```
+
+    References:
+        - Yong, H., Huang, J., Hua, X., & Zhang, L. (2020). Gradient centralization: A new optimization technique for deep neural networks. In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing. https://arxiv.org/abs/2004.01461
     """
     def __init__(
         self,
@@ -308,10 +398,10 @@ class Centralize(Transform):
         target: Target = "update",
     ):
         defaults = dict(dim=dim,min_size=min_size,inverse_dims=inverse_dims)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(settings[0])

         _centralize_(tensors_ = TensorList(tensors), dim=dim, inverse_dims=inverse_dims, min_size=min_size)
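The ClipNorm and Normalize signatures above accept ``ord``, ``dim``, ``inverse_dims`` and ``min_size``. Below is a hedged sketch of how these arguments compose in a chain, using only parameters visible in this diff; the toy model and the choice of ``dim=0`` are illustrative.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 10)

opt = tz.Modular(
    model.parameters(),
    # clip norms computed along dimension 0 of each parameter to at most 1,
    # skipping slices with fewer than 2 elements
    tz.m.ClipNorm(1, ord=2, dim=0, min_size=2),
    # then rescale the whole update, treated as a single vector, to unit norm
    tz.m.Normalize(1, dim="global"),
    tz.m.LR(1e-2),
)
```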
torchzero/modules/clipping/ema_clipping.py
@@ -5,7 +5,7 @@ from collections.abc import Iterable, Sequence
 import torch

 from ...core import Module, Target, Transform, apply_transform, Chainable
-from ...utils import NumberList, TensorList,
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states, Metrics

 class ClipNormByEMA(Transform):
     """Clips norm to be no larger than the norm of an exponential moving average of past updates.
@@ -14,9 +14,10 @@ class ClipNormByEMA(Transform):
         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
         ord (float, optional): order of the norm. Defaults to 2.
         eps (float, optional): epsilon for division. Defaults to 1e-6.
-        tensorwise (bool, optional):
+        tensorwise (bool, optional):
+            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
         max_ema_growth (float | None, optional):
-            if specified, exponential moving average norm can grow
+            if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.
         ema_init (str, optional):
             How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
     """
@@ -24,17 +25,18 @@ class ClipNormByEMA(Transform):
     def __init__(
         self,
         beta=0.99,
-        ord:
+        ord: Metrics = 2,
         eps=1e-6,
         tensorwise:bool=True,
         max_ema_growth: float | None = 1.5,
         ema_init: Literal['zeros', 'update'] = 'zeros',
+        inner: Chainable | None = None,
     ):
         defaults = dict(beta=beta, ord=ord, tensorwise=tensorwise, ema_init=ema_init, eps=eps, max_ema_growth=max_ema_growth)
-        super().__init__(defaults,
+        super().__init__(defaults, inner=inner)

     @torch.no_grad
-    def
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
         ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(settings[0])

@@ -45,7 +47,7 @@ class ClipNormByEMA(Transform):
         ema.lerp_(tensors, 1-beta)

         if tensorwise:
-            ema_norm = ema.
+            ema_norm = ema.metric(ord)

             # clip ema norm growth
             if max_ema_growth is not None:
@@ -62,7 +64,7 @@ class ClipNormByEMA(Transform):
             else: denom.clip_(min=1)

         else:
-            ema_norm = ema.
+            ema_norm = ema.global_metric(ord)

             # clip ema norm growth
             if max_ema_growth is not None:
@@ -73,12 +75,17 @@ class ClipNormByEMA(Transform):
                 ema_norm = allowed_norm
                 prev_ema_norm.set_(ema_norm)

-            tensors_norm = tensors.
+            tensors_norm = tensors.global_metric(ord)
             denom = tensors_norm / ema_norm.clip(min=eps[0])
             if self.NORMALIZE: denom.clip_(min=eps[0])
             else: denom.clip_(min=1)

-
+        self.global_state['denom'] = denom
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        denom = self.global_state.pop('denom')
+        torch._foreach_div_(tensors, denom)
         return tensors

 class NormalizeByEMA(ClipNormByEMA):
@@ -88,9 +95,10 @@ class NormalizeByEMA(ClipNormByEMA):
         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
         ord (float, optional): order of the norm. Defaults to 2.
         eps (float, optional): epsilon for division. Defaults to 1e-6.
-        tensorwise (bool, optional):
+        tensorwise (bool, optional):
+            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
         max_ema_growth (float | None, optional):
-            if specified, exponential moving average norm can grow
+            if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.
         ema_init (str, optional):
             How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
     """
@@ -99,28 +107,30 @@ class NormalizeByEMA(ClipNormByEMA):
 # TODO Centralize by EMA?

 class ClipValueByEMA(Transform):
-    """Clips magnitude of update to be no larger than magnitude of
+    """Clips magnitude of update to be no larger than magnitude of exponential moving average of past (unclipped) updates.

     Args:
         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
         ema_init (str, optional):
             How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
-        ema_tfm (Chainable | None, optional):
+        ema_tfm (Chainable | None, optional):
+            optional modules applied to exponential moving average before clipping by it. Defaults to None.
     """
     def __init__(
         self,
         beta=0.99,
         ema_init: Literal['zeros', 'update'] = 'zeros',
         ema_tfm:Chainable | None=None,
+        inner: Chainable | None = None,
     ):
         defaults = dict(beta=beta, ema_init=ema_init)
-        super().__init__(defaults,
+        super().__init__(defaults, inner=inner)

         if ema_tfm is not None:
             self.set_child('ema_tfm', ema_tfm)

     @torch.no_grad
-    def
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
         ema_init = itemgetter('ema_init')(settings[0])

         beta = unpack_dicts(settings, 'beta', cls=NumberList)
@@ -129,8 +139,12 @@ class ClipValueByEMA(Transform):
         ema = unpack_states(states, tensors, 'ema', init = (torch.zeros_like if ema_init=='zeros' else lambda t: t.abs()), cls=TensorList)
         ema.lerp_(tensors.abs(), 1-beta)

+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        ema = unpack_states(states, tensors, 'ema', cls=TensorList)
+
         if 'ema_tfm' in self.children:
-            ema = TensorList(apply_transform(self.children['ema_tfm'], ema, params, grads, loss))
+            ema = TensorList(apply_transform(self.children['ema_tfm'], ema.clone(), params, grads, loss))

         tensors.clip_(-ema, ema)
         return tensors
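A hedged usage sketch for the EMA-based clipping transforms above, based on the constructor signatures in this diff; the composition and the toy model are illustrative.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 10)

opt = tz.Modular(
    model.parameters(),
    # clip each update's norm against an EMA of past updates;
    # the EMA norm itself may grow by at most max_ema_growth per step
    tz.m.ClipNormByEMA(beta=0.99, tensorwise=True, max_ema_growth=1.5),
    tz.m.LR(1e-2),
)
```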
torchzero/modules/clipping/growth_clipping.py
@@ -19,7 +19,7 @@ class ClipValueGrowth(TensorwiseTransform):
             bounds the tracked multiplicative clipping decay to prevent collapse to 0.
             Next update is at most :code:`max(previous update * mul, max_decay)`.
             Defaults to 2.
-        target (Target, optional): what to set on var
+        target (Target, optional): what to set on var. Defaults to "update".
     """
     def __init__(
         self,
@@ -30,11 +30,11 @@ class ClipValueGrowth(TensorwiseTransform):
         target: Target = "update",
     ):
         defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)


-    def apply_tensor(self, tensor, param, grad, loss, state,
-        add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(setting)
         add: float | None

         if add is None and mul is None:
@@ -120,7 +120,8 @@ class ClipNormGrowth(Transform):

     Args:
         add (float | None, optional): additive clipping, next update norm is at most `previous norm + add`. Defaults to None.
-        mul (float | None, optional):
+        mul (float | None, optional):
+            multiplicative clipping, next update norm is at most `previous norm * mul`. Defaults to 1.5.
         min_value (float | None, optional):
             minimum value for multiplicative clipping to prevent collapse to 0.
             Next norm is at most :code:`max(prev_norm, min_value) * mul`. Defaults to 1e-4.
@@ -144,11 +145,11 @@ class ClipNormGrowth(Transform):
         target: Target = "update",
     ):
         defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay, ord=ord, parameterwise=parameterwise)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)



-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         parameterwise = settings[0]['parameterwise']
         tensors = TensorList(tensors)
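A hedged sketch for the growth-clipping transforms above, using only parameters documented in this diff; the surrounding chain and toy model are illustrative.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 10)

opt = tz.Modular(
    model.parameters(),
    # each step's update norm may be at most 1.5x the previous step's norm
    tz.m.ClipNormGrowth(mul=1.5),
    tz.m.LR(1e-2),
)
```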
|