torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -0,0 +1,300 @@
import math
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from contextlib import nullcontext
from functools import partial
from typing import Literal, cast

import torch

from ...core import Chainable, Modular, Module, Var
from ...core.reformulation import Reformulation
from ...utils import Distributions, NumberList, TensorList
from ..termination import TerminationCriteriaBase, make_termination_criteria


def _reset_except_self(optimizer: Modular, var: Var, self: Module):
    for m in optimizer.unrolled_modules:
        if m is not self:
            m.reset()


class GradientSampling(Reformulation):
    """Samples and aggregates gradients and values at perturbed points.

    This module can be used for gaussian homotopy and gradient sampling methods.

    Args:
        modules (Chainable | None, optional):
            modules that will be optimizing the modified objective.
            If None, returns the gradient of the modified objective as the update. Defaults to None.
        sigma (float, optional): initial magnitude of the perturbations. Defaults to 1.
        n (int, optional): number of perturbations per step. Defaults to 100.
        aggregate (str, optional):
            how to aggregate values and gradients:

            - "mean" - uses mean of the gradients, as in gaussian homotopy.
            - "max" - uses element-wise maximum of the gradients.
            - "min" - uses element-wise minimum of the gradients.
            - "min-norm" - picks the gradient with the lowest norm.
            - "min-value" - picks the gradient at the point with the lowest loss.

            Defaults to 'mean'.
        distribution (Distributions, optional): distribution for random perturbations. Defaults to 'gaussian'.
        include_x0 (bool, optional): whether to include the gradient at the un-perturbed point. Defaults to True.
        fixed (bool, optional):
            if True, perturbations do not get replaced by new random perturbations until the termination criteria is satisfied. Defaults to True.
        pre_generate (bool, optional):
            if True, perturbations are pre-generated before each step.
            This requires more memory to store all of them,
            but ensures they do not change when the closure is evaluated multiple times.
            Defaults to True.
        termination (TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None, optional):
            a termination criteria module; sigma will be multiplied by ``decay`` when the termination criteria is satisfied,
            and new perturbations will be generated if ``fixed``. Defaults to None.
        decay (float, optional): sigma multiplier on termination criteria. Defaults to 2/3.
        reset_on_termination (bool, optional): whether to reset states of all other modules on termination. Defaults to True.
        sigma_strategy (str | None, optional):
            strategy for adapting sigma. If the condition is satisfied, sigma is multiplied by ``sigma_nplus``,
            otherwise it is multiplied by ``sigma_nminus``.

            - "grad-norm" - at least ``sigma_target`` gradients should have lower norm than at the un-perturbed point.
            - "value" - at least ``sigma_target`` values (losses) should be lower than at the un-perturbed point.
            - None - doesn't use adaptive sigma.

            This introduces a side-effect to the closure, so it should be left at None if you use
            a trust region or line search to optimize the modified objective.
            Defaults to None.
        sigma_target (int | float, optional):
            number (if int) or fraction (if float) of evaluations that must satisfy the condition in ``sigma_strategy``. Defaults to 0.2.
        sigma_nplus (float, optional): sigma multiplier when the ``sigma_strategy`` condition is satisfied. Defaults to 4/3.
        sigma_nminus (float, optional): sigma multiplier when the ``sigma_strategy`` condition is not satisfied. Defaults to 2/3.
        seed (int | None, optional): seed. Defaults to None.
    """
    def __init__(
        self,
        modules: Chainable | None = None,
        sigma: float = 1.,
        n: int = 100,
        aggregate: Literal['mean', 'max', 'min', 'min-norm', 'min-value'] = 'mean',
        distribution: Distributions = 'gaussian',
        include_x0: bool = True,

        fixed: bool = True,
        pre_generate: bool = True,
        termination: TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None = None,
        decay: float = 2/3,
        reset_on_termination: bool = True,

        sigma_strategy: Literal['grad-norm', 'value'] | None = None,
        sigma_target: int | float = 0.2,
        sigma_nplus: float = 4/3,
        sigma_nminus: float = 2/3,

        seed: int | None = None,
    ):

        defaults = dict(sigma=sigma, n=n, aggregate=aggregate, distribution=distribution, seed=seed, include_x0=include_x0, fixed=fixed, decay=decay, reset_on_termination=reset_on_termination, sigma_strategy=sigma_strategy, sigma_target=sigma_target, sigma_nplus=sigma_nplus, sigma_nminus=sigma_nminus, pre_generate=pre_generate)
        super().__init__(defaults, modules)

        if termination is not None:
            self.set_child('termination', make_termination_criteria(extra=termination))

    @torch.no_grad
    def pre_step(self, var):
        params = TensorList(var.params)

        fixed = self.defaults['fixed']

        # check termination criteria
        if 'termination' in self.children:
            termination = cast(TerminationCriteriaBase, self.children['termination'])
            if termination.should_terminate(var):

                # decay sigmas
                states = [self.state[p] for p in params]
                settings = [self.settings[p] for p in params]

                for state, setting in zip(states, settings):
                    if 'sigma' not in state: state['sigma'] = setting['sigma']
                    state['sigma'] *= setting['decay']

                # reset on sigmas decay
                if self.defaults['reset_on_termination']:
                    var.post_step_hooks.append(partial(_reset_except_self, self=self))

                # clear perturbations
                self.global_state.pop('perts', None)

        # pre-generate perturbations if not already pre-generated or not fixed
        if self.defaults['pre_generate'] and (('perts' not in self.global_state) or (not fixed)):
            states = [self.state[p] for p in params]
            settings = [self.settings[p] for p in params]

            n = self.defaults['n'] - self.defaults['include_x0']
            generator = self.get_generator(params[0].device, self.defaults['seed'])

            perts = [params.sample_like(self.defaults['distribution'], generator=generator) for _ in range(n)]

            self.global_state['perts'] = perts

    @torch.no_grad
    def closure(self, backward, closure, params, var):
        params = TensorList(params)
        loss_agg = None
        grad_agg = None

        states = [self.state[p] for p in params]
        settings = [self.settings[p] for p in params]
        sigma_inits = [s['sigma'] for s in settings]
        sigmas = [s.setdefault('sigma', si) for s, si in zip(states, sigma_inits)]

        include_x0 = self.defaults['include_x0']
        pre_generate = self.defaults['pre_generate']
        aggregate: Literal['mean', 'max', 'min', 'min-norm', 'min-value'] = self.defaults['aggregate']
        sigma_strategy: Literal['grad-norm', 'value'] | None = self.defaults['sigma_strategy']
        distribution = self.defaults['distribution']
        generator = self.get_generator(params[0].device, self.defaults['seed'])

        n_finite = 0
        n_good = 0
        f_0 = None; g_0 = None

        # evaluate at x_0
        if include_x0:
            f_0 = cast(torch.Tensor, var.get_loss(backward=backward))

            isfinite = math.isfinite(f_0)
            if isfinite:
                n_finite += 1
                loss_agg = f_0

            if backward:
                g_0 = var.get_grad()
                if isfinite: grad_agg = g_0

        # evaluate at x_0 + p for each perturbation
        if pre_generate:
            perts = self.global_state['perts']
        else:
            perts = [None] * (self.defaults['n'] - include_x0)

        x_0 = [p.clone() for p in params]

        for pert in perts:
            loss = None; grad = None

            # generate if not pre-generated
            if pert is None:
                pert = params.sample_like(distribution, generator=generator)

            # add perturbation and evaluate
            pert = pert * sigmas
            torch._foreach_add_(params, pert)

            with torch.enable_grad() if backward else nullcontext():
                loss = closure(backward)

            if math.isfinite(loss):
                n_finite += 1

                # add loss
                if loss_agg is None:
                    loss_agg = loss
                else:
                    if aggregate == 'mean':
                        loss_agg += loss

                    elif (aggregate=='min') or (aggregate=='min-value') or (aggregate=='min-norm' and not backward):
                        loss_agg = loss_agg.clamp(max=loss)

                    elif aggregate == 'max':
                        loss_agg = loss_agg.clamp(min=loss)

                # add grad
                if backward:
                    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                    if grad_agg is None:
                        grad_agg = grad
                    else:
                        if aggregate == 'mean':
                            torch._foreach_add_(grad_agg, grad)

                        elif aggregate == 'min':
                            grad_agg_abs = torch._foreach_abs(grad_agg)
                            torch._foreach_minimum_(grad_agg_abs, torch._foreach_abs(grad))
                            grad_agg = [g_abs.copysign(g) for g_abs, g in zip(grad_agg_abs, grad_agg)]

                        elif aggregate == 'max':
                            grad_agg_abs = torch._foreach_abs(grad_agg)
                            torch._foreach_maximum_(grad_agg_abs, torch._foreach_abs(grad))
                            grad_agg = [g_abs.copysign(g) for g_abs, g in zip(grad_agg_abs, grad_agg)]

                        elif aggregate == 'min-norm':
                            if TensorList(grad).global_vector_norm() < TensorList(grad_agg).global_vector_norm():
                                grad_agg = grad
                                loss_agg = loss

                        elif aggregate == 'min-value':
                            if loss < loss_agg:
                                grad_agg = grad
                                loss_agg = loss

            # undo perturbation
            torch._foreach_copy_(params, x_0)

            # adaptive sigma
            # by value
            if sigma_strategy == 'value':
                if f_0 is None:
                    with torch.enable_grad() if backward else nullcontext():
                        f_0 = closure(False)

                if loss < f_0:
                    n_good += 1

            # by gradient norm
            elif sigma_strategy == 'grad-norm' and backward and math.isfinite(loss):
                assert grad is not None
                if g_0 is None:
                    with torch.enable_grad() if backward else nullcontext():
                        closure()
                    g_0 = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

                if TensorList(grad).global_vector_norm() < TensorList(g_0).global_vector_norm():
                    n_good += 1

        # update sigma if strategy is enabled
        if sigma_strategy is not None:

            sigma_target = self.defaults['sigma_target']
            if isinstance(sigma_target, float):
                sigma_target = int(max(1, n_finite * sigma_target))

            if n_good >= sigma_target:
                key = 'sigma_nplus'
            else:
                key = 'sigma_nminus'

            for p in params:
                self.state[p]['sigma'] *= self.settings[p][key]

        # if no finite losses, just return inf
        if n_finite == 0:
            assert loss_agg is None and grad_agg is None
            loss = torch.tensor(torch.inf, dtype=params[0].dtype, device=params[0].device)
            grad = [torch.full_like(p, torch.inf) for p in params]
            return loss, grad

        assert loss_agg is not None

        # no post processing needed when aggregate is 'max', 'min', 'min-norm', 'min-value'
        if aggregate != 'mean':
            return loss_agg, grad_agg

        # on mean divide by number of evals
        loss_agg /= n_finite

        if backward:
            assert grad_agg is not None
            torch._foreach_div_(grad_agg, n_finite)

        return loss_agg, grad_agg
@@ -0,0 +1,387 @@
"""Various step size strategies"""
import math
from operator import itemgetter
from typing import Any, Literal

import torch

from ...core import Chainable, Transform
from ...utils import NumberList, TensorList, tofloat, unpack_dicts, unpack_states
from ...utils.linalg.linear_operator import ScaledIdentity
from ..functional import epsilon_step_size

def _acceptable_alpha(alpha, param: torch.Tensor):
    finfo = torch.finfo(param.dtype)
    if (alpha is None) or (alpha < finfo.tiny*2) or (not math.isfinite(alpha)) or (alpha > finfo.max/2):
        return False
    return True

def _get_H(self: Transform, var):
    n = sum(p.numel() for p in var.params)
    p = var.params[0]
    alpha = self.global_state.get('alpha', 1)
    if not _acceptable_alpha(alpha, p): alpha = 1

    return ScaledIdentity(1 / alpha, shape=(n,n), device=p.device, dtype=p.dtype)


class PolyakStepSize(Transform):
    """Polyak's subgradient method with known or unknown f*.

    Args:
        f_star (float | None, optional):
            minimal possible value of the objective function. If not known, set to ``None``. Defaults to 0.
        y (float, optional):
            when ``f_star`` is set to None, it is calculated as ``f_best - y``. Defaults to 1.
        y_decay (float, optional):
            ``y`` is multiplied by ``(1 - y_decay)`` after each step. Defaults to 1e-3.
        max (float | None, optional): maximum possible step size. Defaults to None.
        use_grad (bool, optional):
            if True, uses dot product of update and gradient to compute the step size.
            Otherwise, dot product of update with itself is used.
        alpha (float, optional): multiplier to the Polyak step size. Defaults to 1.
    """
    def __init__(self, f_star: float | None = 0, y: float = 1, y_decay: float = 1e-3, max: float | None = None, use_grad=True, alpha: float = 1, inner: Chainable | None = None):

        defaults = dict(alpha=alpha, max=max, f_star=f_star, y=y, y_decay=y_decay)
        super().__init__(defaults, uses_grad=use_grad, uses_loss=True, inner=inner)

    @torch.no_grad
    def update_tensors(self, tensors, params, grads, loss, states, settings):
        assert grads is not None and loss is not None
        tensors = TensorList(tensors)
        grads = TensorList(grads)

        # load variables
        max, f_star, y, y_decay = itemgetter('max', 'f_star', 'y', 'y_decay')(settings[0])
        y_val = self.global_state.get('y_val', y)
        f_best = self.global_state.get('f_best', None)

        # gg
        if self._uses_grad: gg = tensors.dot(grads)
        else: gg = tensors.dot(tensors)

        # store loss
        if f_best is None or loss < f_best: f_best = tofloat(loss)
        if f_star is None: f_star = f_best - y_val

        # calculate the step size
        if gg <= torch.finfo(gg.dtype).tiny * 2: alpha = 0 # converged
        else: alpha = (loss - f_star) / gg

        # clip
        if max is not None:
            if alpha > max: alpha = max

        # store state
        self.global_state['f_best'] = f_best
        self.global_state['y_val'] = y_val * (1 - y_decay)
        self.global_state['alpha'] = alpha

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        alpha = self.global_state.get('alpha', 1)
        if not _acceptable_alpha(alpha, tensors[0]): alpha = epsilon_step_size(TensorList(tensors))

        torch._foreach_mul_(tensors, alpha * unpack_dicts(settings, 'alpha', cls=NumberList))
        return tensors

    def get_H(self, var):
        return _get_H(self, var)


def _bb_short(s: TensorList, y: TensorList, sy, eps):
    yy = y.dot(y)
    if yy < eps:
        if sy < eps: return None # try to fallback on long
        ss = s.dot(s)
        return ss/sy
    return sy/yy

def _bb_long(s: TensorList, y: TensorList, sy, eps):
    ss = s.dot(s)
    if sy < eps:
        yy = y.dot(y) # try to fallback on short
        if yy < eps: return None
        return sy/yy
    return ss/sy

def _bb_geom(s: TensorList, y: TensorList, sy, eps, fallback: bool):
    short = _bb_short(s, y, sy, eps)
    long = _bb_long(s, y, sy, eps)
    if long is None or short is None:
        if fallback:
            if short is not None: return short
            if long is not None: return long
        return None
    return (short * long) ** 0.5

class BarzilaiBorwein(Transform):
    """Barzilai-Borwein step size method.

    Args:
        type (str, optional):
            one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long.
            Defaults to "geom".
        alpha_0 (float, optional):
            fallback step size used when the denominator is not positive (will happen on negative curvature). Defaults to 1e-7.
        inner (Chainable | None, optional):
            step size will be applied to outputs of this module. Defaults to None.
    """

    def __init__(
        self,
        type: Literal["long", "short", "geom", "geom-fallback"] = "geom",
        alpha_0: float = 1e-7,
        use_grad=True,
        inner: Chainable | None = None,
    ):
        defaults = dict(type=type, alpha_0=alpha_0)
        super().__init__(defaults, uses_grad=use_grad, inner=inner)

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('prev_g')
        self.global_state['reset'] = True

    @torch.no_grad
    def update_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
        type = self.defaults['type']

        g = grads if self._uses_grad else tensors
        assert g is not None

        reset = self.global_state.get('reset', False)
        self.global_state.pop('reset', None)

        if step != 0 and not reset:
            s = params-prev_p
            y = g-prev_g
            sy = s.dot(y)
            eps = torch.finfo(sy.dtype).tiny * 2

            if type == 'short': alpha = _bb_short(s, y, sy, eps)
            elif type == 'long': alpha = _bb_long(s, y, sy, eps)
            elif type == 'geom': alpha = _bb_geom(s, y, sy, eps, fallback=False)
            elif type == 'geom-fallback': alpha = _bb_geom(s, y, sy, eps, fallback=True)
            else: raise ValueError(type)

            # if alpha is not None:
            self.global_state['alpha'] = alpha

        prev_p.copy_(params)
        prev_g.copy_(g)

    def get_H(self, var):
        return _get_H(self, var)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        alpha = self.global_state.get('alpha', None)

        if not _acceptable_alpha(alpha, tensors[0]):
            alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])

        torch._foreach_mul_(tensors, alpha)
        return tensors


class BBStab(Transform):
    """Stabilized Barzilai-Borwein method (https://arxiv.org/abs/1907.06409).

    This clips the norm of the Barzilai-Borwein update by ``delta``, where ``delta`` can be adaptive if ``c`` is specified.

    Args:
        c (float, optional):
            adaptive delta parameter. If ``delta`` is set to None, the first ``inf_iters`` updates are performed
            with the non-stabilized Barzilai-Borwein step size. Then ``delta`` is set to the smallest
            update norm observed so far, multiplied by ``c``. Defaults to 0.2.
        delta (float | None, optional):
            the Barzilai-Borwein update is clipped to this value. Set to ``None`` to use an adaptive choice. Defaults to None.
        type (str, optional):
            one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long.
            Defaults to "geom". Note that "long" corresponds to BB1stab and "short" to BB2stab,
            however I found that "geom" works really well.
        inner (Chainable | None, optional):
            step size will be applied to outputs of this module. Defaults to None.

    """
    def __init__(
        self,
        c=0.2,
        delta: float | None = None,
        type: Literal["long", "short", "geom", "geom-fallback"] = "geom",
        alpha_0: float = 1e-7,
        use_grad=True,
        inf_iters: int = 3,
        inner: Chainable | None = None,
    ):
        defaults = dict(type=type, alpha_0=alpha_0, c=c, delta=delta, inf_iters=inf_iters)
        super().__init__(defaults, uses_grad=use_grad, inner=inner)

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('prev_g')
        self.global_state['reset'] = True

    @torch.no_grad
    def update_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
        type = self.defaults['type']
        c = self.defaults['c']
        delta = self.defaults['delta']
        inf_iters = self.defaults['inf_iters']

        g = grads if self._uses_grad else tensors
        assert g is not None
        g = TensorList(g)

        reset = self.global_state.get('reset', False)
        self.global_state.pop('reset', None)

        if step != 0 and not reset:
            s = params-prev_p
            y = g-prev_g
            sy = s.dot(y)
            eps = torch.finfo(sy.dtype).tiny

            if type == 'short': alpha = _bb_short(s, y, sy, eps)
            elif type == 'long': alpha = _bb_long(s, y, sy, eps)
            elif type == 'geom': alpha = _bb_geom(s, y, sy, eps, fallback=False)
            elif type == 'geom-fallback': alpha = _bb_geom(s, y, sy, eps, fallback=True)
            else: raise ValueError(type)

            if alpha is not None:

                # adaptive delta
                if delta is None:
                    niters = self.global_state.get('niters', 0) # this accounts for skipped negative curvature steps
                    self.global_state['niters'] = niters + 1

                    if niters == 0: pass # 1st iteration is scaled GD step, shouldn't be used to find s_norm_min
                    elif niters <= inf_iters:
                        s_norm_min = self.global_state.get('s_norm_min', None)
                        if s_norm_min is None: s_norm_min = s.global_vector_norm()
                        else: s_norm_min = min(s_norm_min, s.global_vector_norm())
                        self.global_state['s_norm_min'] = s_norm_min
                        # first few steps use delta=inf, so delta remains None

                    else:
                        delta = c * self.global_state['s_norm_min']

                if delta is None: # delta is inf for first few steps
                    self.global_state['alpha'] = alpha

                # BBStab step size
                else:
                    a_stab = delta / g.global_vector_norm()
                    self.global_state['alpha'] = min(alpha, a_stab)

        prev_p.copy_(params)
        prev_g.copy_(g)

    def get_H(self, var):
        return _get_H(self, var)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        alpha = self.global_state.get('alpha', None)

        if not _acceptable_alpha(alpha, tensors[0]):
            alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])

        torch._foreach_mul_(tensors, alpha)
        return tensors


class AdGD(Transform):
    """AdGD and AdGD-2 (https://arxiv.org/abs/2308.02261)"""
    def __init__(self, variant: Literal[1,2]=2, alpha_0: float = 1e-7, sqrt: bool=True, use_grad=True, inner: Chainable | None = None,):
        defaults = dict(variant=variant, alpha_0=alpha_0, sqrt=sqrt)
        super().__init__(defaults, uses_grad=use_grad, inner=inner,)

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('prev_g')
        self.global_state['reset'] = True

    @torch.no_grad
    def update_tensors(self, tensors, params, grads, loss, states, settings):
        variant = settings[0]['variant']
        theta_0 = 0 if variant == 1 else 1/3
        theta = self.global_state.get('theta', theta_0)

        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        p = TensorList(params)
        g = grads if self._uses_grad else tensors
        assert g is not None
        g = TensorList(g)

        prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)

        # online
        if self.global_state.get('reset', False):
            del self.global_state['reset']
            prev_p.copy_(p)
            prev_g.copy_(g)
            return

        if step == 0:
            alpha_0 = settings[0]['alpha_0']
            if alpha_0 is None: alpha_0 = epsilon_step_size(g)
            self.global_state['alpha'] = alpha_0
            prev_p.copy_(p)
            prev_g.copy_(g)
            return

        sqrt = settings[0]['sqrt']
        alpha = self.global_state.get('alpha', math.inf)
        L = (g - prev_g).global_vector_norm() / (p - prev_p).global_vector_norm()
        eps = torch.finfo(L.dtype).tiny * 2

        if variant == 1:
            a1 = math.sqrt(1 + theta)*alpha
            val = math.sqrt(2) if sqrt else 2
            if L > eps: a2 = 1 / (val*L)
            else: a2 = math.inf

        elif variant == 2:
            a1 = math.sqrt(2/3 + theta)*alpha
            a2 = alpha / math.sqrt(max(eps, 2 * alpha**2 * L**2 - 1))

        else:
            raise ValueError(variant)

        alpha_new = min(a1, a2)
        if alpha_new < 0: alpha_new = max(a1, a2)
        if alpha_new > eps:
            self.global_state['theta'] = alpha_new/alpha
            self.global_state['alpha'] = alpha_new

        prev_p.copy_(p)
        prev_g.copy_(g)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        alpha = self.global_state.get('alpha', None)

        if not _acceptable_alpha(alpha, tensors[0]):
            # alpha isn't None on 1st step
            self.state.clear()
            self.global_state.clear()
            alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])

        torch._foreach_mul_(tensors, alpha)
        return tensors

    def get_H(self, var):
        return _get_H(self, var)