torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
|
@@ -1,14 +1,47 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Literal
|
|
2
3
|
|
|
3
4
|
import torch
|
|
4
5
|
|
|
5
|
-
from ...core import Chainable, Transform,
|
|
6
|
-
from ...utils import TensorList, as_tensorlist
|
|
6
|
+
from ...core import Chainable, TensorwiseTransform, Transform, apply_transform
|
|
7
|
+
from ...utils import TensorList, as_tensorlist, unpack_dicts, unpack_states
|
|
8
|
+
from .quasi_newton import _safe_clip, HessianUpdateStrategy
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
class ConguateGradientBase(Transform, ABC):
|
|
10
|
-
"""
|
|
11
|
-
|
|
12
|
+
"""Base class for conjugate gradient methods. The only difference between them is how beta is calculated.
|
|
13
|
+
|
|
14
|
+
This is an abstract class, to use it, subclass it and override `get_beta`.
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
Args:
|
|
18
|
+
defaults (dict | None, optional): dictionary of settings defaults. Defaults to None.
|
|
19
|
+
clip_beta (bool, optional): whether to clip beta to be no less than 0. Defaults to False.
|
|
20
|
+
reset_interval (int | None | Literal["auto"], optional):
|
|
21
|
+
interval between resetting the search direction.
|
|
22
|
+
"auto" means number of dimensions + 1, None means no reset. Defaults to None.
|
|
23
|
+
inner (Chainable | None, optional): previous direction is added to the output of this module. Defaults to None.
|
|
24
|
+
|
|
25
|
+
Example:
|
|
26
|
+
|
|
27
|
+
.. code-block:: python
|
|
28
|
+
|
|
29
|
+
class PolakRibiere(ConguateGradientBase):
|
|
30
|
+
def __init__(
|
|
31
|
+
self,
|
|
32
|
+
clip_beta=True,
|
|
33
|
+
reset_interval: int | None = None,
|
|
34
|
+
inner: Chainable | None = None
|
|
35
|
+
):
|
|
36
|
+
super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
|
|
37
|
+
|
|
38
|
+
def get_beta(self, p, g, prev_g, prev_d):
|
|
39
|
+
denom = prev_g.dot(prev_g)
|
|
40
|
+
if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
|
|
41
|
+
return g.dot(g - prev_g) / denom
|
|
42
|
+
|
|
43
|
+
"""
|
|
44
|
+
def __init__(self, defaults = None, clip_beta: bool = False, reset_interval: int | None | Literal['auto'] = None, inner: Chainable | None = None):
|
|
12
45
|
if defaults is None: defaults = {}
|
|
13
46
|
defaults['reset_interval'] = reset_interval
|
|
14
47
|
defaults['clip_beta'] = clip_beta
|
|
@@ -17,6 +50,15 @@ class ConguateGradientBase(Transform, ABC):
|
|
|
17
50
|
if inner is not None:
|
|
18
51
|
self.set_child('inner', inner)
|
|
19
52
|
|
|
53
|
+
def reset(self):
    """Reset this transform; nothing extra is done beyond the parent class's ``reset``."""
    super(ConguateGradientBase, self).reset()
|
|
55
|
+
|
|
56
|
+
def reset_for_online(self):
    """Override of ``reset_for_online``.

    In addition to the parent's ``reset_for_online``, this clears the stored previous
    gradient/direction, drops ``stage`` so the next ``update_tensors`` call re-initializes
    them, and decrements the global step counter.
    """
    super().reset_for_online()
    # the per-parameter state written by update_tensors/apply_tensors lives under
    # 'g_prev' and 'd_prev' (the old name 'prev_grad' no longer matches any key)
    self.clear_state_keys('g_prev')
    self.clear_state_keys('d_prev')
    self.global_state.pop('stage', None)
    self.global_state['step'] = self.global_state.get('step', 1) - 1
|
|
61
|
+
|
|
20
62
|
def initialize(self, p: TensorList, g: TensorList):
|
|
21
63
|
"""runs on first step when prev_grads and prev_dir are not available"""
|
|
22
64
|
|
|
@@ -25,38 +67,55 @@ class ConguateGradientBase(Transform, ABC):
|
|
|
25
67
|
"""returns beta"""
|
|
26
68
|
|
|
27
69
|
@torch.no_grad
def update_tensors(self, tensors, params, grads, loss, states, settings):
    """Advance the step counter and, when needed, seed the conjugate-gradient state.

    If ``stage`` is 0 (first call, or after it was dropped), both ``g_prev`` and
    ``d_prev`` are filled with a copy of ``tensors``, ``initialize`` is called and
    ``stage`` becomes 1. Otherwise ``stage`` is simply set to 2.
    """
    tensors = as_tensorlist(tensors)
    params = as_tensorlist(params)

    gs = self.global_state
    gs['step'] = gs.get('step', 0) + 1

    if gs.get('stage', 0) != 0:
        # already initialized; this also covers repeated update calls
        # before apply_tensors
        gs['stage'] = 2
        return

    prev_grad, prev_dir = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
    prev_dir.copy_(tensors)
    prev_grad.copy_(tensors)
    self.initialize(params, tensors)
    gs['stage'] = 1
|
|
89
|
+
|
|
90
|
+
@torch.no_grad
def apply_tensors(self, tensors, params, grads, loss, states, settings):
    """Compute the conjugate direction ``d = g + beta * d_prev``.

    ``tensors`` are first passed through the ``inner`` child (if set), and the
    result plays the role of ``g``. On the first step after (re)initialization
    (stage 1) ``g`` is returned unchanged. Otherwise ``beta`` comes from
    ``get_beta(params, g, g_prev, d_prev)`` (clipped to be non-negative when
    ``clip_beta`` is set). Then ``g_prev`` is overwritten with the current ``g``
    and the new direction is stored in ``d_prev``.

    If ``reset_interval`` is set (``"auto"`` means number of elements + 1) and the
    current step is a multiple of it, ``self.reset()`` is called after the
    direction is computed.

    Returns:
        The search direction as a TensorList.
    """
    tensors = as_tensorlist(tensors)
    step = self.global_state['step']

    if 'inner' in self.children:
        tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))

    assert self.global_state['stage'] != 0
    if self.global_state['stage'] == 1:
        # first step: there is no previous direction to combine with yet
        self.global_state['stage'] = 2
        return tensors

    params = as_tensorlist(params)
    g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)

    # get beta
    beta = self.get_beta(params, tensors, g_prev, d_prev)
    if settings[0]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]

    # remember the current gradient for the next step's beta; without this g_prev
    # would stay equal to the very first gradient. Must happen before `tensors.add_`
    # below, which overwrites `tensors` in place with the new direction.
    g_prev.copy_(tensors)

    # calculate new direction with beta
    direction = tensors.add_(d_prev.mul_(beta))
    d_prev.copy_(direction)

    # resetting
    reset_interval = settings[0]['reset_interval']
    if reset_interval == 'auto': reset_interval = tensors.global_numel() + 1
    if reset_interval is not None and step % reset_interval == 0:
        self.reset()

    return direction
|
|
@@ -68,7 +127,11 @@ def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
|
|
|
68
127
|
return g.dot(g - prev_g) / denom
|
|
69
128
|
|
|
70
129
|
class PolakRibiere(ConguateGradientBase):
|
|
71
|
-
"""Polak-Ribière-Polyak nonlinear conjugate gradient method.
|
|
130
|
+
"""Polak-Ribière-Polyak nonlinear conjugate gradient method.
|
|
131
|
+
|
|
132
|
+
.. note::
|
|
133
|
+
- This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
|
|
134
|
+
"""
|
|
72
135
|
def __init__(self, clip_beta=True, reset_interval: int | None | Literal['auto'] = None, inner: Chainable | None = None):
    """Polak-Ribière-Polyak CG settings.

    Args:
        clip_beta: clip beta to be no less than 0. Defaults to True.
        reset_interval: interval between resetting the search direction;
            "auto" means number of dimensions + 1, None means no reset. Defaults to None.
        inner: previous direction is added to the output of this module. Defaults to None.
    """
    super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
|
|
74
137
|
|
|
@@ -81,8 +144,12 @@ def fletcher_reeves_beta(gg: torch.Tensor, prev_gg: torch.Tensor):
|
|
|
81
144
|
return gg / prev_gg
|
|
82
145
|
|
|
83
146
|
class FletcherReeves(ConguateGradientBase):
|
|
84
|
-
"""Fletcher–Reeves nonlinear conjugate gradient method.
|
|
85
|
-
|
|
147
|
+
"""Fletcher–Reeves nonlinear conjugate gradient method.
|
|
148
|
+
|
|
149
|
+
.. note::
|
|
150
|
+
- This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
|
|
151
|
+
"""
|
|
152
|
+
def __init__(self, reset_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
    """Fletcher–Reeves CG settings; by default the direction is reset every
    number-of-dimensions + 1 steps ("auto") and beta is not clipped."""
    super().__init__(inner=inner, reset_interval=reset_interval, clip_beta=clip_beta)
|
|
87
154
|
|
|
88
155
|
def initialize(self, p, g):
|
|
@@ -103,8 +170,12 @@ def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
|
|
|
103
170
|
|
|
104
171
|
|
|
105
172
|
class HestenesStiefel(ConguateGradientBase):
|
|
106
|
-
"""Hestenes–Stiefel nonlinear conjugate gradient method.
|
|
107
|
-
|
|
173
|
+
"""Hestenes–Stiefel nonlinear conjugate gradient method.
|
|
174
|
+
|
|
175
|
+
.. note::
|
|
176
|
+
- This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
|
|
177
|
+
"""
|
|
178
|
+
def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
    """Hestenes–Stiefel CG settings; no periodic reset and no beta clipping by default."""
    super().__init__(inner=inner, reset_interval=reset_interval, clip_beta=clip_beta)
|
|
109
180
|
|
|
110
181
|
def get_beta(self, p, g, prev_g, prev_d):
|
|
@@ -118,8 +189,12 @@ def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
|
|
|
118
189
|
return (g.dot(g) / denom).neg()
|
|
119
190
|
|
|
120
191
|
class DaiYuan(ConguateGradientBase):
|
|
121
|
-
"""Dai–Yuan nonlinear conjugate gradient method.
|
|
122
|
-
|
|
192
|
+
"""Dai–Yuan nonlinear conjugate gradient method.
|
|
193
|
+
|
|
194
|
+
.. note::
|
|
195
|
+
- This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this. Although Dai–Yuan formula provides an automatic step size scaling so it is technically possible to omit line search and instead use a small step size.
|
|
196
|
+
"""
|
|
197
|
+
def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
    """Dai–Yuan CG settings; no periodic reset and no beta clipping by default."""
    super().__init__(inner=inner, reset_interval=reset_interval, clip_beta=clip_beta)
|
|
124
199
|
|
|
125
200
|
def get_beta(self, p, g, prev_g, prev_d):
|
|
@@ -133,8 +208,12 @@ def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):
|
|
|
133
208
|
return g.dot(g - prev_g) / denom
|
|
134
209
|
|
|
135
210
|
class LiuStorey(ConguateGradientBase):
|
|
136
|
-
"""Liu-Storey nonlinear conjugate gradient method.
|
|
137
|
-
|
|
211
|
+
"""Liu-Storey nonlinear conjugate gradient method.
|
|
212
|
+
|
|
213
|
+
.. note::
|
|
214
|
+
- This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
|
|
215
|
+
"""
|
|
216
|
+
def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
    """Liu-Storey CG settings; no periodic reset and no beta clipping by default."""
    super().__init__(inner=inner, reset_interval=reset_interval, clip_beta=clip_beta)
|
|
139
218
|
|
|
140
219
|
def get_beta(self, p, g, prev_g, prev_d):
|
|
@@ -142,7 +221,11 @@ class LiuStorey(ConguateGradientBase):
|
|
|
142
221
|
|
|
143
222
|
# ----------------------------- Conjugate Descent ---------------------------- #
|
|
144
223
|
class ConjugateDescent(Transform):
|
|
145
|
-
"""Conjugate Descent (CD).
|
|
224
|
+
"""Conjugate Descent (CD).
|
|
225
|
+
|
|
226
|
+
.. note::
|
|
227
|
+
- This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
|
|
228
|
+
"""
|
|
146
229
|
def __init__(self, inner: Chainable | None = None):
|
|
147
230
|
super().__init__(defaults={}, uses_grad=False)
|
|
148
231
|
|
|
@@ -151,10 +234,10 @@ class ConjugateDescent(Transform):
|
|
|
151
234
|
|
|
152
235
|
|
|
153
236
|
@torch.no_grad
|
|
154
|
-
def
|
|
237
|
+
def apply_tensors(self, tensors, params, grads, loss, states, settings):
|
|
155
238
|
g = as_tensorlist(tensors)
|
|
156
239
|
|
|
157
|
-
prev_d =
|
|
240
|
+
prev_d = unpack_states(states, tensors, 'prev_dir', cls=TensorList, init=torch.zeros_like)
|
|
158
241
|
if 'denom' not in self.global_state:
|
|
159
242
|
self.global_state['denom'] = torch.tensor(0.).to(g[0])
|
|
160
243
|
|
|
@@ -164,7 +247,7 @@ class ConjugateDescent(Transform):
|
|
|
164
247
|
|
|
165
248
|
# inner step
|
|
166
249
|
if 'inner' in self.children:
|
|
167
|
-
g = as_tensorlist(
|
|
250
|
+
g = as_tensorlist(apply_transform(self.children['inner'], g, params, grads))
|
|
168
251
|
|
|
169
252
|
dir = g.add_(prev_d.mul_(beta))
|
|
170
253
|
prev_d.copy_(dir)
|
|
@@ -186,8 +269,11 @@ def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
|
|
|
186
269
|
|
|
187
270
|
class HagerZhang(ConguateGradientBase):
|
|
188
271
|
"""Hager-Zhang nonlinear conjugate gradient method,
|
|
189
|
-
|
|
190
|
-
|
|
272
|
+
|
|
273
|
+
.. note::
|
|
274
|
+
- This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
|
|
275
|
+
"""
|
|
276
|
+
def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
    """Hager-Zhang CG settings; no periodic reset and no beta clipping by default."""
    super().__init__(inner=inner, reset_interval=reset_interval, clip_beta=clip_beta)
|
|
192
278
|
|
|
193
279
|
def get_beta(self, p, g, prev_g, prev_d):
|
|
@@ -210,9 +296,74 @@ def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
|
|
|
210
296
|
|
|
211
297
|
class HybridHS_DY(ConguateGradientBase):
    """Hybrid Hestenes–Stiefel / Dai–Yuan nonlinear conjugate gradient method.

    Beta is computed by ``hs_dy_beta`` from the current gradient, the previous
    direction and the previous gradient.

    .. note::
        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
    """

    def get_beta(self, p, g, prev_g, prev_d):
        return hs_dy_beta(g=g, prev_d=prev_d, prev_g=prev_g)

    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
        super().__init__(inner=inner, reset_interval=reset_interval, clip_beta=clip_beta)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def projected_gradient_(H: torch.Tensor, y: torch.Tensor):
    """In-place Pearson projection update ``H <- H - (H y)(y^T H) / (y^T H y)``.

    The denominator is passed through ``_safe_clip`` before dividing.
    Returns the same (modified) ``H`` tensor.
    """
    Hy = H @ y
    curvature = _safe_clip(y.dot(Hy))
    correction = Hy.outer(y) @ H
    H -= correction / curvature
    return H
|
|
315
|
+
|
|
316
|
+
class ProjectedGradientMethod(HessianUpdateStrategy): # this doesn't maintain hessian
    """Pearson's projected gradient method.

    The base class is always constructed with ``inverse=True``, and ``update_H``
    applies :func:`projected_gradient_`, which uses ``y`` but not ``s``.

    .. note::
        This method uses N^2 memory.

    .. note::
        This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.

    .. note::
        This is not the same as projected gradient descent.

    Reference:
        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
    """

    def __init__(
        self,
        init_scale: float | Literal["auto"] = 1,
        tol: float = 1e-8,
        ptol: float | None = 1e-10,
        ptol_reset: bool = False,
        gtol: float | None = 1e-10,
        reset_interval: int | None | Literal['auto'] = 'auto',
        beta: float | None = None,
        update_freq: int = 1,
        scale_first: bool = False,
        scale_second: bool = False,
        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
        strategy_settings = dict(
            init_scale=init_scale,
            tol=tol,
            ptol=ptol,
            ptol_reset=ptol_reset,
            gtol=gtol,
            reset_interval=reset_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
            scale_second=scale_second,
            concat_params=concat_params,
        )
        # only the inverse approximation is ever maintained
        super().__init__(defaults=None, inverse=True, inner=inner, **strategy_settings)

    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        # the projection depends on y only; s and the other arguments are unused
        return projected_gradient_(H, y)
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from .quasi_newton import (
|
|
6
|
+
HessianUpdateStrategy,
|
|
7
|
+
_HessianUpdateStrategyDefaults,
|
|
8
|
+
_InverseHessianUpdateStrategyDefaults,
|
|
9
|
+
_safe_clip,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _diag_Bv(self: HessianUpdateStrategy):
    """Return a closure computing ``B @ v`` for a diagonal hessian approximation.

    ``self.get_B()`` yields the stored diagonal and a flag telling whether it is
    actually the inverse ``H``. In that case ``B @ v`` is evaluated as ``v / H``.
    """
    diag, stored_inverse = self.get_B()
    if not stored_inverse:
        return lambda v: diag * v
    return lambda v: v / diag
|
|
23
|
+
|
|
24
|
+
def _diag_Hv(self: HessianUpdateStrategy):
    """Return a closure computing ``H @ v`` for a diagonal inverse-hessian approximation.

    ``self.get_H()`` yields the stored diagonal and a flag telling whether it is
    actually the hessian ``B``. In that case ``H @ v`` is evaluated as ``v / B``.
    """
    diag, stored_inverse = self.get_H()
    if not stored_inverse:
        return lambda v: diag * v
    return lambda v: v / diag
|
|
34
|
+
|
|
35
|
+
def diagonal_bfgs_H_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
    """In-place diagonal BFGS update of the inverse-hessian diagonal ``H``.

    The update is skipped (``H`` is returned unchanged) when ``s^T y < tol``.
    Otherwise::

        H += (s^T y + y*H*y) * s*s / (s^T y)^2  -  2 * H*y*s / (s^T y)

    Note that the ``y*H*y`` term is taken per coordinate (elementwise), not as
    the scalar ``y^T H y`` of full BFGS. Returns the same ``H`` tensor.
    """
    sy = s.dot(y)
    if sy < tol:
        return H

    sy_squared = _safe_clip(sy**2)

    first = (sy + (y * H * y)) * s * s
    first.div_(sy_squared)

    second = (H * y * s).add_(s * y * H)
    second.div_(sy)

    H += first.sub_(second)
    return H
|
|
47
|
+
|
|
48
|
+
class DiagonalBFGS(_InverseHessianUpdateStrategyDefaults):
    """Diagonal BFGS: BFGS restricted to the diagonal of the inverse hessian.

    Only ``update_H`` is overridden (via :func:`diagonal_bfgs_H_`). The diagonal
    starts as a vector of ones. It doesn't satisfy the secant equation but may
    still be useful.
    """

    def _init_M(self, size: int, device, dtype, is_inverse: bool):
        # identity diagonal, regardless of is_inverse
        return torch.ones(size, device=device, dtype=dtype)

    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        curvature_tol = setting['tol']
        return diagonal_bfgs_H_(H=H, s=s, y=y, tol=curvature_tol)

    def make_Hv(self):
        return _diag_Hv(self)

    def make_Bv(self):
        return _diag_Bv(self)
|
|
56
|
+
|
|
57
|
+
def diagonal_sr1_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
    """In-place diagonal SR1 update ``H += r*r / (r^T y)`` with ``r = s - H*y``.

    The update is skipped when ``|r^T y| <= tol * ||y|| * ||r||``. This is the
    safeguard from Nocedal & Wright, "Numerical optimization", 2nd ed., p.146.
    Returns the same ``H`` tensor.
    """
    residual = s - H * y
    curvature = residual.dot(y)

    residual_norm = torch.linalg.norm(residual) # pylint:disable=not-callable
    y_norm = torch.linalg.norm(y) # pylint:disable=not-callable

    too_small = curvature.abs() <= tol * y_norm * residual_norm
    if too_small:
        return H

    H += (residual * residual).div_(_safe_clip(curvature))
    return H
|
|
70
|
+
class DiagonalSR1(_InverseHessianUpdateStrategyDefaults):
    """Diagonal SR1: SR1 restricted to the diagonal.

    ``update_H`` applies :func:`diagonal_sr1_` directly. ``update_B`` applies the
    same formula with ``s`` and ``y`` swapped. The diagonal starts as a vector of
    ones. It doesn't satisfy the secant equation but may still be useful.
    """

    def _init_M(self, size: int, device, dtype, is_inverse: bool):
        # identity diagonal, regardless of is_inverse
        return torch.ones(size, device=device, dtype=dtype)

    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        # dual form: roles of s and y exchanged
        return diagonal_sr1_(H=B, s=y, y=s, tol=setting['tol'])

    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return diagonal_sr1_(H=H, s=s, y=y, tol=setting['tol'])

    def make_Hv(self):
        return _diag_Hv(self)

    def make_Bv(self):
        return _diag_Bv(self)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
# Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating. SIAM Journal on Optimization, 1999, vol. 9, no. 4, pp. 1192-1204.
|
|
84
|
+
def diagonal_qc_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
    """In-place quasi-Cauchy update of the hessian diagonal ``B``.

    ``B += s**2 * (s^T y - s^T B s) / sum(s**4)``. Up to the clipping of the
    denominator, this makes ``s^T B s == s^T y`` after the update. Returns the
    same ``B`` tensor.
    """
    weight = _safe_clip((s**4).sum())
    secant_gap = s.dot(y) - (s * B).dot(s)
    B += s**2 * (secant_gap / weight)
    return B
|
|
89
|
+
|
|
90
|
+
class DiagonalQuasiCauchi(_HessianUpdateStrategyDefaults):
    """Diagonal quasi-Cauchy method.

    Maintains a diagonal hessian approximation ``B`` (starting at ones) updated
    with :func:`diagonal_qc_B_`.

    Reference:
        Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating. SIAM Journal on Optimization, 1999, vol. 9, no. 4, pp. 1192-1204.
    """

    def _init_M(self, size: int, device, dtype, is_inverse: bool):
        return torch.ones(size, device=device, dtype=dtype)

    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return diagonal_qc_B_(B=B, s=s, y=y)

    def make_Hv(self):
        return _diag_Hv(self)

    def make_Bv(self):
        return _diag_Bv(self)
|
|
102
|
+
|
|
103
|
+
# Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
|
|
104
|
+
def diagonal_wqc_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
    """In-place weighted quasi-Cauchy update of the hessian diagonal ``B``.

    The weights are ``w = s**2 * B**2``, computed from ``B`` before the update.
    The update is ``B += w * (s^T y - s^T B s) / sum(s**2 * w)``. Up to the
    clipping of the denominator, this makes ``s^T B s == s^T y``. Returns the
    same ``B`` tensor.
    """
    weights = s**2 * B**2
    denom = _safe_clip((s * weights).dot(s))
    secant_gap = s.dot(y) - (s * B).dot(s)
    B += weights * (secant_gap / denom)
    return B
|
|
110
|
+
|
|
111
|
+
class DiagonalWeightedQuasiCauchi(_HessianUpdateStrategyDefaults):
    """Diagonal weighted quasi-Cauchy method.

    Maintains a diagonal hessian approximation ``B`` (starting at ones) updated
    with :func:`diagonal_wqc_B_`. This is the weighted-Frobenius-norm variant of
    the quasi-Cauchy update.

    Reference:
        Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
    """

    def _init_M(self, size: int, device, dtype, is_inverse: bool):
        return torch.ones(size, device=device, dtype=dtype)

    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return diagonal_wqc_B_(B=B, s=s, y=y)

    def make_Hv(self):
        return _diag_Hv(self)

    def make_Bv(self):
        return _diag_Bv(self)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
# Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
|
|
126
|
+
def dnrtr_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
    """In-place DNRTR update of the hessian diagonal ``B`` (Andrei, 2019).

    ``B += s**2 * (s^T y + s^T s - s^T B s) / sum(s**4) - 1``. The ``- 1``
    (identity) shift cancels the ``s^T s`` term, so up to the clipping of the
    denominator ``s^T B s == s^T y`` after the update. Returns the same ``B``
    tensor.
    """
    weight = _safe_clip((s**4).sum())
    numerator = s.dot(y) + s.dot(s) - (s * B).dot(s)
    increment = s**2 * (numerator / weight)
    B += increment - 1
    return B
|
|
131
|
+
|
|
132
|
+
class DNRTR(_HessianUpdateStrategyDefaults):
    """Diagonal quasi-Newton method (DNRTR).

    Maintains a diagonal hessian approximation ``B`` (starting at ones) updated
    with :func:`dnrtr_B_`.

    Reference:
        Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
    """
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        # use the Andrei update this class implements; it previously called
        # diagonal_wqc_B_, which belongs to DiagonalWeightedQuasiCauchi
        return dnrtr_B_(B=B, s=s, y=y)

    def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
    def make_Bv(self): return _diag_Bv(self)
    def make_Hv(self): return _diag_Hv(self)
|
|
144
|
+
|
|
145
|
+
# Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
|
|
146
|
+
def new_dqn_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
    """In-place diagonal update ``B += s**2 * (s^T y) / sum(s**4)`` (Nosrati & Amini, 2024).

    Returns the same ``B`` tensor.
    """
    weight = _safe_clip((s**4).sum())
    curvature = s.dot(y)
    B += s**2 * (curvature / weight)
    return B
|
|
151
|
+
|
|
152
|
+
class NewDQN(_HessianUpdateStrategyDefaults):
    """Diagonal quasi-Newton method.

    Maintains a diagonal hessian approximation ``B`` (starting at ones) updated
    with :func:`new_dqn_B_`.

    Reference:
        Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
    """

    def _init_M(self, size: int, device, dtype, is_inverse: bool):
        return torch.ones(size, device=device, dtype=dtype)

    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return new_dqn_B_(B=B, s=s, y=y)

    def make_Hv(self):
        return _diag_Hv(self)

    def make_Bv(self):
        return _diag_Bv(self)
|