torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/core/transform.py
CHANGED
@@ -1,18 +1,36 @@
 from abc import ABC, abstractmethod
-from collections.abc import Iterable,
+from collections.abc import Iterable, Mapping, Sequence
 from typing import Any, Literal, final
 
 import torch
 
-from ..utils import
-from .module import
+from ..utils import TensorList, set_storage_, vec_to_tensors
+from .module import Chain, Chainable, Module, Var
 
 Target = Literal['grad', 'update', 'closure', 'params_direct', 'params_difference', 'update_difference']
 
+
 class Transform(Module, ABC):
-    """Base class for a transform.
+    """Base class for a transform.
+    This is an abstract class, to use it, subclass it and override ``update_tensors`` and ``apply_tensors`` methods.
 
     A transform is a module that can also be applied manually to an arbitrary sequence of tensors.
+    It has two methods:
+
+    - ``update_tensors`` updates the internal state of this transform, it doesn't modify tensors. \
+    It may be called multiple times before ``apply_tensors``.
+    - ``apply_tensors`` applies this transform to tensors, without modifying the internal state if possible.
+
+    Alternatively, if update-apply structure doesn't make sense for a transform, all logic can be defined within ``apply_tensors``.
+
+    Transform can be applied to tensors corresponding to custom parameters
+    by calling ``keyed_transform_update`` and ``keyed_transform_apply``,
+    parameters will be keys to store per-parameter states, so they should remain the same python objects.
+
+    Alternatively you can manually create a list of state dictionaries per each tensor and pass it to
+    ``transform_update`` and ``transform_apply``.
+
+    A transform can modify the closure instead of directly modifying update by passing ``target="closure"``.
 
     Args:
         defaults (dict[str,Any] | None): dict with default values.
@@ -21,6 +39,7 @@ class Transform(Module, ABC):
             `grad` is always computed and can't be None. Otherwise set to False.
         target (Target, optional):
             what to set on var. Defaults to 'update'.
+
     """
     def __init__(
         self,
@@ -29,7 +48,6 @@ class Transform(Module, ABC):
         uses_loss: bool = False,
         concat_params: bool = False,
         update_freq: int = 1,
-        scale_first: bool = False,
         inner: Chainable | None = None,
         target: Target = 'update',
     ):
@@ -39,8 +57,8 @@ class Transform(Module, ABC):
         self._uses_loss = uses_loss
         self._concat_params = concat_params
         self._update_freq = update_freq
-        self._scale_first = scale_first
         self._inner = inner
+        self._var = None
 
     def update_tensors(
         self,
@@ -93,14 +111,6 @@ class Transform(Module, ABC):
         states = states[:num]
         settings = settings[:num]
 
-        scale_factor = 1
-
-        # scaling factor for 1st step
-        if self._scale_first and step == 0:
-            # initial step size guess from pytorch LBFGS
-            scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
-            scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)
-
         # update transform
         if step % self._update_freq == 0:
             self.update_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
@@ -109,7 +119,6 @@ class Transform(Module, ABC):
         self.global_state["__tensors"] = tensors
         self.global_state["__params"] = params
         self.global_state["__grads"] = grads
-        self.global_state["__scale_factor"] = scale_factor
 
 
     @final
@@ -140,23 +149,19 @@ class Transform(Module, ABC):
         tensors = self.global_state.pop("__tensors")
         params = self.global_state.pop("__params")
         grads = self.global_state.pop("__grads")
-        scale_factor = self.global_state.pop("__scale_factor")
 
         # step with inner
         if self._inner is not None:
-            tensors = apply_transform(self._inner, tensors=un_tensors, params=un_params, grads=un_grads)
+            tensors = apply_transform(self._inner, tensors=un_tensors, params=un_params, grads=un_grads, var=self._var)
             if self._concat_params:
                 tensors = [torch.cat([t.ravel() for t in tensors])]
 
         # apply transform
         tensors = list(self.apply_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings))
 
-        # scale initial step, when preconditioner might not have been applied
-        if self._scale_first and self.global_state['__step'] == 1:
-            torch._foreach_mul_(tensors, scale_factor)
-
         if self._concat_params:
             tensors = vec_to_tensors(vec=tensors[0], reference=un_tensors)
+
         return tensors
 
     def _get_keyed_states_settings(self, params: list[torch.Tensor]):
@@ -220,7 +225,9 @@ class Transform(Module, ABC):
         self.pre_step(var)
 
         # update
+        self._var = var
         self.keyed_transform_update(update, params, var.grad, var.loss)
+        self._var = None
 
     def apply(self, var: Var):
         if self._target != 'update':
@@ -234,7 +241,10 @@ class Transform(Module, ABC):
             params=var.params
 
         # apply
+        self._var = var
         var.update = self.keyed_transform_apply(update, params, var.grad, var.loss)
+        self._var = None
+
         self.post_step(var)
         return var
@@ -246,12 +256,14 @@ class Transform(Module, ABC):
         if self._uses_loss: var.get_loss(False)
         params=var.params
         self.pre_step(var)
+        self._var = var
 
         # ---------------------------------- update ---------------------------------- #
         if self._target == 'update':
             update = var.get_update()
             self.keyed_transform_update(update, params, var.grad, var.loss)
             var.update = list(self.keyed_transform_apply(update, params, var.grad, var.loss))
+            self._var = None
             return var
 
         # ----------------------------------- grad ----------------------------------- #
@@ -259,6 +271,7 @@ class Transform(Module, ABC):
             grad = var.get_grad()
             self.keyed_transform_update(grad, params, grad, var.loss)
             var.grad = list(self.keyed_transform_apply(grad, params, grad, var.loss))
+            self._var = None
             return var
 
         # ------------------------------- params_direct ------------------------------ #
@@ -266,6 +279,7 @@ class Transform(Module, ABC):
             self.keyed_transform_update(var.params, params, var.grad, var.loss)
             new_params = self.keyed_transform_apply(var.params, params, var.grad, var.loss)
             for p, new_p in zip(var.params, new_params): set_storage_(p, new_p)
+            self._var = None
             return var
 
         # ----------------------------- params_differnce ----------------------------- #
@@ -274,6 +288,7 @@ class Transform(Module, ABC):
             self.keyed_transform_update(p_clone, params, var.grad, var.loss)
             new_params = tuple(self.keyed_transform_apply(p_clone, params, var.grad, var.loss))
             var.update = list(torch._foreach_sub(var.params, new_params))
+            self._var = None
             return var
 
         # ----------------------------- update_difference ---------------------------- #
@@ -283,6 +298,7 @@ class Transform(Module, ABC):
             self.keyed_transform_update(u_clone, params, var.grad, var.loss)
             new_update = tuple(self.keyed_transform_apply(u_clone, params, var.grad, var.loss))
             var.update = list(torch._foreach_sub(update, new_update))
+            self._var = None
             return var
 
         # ---------------------------------- closure --------------------------------- #
@@ -291,12 +307,17 @@ class Transform(Module, ABC):
             if original_closure is None: raise ValueError('Target = "closure", but closure is None')
 
             params = var.params
+            parent_var = self._var
             def transformed_closure(backward=True):
                 if backward:
                     loss = original_closure()
                     current_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+                    self._var = parent_var
                     self.keyed_transform_update(current_grad, params, var.grad, var.loss)
                     transformed_grad = list(self.keyed_transform_apply(current_grad, params, var.grad, var.loss))
+                    self._var = None
+
                     for p, g in zip(params, transformed_grad):
                         p.grad = g
 
@@ -307,6 +328,7 @@ class Transform(Module, ABC):
 
             var.closure = transformed_closure
             self.post_step(var)
+            self._var = None
             return var
 
         # ---------------------------------- invalid --------------------------------- #
@@ -316,7 +338,7 @@ class Transform(Module, ABC):
 class TensorwiseTransform(Transform, ABC):
     """Base class for a parameter-wise transform.
 
-    This is an abstract class, to use it, subclass it and override `
+    This is an abstract class, to use it, subclass it and override `update_tensor` and `apply_tensor`.
 
     Args:
         defaults (dict[str,Any] | None): dict with default values.
@@ -333,7 +355,6 @@ class TensorwiseTransform(Transform, ABC):
         uses_loss: bool = False,
         concat_params: bool = False,
         update_freq: int = 1,
-        scale_first: bool = False,
         inner: Chainable | None = None,
         target: Target = 'update',
     ):
@@ -342,7 +363,6 @@ class TensorwiseTransform(Transform, ABC):
             uses_grad=uses_grad,
             concat_params=concat_params,
             update_freq=update_freq,
-            scale_first=scale_first,
             uses_loss=uses_loss,
             inner=inner,
             target=target,
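The new ``Transform`` docstring above describes an update/apply split: ``update_tensors`` mutates per-tensor state without touching the incoming tensors, and ``apply_tensors`` produces the transformed tensors. The sketch below shows what a subclass might look like under the 0.3.14 API. The ``GradEMA`` class and its state keys are hypothetical; only the method signatures and the ``defaults``/``uses_grad`` constructor arguments are taken from this diff (they mirror the ``Adagrad`` implementation further down).

```python
import torch
from torchzero.core import Transform  # Transform is re-exported from torchzero.core in this release


class GradEMA(Transform):
    """Hypothetical example: keeps an EMA of incoming tensors and returns it as the update."""

    def __init__(self, beta: float = 0.9):
        # per-parameter settings are passed through `defaults`, as in Adagrad below
        super().__init__(defaults=dict(beta=beta), uses_grad=False)

    @torch.no_grad
    def update_tensors(self, tensors, params, grads, loss, states, settings):
        # update per-tensor state only; `tensors` must not be modified here
        for t, state, setting in zip(tensors, states, settings):
            ema = state.setdefault("ema", torch.zeros_like(t))
            ema.lerp_(t, 1 - setting["beta"])

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        # produce the transformed tensors without changing state
        return [state["ema"].clone() for state in states]
```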
torchzero/modules/__init__.py
CHANGED
@@ -1,15 +1,23 @@
+from . import experimental
 from .clipping import *
+from .conjugate_gradient import *
 from .grad_approximation import *
+from .higher_order import *
+from .least_squares import *
 from .line_search import *
-from .
+from .misc import *
 from .momentum import *
 from .ops import *
-from .
+from .adaptive import *
 from .projections import *
 from .quasi_newton import *
+from .second_order import *
 from .smoothing import *
+from .step_size import *
+from .termination import *
+from .trust_region import *
+from .variance_reduction import *
 from .weight_decay import *
 from .wrappers import *
-from .
-from .
-from .misc import *
+from .restarts import *
+from .zeroth_order import *
torchzero/modules/{optimizers → adaptive}/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-from .adagrad import Adagrad, FullMatrixAdagrad
+from .adagrad import Adagrad, FullMatrixAdagrad, AdagradNorm
 
 # from .curveball import CurveBall
 # from .spectral import SpectralPreconditioner
@@ -6,12 +6,15 @@ from .adahessian import AdaHessian
 from .adam import Adam
 from .adan import Adan
 from .adaptive_heavyball import AdaptiveHeavyBall
+from .aegd import AEGD
 from .esgd import ESGD
-from .
+from .lmadagrad import LMAdagrad
 from .lion import Lion
 from .mars import MARSCorrection
+from .matrix_momentum import MatrixMomentum
 from .msam import MSAM, MSAMObjective
 from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
+from .natural_gradient import NaturalGradient
 from .orthograd import OrthoGrad, orthograd_
 from .rmsprop import RMSprop
 from .rprop import (
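The two ``__init__`` diffs above cover the ``optimizers`` → ``adaptive`` rename and the new subpackages. Assuming the star re-exports behave as before (nothing in this diff removes them), user code that goes through the aggregated namespaces should be unaffected. A hypothetical smoke test, not part of the release:

```python
# Hypothetical check: classes moved to torchzero/modules/adaptive/ should be importable
# both from the new subpackage and through the aggregated torchzero.modules namespace.
from torchzero.modules.adaptive import Adagrad, AdagradNorm, LMAdagrad
from torchzero.modules import Adagrad as TopLevelAdagrad

assert Adagrad is TopLevelAdagrad
```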
torchzero/modules/adaptive/adagrad.py
ADDED
@@ -0,0 +1,356 @@
+from operator import itemgetter
+from typing import Literal
+
+import torch
+from ...core import (
+    Chainable,
+    Module,
+    Target,
+    TensorwiseTransform,
+    Transform,
+    Var,
+    apply_transform,
+)
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+from ...utils.linalg import matrix_power_eigh
+from ..functional import add_power_, lerp_power_, root, epsilon_step_size
+from ...utils.linalg.linear_operator import Dense
+
+def adagrad_(
+    tensors_: TensorList,
+    sq_sum_: TensorList,
+    alpha: float | NumberList,
+    lr_decay: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+    pow: float = 2,
+    use_sqrt: bool = True,
+    divide: bool = False,
+
+    decay: float | None = None,
+    beta: float | None = None,
+
+    # inner args
+    inner: Module | None = None,
+    params: list[torch.Tensor] | None = None,
+    grads: list[torch.Tensor] | None = None,
+):
+    """returns `tensors_`"""
+    clr = alpha / (1 + step * lr_decay)
+
+    if beta is None or step == 1: sq_sum_ = add_power_(tensors_, sum_=sq_sum_, pow=pow)
+    else: sq_sum_ = lerp_power_(tensors_, exp_avg_pow_=sq_sum_, beta=beta, pow=pow)
+    if decay is not None:
+        sq_sum_.mul_(1-decay)
+
+    if inner is not None:
+        assert params is not None
+        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
+
+    if divide: sq_sum_ = sq_sum_ / max(step, 1)
+
+    if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
+    else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)
+
+    return tensors_
+
+
+
+class Adagrad(Transform):
+    """Adagrad, divides by sum of past squares of gradients.
+
+    This implementation is identical to ``torch.optim.Adagrad``.
+
+    Args:
+        lr_decay (float, optional): learning rate decay. Defaults to 0.
+        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
+        eps (float, optional): division epsilon. Defaults to 1e-10.
+        alpha (float, optional): step size. Defaults to 1.
+        pow (float, optional): power for gradients and accumulator root. Defaults to 2.
+        use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
+        inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
+    """
+    def __init__(
+        self,
+        lr_decay: float = 0,
+        initial_accumulator_value: float = 0,
+        eps: float = 1e-10,
+        alpha: float = 1,
+        pow: float = 2,
+        use_sqrt: bool = True,
+        divide: bool=False,
+        beta:float | None = None,
+        decay: float | None = None,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
+                        eps = eps, pow=pow, use_sqrt = use_sqrt, divide=divide, beta=beta, decay=decay)
+        super().__init__(defaults=defaults, uses_grad=False)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
+
+        pow, use_sqrt, divide = itemgetter('pow', 'use_sqrt', 'divide')(settings[0])
+
+        sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)
+
+        # initialize accumulator on 1st step
+        if step == 1:
+            sq_sum.set_(tensors.full_like([s['initial_accumulator_value'] for s in settings]))
+
+        return adagrad_(
+            tensors,
+            sq_sum_=sq_sum,
+            alpha=alpha,
+            lr_decay=lr_decay,
+            eps=eps,
+            step=step,
+            pow=pow,
+            use_sqrt=use_sqrt,
+            divide=divide,
+
+            beta = self.defaults["beta"],
+            decay = self.defaults["decay"],
+            # inner args
+            inner=self.children.get("inner", None),
+            params=params,
+            grads=grads,
+        )
+
+
+def lerp(start, end, weight):
+    return start + weight * (end - start)
+
+def adagrad_norm_(
+    tensors_: TensorList,
+    accumulator: float | torch.Tensor,
+    alpha: float | NumberList,
+    lr_decay: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+    use_sqrt: bool = True,
+    divide: bool = False,
+
+    decay: float | None = None,
+    beta: float | None = None,
+
+    # inner args
+    inner: Module | None = None,
+    params: list[torch.Tensor] | None = None,
+    grads: list[torch.Tensor] | None = None,
+):
+    """returns `tensors_`"""
+    clr = alpha / (1 + step * lr_decay)
+
+    gg = tensors_.dot(tensors_)
+
+    if beta is None or step == 1: accumulator += gg
+    else: accumulator = lerp(accumulator, gg, 1-beta)
+
+    if decay is not None:
+        accumulator *= 1-decay
+
+    if inner is not None:
+        assert params is not None
+        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
+
+    if divide: accumulator = accumulator / max(step, 1)
+
+    if use_sqrt: tensors_.div_(eps + accumulator.sqrt()).mul_(clr)
+    else: tensors_.div_(eps + accumulator).mul_(clr)
+
+    return tensors_, accumulator
+
+class AdagradNorm(Transform):
+    """Adagrad-Norm, divides by sum of past means of squares of gradients.
+
+    Args:
+        lr_decay (float, optional): learning rate decay. Defaults to 0.
+        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
+        eps (float, optional): division epsilon. Defaults to 1e-10.
+        alpha (float, optional): step size. Defaults to 1.
+        pow (float, optional): power for gradients and accumulator root. Defaults to 2.
+        use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
+        inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
+    """
+    def __init__(
+        self,
+        lr_decay: float = 0,
+        initial_accumulator_value: float = 0,
+        eps: float = 1e-10,
+        alpha: float = 1,
+        pow: float = 2,
+        use_sqrt: bool = True,
+        divide: bool=False,
+        beta:float | None = None,
+        decay: float | None = None,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
+                        eps = eps, pow=pow, use_sqrt = use_sqrt, divide=divide, beta=beta, decay=decay)
+        super().__init__(defaults=defaults, uses_grad=False)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+        lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
+
+        use_sqrt, divide, initial_accumulator_value = itemgetter('use_sqrt', 'divide', "initial_accumulator_value")(settings[0])
+
+        accumulator = self.global_state.get("accumulator", initial_accumulator_value)
+
+        d, self.global_state["accumulator"] = adagrad_norm_(
+            tensors,
+            accumulator=accumulator,
+            alpha=alpha,
+            lr_decay=lr_decay,
+            eps=eps,
+            step=step,
+            use_sqrt=use_sqrt,
+            divide=divide,
+
+            beta = self.defaults["beta"],
+            decay = self.defaults["decay"],
+            # inner args
+            inner=self.children.get("inner", None),
+            params=params,
+            grads=grads,
+        )
+
+        return d
+
+
+class FullMatrixAdagrad(TensorwiseTransform):
+    """Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).
+
+    Note:
+        A more memory-efficient version equivalent to full matrix Adagrad on last n gradients is implemented in ``tz.m.LMAdagrad``.
+
+    Args:
+        beta (float | None, optional): momentum for gradient outer product accumulators. if None, uses sum. Defaults to None.
+        decay (float | None, optional): decay for gradient outer product accumulators. Defaults to None.
+        sqrt (bool, optional): whether to take the square root of the accumulator. Defaults to True.
+        concat_params (bool, optional): if False, each parameter will have it's own accumulator. Defaults to True.
+        precond_freq (int, optional): frequency of updating the inverse square root of the accumulator. Defaults to 1.
+        init (Literal[str], optional):
+            how to initialize the accumulator.
+            - "identity" - with identity matrix (default).
+            - "zeros" - with zero matrix.
+            - "ones" - with matrix of ones.
+            -"GGT" - with the first outer product
+        divide (bool, optional): whether to divide the accumulator by number of gradients in it. Defaults to False.
+        inner (Chainable | None, optional): inner modules to apply preconditioning to. Defaults to None.
+
+    ## Examples:
+
+    Plain full-matrix adagrad
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.FullMatrixAdagrd(),
+        tz.m.LR(1e-2),
+    )
+    ```
+
+    Full-matrix RMSprop
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.FullMatrixAdagrad(beta=0.99),
+        tz.m.LR(1e-2),
+    )
+    ```
+
+    Full-matrix Adam
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9)),
+        tz.m.Debias(0.9, 0.999),
+        tz.m.LR(1e-2),
+    )
+    ```
+    """
+    def __init__(
+        self,
+        beta: float | None = None,
+        decay: float | None = None,
+        sqrt: bool = True,
+        concat_params=True,
+        precond_freq: int = 1,
+        init: Literal["identity", "zeros", "ones", "GGT"] = "identity",
+        reg: float = 1e-12,
+        divide: bool = False,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(beta=beta, decay=decay, sqrt=sqrt, precond_freq=precond_freq, init=init, divide=divide, reg=reg)
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner,)
+
+    @torch.no_grad
+    def update_tensor(self, tensor, param, grad, loss, state, setting):
+        G = tensor.ravel()
+        GG = torch.outer(G, G)
+        decay = setting['decay']
+        beta = setting['beta']
+        init = setting['init']
+
+        if 'GG' not in state:
+            if init == 'identity': state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
+            elif init == 'zeros': state['GG'] = torch.zeros_like(GG)
+            elif init == 'ones': state['GG'] = torch.ones_like(GG)
+            elif init == 'GGT': state['GG'] = GG.clone()
+            else: raise ValueError(init)
+        if decay is not None: state['GG'].mul_(decay)
+
+        if beta is not None: state['GG'].lerp_(GG, 1-beta)
+        else: state['GG'].add_(GG)
+        state['i'] = state.get('i', 0) + 1 # number of GGTs in sum
+
+    @torch.no_grad
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        step = state.get('step', 0)
+        state['step'] = step + 1
+
+        GG: torch.Tensor = state['GG']
+        sqrt = setting['sqrt']
+        divide = setting['divide']
+        precond_freq = setting['precond_freq']
+        reg = setting['reg']
+
+        if divide: GG = GG/state.get('i', 1)
+
+        if reg != 0:
+            GG = GG + torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype).mul_(reg)
+
+        if tensor.numel() == 1:
+            GG = GG.squeeze()
+            if sqrt: return tensor / GG.sqrt()
+            return tensor / GG
+
+        try:
+            if sqrt:
+                if "B" not in state or step % precond_freq == 0:
+                    B = state["B"] = matrix_power_eigh(GG, -1/2)
+                else:
+                    B = state["B"]
+
+            else: return torch.linalg.solve(GG, tensor.ravel()).view_as(tensor) # pylint:disable = not-callable
+
+        except torch.linalg.LinAlgError:
+            # fallback to diagonal AdaGrad
+            denom = GG.diagonal()
+            if sqrt: denom = denom.sqrt()
+            return tensor.div_(denom + max(reg, 1e-12))
+
+        return (B @ tensor.ravel()).view_as(tensor)
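For comparison with the ``FullMatrixAdagrad`` examples embedded in the docstring above, the new diagonal ``Adagrad`` and the scalar-accumulator ``AdagradNorm`` would presumably be used the same way inside a ``tz.Modular`` chain. A sketch only; the learning rate, module order, and the existing ``model`` are illustrative, not taken from the package:

```python
import torchzero as tz

# Sketch: plug the new 0.3.14 transforms into a Modular optimizer the same way
# the FullMatrixAdagrad docstring examples do. `model` is assumed to exist.
opt = tz.Modular(
    model.parameters(),
    tz.m.Adagrad(),       # divides by the running sum of squared gradients
    tz.m.LR(1e-2),
)

opt_norm = tz.Modular(
    model.parameters(),
    tz.m.AdagradNorm(),   # divides by a single scalar accumulator of squared gradient norms
    tz.m.LR(1e-2),
)
```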