torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from collections.abc import Callable, Sequence
|
|
4
|
+
from contextlib import nullcontext
|
|
5
|
+
from functools import partial
|
|
6
|
+
from typing import Literal, cast
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from ...core import Chainable, Modular, Module, Var
|
|
11
|
+
from ...core.reformulation import Reformulation
|
|
12
|
+
from ...utils import Distributions, NumberList, TensorList
|
|
13
|
+
from ..termination import TerminationCriteriaBase, make_termination_criteria
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _reset_except_self(optimizer: Modular, var: Var, self: Module):
    """Reset every module of ``optimizer`` other than ``self``.

    Registered (via ``functools.partial`` binding ``self``) as a post-step hook.
    ``var`` is accepted but not used here — presumably the hook is invoked as
    ``hook(optimizer, var)``; confirm against the caller.
    """
    others = (module for module in optimizer.unrolled_modules if module is not self)
    for module in others:
        module.reset()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class GradientSampling(Reformulation):
    """Samples and aggregates gradients and values at perturbed points.

    This module can be used for gaussian homotopy and gradient sampling methods.

    Args:
        modules (Chainable | None, optional):
            modules that will be optimizing the modified objective.
            if None, returns gradient of the modified objective as the update. Defaults to None.
        sigma (float, optional): initial magnitude of the perturbations. Defaults to 1.
        n (int, optional):
            number of evaluations per step. When ``include_x0`` is True, one of them is the
            un-perturbed point, so ``n - 1`` perturbations are sampled. Defaults to 100.
        aggregate (str, optional):
            how to aggregate values and gradients
            - "mean" - uses mean of the gradients, as in gaussian homotopy.
            - "max" - element-wise, keeps the gradient entry with the largest magnitude.
            - "min" - element-wise, keeps the gradient entry with the smallest magnitude.
            - "min-norm" - picks gradient with the lowest norm.
            - "min-value" - picks gradient at the point with the lowest value (loss).

            Defaults to 'mean'.
        distribution (Distributions, optional): distribution for random perturbations. Defaults to 'gaussian'.
        include_x0 (bool, optional): whether to include gradient at un-perturbed point. Defaults to True.
        fixed (bool, optional):
            if True, perturbations do not get replaced by new random perturbations until termination criteria is satisfied. Defaults to True.
        pre_generate (bool, optional):
            if True, perturbations are pre-generated before each step.
            This requires more memory to store all of them,
            but ensures they do not change when closure is evaluated multiple times.
            Defaults to True.
        termination (TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None, optional):
            a termination criteria module, sigma will be multiplied by ``decay`` when termination criteria is satisfied,
            and new perturbations will be generated if ``fixed``. Defaults to None.
        decay (float, optional): sigma multiplier on termination criteria. Defaults to 2/3.
        reset_on_termination (bool, optional): whether to reset states of all other modules on termination. Defaults to True.
        sigma_strategy (str | None, optional):
            strategy for adapting sigma. If condition is satisfied, sigma is multiplied by ``sigma_nplus``,
            otherwise it is multiplied by ``sigma_nminus``.
            - "grad-norm" - at least ``sigma_target`` gradients should have lower norm than at un-perturbed point.
            - "value" - at least ``sigma_target`` values (losses) should be lower than at un-perturbed point.
            - None - doesn't use adaptive sigma.

            This introduces a side-effect to the closure, so it should be left at None if you use
            trust region or line search to optimize the modified objective.
            Defaults to None.
        sigma_target (int | float, optional):
            number of samples that must satisfy the condition in ``sigma_strategy``.
            If a float, it is interpreted as a fraction of the number of finite evaluations
            (rounded down, at least 1). Defaults to 0.2.
        sigma_nplus (float, optional): sigma multiplier when ``sigma_strategy`` condition is satisfied. Defaults to 4/3.
        sigma_nminus (float, optional): sigma multiplier when ``sigma_strategy`` condition is not satisfied. Defaults to 2/3.
        seed (int | None, optional): seed. Defaults to None.
    """
    def __init__(
        self,
        modules: Chainable | None = None,
        sigma: float = 1.,
        n:int = 100,
        aggregate: Literal['mean', 'max', 'min', 'min-norm', 'min-value'] = 'mean',
        distribution: Distributions = 'gaussian',
        include_x0: bool = True,

        fixed: bool=True,
        pre_generate: bool = True,
        termination: TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None = None,
        decay: float = 2/3,
        reset_on_termination: bool = True,

        sigma_strategy: Literal['grad-norm', 'value'] | None = None,
        sigma_target: int | float = 0.2,
        sigma_nplus: float = 4/3,
        sigma_nminus: float = 2/3,

        seed: int | None = None,
    ):

        defaults = dict(sigma=sigma, n=n, aggregate=aggregate, distribution=distribution, seed=seed, include_x0=include_x0, fixed=fixed, decay=decay, reset_on_termination=reset_on_termination, sigma_strategy=sigma_strategy, sigma_target=sigma_target, sigma_nplus=sigma_nplus, sigma_nminus=sigma_nminus, pre_generate=pre_generate)
        super().__init__(defaults, modules)

        if termination is not None:
            self.set_child('termination', make_termination_criteria(extra=termination))

    @torch.no_grad
    def pre_step(self, var):
        """Handle termination and (re)generate perturbations before a step.

        If a ``termination`` child exists and reports termination, every parameter's
        sigma is multiplied by ``decay``. A reset of all other modules is then scheduled
        as a post-step hook (if ``reset_on_termination``), and stored perturbations are
        discarded. Afterwards, when ``pre_generate`` is enabled, perturbations are
        sampled if none are stored or if ``fixed`` is False.
        """
        params = TensorList(var.params)

        fixed = self.defaults['fixed']

        # check termination criteria
        if 'termination' in self.children:
            termination = cast(TerminationCriteriaBase, self.children['termination'])
            if termination.should_terminate(var):

                # decay sigmas
                states = [self.state[p] for p in params]
                settings = [self.settings[p] for p in params]

                for state, setting in zip(states, settings):
                    if 'sigma' not in state: state['sigma'] = setting['sigma']
                    state['sigma'] *= setting['decay']

                # reset on sigmas decay
                if self.defaults['reset_on_termination']:
                    var.post_step_hooks.append(partial(_reset_except_self, self=self))

                # clear perturbations so that new ones are generated below
                self.global_state.pop('perts', None)

        # pre-generate perturbations if not already pre-generated or not fixed
        if self.defaults['pre_generate'] and (('perts' not in self.global_state) or (not fixed)):
            # one of the ``n`` evaluations is spent on x_0 when include_x0 is set
            n = self.defaults['n'] - self.defaults['include_x0']
            generator = self.get_generator(params[0].device, self.defaults['seed'])

            perts = [params.sample_like(self.defaults['distribution'], generator=generator) for _ in range(n)]

            self.global_state['perts'] = perts

    @torch.no_grad
    def closure(self, backward, closure, params, var):
        """Evaluate the sampled objective and return ``(loss, grad)``.

        Evaluates at ``x_0`` (if ``include_x0``) and at ``x_0 + sigma * pert`` for each
        perturbation, restoring the parameters after every evaluation. Only finite losses
        (and their gradients, when ``backward``) are aggregated according to ``aggregate``.
        If ``sigma_strategy`` is set, per-parameter sigmas are multiplied by
        ``sigma_nplus`` or ``sigma_nminus`` afterwards.

        Returns:
            ``(loss, grad)``. ``grad`` is None when ``backward`` is False. If no evaluation
            was finite, an ``inf`` loss and ``inf``-filled gradients are returned.
        """
        params = TensorList(params)
        loss_agg = None
        grad_agg = None

        states = [self.state[p] for p in params]
        settings = [self.settings[p] for p in params]
        sigma_inits = [s['sigma'] for s in settings]
        sigmas = [s.setdefault('sigma', si) for s, si in zip(states, sigma_inits)]

        include_x0 = self.defaults['include_x0']
        pre_generate = self.defaults['pre_generate']
        aggregate: Literal['mean', 'max', 'min', 'min-norm', 'min-value'] = self.defaults['aggregate']
        sigma_strategy: Literal['grad-norm', 'value'] | None = self.defaults['sigma_strategy']
        distribution = self.defaults['distribution']
        generator = self.get_generator(params[0].device, self.defaults['seed'])

        n_finite = 0
        n_good = 0
        f_0 = None; g_0 = None

        # evaluate at x_0
        if include_x0:
            f_0 = cast(torch.Tensor, var.get_loss(backward=backward))

            isfinite = math.isfinite(f_0)
            if isfinite:
                n_finite += 1
                # aliasing is safe: every update of the aggregates below is out-of-place,
                # so f_0 / g_0 keep their x_0 values for the sigma strategies
                loss_agg = f_0

            if backward:
                g_0 = var.get_grad()
                if isfinite: grad_agg = g_0

        # evaluate at x_0 + p for each perturbation
        if pre_generate:
            perts = self.global_state['perts']
        else:
            perts = [None] * (self.defaults['n'] - include_x0)

        x_0 = [p.clone() for p in params]

        for pert in perts:
            loss = None; grad = None

            # generate if not pre-generated
            if pert is None:
                pert = params.sample_like(distribution, generator=generator)

            # add perturbation and evaluate
            pert = pert * sigmas
            torch._foreach_add_(params, pert)

            with torch.enable_grad() if backward else nullcontext():
                loss = closure(backward)

            if math.isfinite(loss):
                n_finite += 1

                # add loss
                if loss_agg is None:
                    loss_agg = loss
                else:
                    if aggregate == 'mean':
                        # out-of-place: loss_agg may be f_0 or the closure's own tensor
                        loss_agg = loss_agg + loss

                    # with backward, 'min-norm' and 'min-value' pick loss and grad together
                    # in the gradient section; clamping here first would make the
                    # 'min-value' comparison below always false
                    elif (aggregate == 'min') or (aggregate in ('min-value', 'min-norm') and not backward):
                        loss_agg = loss_agg.clamp(max=loss)

                    elif aggregate == 'max':
                        loss_agg = loss_agg.clamp(min=loss)

                # add grad
                if backward:
                    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                    if grad_agg is None:
                        grad_agg = grad
                    else:
                        if aggregate == 'mean':
                            # out-of-place: grad_agg may alias g_0 (var's gradient)
                            grad_agg = torch._foreach_add(grad_agg, grad)

                        elif aggregate == 'min':
                            # element-wise: keep the entry with the smaller magnitude, with its own sign
                            grad_agg = [torch.where(g.abs() < ga.abs(), g, ga) for ga, g in zip(grad_agg, grad)]

                        elif aggregate == 'max':
                            # element-wise: keep the entry with the larger magnitude, with its own sign
                            grad_agg = [torch.where(g.abs() > ga.abs(), g, ga) for ga, g in zip(grad_agg, grad)]

                        elif aggregate == 'min-norm':
                            if TensorList(grad).global_vector_norm() < TensorList(grad_agg).global_vector_norm():
                                grad_agg = grad
                                loss_agg = loss

                        elif aggregate == 'min-value':
                            if loss < loss_agg:
                                grad_agg = grad
                                loss_agg = loss

            # undo perturbation
            torch._foreach_copy_(params, x_0)

            # adaptive sigma
            # by value
            if sigma_strategy == 'value':
                if f_0 is None:
                    with torch.enable_grad() if backward else nullcontext():
                        f_0 = closure(False)

                if loss < f_0:
                    n_good += 1

            # by gradient norm
            elif sigma_strategy == 'grad-norm' and backward and math.isfinite(loss):
                assert grad is not None
                if g_0 is None:
                    with torch.enable_grad():
                        closure()
                    g_0 = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

                if TensorList(grad).global_vector_norm() < TensorList(g_0).global_vector_norm():
                    n_good += 1

        # update sigma if strategy is enabled
        if sigma_strategy is not None:

            sigma_target = self.defaults['sigma_target']
            if isinstance(sigma_target, float):
                # fractional target relative to the number of finite evaluations
                sigma_target = int(max(1, n_finite * sigma_target))

            key = 'sigma_nplus' if n_good >= sigma_target else 'sigma_nminus'

            for p in params:
                self.state[p]['sigma'] *= self.settings[p][key]

        # if no finite losses, just return inf
        if n_finite == 0:
            assert loss_agg is None and grad_agg is None
            loss = torch.tensor(torch.inf, dtype=params[0].dtype, device=params[0].device)
            grad = [torch.full_like(p, torch.inf) for p in params]
            return loss, grad

        assert loss_agg is not None

        # no post processing needed when aggregate is 'max', 'min', 'min-norm', 'min-value'
        if aggregate != 'mean':
            return loss_agg, grad_agg

        # on mean divide by number of finite evals (out-of-place, see above)
        loss_agg = loss_agg / n_finite

        if backward:
            assert grad_agg is not None
            grad_agg = torch._foreach_div(grad_agg, n_finite)

        return loss_agg, grad_agg
|
|
@@ -1,2 +1,2 @@
|
|
|
1
1
|
from .lr import LR, StepSize, Warmup, WarmupNormClip, RandomStepSize
|
|
2
|
-
from .adaptive import PolyakStepSize, BarzilaiBorwein
|
|
2
|
+
from .adaptive import PolyakStepSize, BarzilaiBorwein, BBStab, AdGD
|