torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/clipping/clipping.py

@@ -5,7 +5,7 @@ import math
 import torch
 
 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList
 
 
 def clip_grad_value_(params: Iterable[torch.Tensor], value: float):

@@ -24,7 +24,7 @@ def _clip_norm_(
     min: float | NumberList | None,
     max: float | NumberList | None,
     norm_value: float | NumberList | None,
-    ord: float,
+    ord: float | Literal['mean_abs'],
     dim: int | Sequence[int] | Literal["global"] | None,
     inverse_dims: bool,
     min_size: int,

@@ -54,9 +54,13 @@ def _clip_norm_(
         size = math.prod(tensor.size(d) for d in real_dim)
         if size < min_size: continue
 
-
+        if ord == 'mean_abs':
+            norm = tensor.abs().mean(dim=real_dim, keepdim=True)
+        else:
+            norm: torch.Tensor = torch.linalg.vector_norm(tensor, ord=ord, dim=real_dim, keepdim=True) # pylint:disable=not-callable
+
         if norm.numel() == 1 and norm == 0: continue
-        norm = torch.where(norm
+        norm = torch.where(norm <= 1e-12, 1, norm)
 
         # normalize = True, perform normalization
         norm_v = norm_value[i] if isinstance(norm_value, (list,tuple)) else norm_value
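The hunk above adds a 'mean_abs' mode next to the usual vector norm. A minimal standalone sketch of the two measurement modes; the function name here is illustrative, not the library's private _clip_norm_ helper:

import torch

# Illustrative only: 'mean_abs' measures a tensor by its mean absolute value,
# any float ord falls through to a standard vector norm.
def measure(tensor: torch.Tensor, ord):
    if ord == 'mean_abs':
        return tensor.abs().mean()
    return torch.linalg.vector_norm(tensor, ord=ord)

t = torch.tensor([3.0, -4.0, 0.0, 0.0])
print(measure(t, ord=2))           # tensor(5.)
print(measure(t, ord='mean_abs'))  # tensor(1.7500)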
@@ -90,7 +94,7 @@ def _clip_norm_(
 def clip_grad_norm_(
     params: Iterable[torch.Tensor],
     max_norm: float | None,
-    ord: float = 2,
+    ord: float | Literal['mean_abs'] = 2,
     dim: int | Sequence[int] | Literal["global"] | None = None,
     inverse_dims: bool = False,
     min_size: int = 2,

@@ -118,7 +122,7 @@ def clip_grad_norm_(
 def normalize_grads_(
     params: Iterable[torch.Tensor],
     norm_value: float,
-    ord: float = 2,
+    ord: float | Literal['mean_abs'] = 2,
     dim: int | Sequence[int] | Literal["global"] | None = None,
     inverse_dims: bool = False,
     min_size: int = 1,
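Both public helpers gain the same ord='mean_abs' option. A hedged usage sketch: the import path is assumed from the file layout in the summary above, and only max_norm and ord are taken from the signatures in these hunks:

import torch
from torchzero.modules.clipping.clipping import clip_grad_norm_  # path assumed from the file list

model = torch.nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

# clip gradients in-place so that their mean absolute value does not exceed 1.0
clip_grad_norm_(model.parameters(), max_norm=1.0, ord='mean_abs')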
@@ -145,14 +149,44 @@ def normalize_grads_(
 
 
 class ClipValue(Transform):
-    """Clips update magnitude to be within `(-value, value)` range.
+    """Clips update magnitude to be within `(-value, value)` range.
+
+    Args:
+        value (float): value to clip to.
+        target (str): refer to :ref:`target argument` in documentation.
+
+    Examples:
+
+        Gradient clipping:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ClipValue(1),
+                tz.m.Adam(),
+                tz.m.LR(1e-2),
+            )
+
+        Update clipping:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.ClipValue(1),
+                tz.m.LR(1e-2),
+            )
+
+    """
     def __init__(self, value: float, target: Target = 'update'):
         defaults = dict(value=value)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)
 
     @torch.no_grad
-    def
-    value =
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        value = [s['value'] for s in settings]
         return TensorList(tensors).clip_([-v for v in value], value)
 
 class ClipNorm(Transform):

@@ -172,23 +206,47 @@ class ClipNorm(Transform):
             minimal numer of elements in a parameter or slice to clip norm. Defaults to 1.
         target (str, optional):
             what this affects.
+
+    Examples:
+
+        Gradient norm clipping:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ClipNorm(1),
+                tz.m.Adam(),
+                tz.m.LR(1e-2),
+            )
+
+        Update norm clipping:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.ClipNorm(1),
+                tz.m.LR(1e-2),
+            )
     """
     def __init__(
         self,
         max_norm: float,
-        ord: float = 2,
+        ord: float | Literal['mean_abs'] = 2,
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 1,
         target: Target = "update",
     ):
         defaults = dict(max_norm=max_norm,ord=ord,dim=dim,min_size=min_size,inverse_dims=inverse_dims)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)
 
     @torch.no_grad
-    def
-    max_norm =
-    ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        max_norm = NumberList(s['max_norm'] for s in settings)
+        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])
         _clip_norm_(
             tensors_ = TensorList(tensors),
             min = 0,

@@ -218,23 +276,47 @@ class Normalize(Transform):
             minimal size of a dimension to normalize along it. Defaults to 1.
         target (str, optional):
             what this affects.
+
+    Examples:
+
+        Gradient normalization:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Normalize(1),
+                tz.m.Adam(),
+                tz.m.LR(1e-2),
+            )
+
+        Update normalization:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.Normalize(1),
+                tz.m.LR(1e-2),
+            )
     """
     def __init__(
         self,
         norm_value: float = 1,
-        ord: float = 2,
+        ord: float | Literal['mean_abs'] = 2,
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 1,
         target: Target = "update",
     ):
         defaults = dict(norm_value=norm_value,ord=ord,dim=dim,min_size=min_size, inverse_dims=inverse_dims)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)
 
     @torch.no_grad
-    def
-    norm_value =
-    ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        norm_value = NumberList(s['norm_value'] for s in settings)
+        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])
 
         _clip_norm_(
             tensors_ = TensorList(tensors),

@@ -299,6 +381,21 @@ class Centralize(Transform):
             if True, the `dims` argument is inverted, and all other dimensions are centralized.
         min_size (int, optional):
             minimal size of a dimension to normalize along it. Defaults to 1.
+
+    Examples:
+
+        Standard gradient centralization:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Centralize(dim=0),
+                tz.m.LR(1e-2),
+            )
+
+    References:
+        - Yong, H., Huang, J., Hua, X., & Zhang, L. (2020). Gradient centralization: A new optimization technique for deep neural networks. In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing. https://arxiv.org/abs/2004.01461
     """
     def __init__(
         self,

@@ -308,11 +405,11 @@ class Centralize(Transform):
         target: Target = "update",
     ):
         defaults = dict(dim=dim,min_size=min_size,inverse_dims=inverse_dims)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)
 
     @torch.no_grad
-    def
-    dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(settings[0])
 
         _centralize_(tensors_ = TensorList(tensors), dim=dim, inverse_dims=inverse_dims, min_size=min_size)
 
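Centralize implements gradient centralization from the Yong et al. (2020) reference cited in the docstring above. As an illustration of the operation itself (not the library's _centralize_ helper), centralizing along a dimension simply removes the mean along it:

import torch

# Illustrative only: subtract the per-slice mean along the chosen dimension,
# so each slice of the (gradient) tensor sums to zero afterwards.
def centralize(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
    return tensor - tensor.mean(dim=dim, keepdim=True)

g = torch.randn(16, 8)
print(centralize(g, dim=0).mean(dim=0).abs().max())  # ~0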
torchzero/modules/clipping/ema_clipping.py

@@ -4,8 +4,8 @@ from collections.abc import Iterable, Sequence
 
 import torch
 
-from ...core import Module, Target, Transform,
-from ...utils import NumberList, TensorList,
+from ...core import Module, Target, Transform, apply_transform, Chainable
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 
 class ClipNormByEMA(Transform):
     """Clips norm to be no larger than the norm of an exponential moving average of past updates.

@@ -14,9 +14,10 @@ class ClipNormByEMA(Transform):
         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
         ord (float, optional): order of the norm. Defaults to 2.
         eps (float, optional): epsilon for division. Defaults to 1e-6.
-        tensorwise (bool, optional):
+        tensorwise (bool, optional):
+            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
         max_ema_growth (float | None, optional):
-            if specified, exponential moving average norm can grow
+            if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.
         ema_init (str, optional):
             How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
     """

@@ -29,18 +30,20 @@ class ClipNormByEMA(Transform):
         tensorwise:bool=True,
         max_ema_growth: float | None = 1.5,
         ema_init: Literal['zeros', 'update'] = 'zeros',
+        inner: Chainable | None = None,
     ):
         defaults = dict(beta=beta, ord=ord, tensorwise=tensorwise, ema_init=ema_init, eps=eps, max_ema_growth=max_ema_growth)
-        super().__init__(defaults,
+        super().__init__(defaults, inner=inner)
 
     @torch.no_grad
-    def
-    ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(self.settings[params[0]])
-
-    beta, eps = self.get_settings('beta', 'eps', params=params, cls=NumberList)
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
+        ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(settings[0])
+
+        beta, eps = unpack_dicts(settings, 'beta', 'eps', cls=NumberList)
+
+        ema = unpack_states(states, tensors, 'ema', init = (torch.zeros_like if ema_init=='zeros' else tensors), cls=TensorList)
 
-        ema = self.get_state('ema', params=params, init = (torch.zeros_like if ema_init=='zeros' else tensors), cls=TensorList)
         ema.lerp_(tensors, 1-beta)
 
         if tensorwise:

@@ -48,7 +51,7 @@ class ClipNormByEMA(Transform):
 
             # clip ema norm growth
             if max_ema_growth is not None:
-                prev_ema_norm =
+                prev_ema_norm = unpack_states(states, tensors, 'prev_ema_norm', init=ema_norm, cls=TensorList)
                 allowed_norm = (prev_ema_norm * max_ema_growth).clip(min=1e-6)
                 ema_denom = (ema_norm / allowed_norm).clip(min=1)
                 ema.div_(ema_denom)

@@ -77,7 +80,12 @@ class ClipNormByEMA(Transform):
         if self.NORMALIZE: denom.clip_(min=eps[0])
         else: denom.clip_(min=1)
 
-
+        self.global_state['denom'] = denom
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        denom = self.global_state.pop('denom')
+        torch._foreach_div_(tensors, denom)
         return tensors
 
 class NormalizeByEMA(ClipNormByEMA):
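ClipNormByEMA now splits its work between update_tensors (compute the EMA and the clipping denominator, stash it in global_state) and apply_tensors (pop the denominator and divide the update by it). A minimal self-contained sketch of that stash-and-apply flow; the class and method names below are hypothetical, only the pop-and-torch._foreach_div_ step mirrors the diff:

import torch

class TwoPhaseClip:
    """Toy two-phase clipper: statistics in update(), in-place modification in apply()."""

    def __init__(self, max_norm: float = 1.0):
        self.max_norm = max_norm
        self.global_state: dict = {}

    @torch.no_grad()
    def update(self, tensors: list[torch.Tensor]) -> None:
        # denominator is > 1 only when the global norm exceeds max_norm
        norm = torch.linalg.vector_norm(torch.cat([t.ravel() for t in tensors]))
        self.global_state['denom'] = float((norm / self.max_norm).clamp(min=1))

    @torch.no_grad()
    def apply(self, tensors: list[torch.Tensor]) -> list[torch.Tensor]:
        denom = self.global_state.pop('denom')
        torch._foreach_div_(tensors, denom)  # in-place division, as in the diff
        return tensors

ts = [torch.full((3,), 2.0), torch.full((4,), -2.0)]
clip = TwoPhaseClip(max_norm=1.0)
clip.update(ts)
clip.apply(ts)  # global norm of ts is now at most max_norm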
@@ -87,9 +95,10 @@ class NormalizeByEMA(ClipNormByEMA):
         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
         ord (float, optional): order of the norm. Defaults to 2.
         eps (float, optional): epsilon for division. Defaults to 1e-6.
-        tensorwise (bool, optional):
+        tensorwise (bool, optional):
+            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
         max_ema_growth (float | None, optional):
-            if specified, exponential moving average norm can grow
+            if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.
         ema_init (str, optional):
             How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
     """

@@ -98,38 +107,44 @@ class NormalizeByEMA(ClipNormByEMA):
 # TODO Centralize by EMA?
 
 class ClipValueByEMA(Transform):
-    """Clips magnitude of update to be no larger than magnitude of
+    """Clips magnitude of update to be no larger than magnitude of exponential moving average of past (unclipped) updates.
 
     Args:
         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
         ema_init (str, optional):
             How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
-        ema_tfm (Chainable | None, optional):
+        ema_tfm (Chainable | None, optional):
+            optional modules applied to exponential moving average before clipping by it. Defaults to None.
     """
     def __init__(
         self,
         beta=0.99,
         ema_init: Literal['zeros', 'update'] = 'zeros',
         ema_tfm:Chainable | None=None,
+        inner: Chainable | None = None,
     ):
         defaults = dict(beta=beta, ema_init=ema_init)
-        super().__init__(defaults,
+        super().__init__(defaults, inner=inner)
 
         if ema_tfm is not None:
             self.set_child('ema_tfm', ema_tfm)
 
     @torch.no_grad
-    def
-    ema_init = itemgetter('ema_init')(
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        ema_init = itemgetter('ema_init')(settings[0])
 
-    beta =
+        beta = unpack_dicts(settings, 'beta', cls=NumberList)
         tensors = TensorList(tensors)
 
-    ema =
+        ema = unpack_states(states, tensors, 'ema', init = (torch.zeros_like if ema_init=='zeros' else lambda t: t.abs()), cls=TensorList)
         ema.lerp_(tensors.abs(), 1-beta)
 
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        ema = unpack_states(states, tensors, 'ema', cls=TensorList)
+
         if 'ema_tfm' in self.children:
-            ema = TensorList(
+            ema = TensorList(apply_transform(self.children['ema_tfm'], ema.clone(), params, grads, loss))
 
         tensors.clip_(-ema, ema)
         return tensors
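The EMA clipping modules above are Transforms, so they compose in tz.Modular the same way as the ClipNorm/ClipValue docstring examples earlier in this diff. A hedged sketch; whether ClipNormByEMA is re-exported under tz.m, and its placement before Adam, are assumptions not shown in the source:

import torch
import torchzero as tz

model = torch.nn.Linear(4, 2)

# composition pattern copied from the ClipNorm/ClipValue docstrings above;
# the ClipNormByEMA name and its beta argument come from this file's diff
opt = tz.Modular(
    model.parameters(),
    tz.m.ClipNormByEMA(beta=0.99),
    tz.m.Adam(),
    tz.m.LR(1e-2),
)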
torchzero/modules/clipping/growth_clipping.py

@@ -19,7 +19,7 @@ class ClipValueGrowth(TensorwiseTransform):
             bounds the tracked multiplicative clipping decay to prevent collapse to 0.
             Next update is at most :code:`max(previous update * mul, max_decay)`.
             Defaults to 2.
-        target (Target, optional): what to set on
+        target (Target, optional): what to set on var. Defaults to "update".
     """
     def __init__(
         self,

@@ -30,15 +30,13 @@ class ClipValueGrowth(TensorwiseTransform):
         target: Target = "update",
     ):
         defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)
 
 
-    def
-    add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(setting)
         add: float | None
 
-        state = self.state[param]
-
         if add is None and mul is None:
             return tensor
 

@@ -122,7 +120,8 @@ class ClipNormGrowth(Transform):
 
     Args:
         add (float | None, optional): additive clipping, next update norm is at most `previous norm + add`. Defaults to None.
-        mul (float | None, optional):
+        mul (float | None, optional):
+            multiplicative clipping, next update norm is at most `previous norm * mul`. Defaults to 1.5.
         min_value (float | None, optional):
             minimum value for multiplicative clipping to prevent collapse to 0.
             Next norm is at most :code:`max(prev_norm, min_value) * mul`. Defaults to 1e-4.

@@ -133,7 +132,7 @@ class ClipNormGrowth(Transform):
         ord (float, optional): norm order. Defaults to 2.
         parameterwise (bool, optional):
             if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
-        target (Target, optional): what to set on
+        target (Target, optional): what to set on var. Defaults to "update".
     """
     def __init__(
         self,

@@ -146,39 +145,39 @@ class ClipNormGrowth(Transform):
         target: Target = "update",
     ):
         defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay, ord=ord, parameterwise=parameterwise)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)
 
 
 
-    def
-    parameterwise =
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        parameterwise = settings[0]['parameterwise']
         tensors = TensorList(tensors)
 
         if parameterwise:
             ts = tensors
-            stts =
-            stns =
+            stts = states
+            stns = settings
 
         else:
             ts = [tensors.to_vec()]
             stts = [self.global_state]
-            stns = [
+            stns = [settings[0]]
 
 
-        for t,state,
+        for t, state, setting in zip(ts, stts, stns):
             if 'prev_norm' not in state:
-                state['prev_norm'] = torch.linalg.vector_norm(t, ord=
+                state['prev_norm'] = torch.linalg.vector_norm(t, ord=setting['ord']) # pylint:disable=not-callable
                 state['prev_denom'] = 1
                 continue
 
             _, state['prev_norm'], state['prev_denom'] = norm_growth_clip_(
                 tensor_ = t,
                 prev_norm = state['prev_norm'],
-                add =
-                mul =
-                min_value =
-                max_decay =
-                ord =
+                add = setting['add'],
+                mul = setting['mul'],
+                min_value = setting['min_value'],
+                max_decay = setting['max_decay'],
+                ord = setting['ord'],
             )
 
         if not parameterwise:
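ClipNormGrowth bounds how fast the update norm may grow between steps: per the docstring above, the next norm is capped at `previous norm + add` (additive) and/or `max(prev_norm, min_value) * mul` (multiplicative). A hedged standalone sketch of that rule, not the library's norm_growth_clip_ helper:

import torch

# Illustrative growth-clipping rule (assumed from the ClipNormGrowth docstring):
# rescale the tensor whenever its norm exceeds what the previous norm is allowed
# to grow into, then report the (possibly clipped) norm as the new "previous norm".
def growth_clip(tensor: torch.Tensor, prev_norm: float,
                add: float | None = None, mul: float | None = 1.5,
                min_value: float = 1e-4, ord: float = 2):
    norm = torch.linalg.vector_norm(tensor, ord=ord)
    allowed = float('inf')
    if add is not None:
        allowed = min(allowed, prev_norm + add)
    if mul is not None:
        allowed = min(allowed, max(prev_norm, min_value) * mul)
    if norm > allowed:
        tensor = tensor * (allowed / norm)
        norm = torch.linalg.vector_norm(tensor, ord=ord)
    return tensor, float(norm)

t = torch.full((4,), 3.0)                         # norm 6
clipped, new_norm = growth_clip(t, prev_norm=2.0, mul=1.5)
print(new_norm)                                   # ~3.0 = prev_norm * mul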
torchzero/modules/experimental/__init__.py

@@ -1,15 +1,41 @@
+"""This submodule contains various untested experimental modules, some of them are to be moved out of experimental when properly tested, some are to remain here forever or to be deleted depending on the degree of their usefulness."""
 from .absoap import ABSOAP
 from .adadam import Adadam
+from .adam_lambertw import AdamLambertW
 from .adamY import AdamY
+from .adaptive_step_size import AdaptiveStepSize
 from .adasoap import AdaSOAP
+from .cosine import (
+    AdaptiveDifference,
+    AdaptiveDifferenceEMA,
+    CosineDebounce,
+    CosineMomentum,
+    CosineStepSize,
+    ScaledAdaptiveDifference,
+)
+from .cubic_adam import CubicAdam
 from .curveball import CurveBall
-
+
+# from dct import DCTProjection
+from .eigendescent import EigenDescent
+from .etf import (
+    ExponentialTrajectoryFit,
+    ExponentialTrajectoryFitV2,
+    PointwiseExponential,
+)
+from .exp_adam import ExpAdam
+from .expanded_lbfgs import ExpandedLBFGS
+from .fft import FFTProjection
 from .gradmin import GradMin
+from .hnewton import HNewton
+from .modular_lbfgs import ModularLBFGS
+from .newton_solver import NewtonSolver
+from .newtonnewton import NewtonNewton
+from .parabolic_search import CubicParabolaSearch, ParabolaSearch
 from .reduce_outward_lr import ReduceOutwardLR
-from .
+from .structural_projections import BlockPartition, TensorizeProjection
 from .subspace_preconditioners import (
     HistorySubspacePreconditioning,
     RandomSubspacePreconditioning,
 )
-from .
-from .newton_solver import NewtonSolver
+from .tensor_adagrad import TensorAdagrad