torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/clipping/clipping.py
@@ -0,0 +1,320 @@
+from operator import itemgetter
+from typing import Literal
+from collections.abc import Iterable, Sequence
+import math
+import torch
+
+from ...core import Module, Target, Transform
+from ...utils import NumberList, TensorList, generic_eq
+
+
+def clip_grad_value_(params: Iterable[torch.Tensor], value: float):
+    """Clips gradient of an iterable of parameters at specified value.
+    Gradients are modified in-place.
+    Args:
+        params (Iterable[Tensor]): iterable of tensors with gradients to clip.
+        value (float or int): maximum allowed value of gradient
+    """
+    grads = [p.grad for p in params if p.grad is not None]
+    torch._foreach_clamp_min_(grads, -value)
+    torch._foreach_clamp_max_(grads, value)
+
+def _clip_norm_(
+    tensors_: TensorList,
+    min: float | NumberList | None,
+    max: float | NumberList | None,
+    norm_value: float | NumberList | None,
+    ord: float,
+    dim: int | Sequence[int] | Literal["global"] | None,
+    inverse_dims: bool,
+    min_size: int,
+) -> TensorList:
+    """generic function that can clip norm or normalize"""
+    if norm_value is not None:
+        if min is not None or max is not None:
+            raise ValueError(f'if norm_value is given then min and max must be None got {min = }; {max = }')
+
+        # if dim is None: return tensors_.mul_(norm_value / tensors_.norm(ord=ord))
+        if dim == 'global': return tensors_.mul_(norm_value / tensors_.global_vector_norm(ord=ord))
+
+    # if dim is None: return tensors_.clip_norm_(min,max,tensorwise=True,ord=ord)
+    if dim == 'global': return tensors_.clip_norm_(min,max,tensorwise=False,ord=ord)
+
+    muls = []
+    tensors_to_mul = []
+    if isinstance(dim, int): dim = (dim, )
+
+    for i, tensor in enumerate(tensors_):
+        # remove dimensions that overflow tensor.ndim or are too small
+        if tensor.ndim == 0: tensor = tensor.unsqueeze(0)
+        if dim is None: dim = list(range(tensor.ndim))
+        real_dim = [d for d in dim if d < tensor.ndim]
+        if inverse_dims: real_dim = [d for d in range(tensor.ndim) if d not in real_dim]
+        if len(real_dim) == 0: continue
+        size = math.prod(tensor.size(d) for d in real_dim)
+        if size < min_size: continue
+
+        norm: torch.Tensor = torch.linalg.vector_norm(tensor, ord=ord, dim=real_dim, keepdim=True) # pylint:disable=not-callable
+        if norm.numel() == 1 and norm == 0: continue
+        norm = torch.where(norm == 0, 1, norm)
+
+        # normalize = True, perform normalization
+        norm_v = norm_value[i] if isinstance(norm_value, (list,tuple)) else norm_value
+        if norm_v is not None:
+            mul = norm_v / norm
+
+        # else clip to min and max norms
+        else:
+            minv = min[i] if isinstance(min, (list,tuple)) else min
+            maxv = max[i] if isinstance(max, (list,tuple)) else max
+
+            mul = 1
+            if minv is not None:
+                mul_to_min = (minv / norm).clamp(min=1)
+                mul *= mul_to_min
+
+            if maxv is not None:
+                mul_to_max = (maxv / norm).clamp(max=1)
+                mul *= mul_to_max
+
+        muls.append(mul)
+        tensors_to_mul.append(tensor)
+
+    if len(muls) > 0:
+        torch._foreach_mul_(tensors_to_mul, muls)
+    return tensors_
+
+
+def clip_grad_norm_(
+    params: Iterable[torch.Tensor],
+    max_norm: float | None,
+    ord: float = 2,
+    dim: int | Sequence[int] | Literal["global"] | None = None,
+    inverse_dims: bool = False,
+    min_size: int = 2,
+    min_norm: float | None = None,
+):
+    """Clips gradient of an iterable of parameters to specified norm value.
+    Gradients are modified in-place.
+
+    Args:
+        params (Iterable[torch.Tensor]): parameters with gradients to clip.
+        max_norm (float | None): value to clip norm to.
+        ord (float, optional): norm order. Defaults to 2.
+        dim (int | Sequence[int] | str | None, optional):
+            calculates norm along those dimensions.
+            If list/tuple, tensors are normalized along all dimensions in `dim` that they have.
+            Can be set to "global" to normalize by global norm of all gradients concatenated to a vector.
+            Defaults to None.
+        min_size (int, optional):
+            minimal size of a dimension to normalize along it. Defaults to 1.
+    """
+    grads = TensorList(p.grad for p in params if p.grad is not None)
+    _clip_norm_(grads, min=min_norm, max=max_norm, norm_value=None, ord=ord, dim=dim, inverse_dims=inverse_dims, min_size=min_size)
+
+
+def normalize_grads_(
+    params: Iterable[torch.Tensor],
+    norm_value: float,
+    ord: float = 2,
+    dim: int | Sequence[int] | Literal["global"] | None = None,
+    inverse_dims: bool = False,
+    min_size: int = 1,
+):
+    """Normalizes gradient of an iterable of parameters to specified norm value.
+    Gradients are modified in-place.
+
+    Args:
+        params (Iterable[torch.Tensor]): parameters with gradients to normalize.
+        norm_value (float): value to set norm to.
+        ord (float, optional): norm order. Defaults to 2.
+        dim (int | Sequence[int] | str | None, optional):
+            calculates norm along those dimensions.
+            If list/tuple, tensors are normalized along all dimensions in `dim` that they have.
+            Can be set to "global" to normalize by global norm of all gradients concatenated to a vector.
+            Defaults to None.
+        inverse_dims (bool, optional):
+            if True, the `dims` argument is inverted, and all other dimensions are normalized.
+        min_size (int, optional):
+            minimal size of a dimension to normalize along it. Defaults to 1.
+    """
+    grads = TensorList(p.grad for p in params if p.grad is not None)
+    _clip_norm_(grads, min=None, max=None, norm_value=norm_value, ord=ord, dim=dim, inverse_dims=inverse_dims, min_size=min_size)
+
+
+class ClipValue(Transform):
+    """Clips update magnitude to be within `(-value, value)` range."""
+    def __init__(self, value: float, target: Target = 'update'):
+        defaults = dict(value=value)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        value = self.get_settings('value', params=params)
+        return TensorList(tensors).clip_([-v for v in value], value)
+
+class ClipNorm(Transform):
+    """Clips update norm to be no larger than `max_norm`.
+
+    Args:
+        max_norm (float): value to clip norm to.
+        ord (float, optional): norm order. Defaults to 2.
+        dim (int | Sequence[int] | str | None, optional):
+            calculates norm along those dimensions.
+            If list/tuple, tensors are normalized along all dimensions in `dim` that they have.
+            Can be set to "global" to normalize by global norm of all gradients concatenated to a vector.
+            Defaults to None.
+        inverse_dims (bool, optional):
+            if True, the `dims` argument is inverted, and all other dimensions are normalized.
+        min_size (int, optional):
+            minimal number of elements in a parameter or slice to clip norm. Defaults to 1.
+        target (str, optional):
+            what this affects.
+    """
+    def __init__(
+        self,
+        max_norm: float,
+        ord: float = 2,
+        dim: int | Sequence[int] | Literal["global"] | None = None,
+        inverse_dims: bool = False,
+        min_size: int = 1,
+        target: Target = "update",
+    ):
+        defaults = dict(max_norm=max_norm,ord=ord,dim=dim,min_size=min_size,inverse_dims=inverse_dims)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        max_norm = self.get_settings('max_norm', params=params, cls=NumberList)
+        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(self.settings[params[0]])
+        _clip_norm_(
+            tensors_ = TensorList(tensors),
+            min = 0,
+            max = max_norm,
+            norm_value = None,
+            ord = ord,
+            dim = dim,
+            inverse_dims=inverse_dims,
+            min_size = min_size,
+        )
+        return tensors
+
+class Normalize(Transform):
+    """Normalizes the update.
+
+    Args:
+        norm_value (float): desired norm value.
+        ord (float, optional): norm order. Defaults to 2.
+        dim (int | Sequence[int] | str | None, optional):
+            calculates norm along those dimensions.
+            If list/tuple, tensors are normalized along all dimensions in `dim` that they have.
+            Can be set to "global" to normalize by global norm of all gradients concatenated to a vector.
+            Defaults to None.
+        inverse_dims (bool, optional):
+            if True, the `dims` argument is inverted, and all other dimensions are normalized.
+        min_size (int, optional):
+            minimal size of a dimension to normalize along it. Defaults to 1.
+        target (str, optional):
+            what this affects.
+    """
+    def __init__(
+        self,
+        norm_value: float = 1,
+        ord: float = 2,
+        dim: int | Sequence[int] | Literal["global"] | None = None,
+        inverse_dims: bool = False,
+        min_size: int = 1,
+        target: Target = "update",
+    ):
+        defaults = dict(norm_value=norm_value,ord=ord,dim=dim,min_size=min_size, inverse_dims=inverse_dims)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        norm_value = self.get_settings('norm_value', params=params, cls=NumberList)
+        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(self.settings[params[0]])
+
+        _clip_norm_(
+            tensors_ = TensorList(tensors),
+            min = None,
+            max = None,
+            norm_value = norm_value,
+            ord = ord,
+            dim = dim,
+            inverse_dims=inverse_dims,
+            min_size = min_size,
+        )
+
+        return tensors
+
+
+def _centralize_(
+    tensors_: TensorList,
+    dim: int | Sequence[int] | Literal["global"] | None,
+    min_size: int,
+    inverse_dims: bool,
+) -> TensorList:
+    """generic function that centralizes (subtracts the mean)"""
+    if dim == 'global': return tensors_.sub_(tensors_.global_mean().item())
+
+    subs = []
+    tensors_to_sub = []
+    if isinstance(dim, int): dim = (dim, )
+
+    for tensor in tensors_:
+        # remove dimensions that overflow tensor.ndim or are too small
+        if dim is None: dim = list(range(tensor.ndim))
+        real_dim = [d for d in dim if d < tensor.ndim]
+        if inverse_dims: real_dim = [d for d in range(tensor.ndim) if d not in real_dim]
+        if len(real_dim) == 0: continue
+        size = math.prod(tensor.size(d) for d in real_dim)
+        if size < min_size: continue
+
+        mean: torch.Tensor = torch.mean(tensor, dim=real_dim, keepdim=True)
+        if mean.numel() == 1 and mean == 0: continue
+
+        subs.append(mean)
+        tensors_to_sub.append(tensor)
+
+    if len(subs) > 0:
+        torch._foreach_sub_(tensors_to_sub, subs)
+
+    return tensors_
+
+
+class Centralize(Transform):
+    """Centralizes the update.
+
+    Args:
+        dim (int | Sequence[int] | str | None, optional):
+            calculates mean along those dimensions.
+            If list/tuple, tensors are centralized along all dimensions in `dim` that they have.
+            Can be set to "global" to centralize by global mean of all gradients concatenated to a vector.
+            Defaults to None.
+        inverse_dims (bool, optional):
+            if True, the `dims` argument is inverted, and all other dimensions are centralized.
+        min_size (int, optional):
+            minimal size of a dimension to centralize along it. Defaults to 1.
+    """
+    def __init__(
+        self,
+        dim: int | Sequence[int] | Literal["global"] | None = None,
+        inverse_dims: bool = False,
+        min_size: int = 2,
+        target: Target = "update",
+    ):
+        defaults = dict(dim=dim,min_size=min_size,inverse_dims=inverse_dims)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(self.settings[params[0]])
+
+        _centralize_(tensors_ = TensorList(tensors), dim=dim, inverse_dims=inverse_dims, min_size=min_size)
+
+        return tensors
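Besides the `Transform` classes, the hunk above defines standalone gradient helpers (`clip_grad_value_`, `clip_grad_norm_`, `normalize_grads_`) that operate directly on parameter gradients. A minimal usage sketch follows; it assumes the helpers can be imported from the module path given in the file listing (they may also be re-exported at a higher level), and `model`/`opt` are placeholder names.

import torch
from torch import nn

# assumed import path, taken from the file listing above
from torchzero.modules.clipping.clipping import clip_grad_value_, clip_grad_norm_

model = nn.Linear(10, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

x, y = torch.randn(32, 10), torch.randn(32, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# clamp every gradient element into [-1, 1]
clip_grad_value_(model.parameters(), 1.0)

# then cap the norm of all gradients concatenated into one vector at 1
clip_grad_norm_(model.parameters(), max_norm=1.0, dim="global")

# normalize_grads_(model.parameters(), norm_value=1.0) would instead rescale
# each gradient tensor to the requested norm

opt.step()
opt.zero_grad()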
torchzero/modules/clipping/ema_clipping.py
@@ -0,0 +1,135 @@
+from operator import itemgetter
+from typing import Literal
+from collections.abc import Iterable, Sequence
+
+import torch
+
+from ...core import Module, Target, Transform, apply, Chainable
+from ...utils import NumberList, TensorList, generic_eq
+
+class ClipNormByEMA(Transform):
+    """Clips norm to be no larger than the norm of an exponential moving average of past updates.
+
+    Args:
+        beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
+        ord (float, optional): order of the norm. Defaults to 2.
+        eps (float, optional): epsilon for division. Defaults to 1e-6.
+        tensorwise (bool, optional): whether to calculate norm separately for each layer, or global norm for all layers. Defaults to True.
+        max_ema_growth (float | None, optional):
+            if specified, the exponential moving average norm can grow by at most this factor per step. Defaults to 1.5.
+        ema_init (str, optional):
+            How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
+    """
+    NORMALIZE = False
+    def __init__(
+        self,
+        beta=0.99,
+        ord: float = 2,
+        eps=1e-6,
+        tensorwise:bool=True,
+        max_ema_growth: float | None = 1.5,
+        ema_init: Literal['zeros', 'update'] = 'zeros',
+    ):
+        defaults = dict(beta=beta, ord=ord, tensorwise=tensorwise, ema_init=ema_init, eps=eps, max_ema_growth=max_ema_growth)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(self.settings[params[0]])
+
+        beta, eps = self.get_settings('beta', 'eps', params=params, cls=NumberList)
+        tensors = TensorList(tensors)
+
+        ema = self.get_state('ema', params=params, init = (torch.zeros_like if ema_init=='zeros' else tensors), cls=TensorList)
+        ema.lerp_(tensors, 1-beta)
+
+        if tensorwise:
+            ema_norm = ema.norm(ord)
+
+            # clip ema norm growth
+            if max_ema_growth is not None:
+                prev_ema_norm = self.get_state('prev_ema_norm', params=params, init=ema_norm, cls=TensorList)
+                allowed_norm = (prev_ema_norm * max_ema_growth).clip(min=1e-6)
+                ema_denom = (ema_norm / allowed_norm).clip(min=1)
+                ema.div_(ema_denom)
+                ema_norm.div_(ema_denom)
+                prev_ema_norm.set_(ema_norm)
+
+            tensors_norm = tensors.norm(ord)
+            denom = tensors_norm / ema_norm.clip(min=eps)
+            if self.NORMALIZE: denom.clip_(min=eps)
+            else: denom.clip_(min=1)
+
+        else:
+            ema_norm = ema.global_vector_norm(ord)
+
+            # clip ema norm growth
+            if max_ema_growth is not None:
+                prev_ema_norm = self.global_state.setdefault('prev_ema_norm', ema_norm)
+                allowed_norm = prev_ema_norm * max_ema_growth
+                if ema_norm > allowed_norm:
+                    ema.div_(ema_norm / allowed_norm)
+                    ema_norm = allowed_norm
+                prev_ema_norm.set_(ema_norm)
+
+            tensors_norm = tensors.global_vector_norm(ord)
+            denom = tensors_norm / ema_norm.clip(min=eps[0])
+            if self.NORMALIZE: denom.clip_(min=eps[0])
+            else: denom.clip_(min=1)
+
+        tensors.div_(denom)
+        return tensors
+
+class NormalizeByEMA(ClipNormByEMA):
+    """Sets norm of the update to be the same as the norm of an exponential moving average of past updates.
+
+    Args:
+        beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
+        ord (float, optional): order of the norm. Defaults to 2.
+        eps (float, optional): epsilon for division. Defaults to 1e-6.
+        tensorwise (bool, optional): whether to calculate norm separately for each layer, or global norm for all layers. Defaults to True.
+        max_ema_growth (float | None, optional):
+            if specified, the exponential moving average norm can grow by at most this factor per step. Defaults to 1.5.
+        ema_init (str, optional):
+            How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
+    """
+    NORMALIZE = True
+
+# TODO Centralize by EMA?
+
+class ClipValueByEMA(Transform):
+    """Clips magnitude of update to be no larger than magnitude of an exponential moving average of past (unclipped) updates.
+
+    Args:
+        beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
+        ema_init (str, optional):
+            How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
+        ema_tfm (Chainable | None, optional): optional modules applied to exponential moving average before clipping by it. Defaults to None.
+    """
+    def __init__(
+        self,
+        beta=0.99,
+        ema_init: Literal['zeros', 'update'] = 'zeros',
+        ema_tfm:Chainable | None=None,
+    ):
+        defaults = dict(beta=beta, ema_init=ema_init)
+        super().__init__(defaults, uses_grad=False)
+
+        if ema_tfm is not None:
+            self.set_child('ema_tfm', ema_tfm)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        ema_init = itemgetter('ema_init')(self.settings[params[0]])
+
+        beta = self.get_settings('beta', params=params, cls=NumberList)
+        tensors = TensorList(tensors)
+
+        ema = self.get_state('ema', params=params, init = (torch.zeros_like if ema_init=='zeros' else lambda t: t.abs()), cls=TensorList)
+        ema.lerp_(tensors.abs(), 1-beta)
+
+        if 'ema_tfm' in self.children:
+            ema = TensorList(apply(self.children['ema_tfm'], ema, params, vars.grad, vars))
+
+        tensors.clip_(-ema, ema)
+        return tensors
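For intuition, the rule that `ClipNormByEMA` documents above (keep an exponential moving average of past updates, optionally cap how fast the EMA norm may grow, then shrink the current update so its norm never exceeds the EMA norm) can be sketched for a single tensor in plain PyTorch. This is only an illustration of the update rule under those assumptions, not the library's `Transform`/`TensorList` API; the function name and the toy loop below are hypothetical.

import torch

def clip_norm_by_ema_step(update, ema, prev_ema_norm, beta=0.99, ord=2,
                          eps=1e-6, max_ema_growth=1.5, normalize=False):
    """One step of the rule described above, for a single tensor (illustration only)."""
    # track an EMA of past updates and its norm
    ema.lerp_(update, 1 - beta)
    ema_norm = torch.linalg.vector_norm(ema, ord=ord)

    # optionally cap how much the EMA norm may grow compared to the previous step
    if max_ema_growth is not None and prev_ema_norm is not None:
        allowed = (prev_ema_norm * max_ema_growth).clamp(min=1e-6)
        ema_denom = (ema_norm / allowed).clamp(min=1)
        ema.div_(ema_denom)
        ema_norm = ema_norm / ema_denom

    # shrink the update so its norm does not exceed the EMA norm
    # (with normalize=True it is rescaled to match the EMA norm instead)
    denom = torch.linalg.vector_norm(update, ord=ord) / ema_norm.clamp(min=eps)
    denom = denom.clamp(min=eps) if normalize else denom.clamp(min=1)
    update.div_(denom)
    return update, ema, ema_norm

# toy usage with random updates
u = torch.randn(5)
ema_state = torch.zeros_like(u)
prev_norm = None
for _ in range(3):
    u, ema_state, prev_norm = clip_norm_by_ema_step(torch.randn(5), ema_state, prev_norm)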
torchzero/modules/clipping/growth_clipping.py
@@ -0,0 +1,187 @@
+from operator import itemgetter
+
+import torch
+
+from ...core import TensorwiseTransform, Target, Transform
+from ...utils import TensorList, as_tensorlist
+
+
+class ClipValueGrowth(TensorwiseTransform):
+    """Clips update value magnitude growth.
+
+    Args:
+        add (float | None, optional): additive clipping, next update is at most `previous update + add`. Defaults to None.
+        mul (float | None, optional): multiplicative clipping, next update is at most `previous update * mul`. Defaults to 1.5.
+        min_value (float | None, optional):
+            minimum value for multiplicative clipping to prevent collapse to 0.
+            Next update is at most :code:`max(prev_update, min_value) * mul`. Defaults to 1e-4.
+        max_decay (float | None, optional):
+            bounds the tracked multiplicative clipping decay to prevent collapse to 0.
+            Next update is at most :code:`max(previous update * mul, max_decay)`.
+            Defaults to 2.
+        target (Target, optional): what to set on vars. Defaults to "update".
+    """
+    def __init__(
+        self,
+        add: float | None = None,
+        mul: float | None = 1.5,
+        min_value: float | None = 1e-4,
+        max_decay: float | None = 2,
+        target: Target = "update",
+    ):
+        defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+
+    def transform(self, tensor, param, grad, vars):
+        add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(self.settings[param])
+        add: float | None
+
+        state = self.state[param]
+
+        if add is None and mul is None:
+            return tensor
+
+        if 'prev' not in state:
+            state['prev'] = tensor.clone()
+            return tensor
+
+        prev: torch.Tensor = state['prev']
+
+        # additive bound
+        if add is not None:
+            growth = (tensor.abs() - prev.abs()).clip(min=0)
+            tensor.sub_(torch.where(growth > add, (growth-add).copysign_(tensor), 0))
+
+        # multiplicative bound
+        growth = None
+        if mul is not None:
+            prev_magn = prev.abs()
+            if min_value is not None: prev_magn.clip_(min=min_value)
+            growth = (tensor.abs() / prev_magn).clamp_(min=1e-8)
+
+            denom = torch.where(growth > mul, growth/mul, 1)
+
+            tensor.div_(denom)
+
+        # limit max growth decay
+        if max_decay is not None:
+            if growth is None:
+                prev_magn = prev.abs()
+                if min_value is not None: prev_magn.clip_(min=min_value)
+                growth = (tensor.abs() / prev_magn).clamp_(min=1e-8)
+
+            new_prev = torch.where(growth < (1/max_decay), prev/max_decay, tensor)
+        else:
+            new_prev = tensor.clone()
+
+        state['prev'] = new_prev
+        return tensor
+
+
+def norm_growth_clip_(
+    tensor_: torch.Tensor,
+    prev_norm: torch.Tensor,
+    add: float | None,
+    mul: float | None,
+    min_value: float | None,
+    max_decay: float | None,
+    ord: float,
+):
+    if add is None and mul is None: return tensor_
+    norm = torch.linalg.vector_norm(tensor_, ord=ord) # pylint:disable=not-callable
+
+    denom = 1
+    # additive bound
+    if add is not None:
+        allowed_norm = prev_norm + add
+        if norm > allowed_norm: denom = norm / allowed_norm
+
+    # multiplicative bound
+    if mul is not None:
+        allowed_norm = prev_norm * mul
+        if norm > allowed_norm: denom = max(denom, norm / allowed_norm)
+
+    # minimal norm
+    if min_value is not None:
+        denom = max(denom, min_value)
+
+    # limit max growth decay
+    new_prev_norm = norm/denom
+    if max_decay is not None:
+        decay = norm / prev_norm
+        if decay < (1/max_decay):
+            new_prev_norm = prev_norm / max_decay
+
+    if min_value is not None: new_prev_norm = max(new_prev_norm, min_value) # pyright:ignore[reportArgumentType]
+    return tensor_.div_(denom), new_prev_norm, denom
+
+
+class ClipNormGrowth(Transform):
+    """Clips update norm growth.
+
+    Args:
+        add (float | None, optional): additive clipping, next update norm is at most `previous norm + add`. Defaults to None.
+        mul (float | None, optional): multiplicative clipping, next update norm is at most `previous norm * mul`. Defaults to 1.5.
+        min_value (float | None, optional):
+            minimum value for multiplicative clipping to prevent collapse to 0.
+            Next norm is at most :code:`max(prev_norm, min_value) * mul`. Defaults to 1e-4.
+        max_decay (float | None, optional):
+            bounds the tracked multiplicative clipping decay to prevent collapse to 0.
+            Next norm is at most :code:`max(previous norm * mul, max_decay)`.
+            Defaults to 2.
+        ord (float, optional): norm order. Defaults to 2.
+        parameterwise (bool, optional):
+            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
+        target (Target, optional): what to set on vars. Defaults to "update".
+    """
+    def __init__(
+        self,
+        add: float | None = None,
+        mul: float | None = 1.5,
+        min_value: float | None = 1e-4,
+        max_decay: float | None = 2,
+        ord: float = 2,
+        parameterwise=True,
+        target: Target = "update",
+    ):
+        defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay, ord=ord, parameterwise=parameterwise)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+
+    def transform(self, tensors, params, grads, vars):
+        parameterwise = self.settings[params[0]]['parameterwise']
+        tensors = TensorList(tensors)
+
+        if parameterwise:
+            ts = tensors
+            stts = [self.state[p] for p in params]
+            stns = [self.settings[p] for p in params]
+
+        else:
+            ts = [tensors.to_vec()]
+            stts = [self.global_state]
+            stns = [self.settings[params[0]]]
+
+
+        for t, state, settings in zip(ts, stts, stns):
+            if 'prev_norm' not in state:
+                state['prev_norm'] = torch.linalg.vector_norm(t, ord=settings['ord']) # pylint:disable=not-callable
+                state['prev_denom'] = 1
+                continue
+
+            _, state['prev_norm'], state['prev_denom'] = norm_growth_clip_(
+                tensor_ = t,
+                prev_norm = state['prev_norm'],
+                add = settings['add'],
+                mul = settings['mul'],
+                min_value = settings['min_value'],
+                max_decay = settings['max_decay'],
+                ord = settings['ord'],
+            )
+
+        if not parameterwise:
+            tensors.from_vec_(ts[0])
+
+        return tensors
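The module above also exposes `norm_growth_clip_` as a plain function, so the norm-growth rule can be exercised outside the `Transform` machinery. A hedged sketch follows; the import path is assumed from the file listing, and the synthetic updates and the choice of the first update's norm as the initial tracked norm are placeholders.

import torch

# assumed import path, taken from the file listing above
from torchzero.modules.clipping.growth_clipping import norm_growth_clip_

torch.manual_seed(0)
update = torch.randn(10)
prev_norm = torch.linalg.vector_norm(update)  # seed the tracked norm with the first update

for step in range(5):
    update = torch.randn(10) * (step + 1)      # raw update whose norm keeps growing
    update, prev_norm, denom = norm_growth_clip_(
        tensor_=update,
        prev_norm=prev_norm,
        add=None,          # no additive bound
        mul=1.5,           # norm may grow by at most 1.5x per step
        min_value=1e-4,
        max_decay=2,
        ord=2,
    )
    print(step, float(torch.linalg.vector_norm(update)), float(prev_norm))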
torchzero/modules/experimental/__init__.py
@@ -1,19 +1,14 @@
-
-
-
-from .
-from .
-from .
-
-
-
-
-
-
-    ProjGradRay,
-    ProjLastAscentDifference,
-    ProjLastGradDifference,
-    ProjNormalize,
-    ProjRandom,
-    Subspace,
+from .absoap import ABSOAP
+from .adadam import Adadam
+from .adamY import AdamY
+from .adasoap import AdaSOAP
+from .curveball import CurveBall
+from .dsoap import DSOAP
+from .gradmin import GradMin
+from .reduce_outward_lr import ReduceOutwardLR
+from .spectral import SpectralPreconditioner
+from .subspace_preconditioners import (
+    HistorySubspacePreconditioning,
+    RandomSubspacePreconditioning,
 )
+from .tropical_newton import TropicalNewton