torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/smoothing/gaussian.py

@@ -6,7 +6,7 @@ from typing import Literal
 
 
 import torch
 
-from ...core import Modular, Module,
+from ...core import Modular, Module, Var
 from ...utils import NumberList, TensorList
 from ...utils.derivatives import jacobian_wrt
 from ..grad_approximation import GradApproximator, GradTarget

@@ -17,24 +17,24 @@ class Reformulation(Module, ABC):
         super().__init__(defaults)
 
     @abstractmethod
-    def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor],
+    def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], var: Var) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
        """returns loss and gradient, if backward is False then gradient can be None"""
 
-    def pre_step(self,
+    def pre_step(self, var: Var) -> Var | None:
        """This runs once before each step, whereas `closure` may run multiple times per step if further modules
        evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
-        return
+        return var
 
-    def step(self,
-        ret = self.pre_step(
-        if isinstance(ret,
+    def step(self, var):
+        ret = self.pre_step(var)
+        if isinstance(ret, Var): var = ret
 
-        if
-        params, closure =
+        if var.closure is None: raise RuntimeError("Reformulation requires closure")
+        params, closure = var.params, var.closure
 
 
        def modified_closure(backward=True):
-            loss, grad = self.closure(backward, closure, params,
+            loss, grad = self.closure(backward, closure, params, var)
 
            if grad is not None:
                for p,g in zip(params, grad):

@@ -42,8 +42,8 @@ class Reformulation(Module, ABC):
 
            return loss
 
-
-        return
+        var.closure = modified_closure
+        return var
 
 
 def _decay_sigma_(self: Module, params):

@@ -58,12 +58,46 @@ def _generate_perturbations_to_state_(self: Module, params: TensorList, n_sample
    for param, prt in zip(params, zip(*perturbations)):
        self.state[param]['perturbations'] = prt
 
-def _clear_state_hook(optimizer: Modular,
+def _clear_state_hook(optimizer: Modular, var: Var, self: Module):
    for m in optimizer.unrolled_modules:
        if m is not self:
            m.reset()
 
 class GaussianHomotopy(Reformulation):
+    """Approximately smoothes the function with a gaussian kernel by sampling it at random perturbed points around current point. Both function values and gradients are averaged over all samples. The perturbed points are generated before each
+    step and remain the same throughout the step.
+
+    .. note::
+        This module reformulates the objective, it modifies the closure to evaluate value and gradients of a smoothed function. All modules after this will operate on the modified objective.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients at perturbed points.
+
+    Args:
+        n_samples (int): number of points to sample, larger values lead to a more accurate smoothing.
+        init_sigma (float): initial scale of perturbations.
+        tol (float | None, optional):
+            if maximal parameters change value is smaller than this, sigma is reduced by :code:`decay`. Defaults to 1e-4.
+        decay (float, optional): multiplier to sigma when converged on a smoothed function. Defaults to 0.5.
+        max_steps (int | None, optional): maximum number of steps before decaying sigma. Defaults to None.
+        clear_state (bool, optional):
+            whether to clear all other module states when sigma is decayed, because the objective function changes. Defaults to True.
+        seed (int | None, optional): seed for random perturbationss. Defaults to None.
+
+    Examples:
+        Gaussian-smoothed NewtonCG
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.GaussianHomotopy(100),
+                tz.m.NewtonCG(maxiter=20),
+                tz.m.AdaptiveBacktracking(),
+            )
+
+    """
    def __init__(
        self,
        n_samples: int,

@@ -85,12 +119,12 @@ class GaussianHomotopy(Reformulation):
        else: self.global_state['generator'] = None
        return self.global_state['generator']
 
-    def pre_step(self,
-        params = TensorList(
+    def pre_step(self, var):
+        params = TensorList(var.params)
        settings = self.settings[params[0]]
        n_samples = settings['n_samples']
-        init_sigma = self.
-        sigmas = self.get_state('sigma',
+        init_sigma = [self.settings[p]['init_sigma'] for p in params]
+        sigmas = self.get_state(params, 'sigma', init=init_sigma)
 
        if any('perturbations' not in self.state[p] for p in params):
            generator = self._get_generator(settings['seed'], params)

@@ -109,9 +143,9 @@ class GaussianHomotopy(Reformulation):
        tol = settings['tol']
        if tol is not None and not decayed:
            if not any('prev_params' in self.state[p] for p in params):
-                prev_params = self.get_state('prev_params',
+                prev_params = self.get_state(params, 'prev_params', cls=TensorList, init='param')
            else:
-                prev_params = self.get_state('prev_params',
+                prev_params = self.get_state(params, 'prev_params', cls=TensorList, init='param')
                s = params - prev_params
 
                if s.abs().global_max() <= tol:

@@ -124,10 +158,10 @@ class GaussianHomotopy(Reformulation):
            generator = self._get_generator(settings['seed'], params)
            _generate_perturbations_to_state_(self, params=params, n_samples=n_samples, sigmas=sigmas, generator=generator)
            if settings['clear_state']:
-
+                var.post_step_hooks.append(partial(_clear_state_hook, self=self))
 
    @torch.no_grad
-    def closure(self, backward, closure, params,
+    def closure(self, backward, closure, params, var):
        params = TensorList(params)
 
        settings = self.settings[params[0]]
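The GaussianHomotopy docstring above stresses that the optimizer step needs a closure it can re-evaluate at the perturbed points. Below is a minimal sketch of such a closure-based step, assuming the usual torchzero convention that the closure accepts a backward flag (the same convention as the modified_closure(backward=True) wrapper in the hunk above); the toy model and data are made up for illustration only.

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)
    X, y = torch.randn(64, 4), torch.randn(64, 1)

    opt = tz.Modular(
        model.parameters(),
        tz.m.GaussianHomotopy(100),
        tz.m.NewtonCG(maxiter=20),
        tz.m.AdaptiveBacktracking(),
    )

    def closure(backward=True):
        # GaussianHomotopy calls this repeatedly at the pre-generated perturbed points
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            model.zero_grad()
            loss.backward()
        return loss

    for _ in range(10):
        opt.step(closure)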
torchzero/modules/smoothing/laplacian.py

@@ -56,7 +56,7 @@ def _precompute_denominator(tensor: torch.Tensor, sigma) -> torch.Tensor:
    return 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable
 
 class LaplacianSmoothing(Transform):
-    """Applies laplacian smoothing via a fast Fourier transform solver.
+    """Applies laplacian smoothing via a fast Fourier transform solver which can improve generalization.
 
    Args:
        sigma (float, optional): controls the amount of smoothing. Defaults to 1.

@@ -67,11 +67,21 @@ class LaplacianSmoothing(Transform):
            minimum number of elements in a parameter to apply laplacian smoothing to.
            Only has effect if `layerwise` is True. Defaults to 4.
        target (str, optional):
-            what to set on
+            what to set on var.
+
+    Examples:
+        Laplacian Smoothing Gradient Descent optimizer as in the paper
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LaplacianSmoothing(),
+                tz.m.LR(1e-2),
+            )
 
    Reference:
-
-        Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*
+        Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022). Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.
 
    """
    def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4, target: Target = 'update'):

@@ -82,19 +92,17 @@ class LaplacianSmoothing(Transform):
 
 
    @torch.no_grad
-    def
-        layerwise =
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        layerwise = settings[0]['layerwise']
 
        # layerwise laplacian smoothing
        if layerwise:
 
            # precompute the denominator for each layer and store it in each parameters state
            smoothed_target = TensorList()
-            for p, t in zip(params, tensors):
-
-
-                state = self.state[p]
-                if 'denominator' not in state: state['denominator'] = _precompute_denominator(p, settings['sigma'])
+            for p, t, state, setting in zip(params, tensors, states, settings):
+                if p.numel() > setting['min_numel']:
+                    if 'denominator' not in state: state['denominator'] = _precompute_denominator(p, setting['sigma'])
                    smoothed_target.append(torch.fft.ifft(torch.fft.fft(t.view(-1)) / state['denominator']).real.view_as(t)) #pylint:disable=not-callable
                else:
                    smoothed_target.append(t)

@@ -106,7 +114,7 @@ class LaplacianSmoothing(Transform):
        # precompute full denominator
        tensors = TensorList(tensors)
        if self.global_state.get('full_denominator', None) is None:
-            self.global_state['full_denominator'] = _precompute_denominator(tensors.to_vec(),
+            self.global_state['full_denominator'] = _precompute_denominator(tensors.to_vec(), settings[0]['sigma'])
 
        # apply the smoothing
        vec = tensors.to_vec()
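For reference, the smoothing applied in apply_tensors solves (I − σ·Laplacian)·u_smooth = u in the Fourier domain, which is why only the precomputed denominator 1 - sigma * torch.fft.fft(v) is needed. A standalone sketch of that computation on a plain vector, assuming the usual periodic second-difference stencil for v (its construction is outside the hunks shown here):

    import torch

    sigma = 1.0
    u = torch.randn(16)

    # assumed periodic 1-D Laplacian stencil, as in the referenced paper; not shown in this diff
    v = torch.zeros(16)
    v[0], v[1], v[-1] = -2.0, 1.0, 1.0

    denominator = 1 - sigma * torch.fft.fft(v)   # same form as _precompute_denominator above
    u_smooth = torch.fft.ifft(torch.fft.fft(u) / denominator).real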
torchzero/modules/step_size/adaptive.py (new file)

@@ -0,0 +1,122 @@
+"""Various step size strategies"""
+from typing import Any, Literal
+from operator import itemgetter
+import torch
+
+from ...core import Transform, Chainable
+from ...utils import TensorList, unpack_dicts, unpack_states, NumberList
+
+
+class PolyakStepSize(Transform):
+    """Polyak's subgradient method.
+
+    Args:
+        f_star (int, optional):
+            (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
+        max (float | None, optional): maximum possible step size. Defaults to None.
+        use_grad (bool, optional):
+            if True, uses dot product of update and gradient to compute the step size.
+            Otherwise, dot product of update with itself is used, which has no geometric meaning so it probably won't work well.
+            Defaults to False.
+        alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
+    """
+    def __init__(self, f_star: float = 0, max: float | None = None, use_grad=False, alpha: float = 1, inner: Chainable | None = None):
+
+        defaults = dict(alpha=alpha, max=max, f_star=f_star, use_grad=use_grad)
+        super().__init__(defaults, uses_grad=use_grad, uses_loss=True, inner=inner)
+
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        assert grads is not None and loss is not None
+        tensors = TensorList(tensors)
+        grads = TensorList(grads)
+
+        use_grad, max, f_star = itemgetter('use_grad', 'max', 'f_star')(settings[0])
+
+        if use_grad: gg = tensors.dot(grads)
+        else: gg = tensors.dot(tensors)
+
+        if gg.abs() <= torch.finfo(gg.dtype).eps: step_size = 0 # converged
+        else: step_size = (loss - f_star) / gg
+
+        if max is not None:
+            if step_size > max: step_size = max
+
+        self.global_state['step_size'] = step_size
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step_size = self.global_state.get('step_size', 1)
+        torch._foreach_mul_(tensors, step_size * unpack_dicts(settings, 'alpha', cls=NumberList))
+        return tensors
+
+
+
+def _bb_short(s: TensorList, y: TensorList, sy, eps, fallback):
+    yy = y.dot(y)
+    if yy < eps:
+        if sy < eps: return fallback # try to fallback on long
+        ss = s.dot(s)
+        return ss/sy
+    return sy/yy
+
+def _bb_long(s: TensorList, y: TensorList, sy, eps, fallback):
+    ss = s.dot(s)
+    if sy < eps:
+        yy = y.dot(y) # try to fallback on short
+        if yy < eps: return fallback
+        return sy/yy
+    return ss/sy
+
+def _bb_geom(s: TensorList, y: TensorList, sy, eps, fallback):
+    short = _bb_short(s, y, sy, eps, fallback)
+    long = _bb_long(s, y, sy, eps, fallback)
+    return (short * long) ** 0.5
+
+class BarzilaiBorwein(Transform):
+    """Barzilai-Borwein method.
+
+    Args:
+        type (str, optional):
+            one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long.
+            Defaults to 'geom'.
+        scale_first (bool, optional):
+            whether to make first step very small when previous gradient is not available. Defaults to True.
+        fallback (float, optional): step size when denominator is less than 0 (will happen on negative curvature). Defaults to 1e-3.
+        inner (Chainable | None, optional):
+            step size will be applied to outputs of this module. Defaults to None.
+
+    """
+    def __init__(self, type: Literal['long', 'short', 'geom'] = 'geom', scale_first:bool=True, fallback:float=1e-3, inner:Chainable|None = None):
+        defaults = dict(type=type, fallback=fallback)
+        super().__init__(defaults, uses_grad=False, scale_first=scale_first, inner=inner)
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_p', 'prev_g')
+
+    @torch.no_grad
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
+        fallback = unpack_dicts(settings, 'fallback', cls=NumberList)
+        type = settings[0]['type']
+
+        s = params-prev_p
+        y = tensors-prev_g
+        sy = s.dot(y)
+        eps = torch.finfo(sy.dtype).eps
+
+        if type == 'short': step_size = _bb_short(s, y, sy, eps, fallback)
+        elif type == 'long': step_size = _bb_long(s, y, sy, eps, fallback)
+        elif type == 'geom': step_size = _bb_geom(s, y, sy, eps, fallback)
+        else: raise ValueError(type)
+
+        self.global_state['step_size'] = step_size
+
+        prev_p.copy_(params)
+        prev_g.copy_(tensors)
+
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step_size = self.global_state.get('step_size', 1)
+        torch._foreach_mul_(tensors, step_size)
+        return tensors
+
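With s = x_k − x_{k−1} and y = g_k − g_{k−1}, the helpers above compute the short step sᵀy/yᵀy, the long step sᵀs/sᵀy, and their geometric mean for type='geom'. Neither new class carries a usage example yet; below is a hedged sketch, assuming these transforms are exported under tz.m like the other modules referenced in this diff, with a placeholder toy model.

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)

    # BarzilaiBorwein supplies the whole step length, so no separate LR module is used here
    opt = tz.Modular(
        model.parameters(),
        tz.m.BarzilaiBorwein(type="geom", fallback=1e-3),
    )

    # PolyakStepSize sets uses_loss=True above, so stepping it presumably requires
    # a closure that returns the loss
    polyak = tz.Modular(
        model.parameters(),
        tz.m.PolyakStepSize(f_star=0.0),
    )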
torchzero/modules/step_size/lr.py (new file)

@@ -0,0 +1,154 @@
+"""Learning rate"""
+import torch
+import random
+
+from ...core import Transform
+from ...utils import NumberList, TensorList, generic_ne, unpack_dicts
+
+def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
+    """multiplies by lr if lr is not 1"""
+    if generic_ne(lr, 1):
+        if inplace: return tensors.mul_(lr)
+        return tensors * lr
+    return tensors
+
+class LR(Transform):
+    """Learning rate. Adding this module also adds support for LR schedulers."""
+    def __init__(self, lr: float):
+        defaults=dict(lr=lr)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)
+
+class StepSize(Transform):
+    """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
+    def __init__(self, step_size: float, key = 'step_size'):
+        defaults={"key": key, key: step_size}
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)
+
+
+def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberList, steps: float):
+    """returns warm up lr scalar"""
+    if step > steps: return end_lr
+    return start_lr + (end_lr - start_lr) * (step / steps)
+
+class Warmup(Transform):
+    """Learning rate warmup, linearly increases learning rate multiplier from :code:`start_lr` to :code:`end_lr` over :code:`steps` steps.
+
+    Args:
+        steps (int, optional): number of steps to perform warmup for. Defaults to 100.
+        start_lr (_type_, optional): initial learning rate multiplier on first step. Defaults to 1e-5.
+        end_lr (float, optional): learning rate multiplier at the end and after warmup. Defaults to 1.
+
+    Example:
+        Adam with 1000 steps warmup
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.LR(1e-2),
+                tz.m.Warmup(steps=1000)
+            )
+
+    """
+    def __init__(self, steps = 100, start_lr = 1e-5, end_lr:float = 1):
+        defaults = dict(start_lr=start_lr,end_lr=end_lr, steps=steps)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
+        num_steps = settings[0]['steps']
+        step = self.global_state.get('step', 0)
+
+        tensors = lazy_lr(
+            TensorList(tensors),
+            lr=_warmup_lr(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
+            inplace=True
+        )
+        self.global_state['step'] = step + 1
+        return tensors
+
+class WarmupNormClip(Transform):
+    """Warmup via clipping of the update norm.
+
+    Args:
+        start_norm (_type_, optional): maximal norm on the first step. Defaults to 1e-5.
+        end_norm (float, optional): maximal norm on the last step. After that, norm clipping is disabled. Defaults to 1.
+        steps (int, optional): number of steps to perform warmup for. Defaults to 100.
+
+    Example:
+        Adam with 1000 steps norm clip warmup
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.WarmupNormClip(steps=1000)
+                tz.m.LR(1e-2),
+            )
+    """
+    def __init__(self, steps = 100, start_norm = 1e-5, end_norm:float = 1):
+        defaults = dict(start_norm=start_norm,end_norm=end_norm, steps=steps)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        start_norm, end_norm = unpack_dicts(settings, 'start_norm', 'end_norm', cls = NumberList)
+        num_steps = settings[0]['steps']
+        step = self.global_state.get('step', 0)
+        if step > num_steps: return tensors
+
+        tensors = TensorList(tensors)
+        norm = tensors.global_vector_norm()
+        current_max_norm = _warmup_lr(step, start_norm[0], end_norm[0], num_steps)
+        if norm > current_max_norm:
+            tensors.mul_(current_max_norm / norm)
+
+        self.global_state['step'] = step + 1
+        return tensors
+
+
+class RandomStepSize(Transform):
+    """Uses random global or layer-wise step size from `low` to `high`.
+
+    Args:
+        low (float, optional): minimum learning rate. Defaults to 0.
+        high (float, optional): maximum learning rate. Defaults to 1.
+        parameterwise (bool, optional):
+            if True, generate random step size for each parameter separately,
+            if False generate one global random step size. Defaults to False.
+    """
+    def __init__(self, low: float = 0, high: float = 1, parameterwise=False, seed:int|None=None):
+        defaults = dict(low=low, high=high, parameterwise=parameterwise,seed=seed)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        s = settings[0]
+        parameterwise = s['parameterwise']
+
+        seed = s['seed']
+        if 'generator' not in self.global_state:
+            self.global_state['generator'] = random.Random(seed)
+        generator: random.Random = self.global_state['generator']
+
+        if parameterwise:
+            low, high = unpack_dicts(settings, 'low', 'high')
+            lr = [generator.uniform(l, h) for l, h in zip(low, high)]
+        else:
+            low = s['low']
+            high = s['high']
+            lr = generator.uniform(low, high)
+
+        torch._foreach_mul_(tensors, lr)
+        return tensors
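The StepSize docstring explains the key renaming but gives no example; below is a minimal hedged sketch, assuming these classes are exported under tz.m like the modules referenced elsewhere in this diff, with a placeholder toy model.

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)

    # the extra multiplier lives under the custom 'radius' key, so it does not
    # clash with the 'lr' setting that LR (and LR schedulers) use
    opt = tz.Modular(
        model.parameters(),
        tz.m.StepSize(0.5, key="radius"),
        tz.m.LR(1e-2),
    )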
torchzero/modules/weight_decay/__init__.py

@@ -1 +1 @@
-from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_
+from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, RelativeWeightDecay
torchzero/modules/weight_decay/weight_decay.py

@@ -1,9 +1,11 @@
 from collections.abc import Iterable, Sequence
+from typing import Literal
 
 import torch
 
 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList, as_tensorlist
+from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+
 
 @torch.no_grad
 def weight_decay_(

@@ -20,17 +22,126 @@ def weight_decay_(
 
 
 class WeightDecay(Transform):
+    """Weight decay.
+
+    Args:
+        weight_decay (float): weight decay scale.
+        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
+        target (Target, optional): what to set on var. Defaults to 'update'.
+
+    Examples:
+        Adam with non-decoupled weight decay
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.WeightDecay(1e-3),
+                tz.m.Adam(),
+                tz.m.LR(1e-3)
+            )
+
+        Adam with decoupled weight decay that still scales with learning rate
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.WeightDecay(1e-3),
+                tz.m.LR(1e-3)
+            )
+
+        Adam with fully decoupled weight decay that doesn't scale with learning rate
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.LR(1e-3),
+                tz.m.WeightDecay(1e-6)
+            )
+
+    """
    def __init__(self, weight_decay: float, ord: int = 2, target: Target = 'update'):
+
        defaults = dict(weight_decay=weight_decay, ord=ord)
        super().__init__(defaults, uses_grad=False, target=target)
 
    @torch.no_grad
-    def
-        weight_decay =
-        ord =
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        weight_decay = NumberList(s['weight_decay'] for s in settings)
+        ord = settings[0]['ord']
 
        return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)
 
+class RelativeWeightDecay(Transform):
+    """Weight decay relative to the mean absolute value of update, gradient or parameters depending on value of :code:`norm_input` argument.
+
+    Args:
+        weight_decay (float): relative weight decay scale.
+        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
+        norm_input (str, optional):
+            determines what should weight decay be relative to. "update", "grad" or "params".
+            Defaults to "update".
+        target (Target, optional): what to set on var. Defaults to 'update'.
+
+    Examples:
+        Adam with non-decoupled relative weight decay
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.RelativeWeightDecay(1e-3),
+                tz.m.Adam(),
+                tz.m.LR(1e-3)
+            )
+
+        Adam with decoupled relative weight decay
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.RelativeWeightDecay(1e-3),
+                tz.m.LR(1e-3)
+            )
+
+    """
+    def __init__(
+        self,
+        weight_decay: float = 0.1,
+        ord: int = 2,
+        norm_input: Literal["update", "grad", "params"] = "update",
+        target: Target = "update",
+    ):
+        defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input)
+        super().__init__(defaults, uses_grad=norm_input == 'grad', target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        weight_decay = NumberList(s['weight_decay'] for s in settings)
+
+        ord = settings[0]['ord']
+        norm_input = settings[0]['norm_input']
+
+        if norm_input == 'update': src = TensorList(tensors)
+        elif norm_input == 'grad':
+            assert grads is not None
+            src = TensorList(grads)
+        elif norm_input == 'params':
+            src = TensorList(params)
+        else:
+            raise ValueError(norm_input)
+
+        mean_abs = src.abs().global_mean()
+
+        return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay * mean_abs, ord)
+
+
 @torch.no_grad
 def decay_weights_(params: Iterable[torch.Tensor], weight_decay: float | NumberList, ord:int=2):
    """directly decays weights in-place"""

@@ -38,15 +149,20 @@ def decay_weights_(params: Iterable[torch.Tensor], weight_decay: float | NumberL
    weight_decay_(params, params, -weight_decay, ord)
 
 class DirectWeightDecay(Module):
-    """
+    """Directly applies weight decay to parameters.
+
+    Args:
+        weight_decay (float): weight decay scale.
+        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
+    """
    def __init__(self, weight_decay: float, ord: int = 2,):
        defaults = dict(weight_decay=weight_decay, ord=ord)
        super().__init__(defaults)
 
    @torch.no_grad
-    def step(self,
-        weight_decay = self.get_settings('weight_decay',
-        ord = self.settings[
+    def step(self, var):
+        weight_decay = self.get_settings(var.params, 'weight_decay', cls=NumberList)
+        ord = self.settings[var.params[0]]['ord']
 
-        decay_weights_(
-        return
+        decay_weights_(var.params, weight_decay, ord)
+        return var
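WeightDecay and RelativeWeightDecay carry usage examples in their new docstrings, while DirectWeightDecay and decay_weights_ do not. Below is a hedged sketch under the same assumptions as before (tz.m export pattern; the placement of DirectWeightDecay is purely illustrative, since it acts on the parameters themselves rather than on the update):

    import torch
    import torchzero as tz
    from torchzero.modules.weight_decay import decay_weights_

    model = torch.nn.Linear(4, 1)

    # decays the weights in-place every step, independently of the computed update
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.LR(1e-3),
        tz.m.DirectWeightDecay(1e-4),
    )

    # the functional form can also be applied manually, in-place
    decay_weights_(list(model.parameters()), weight_decay=1e-4)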