torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/step_size/lr.py
@@ -0,0 +1,154 @@
"""Learning rate"""
import torch
import random

from ...core import Transform
from ...utils import NumberList, TensorList, generic_ne, unpack_dicts

def lazy_lr(tensors: TensorList, lr: float | list, inplace: bool):
    """multiplies by lr if lr is not 1"""
    if generic_ne(lr, 1):
        if inplace: return tensors.mul_(lr)
        return tensors * lr
    return tensors

class LR(Transform):
    """Learning rate. Adding this module also adds support for LR schedulers."""
    def __init__(self, lr: float):
        defaults = dict(lr=lr)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)

class StepSize(Transform):
    """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
    def __init__(self, step_size: float, key = 'step_size'):
        defaults = {"key": key, key: step_size}
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)


def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberList, steps: float):
    """returns warm up lr scalar"""
    if step > steps: return end_lr
    return start_lr + (end_lr - start_lr) * (step / steps)

class Warmup(Transform):
    """Learning rate warmup, linearly increases learning rate multiplier from :code:`start_lr` to :code:`end_lr` over :code:`steps` steps.

    Args:
        steps (int, optional): number of steps to perform warmup for. Defaults to 100.
        start_lr (float, optional): initial learning rate multiplier on first step. Defaults to 1e-5.
        end_lr (float, optional): learning rate multiplier at the end and after warmup. Defaults to 1.

    Example:
        Adam with 1000 steps warmup

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.Adam(),
                tz.m.LR(1e-2),
                tz.m.Warmup(steps=1000)
            )

    """
    def __init__(self, steps = 100, start_lr = 1e-5, end_lr: float = 1):
        defaults = dict(start_lr=start_lr, end_lr=end_lr, steps=steps)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
        num_steps = settings[0]['steps']
        step = self.global_state.get('step', 0)

        tensors = lazy_lr(
            TensorList(tensors),
            lr=_warmup_lr(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
            inplace=True
        )
        self.global_state['step'] = step + 1
        return tensors

class WarmupNormClip(Transform):
    """Warmup via clipping of the update norm.

    Args:
        start_norm (float, optional): maximal norm on the first step. Defaults to 1e-5.
        end_norm (float, optional): maximal norm on the last step. After that, norm clipping is disabled. Defaults to 1.
        steps (int, optional): number of steps to perform warmup for. Defaults to 100.

    Example:
        Adam with 1000 steps norm clip warmup

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.Adam(),
                tz.m.WarmupNormClip(steps=1000),
                tz.m.LR(1e-2),
            )
    """
    def __init__(self, steps = 100, start_norm = 1e-5, end_norm: float = 1):
        defaults = dict(start_norm=start_norm, end_norm=end_norm, steps=steps)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        start_norm, end_norm = unpack_dicts(settings, 'start_norm', 'end_norm', cls = NumberList)
        num_steps = settings[0]['steps']
        step = self.global_state.get('step', 0)
        if step > num_steps: return tensors

        tensors = TensorList(tensors)
        norm = tensors.global_vector_norm()
        current_max_norm = _warmup_lr(step, start_norm[0], end_norm[0], num_steps)
        if norm > current_max_norm:
            tensors.mul_(current_max_norm / norm)

        self.global_state['step'] = step + 1
        return tensors


class RandomStepSize(Transform):
    """Uses random global or layer-wise step size from `low` to `high`.

    Args:
        low (float, optional): minimum learning rate. Defaults to 0.
        high (float, optional): maximum learning rate. Defaults to 1.
        parameterwise (bool, optional):
            if True, generate random step size for each parameter separately,
            if False generate one global random step size. Defaults to False.
    """
    def __init__(self, low: float = 0, high: float = 1, parameterwise=False, seed: int | None = None):
        defaults = dict(low=low, high=high, parameterwise=parameterwise, seed=seed)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        s = settings[0]
        parameterwise = s['parameterwise']

        seed = s['seed']
        if 'generator' not in self.global_state:
            self.global_state['generator'] = random.Random(seed)
        generator: random.Random = self.global_state['generator']

        if parameterwise:
            low, high = unpack_dicts(settings, 'low', 'high')
            lr = [generator.uniform(l, h) for l, h in zip(low, high)]
        else:
            low = s['low']
            high = s['high']
            lr = generator.uniform(low, high)

        torch._foreach_mul_(tensors, lr)
        return tensors
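The Warmup docstring above already shows how these step-size modules compose; as a quick orientation, here is a minimal sketch along the same lines (the `import torchzero as tz` alias and the toy model are illustrative assumptions, not part of this diff):

    import torch.nn as nn
    import torchzero as tz  # assumed alias, matching the tz.* names used in the docstrings

    model = nn.Linear(10, 1)  # placeholder model

    # Adam update direction, scaled by a fixed learning rate, with a
    # 1000-step linear warmup applied on top (mirrors the Warmup docstring example)
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.LR(1e-2),
        tz.m.Warmup(steps=1000),
    )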
torchzero/modules/termination/__init__.py
@@ -0,0 +1,14 @@
from .termination import (
    TerminateAfterNEvaluations,
    TerminateAfterNSeconds,
    TerminateAfterNSteps,
    TerminateAll,
    TerminateAny,
    TerminateByGradientNorm,
    TerminateByUpdateNorm,
    TerminateOnLossReached,
    TerminateOnNoImprovement,
    TerminationCriteriaBase,
    TerminateNever,
    make_termination_criteria
)
torchzero/modules/termination/termination.py
@@ -0,0 +1,207 @@
import time
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import cast

import torch

from ...core import Module, Var
from ...utils import Metrics, TensorList, safe_dict_update_, tofloat


class TerminationCriteriaBase(Module):
    def __init__(self, defaults: dict | None = None, n: int = 1):
        if defaults is None: defaults = {}
        safe_dict_update_(defaults, {"_n": n})
        super().__init__(defaults)

    @abstractmethod
    def termination_criteria(self, var: Var) -> bool:
        ...

    def should_terminate(self, var: Var) -> bool:
        n_bad = self.global_state.get('_n_bad', 0)
        n = self.defaults['_n']

        if self.termination_criteria(var):
            n_bad += 1
            if n_bad >= n:
                self.global_state['_n_bad'] = 0
                return True

        else:
            n_bad = 0

        self.global_state['_n_bad'] = n_bad
        return False


    def update(self, var):
        var.should_terminate = self.should_terminate(var)
        if var.should_terminate: self.global_state['_n_bad'] = 0

    def apply(self, var):
        return var


class TerminateAfterNSteps(TerminationCriteriaBase):
    def __init__(self, steps: int):
        defaults = dict(steps=steps)
        super().__init__(defaults)

    def termination_criteria(self, var):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        max_steps = self.defaults['steps']
        return step >= max_steps

class TerminateAfterNEvaluations(TerminationCriteriaBase):
    def __init__(self, maxevals: int):
        defaults = dict(maxevals=maxevals)
        super().__init__(defaults)

    def termination_criteria(self, var):
        maxevals = self.defaults['maxevals']
        return var.modular.num_evaluations >= maxevals

class TerminateAfterNSeconds(TerminationCriteriaBase):
    def __init__(self, seconds: float, sec_fn = time.time):
        defaults = dict(seconds=seconds, sec_fn=sec_fn)
        super().__init__(defaults)

    def termination_criteria(self, var):
        max_seconds = self.defaults['seconds']
        sec_fn = self.defaults['sec_fn']

        if 'start' not in self.global_state:
            self.global_state['start'] = sec_fn()
            return False

        seconds_passed = sec_fn() - self.global_state['start']
        return seconds_passed >= max_seconds



class TerminateByGradientNorm(TerminationCriteriaBase):
    def __init__(self, tol: float = 1e-8, n: int = 3, ord: Metrics = 2):
        defaults = dict(tol=tol, ord=ord)
        super().__init__(defaults, n=n)

    def termination_criteria(self, var):
        tol = self.defaults['tol']
        ord = self.defaults['ord']
        return TensorList(var.get_grad()).global_metric(ord) <= tol


class TerminateByUpdateNorm(TerminationCriteriaBase):
    """update is calculated as parameter difference"""
    def __init__(self, tol: float = 1e-8, n: int = 3, ord: Metrics = 2):
        defaults = dict(tol=tol, ord=ord)
        super().__init__(defaults, n=n)

    def termination_criteria(self, var):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        tol = self.defaults['tol']
        ord = self.defaults['ord']

        p_prev = self.get_state(var.params, 'p_prev', cls=TensorList)
        if step == 0:
            p_prev.copy_(var.params)
            return False

        should_terminate = (p_prev - var.params).global_metric(ord) <= tol
        p_prev.copy_(var.params)
        return should_terminate


class TerminateOnNoImprovement(TerminationCriteriaBase):
    def __init__(self, tol: float = 1e-8, n: int = 10):
        defaults = dict(tol=tol)
        super().__init__(defaults, n=n)

    def termination_criteria(self, var):
        tol = self.defaults['tol']

        f = tofloat(var.get_loss(False))
        if 'f_min' not in self.global_state:
            self.global_state['f_min'] = f
            return False

        f_min = self.global_state['f_min']
        d = f_min - f
        should_terminate = d <= tol
        self.global_state['f_min'] = min(f, f_min)
        return should_terminate

class TerminateOnLossReached(TerminationCriteriaBase):
    def __init__(self, value: float):
        defaults = dict(value=value)
        super().__init__(defaults)

    def termination_criteria(self, var):
        value = self.defaults['value']
        return var.get_loss(False) <= value

class TerminateAny(TerminationCriteriaBase):
    def __init__(self, *criteria: TerminationCriteriaBase):
        super().__init__()

        self.set_children_sequence(criteria)

    def termination_criteria(self, var: Var) -> bool:
        for c in self.get_children_sequence():
            if cast(TerminationCriteriaBase, c).termination_criteria(var): return True

        return False

class TerminateAll(TerminationCriteriaBase):
    def __init__(self, *criteria: TerminationCriteriaBase):
        super().__init__()

        self.set_children_sequence(criteria)

    def termination_criteria(self, var: Var) -> bool:
        for c in self.get_children_sequence():
            if not cast(TerminationCriteriaBase, c).termination_criteria(var): return False

        return True

class TerminateNever(TerminationCriteriaBase):
    def __init__(self):
        super().__init__()

    def termination_criteria(self, var): return False

def make_termination_criteria(
    ftol: float | None = None,
    gtol: float | None = None,
    stol: float | None = None,
    maxiter: int | None = None,
    maxeval: int | None = None,
    maxsec: float | None = None,
    target_loss: float | None = None,
    extra: TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None = None,
    n: int = 3,
):
    criteria: list[TerminationCriteriaBase] = []

    if ftol is not None: criteria.append(TerminateOnNoImprovement(ftol, n=n))
    if gtol is not None: criteria.append(TerminateByGradientNorm(gtol, n=n))
    if stol is not None: criteria.append(TerminateByUpdateNorm(stol, n=n))

    if maxiter is not None: criteria.append(TerminateAfterNSteps(maxiter))
    if maxeval is not None: criteria.append(TerminateAfterNEvaluations(maxeval))
    if maxsec is not None: criteria.append(TerminateAfterNSeconds(maxsec))

    if target_loss is not None: criteria.append(TerminateOnLossReached(target_loss))

    if extra is not None:
        if isinstance(extra, TerminationCriteriaBase): criteria.append(extra)
        else: criteria.extend(extra)

    if len(criteria) == 0: return TerminateNever()
    if len(criteria) == 1: return criteria[0]
    return TerminateAny(*criteria)
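The make_termination_criteria helper above simply maps keyword tolerances onto the criterion classes and ORs them together via TerminateAny. A short sketch of calling it (how the resulting module is wired into an optimizer chain is not shown in this diff, so that part is omitted):

    from torchzero.modules.termination import (
        TerminateAfterNSeconds,
        make_termination_criteria,
    )

    # terminate when the gradient norm stays below 1e-8 for 3 consecutive checks,
    # or after 500 steps, or after roughly 60 seconds of wall-clock time
    criteria = make_termination_criteria(gtol=1e-8, maxiter=500, maxsec=60, n=3)

    # already-constructed criteria can be passed through `extra`
    criteria = make_termination_criteria(maxiter=500, extra=TerminateAfterNSeconds(60))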
torchzero/modules/trust_region/cubic_regularization.py
@@ -0,0 +1,170 @@
# pylint:disable=not-callable
from collections.abc import Callable

import torch

from ...core import Chainable, Module
from ...utils import TensorList, vec_to_tensors
from ...utils.linalg.linear_operator import LinearOperator
from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy


# code from https://github.com/konstmish/opt_methods/blob/master/optmethods/second_order/cubic.py
# ported to pytorch and linear operator
def ls_cubic_solver(f, g: torch.Tensor, H: LinearOperator, M: float, loss_at_params_plus_x_fn: Callable | None, it_max=100, epsilon=1e-8, ):
    """
    Solve min_z <g, z-x> + 1/2<z-x, H(z-x)> + M/3 ||z-x||^3

    For explanation of Cauchy point, see "Gradient Descent
    Efficiently Finds the Cubic-Regularized Non-Convex Newton Step"
    https://arxiv.org/pdf/1612.00547.pdf
    Other potential implementations can be found in paper
    "Adaptive cubic regularisation methods"
    https://people.maths.ox.ac.uk/cartis/papers/ARCpI.pdf
    """
    solver_it = 1
    newton_step = H.solve(g).neg_()
    if M == 0:
        return newton_step, solver_it

    def cauchy_point(g, H: LinearOperator, M):
        if torch.linalg.vector_norm(g) == 0 or M == 0:
            return 0 * g
        g_dir = g / torch.linalg.vector_norm(g)
        H_g_g = H.matvec(g_dir) @ g_dir
        R = -H_g_g / (2*M) + torch.sqrt((H_g_g/M)**2/4 + torch.linalg.vector_norm(g)/M)
        return -R * g_dir

    def conv_criterion(s, r):
        """
        The convergence criterion is an increasing and concave function in r
        and it is equal to 0 only if r is the solution to the cubic problem
        """
        s_norm = torch.linalg.vector_norm(s)
        return 1/s_norm - 1/r

    # Solution s satisfies ||s|| >= Cauchy_radius
    r_min = torch.linalg.vector_norm(cauchy_point(g, H, M))

    if (loss_at_params_plus_x_fn is not None) and (f > loss_at_params_plus_x_fn(newton_step)):
        return newton_step, solver_it

    r_max = torch.linalg.vector_norm(newton_step)
    if r_max - r_min < epsilon:
        return newton_step, solver_it

    # id_matrix = torch.eye(g.size(0), device=g.device, dtype=g.dtype)
    s_lam = None
    for _ in range(it_max):
        r_try = (r_min + r_max) / 2
        lam = r_try * M
        s_lam = H.add_diagonal(lam).solve(g).neg()
        # s_lam = -torch.linalg.solve(B + lam*id_matrix, g)
        solver_it += 1
        crit = conv_criterion(s_lam, r_try)
        if torch.abs(crit) < epsilon:
            return s_lam, solver_it
        if crit < 0:
            r_min = r_try
        else:
            r_max = r_try
        if r_max - r_min < epsilon:
            break
    assert s_lam is not None
    return s_lam, solver_it


class CubicRegularization(TrustRegionBase):
    """Cubic regularization.

    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.15.
        nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
        init (float, optional): Initial trust region value. Defaults to 1.
        maxiter (int, optional): maximum iterations when solving the cubic subproblem. Defaults to 100.
        eps (float, optional): epsilon for the solver, defaults to 1e-8.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        max_attempts (max_attempts, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        fallback (bool, optional):
            if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
            be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.


    Examples:
        Cubic regularized newton

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.CubicRegularization(tz.m.Newton()),
            )

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float = 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        maxiter: int = 100,
        eps: float = 1e-8,
        check_decrease: bool = False,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(maxiter=maxiter, eps=eps, check_decrease=check_decrease)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            boundary_tol=None,
            radius_fn=None,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        params = TensorList(params)

        loss_at_params_plus_x_fn = None
        if settings['check_decrease']:
            def closure_plus_x(x):
                x_unflat = vec_to_tensors(x, params)
                params.add_(x_unflat)
                loss_x = closure(False)
                params.sub_(x_unflat)
                return loss_x
            loss_at_params_plus_x_fn = closure_plus_x


        d, _ = ls_cubic_solver(f=f, g=g, H=H, M=1/radius, loss_at_params_plus_x_fn=loss_at_params_plus_x_fn,
                               it_max=settings['maxiter'], epsilon=settings['eps'])
        return d.neg_()
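A note on why the bisection in ls_cubic_solver above works (this follows from the papers cited in its docstring, not from anything new in this release): a minimizer s of the cubic model m(s) = <g, s> + 1/2 <s, Hs> + M/3 ||s||^3 satisfies the stationarity condition

    (H + M \lVert s \rVert I)\, s = -g,

so for a trial radius r the solver sets \lambda = M r, solves (H + \lambda I) s_\lambda = -g, and compares \lVert s_\lambda \rVert with r through conv_criterion(s, r) = 1/\lVert s \rVert - 1/r, bracketing r between the Cauchy-point norm and the Newton-step norm. CubicRegularization.trust_solve calls the solver with M = 1/radius, so a larger trust-region radius corresponds to weaker cubic damping.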
torchzero/modules/trust_region/dogleg.py
@@ -0,0 +1,92 @@
# pylint:disable=not-callable
import torch

from ...core import Chainable, Module
from ...utils import TensorList, vec_to_tensors
from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy

class Dogleg(TrustRegionBase):
    """Dogleg trust region algorithm.


    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.15.
        nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
        init (float, optional): Initial trust region value. Defaults to 1.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        max_attempts (max_attempts, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float = 0.0,
        nplus: float = 2,
        nminus: float = 0.25,
        rho_good: float = 0.75,
        rho_bad: float = 0.25,
        boundary_tol: float | None = None,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict()
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            boundary_tol=boundary_tol,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            radius_fn=torch.linalg.vector_norm,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        if radius > 2: radius = self.global_state['radius'] = 2
        eps = torch.finfo(g.dtype).tiny * 2

        gHg = g.dot(H.matvec(g))
        if gHg <= eps:
            return (radius / torch.linalg.vector_norm(g)) * g # pylint:disable=not-callable

        p_cauchy = (g.dot(g) / gHg) * g
        p_newton = H.solve(g)

        a = p_newton - p_cauchy
        b = p_cauchy

        aa = a.dot(a)
        if aa < eps:
            return (radius / torch.linalg.vector_norm(g)) * g # pylint:disable=not-callable

        ab = a.dot(b)
        bb = b.dot(b)
        c = bb - radius**2
        discriminant = (2*ab)**2 - 4*aa*c
        beta = (-2*ab + torch.sqrt(discriminant.clip(min=0))) / (2 * aa)
        return p_cauchy + beta * (p_newton - p_cauchy)
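The closing algebra in Dogleg.trust_solve above is the standard boundary intersection of the dogleg path: with a = p_newton - p_cauchy and b = p_cauchy, requiring \lVert b + \beta a \rVert = radius gives the quadratic

    (a \cdot a)\,\beta^2 + 2(a \cdot b)\,\beta + (b \cdot b - \text{radius}^2) = 0,

and the code takes its larger root, \beta = (-2ab + \sqrt{(2ab)^2 - 4\,aa\,(bb - \text{radius}^2)}) / (2\,aa), i.e. the point where the Cauchy-to-Newton segment crosses the trust-region boundary, with the discriminant clipped at zero for numerical safety.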