torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
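
The bulk of this release is a reorganization: the contents of `torchzero/modules/optimizers/` move to `torchzero/modules/adaptive/`, conjugate-gradient, trust-region, termination, least-squares, variance-reduction and restart modules become their own packages, and a new `torchzero/utils/metrics.py` backs the `Metrics` type that threads through the clipping diff below. A minimal sketch of what the move means in practice, assuming the deep module paths mirror the file moves listed above and that the `tz.m` re-exports shown in the docstrings are unchanged (both inferred from this diff, not confirmed API):

```python
# Sketch only: deep paths are inferred from the file moves listed above;
# tz.Modular / tz.m usage is taken from the docstrings in the diff below.
import torch
import torchzero as tz

# 0.3.11: from torchzero.modules.optimizers.adam import ...
# 0.3.13: from torchzero.modules.adaptive.adam import ...

model = torch.nn.Linear(4, 2)  # placeholder model for the sketch

# the modular API itself is unaffected by the file moves:
opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-2),
)
```

The rest of this page is the diff for the two clipping modules: `clipping.py` (+85 -92) and `ema_clipping.py` (+5 -5).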
--- a/torchzero/modules/clipping/clipping.py
+++ b/torchzero/modules/clipping/clipping.py
@@ -1,11 +1,13 @@
+import math
+from collections.abc import Iterable, Sequence
 from operator import itemgetter
 from typing import Literal
-
-import math
+
 import torch
 
 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList
+from ...utils import Metrics, NumberList, TensorList
+from ...utils.metrics import _METRICS
 
 
 def clip_grad_value_(params: Iterable[torch.Tensor], value: float):
@@ -24,7 +26,7 @@ def _clip_norm_(
     min: float | NumberList | None,
     max: float | NumberList | None,
     norm_value: float | NumberList | None,
-    ord:
+    ord: Metrics,
     dim: int | Sequence[int] | Literal["global"] | None,
     inverse_dims: bool,
     min_size: int,
@@ -35,7 +37,7 @@ def _clip_norm_(
         raise ValueError(f'if norm_value is given then min and max must be None got {min = }; {max = }')
 
     # if dim is None: return tensors_.mul_(norm_value / tensors_.norm(ord=ord))
-    if dim == 'global': return tensors_.mul_(norm_value / tensors_.
+    if dim == 'global': return tensors_.mul_(norm_value / tensors_.global_metric(ord))
 
     # if dim is None: return tensors_.clip_norm_(min,max,tensorwise=True,ord=ord)
     if dim == 'global': return tensors_.clip_norm_(min,max,tensorwise=False,ord=ord)
@@ -54,8 +56,8 @@ def _clip_norm_(
         size = math.prod(tensor.size(d) for d in real_dim)
         if size < min_size: continue
 
-        if ord
-        norm =
+        if isinstance(ord, str):
+            norm = _METRICS[ord].evaluate_tensor(tensor, dim=real_dim, keepdim=True)
         else:
             norm: torch.Tensor = torch.linalg.vector_norm(tensor, ord=ord, dim=real_dim, keepdim=True) # pylint:disable=not-callable
 
@@ -94,7 +96,7 @@ def _clip_norm_(
 def clip_grad_norm_(
     params: Iterable[torch.Tensor],
     max_norm: float | None,
-    ord:
+    ord: Metrics = 2,
     dim: int | Sequence[int] | Literal["global"] | None = None,
     inverse_dims: bool = False,
     min_size: int = 2,
@@ -105,7 +107,7 @@ def clip_grad_norm_(
 
     Args:
         params (Iterable[torch.Tensor]): parameters with gradients to clip.
-
+        max_norm (float): value to clip norm to.
         ord (float, optional): norm order. Defaults to 2.
         dim (int | Sequence[int] | str | None, optional):
             calculates norm along those dimensions.
@@ -122,7 +124,7 @@ def clip_grad_norm_(
 def normalize_grads_(
     params: Iterable[torch.Tensor],
     norm_value: float,
-    ord:
+    ord: Metrics = 2,
     dim: int | Sequence[int] | Literal["global"] | None = None,
     inverse_dims: bool = False,
     min_size: int = 1,
@@ -149,35 +151,33 @@ def normalize_grads_(
 
 
 class ClipValue(Transform):
-    """Clips update magnitude to be within
+    """Clips update magnitude to be within ``(-value, value)`` range.
 
     Args:
         value (float): value to clip to.
-        target (str): refer to
+        target (str): refer to ``target argument`` in documentation.
 
     Examples:
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        tz.m.LR(1e-2),
-        )
+        Gradient clipping:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.ClipValue(1),
+            tz.m.Adam(),
+            tz.m.LR(1e-2),
+        )
+        ```
+
+        Update clipping:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Adam(),
+            tz.m.ClipValue(1),
+            tz.m.LR(1e-2),
+        )
+        ```
 
     """
     def __init__(self, value: float, target: Target = 'update'):
@@ -193,7 +193,7 @@ class ClipNorm(Transform):
     """Clips update norm to be no larger than `value`.
 
     Args:
-
+        max_norm (float): value to clip norm to.
         ord (float, optional): norm order. Defaults to 2.
         dim (int | Sequence[int] | str | None, optional):
             calculates norm along those dimensions.
@@ -209,32 +209,30 @@ class ClipNorm(Transform):
 
     Examples:
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        tz.m.LR(1e-2),
-        )
+        Gradient norm clipping:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.ClipNorm(1),
+            tz.m.Adam(),
+            tz.m.LR(1e-2),
+        )
+        ```
+
+        Update norm clipping:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Adam(),
+            tz.m.ClipNorm(1),
+            tz.m.LR(1e-2),
+        )
+        ```
     """
     def __init__(
         self,
        max_norm: float,
-        ord:
+        ord: Metrics = 2,
        dim: int | Sequence[int] | Literal["global"] | None = None,
        inverse_dims: bool = False,
        min_size: int = 1,
@@ -263,7 +261,7 @@ class Normalize(Transform):
     """Normalizes the update.
 
     Args:
-
+        norm_value (float): desired norm value.
         ord (float, optional): norm order. Defaults to 2.
         dim (int | Sequence[int] | str | None, optional):
             calculates norm along those dimensions.
@@ -278,33 +276,31 @@ class Normalize(Transform):
         what this affects.
 
     Examples:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        tz.m.LR(1e-2),
-        )
+        Gradient normalization:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Normalize(1),
+            tz.m.Adam(),
+            tz.m.LR(1e-2),
+        )
+        ```
+
+        Update normalization:
+
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Adam(),
+            tz.m.Normalize(1),
+            tz.m.LR(1e-2),
+        )
+        ```
     """
     def __init__(
         self,
         norm_value: float = 1,
-        ord:
+        ord: Metrics = 2,
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 1,
@@ -370,8 +366,6 @@ class Centralize(Transform):
     """Centralizes the update.
 
     Args:
-        value (float): desired norm value.
-        ord (float, optional): norm order. Defaults to 2.
         dim (int | Sequence[int] | str | None, optional):
             calculates norm along those dimensions.
             If list/tuple, tensors are centralized along all dimensios in `dim` that they have.
@@ -384,18 +378,17 @@ class Centralize(Transform):
 
     Examples:
 
-
-
-
-
-
-
-
-
-        )
+        Standard gradient centralization:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Centralize(dim=0),
+            tz.m.LR(1e-2),
+        )
+        ```
 
     References:
-
+        - Yong, H., Huang, J., Hua, X., & Zhang, L. (2020). Gradient centralization: A new optimization technique for deep neural networks. In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing. https://arxiv.org/abs/2004.01461
     """
     def __init__(
         self,
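
The change that repeats through these hunks is the widening of `ord`: every `ord:` parameter is re-typed as `Metrics`, TensorList norms go through the new `metric`/`global_metric` helpers, and per-tensor norms dispatch through the `_METRICS` registry when `ord` is a string. In effect `ord` now accepts either a numeric vector-norm order or a string metric name. A sketch of the dispatch as the `_clip_norm_` hunk shows it (the registry import matches the hunk at the top of the file; the concrete string keys of `_METRICS` are not visible in this diff):

```python
# Sketch of the per-tensor branch in _clip_norm_ above, not confirmed API.
import torch
from torchzero.utils.metrics import _METRICS  # registry added in 0.3.13

def _norm(tensor: torch.Tensor, ord, dim, keepdim: bool = True) -> torch.Tensor:
    if isinstance(ord, str):
        # string metrics dispatch through the new registry
        return _METRICS[ord].evaluate_tensor(tensor, dim=dim, keepdim=keepdim)
    # numeric orders keep the pre-0.3.13 vector-norm path
    return torch.linalg.vector_norm(tensor, ord=ord, dim=dim, keepdim=keepdim)
```

Callers that passed a float `ord` keep the old behavior; only the string path is new.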
--- a/torchzero/modules/clipping/ema_clipping.py
+++ b/torchzero/modules/clipping/ema_clipping.py
@@ -5,7 +5,7 @@ from collections.abc import Iterable, Sequence
 import torch
 
 from ...core import Module, Target, Transform, apply_transform, Chainable
-from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states, Metrics
 
 class ClipNormByEMA(Transform):
     """Clips norm to be no larger than the norm of an exponential moving average of past updates.
@@ -25,7 +25,7 @@ class ClipNormByEMA(Transform):
     def __init__(
         self,
         beta=0.99,
-        ord:
+        ord: Metrics = 2,
         eps=1e-6,
         tensorwise:bool=True,
         max_ema_growth: float | None = 1.5,
@@ -47,7 +47,7 @@ class ClipNormByEMA(Transform):
         ema.lerp_(tensors, 1-beta)
 
         if tensorwise:
-            ema_norm = ema.
+            ema_norm = ema.metric(ord)
 
             # clip ema norm growth
             if max_ema_growth is not None:
@@ -64,7 +64,7 @@ class ClipNormByEMA(Transform):
             else: denom.clip_(min=1)
 
         else:
-            ema_norm = ema.
+            ema_norm = ema.global_metric(ord)
 
             # clip ema norm growth
             if max_ema_growth is not None:
@@ -75,7 +75,7 @@ class ClipNormByEMA(Transform):
                 ema_norm = allowed_norm
                 prev_ema_norm.set_(ema_norm)
 
-            tensors_norm = tensors.
+            tensors_norm = tensors.global_metric(ord)
             denom = tensors_norm / ema_norm.clip(min=eps[0])
             if self.NORMALIZE: denom.clip_(min=eps[0])
            else: denom.clip_(min=1)