torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
|
@@ -5,10 +5,42 @@ import torch
|
|
|
5
5
|
|
|
6
6
|
from ...core import Chainable, TensorwiseTransform, Transform, apply_transform
|
|
7
7
|
from ...utils import TensorList, as_tensorlist, unpack_dicts, unpack_states
|
|
8
|
+
from .quasi_newton import _safe_clip, HessianUpdateStrategy
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
class ConguateGradientBase(Transform, ABC):
|
|
11
|
-
"""
|
|
12
|
+
"""Base class for conjugate gradient methods. The only difference between them is how beta is calculated.
|
|
13
|
+
|
|
14
|
+
This is an abstract class, to use it, subclass it and override `get_beta`.
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
Args:
|
|
18
|
+
defaults (dict | None, optional): dictionary of settings defaults. Defaults to None.
|
|
19
|
+
clip_beta (bool, optional): whether to clip beta to be no less than 0. Defaults to False.
|
|
20
|
+
reset_interval (int | None | Literal["auto"], optional):
|
|
21
|
+
interval between resetting the search direction.
|
|
22
|
+
"auto" means number of dimensions + 1, None means no reset. Defaults to None.
|
|
23
|
+
inner (Chainable | None, optional): previous direction is added to the output of this module. Defaults to None.
|
|
24
|
+
|
|
25
|
+
Example:
|
|
26
|
+
|
|
27
|
+
.. code-block:: python
|
|
28
|
+
|
|
29
|
+
class PolakRibiere(ConguateGradientBase):
|
|
30
|
+
def __init__(
|
|
31
|
+
self,
|
|
32
|
+
clip_beta=True,
|
|
33
|
+
reset_interval: int | None = None,
|
|
34
|
+
inner: Chainable | None = None
|
|
35
|
+
):
|
|
36
|
+
super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
|
|
37
|
+
|
|
38
|
+
def get_beta(self, p, g, prev_g, prev_d):
|
|
39
|
+
denom = prev_g.dot(prev_g)
|
|
40
|
+
if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
|
|
41
|
+
return g.dot(g - prev_g) / denom
|
|
42
|
+
|
|
43
|
+
"""
|
|
12
44
|
def __init__(self, defaults = None, clip_beta: bool = False, reset_interval: int | None | Literal['auto'] = None, inner: Chainable | None = None):
|
|
13
45
|
if defaults is None: defaults = {}
|
|
14
46
|
defaults['reset_interval'] = reset_interval
|
|
@@ -18,6 +50,15 @@ class ConguateGradientBase(Transform, ABC):
|
|
|
18
50
|
if inner is not None:
|
|
19
51
|
self.set_child('inner', inner)
|
|
20
52
|
|
|
53
|
+
def reset(self):
|
|
54
|
+
super().reset()
|
|
55
|
+
|
|
56
|
+
def reset_for_online(self):
|
|
57
|
+
super().reset_for_online()
|
|
58
|
+
self.clear_state_keys('prev_grad')
|
|
59
|
+
self.global_state.pop('stage', None)
|
|
60
|
+
self.global_state['step'] = self.global_state.get('step', 1) - 1
|
|
61
|
+
|
|
21
62
|
def initialize(self, p: TensorList, g: TensorList):
|
|
22
63
|
"""runs on first step when prev_grads and prev_dir are not available"""
|
|
23
64
|
|
|
@@ -26,39 +67,55 @@ class ConguateGradientBase(Transform, ABC):
|
|
|
26
67
|
"""returns beta"""
|
|
27
68
|
|
|
28
69
|
@torch.no_grad
|
|
29
|
-
def
|
|
70
|
+
def update_tensors(self, tensors, params, grads, loss, states, settings):
|
|
30
71
|
tensors = as_tensorlist(tensors)
|
|
31
72
|
params = as_tensorlist(params)
|
|
32
73
|
|
|
33
|
-
step = self.global_state.get('step', 0)
|
|
34
|
-
|
|
74
|
+
step = self.global_state.get('step', 0) + 1
|
|
75
|
+
self.global_state['step'] = step
|
|
35
76
|
|
|
36
77
|
# initialize on first step
|
|
37
|
-
if
|
|
78
|
+
if self.global_state.get('stage', 0) == 0:
|
|
79
|
+
g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
|
|
80
|
+
d_prev.copy_(tensors)
|
|
81
|
+
g_prev.copy_(tensors)
|
|
38
82
|
self.initialize(params, tensors)
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
83
|
+
self.global_state['stage'] = 1
|
|
84
|
+
|
|
85
|
+
else:
|
|
86
|
+
# if `update_tensors` was called multiple times before `apply_tensors`,
|
|
87
|
+
# stage becomes 2
|
|
88
|
+
self.global_state['stage'] = 2
|
|
89
|
+
|
|
90
|
+
@torch.no_grad
|
|
91
|
+
def apply_tensors(self, tensors, params, grads, loss, states, settings):
|
|
92
|
+
tensors = as_tensorlist(tensors)
|
|
93
|
+
step = self.global_state['step']
|
|
94
|
+
|
|
95
|
+
if 'inner' in self.children:
|
|
96
|
+
tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))
|
|
97
|
+
|
|
98
|
+
assert self.global_state['stage'] != 0
|
|
99
|
+
if self.global_state['stage'] == 1:
|
|
100
|
+
self.global_state['stage'] = 2
|
|
42
101
|
return tensors
|
|
43
102
|
|
|
103
|
+
params = as_tensorlist(params)
|
|
104
|
+
g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
|
|
105
|
+
|
|
44
106
|
# get beta
|
|
45
|
-
beta = self.get_beta(params, tensors,
|
|
107
|
+
beta = self.get_beta(params, tensors, g_prev, d_prev)
|
|
46
108
|
if settings[0]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]
|
|
47
|
-
prev_grads.copy_(tensors)
|
|
48
109
|
|
|
49
110
|
# inner step
|
|
50
|
-
if 'inner' in self.children:
|
|
51
|
-
tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))
|
|
52
|
-
|
|
53
111
|
# calculate new direction with beta
|
|
54
|
-
dir = tensors.add_(
|
|
55
|
-
|
|
112
|
+
dir = tensors.add_(d_prev.mul_(beta))
|
|
113
|
+
d_prev.copy_(dir)
|
|
56
114
|
|
|
57
115
|
# resetting
|
|
58
|
-
self.global_state['step'] = step + 1
|
|
59
116
|
reset_interval = settings[0]['reset_interval']
|
|
60
117
|
if reset_interval == 'auto': reset_interval = tensors.global_numel() + 1
|
|
61
|
-
if reset_interval is not None and
|
|
118
|
+
if reset_interval is not None and step % reset_interval == 0:
|
|
62
119
|
self.reset()
|
|
63
120
|
|
|
64
121
|
return dir
|
|
@@ -70,7 +127,11 @@ def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
|
|
|
70
127
|
return g.dot(g - prev_g) / denom
|
|
71
128
|
|
|
72
129
|
class PolakRibiere(ConguateGradientBase):
|
|
73
|
-
"""Polak-Ribière-Polyak nonlinear conjugate gradient method.
|
|
130
|
+
"""Polak-Ribière-Polyak nonlinear conjugate gradient method.
|
|
131
|
+
|
|
132
|
+
.. note::
|
|
133
|
+
- This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
|
|
134
|
+
"""
|
|
74
135
|
def __init__(self, clip_beta=True, reset_interval: int | None = None, inner: Chainable | None = None):
|
|
75
136
|
super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
|
|
76
137
|
|
|
@@ -83,7 +144,11 @@ def fletcher_reeves_beta(gg: torch.Tensor, prev_gg: torch.Tensor):
|
|
|
83
144
|
return gg / prev_gg
|
|
84
145
|
|
|
85
146
|
class FletcherReeves(ConguateGradientBase):
|
|
86
|
-
"""Fletcher–Reeves nonlinear conjugate gradient method.
|
|
147
|
+
"""Fletcher–Reeves nonlinear conjugate gradient method.
|
|
148
|
+
|
|
149
|
+
.. note::
|
|
150
|
+
- This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
|
|
151
|
+
"""
|
|
87
152
|
def __init__(self, reset_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
|
|
88
153
|
super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
|
|
89
154
|
|
|
@@ -105,7 +170,11 @@ def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
|
|
|
105
170
|
|
|
106
171
|
|
|
107
172
|
class HestenesStiefel(ConguateGradientBase):
|
|
108
|
-
"""Hestenes–Stiefel nonlinear conjugate gradient method.
|
|
173
|
+
"""Hestenes–Stiefel nonlinear conjugate gradient method.
|
|
174
|
+
|
|
175
|
+
.. note::
|
|
176
|
+
- This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
|
|
177
|
+
"""
|
|
109
178
|
def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
|
|
110
179
|
super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
|
|
111
180
|
|
|
@@ -120,7 +189,11 @@ def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
|
|
|
120
189
|
return (g.dot(g) / denom).neg()
|
|
121
190
|
|
|
122
191
|
class DaiYuan(ConguateGradientBase):
|
|
123
|
-
"""Dai–Yuan nonlinear conjugate gradient method.
|
|
192
|
+
"""Dai–Yuan nonlinear conjugate gradient method.
|
|
193
|
+
|
|
194
|
+
.. note::
|
|
195
|
+
- This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this. Although Dai–Yuan formula provides an automatic step size scaling so it is technically possible to omit line search and instead use a small step size.
|
|
196
|
+
"""
|
|
124
197
|
def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
|
|
125
198
|
super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
|
|
126
199
|
|
|
@@ -135,7 +208,11 @@ def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):
|
|
|
135
208
|
return g.dot(g - prev_g) / denom
|
|
136
209
|
|
|
137
210
|
class LiuStorey(ConguateGradientBase):
|
|
138
|
-
"""Liu-Storey nonlinear conjugate gradient method.
|
|
211
|
+
"""Liu-Storey nonlinear conjugate gradient method.
|
|
212
|
+
|
|
213
|
+
.. note::
|
|
214
|
+
- This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
|
|
215
|
+
"""
|
|
139
216
|
def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
|
|
140
217
|
super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
|
|
141
218
|
|
|
@@ -144,7 +221,11 @@ class LiuStorey(ConguateGradientBase):
|
|
|
144
221
|
|
|
145
222
|
# ----------------------------- Conjugate Descent ---------------------------- #
|
|
146
223
|
class ConjugateDescent(Transform):
|
|
147
|
-
"""Conjugate Descent (CD).
|
|
224
|
+
"""Conjugate Descent (CD).
|
|
225
|
+
|
|
226
|
+
.. note::
|
|
227
|
+
- This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
|
|
228
|
+
"""
|
|
148
229
|
def __init__(self, inner: Chainable | None = None):
|
|
149
230
|
super().__init__(defaults={}, uses_grad=False)
|
|
150
231
|
|
|
@@ -153,7 +234,7 @@ class ConjugateDescent(Transform):
|
|
|
153
234
|
|
|
154
235
|
|
|
155
236
|
@torch.no_grad
|
|
156
|
-
def
|
|
237
|
+
def apply_tensors(self, tensors, params, grads, loss, states, settings):
|
|
157
238
|
g = as_tensorlist(tensors)
|
|
158
239
|
|
|
159
240
|
prev_d = unpack_states(states, tensors, 'prev_dir', cls=TensorList, init=torch.zeros_like)
|
|
@@ -188,7 +269,10 @@ def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
|
|
|
188
269
|
|
|
189
270
|
class HagerZhang(ConguateGradientBase):
|
|
190
271
|
"""Hager-Zhang nonlinear conjugate gradient method,
|
|
191
|
-
|
|
272
|
+
|
|
273
|
+
.. note::
|
|
274
|
+
- This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
|
|
275
|
+
"""
|
|
192
276
|
def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
|
|
193
277
|
super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
|
|
194
278
|
|
|
@@ -212,7 +296,10 @@ def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
|
|
|
212
296
|
|
|
213
297
|
class HybridHS_DY(ConguateGradientBase):
|
|
214
298
|
"""HS-DY hybrid conjugate gradient method.
|
|
215
|
-
|
|
299
|
+
|
|
300
|
+
.. note::
|
|
301
|
+
- This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
|
|
302
|
+
"""
|
|
216
303
|
def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
|
|
217
304
|
super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
|
|
218
305
|
|
|
@@ -220,49 +307,63 @@ class HybridHS_DY(ConguateGradientBase):
|
|
|
220
307
|
return hs_dy_beta(g, prev_d, prev_g)
|
|
221
308
|
|
|
222
309
|
|
|
223
|
-
def projected_gradient_(H:torch.Tensor, y:torch.Tensor
|
|
310
|
+
def projected_gradient_(H:torch.Tensor, y:torch.Tensor):
|
|
224
311
|
Hy = H @ y
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
H -= (H @ y.outer(y) @ H) / denom
|
|
312
|
+
yHy = _safe_clip(y.dot(Hy))
|
|
313
|
+
H -= (Hy.outer(y) @ H) / yHy
|
|
228
314
|
return H
|
|
229
315
|
|
|
230
|
-
class ProjectedGradientMethod(
|
|
231
|
-
"""
|
|
316
|
+
class ProjectedGradientMethod(HessianUpdateStrategy): # this doesn't maintain hessian
|
|
317
|
+
"""Projected gradient method.
|
|
318
|
+
|
|
319
|
+
.. note::
|
|
320
|
+
This method uses N^2 memory.
|
|
321
|
+
|
|
322
|
+
.. note::
|
|
323
|
+
This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
|
|
324
|
+
|
|
325
|
+
.. note::
|
|
326
|
+
This is not the same as projected gradient descent.
|
|
327
|
+
|
|
328
|
+
Reference:
|
|
329
|
+
Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
|
|
232
330
|
|
|
233
|
-
(This is not the same as projected gradient descent)
|
|
234
331
|
"""
|
|
235
332
|
|
|
236
333
|
def __init__(
|
|
237
334
|
self,
|
|
238
|
-
|
|
239
|
-
|
|
335
|
+
init_scale: float | Literal["auto"] = 1,
|
|
336
|
+
tol: float = 1e-8,
|
|
337
|
+
ptol: float | None = 1e-10,
|
|
338
|
+
ptol_reset: bool = False,
|
|
339
|
+
gtol: float | None = 1e-10,
|
|
340
|
+
reset_interval: int | None | Literal['auto'] = 'auto',
|
|
341
|
+
beta: float | None = None,
|
|
240
342
|
update_freq: int = 1,
|
|
241
343
|
scale_first: bool = False,
|
|
344
|
+
scale_second: bool = False,
|
|
242
345
|
concat_params: bool = True,
|
|
346
|
+
# inverse: bool = True,
|
|
243
347
|
inner: Chainable | None = None,
|
|
244
348
|
):
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
def apply_tensor(self, tensor, param, grad, loss, state, settings):
|
|
267
|
-
H = state['H']
|
|
268
|
-
return (H @ tensor.view(-1)).view_as(tensor)
|
|
349
|
+
super().__init__(
|
|
350
|
+
defaults=None,
|
|
351
|
+
init_scale=init_scale,
|
|
352
|
+
tol=tol,
|
|
353
|
+
ptol=ptol,
|
|
354
|
+
ptol_reset=ptol_reset,
|
|
355
|
+
gtol=gtol,
|
|
356
|
+
reset_interval=reset_interval,
|
|
357
|
+
beta=beta,
|
|
358
|
+
update_freq=update_freq,
|
|
359
|
+
scale_first=scale_first,
|
|
360
|
+
scale_second=scale_second,
|
|
361
|
+
concat_params=concat_params,
|
|
362
|
+
inverse=True,
|
|
363
|
+
inner=inner,
|
|
364
|
+
)
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
|
|
369
|
+
return projected_gradient_(H=H, y=y)
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from .quasi_newton import (
|
|
6
|
+
HessianUpdateStrategy,
|
|
7
|
+
_HessianUpdateStrategyDefaults,
|
|
8
|
+
_InverseHessianUpdateStrategyDefaults,
|
|
9
|
+
_safe_clip,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _diag_Bv(self: HessianUpdateStrategy):
    """Build a ``v -> B @ v`` closure for a diagonal Hessian approximation.

    ``self.get_B()`` yields the stored diagonal together with a flag. When the
    flag is false, the diagonal is ``B`` itself and the product is the
    elementwise ``B * v``. When it is true, the stored diagonal is the inverse
    ``H``, so ``B @ v`` is evaluated as ``v / H``.
    """
    diag, stored_is_inverse = self.get_B()
    if not stored_is_inverse:
        return lambda v: diag * v
    # stored diagonal is H = B^-1, hence B @ v = v / H
    return lambda v: v / diag
|
|
23
|
+
|
|
24
|
+
def _diag_Hv(self: HessianUpdateStrategy):
    """Build a ``v -> H @ v`` closure for a diagonal inverse-Hessian approximation.

    ``self.get_H()`` yields the stored diagonal together with a flag. When the
    flag is false, the diagonal is ``H`` itself and the product is the
    elementwise ``H * v``. When it is true, the stored diagonal is ``B = H^-1``,
    so ``H @ v`` is evaluated as ``v / B``.
    """
    diag, stored_is_inverse = self.get_H()
    if not stored_is_inverse:
        return lambda v: diag * v
    # stored diagonal is B = H^-1, hence H @ v = v / B
    return lambda v: v / diag
|
|
34
|
+
|
|
35
|
+
def diagonal_bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
    """Update the inverse-Hessian diagonal ``H`` in place with the diagonal of the BFGS update.

    Treating ``H`` as the matrix ``diag(H)``, the inverse BFGS update is::

        H_new = H + (sᵀy + yᵀHy) s sᵀ / (sᵀy)² - (H y sᵀ + s yᵀ H) / (sᵀy)

    Its diagonal is ``H + (sᵀy + yᵀHy) s² / (sᵀy)² - 2 H y s / (sᵀy)``.
    This function applies exactly that diagonal.

    Args:
        H: 1-D diagonal of the inverse Hessian approximation; modified in place.
        s: 1-D parameter difference.
        y: 1-D gradient difference.
        tol: the update is skipped (``H`` returned unchanged) when ``sᵀy < tol``.

    Returns:
        ``H``, the same tensor object that was passed in.
    """
    sy = s.dot(y)
    if sy < tol: return H

    sy_sq = _safe_clip(sy**2)

    Hy = H * y
    # yᵀHy is a *scalar* in the BFGS formula (sum over all entries),
    # not the elementwise product y*H*y
    yHy = y.dot(Hy)

    term1 = (s * s).mul_(sy + yHy).div_(sy_sq)
    # H y sᵀ and s yᵀ H share the same diagonal, H*y*s, hence the factor 2
    term2 = (Hy * s).mul_(2).div_(sy)
    H += term1.sub_(term2)
    return H
|
|
47
|
+
|
|
48
|
+
class DiagonalBFGS(_InverseHessianUpdateStrategyDefaults):
    """Diagonal BFGS.

    Only a diagonal inverse-Hessian approximation is stored, updated and used;
    each step applies ``diagonal_bfgs_H_`` to it. The diagonal does not satisfy
    the secant equation, but the method may still be useful.
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        tol = setting['tol']
        return diagonal_bfgs_H_(H=H, s=s, y=y, tol=tol)

    def _init_M(self, size:int, device, dtype, is_inverse:bool):
        # the diagonal starts as all ones, i.e. the identity matrix
        return torch.ones(size, dtype=dtype, device=device)

    def make_Bv(self):
        return _diag_Bv(self)

    def make_Hv(self):
        return _diag_Hv(self)
|
|
56
|
+
|
|
57
|
+
def diagonal_sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
    """Apply the diagonal of the SR1 correction to ``H`` in place.

    With ``z = s - H*y``, the SR1 correction is ``z zᵀ / (zᵀy)``. Only its
    diagonal, ``z² / (zᵀy)``, is added to ``H``. The update is skipped when
    ``|zᵀy| <= tol * ||y|| * ||z||``, the safeguard from Nocedal & Wright,
    "Numerical Optimization", 2nd ed., p.146.

    Returns:
        ``H``, the same tensor object that was passed in.
    """
    residual = s - H * y
    curvature = residual.dot(y)

    threshold = tol * torch.linalg.norm(y) * torch.linalg.norm(residual) # pylint:disable=not-callable
    if curvature.abs() <= threshold:
        return H

    H += (residual * residual).div_(_safe_clip(curvature))
    return H
|
|
70
|
+
class DiagonalSR1(_InverseHessianUpdateStrategyDefaults):
    """Diagonal SR1.

    Only a diagonal approximation is stored, updated and used, with
    ``diagonal_sr1_`` supplying the update. The diagonal does not satisfy the
    secant equation, but the method may still be useful.
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        tol = setting['tol']
        return diagonal_sr1_(H=H, s=s, y=y, tol=tol)

    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        # the direct SR1 formula is the inverse one with the roles of s and y swapped
        tol = setting['tol']
        return diagonal_sr1_(H=B, s=y, y=s, tol=tol)

    def _init_M(self, size:int, device, dtype, is_inverse:bool):
        # the diagonal starts as all ones, i.e. the identity matrix
        return torch.ones(size, dtype=dtype, device=device)

    def make_Bv(self):
        return _diag_Bv(self)

    def make_Hv(self):
        return _diag_Hv(self)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
# Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating. SIAM Journal on Optimization, 1999, Vol. 9, No. 4, pp. 1192-1204.
def diagonal_qc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
    """Diagonal quasi-Cauchy update of the Hessian diagonal ``B``, in place.

    Adds a multiple of ``s**2`` chosen so that, up to the clipping of the
    denominator, the new diagonal satisfies ``sᵀ B_new s = sᵀ y``.

    Returns:
        ``B``, the same tensor object that was passed in.
    """
    mismatch = s.dot(y) - (s * B).dot(s)
    fourth_moment = _safe_clip((s**4).sum())
    B += s**2 * (mismatch / fourth_moment)
    return B
|
|
89
|
+
|
|
90
|
+
class DiagonalQuasiCauchi(_HessianUpdateStrategyDefaults):
    """Diagonal quasi-Cauchy method.

    Maintains a diagonal Hessian approximation updated with ``diagonal_qc_B_``.

    Reference:
        Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating. SIAM Journal on Optimization, 9(4), 1192-1204 (1999).
    """
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        updated = diagonal_qc_B_(B=B, s=s, y=y)
        return updated

    def _init_M(self, size:int, device, dtype, is_inverse:bool):
        # the diagonal starts as all ones, i.e. the identity matrix
        return torch.ones(size, dtype=dtype, device=device)

    def make_Bv(self):
        return _diag_Bv(self)

    def make_Hv(self):
        return _diag_Hv(self)
|
|
102
|
+
|
|
103
|
+
# Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
def diagonal_wqc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
    """Weighted diagonal quasi-Cauchy update of the Hessian diagonal ``B``, in place.

    The correction direction is ``s² * B²`` (computed from the incoming ``B``).
    It is scaled so that, up to the clipping of the denominator,
    ``sᵀ B_new s = sᵀ y``.

    Returns:
        ``B``, the same tensor object that was passed in.
    """
    weights = s**2 * B**2
    scale_denom = _safe_clip((s * weights).dot(s))
    mismatch = s.dot(y) - (s * B).dot(s)
    B += weights * (mismatch / scale_denom)
    return B
|
|
110
|
+
|
|
111
|
+
class DiagonalWeightedQuasiCauchi(_HessianUpdateStrategyDefaults):
    """Diagonal weighted quasi-Cauchy method.

    Maintains a diagonal Hessian approximation updated with ``diagonal_wqc_B_``.

    Reference:
        Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
    """
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        updated = diagonal_wqc_B_(B=B, s=s, y=y)
        return updated

    def _init_M(self, size:int, device, dtype, is_inverse:bool):
        # the diagonal starts as all ones, i.e. the identity matrix
        return torch.ones(size, dtype=dtype, device=device)

    def make_Bv(self):
        return _diag_Bv(self)

    def make_Hv(self):
        return _diag_Hv(self)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
# Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
def dnrtr_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
    """DNRTR update of the Hessian diagonal ``B``, in place.

    Adds ``s² * (sᵀy + sᵀs - sᵀBs) / Σs⁴`` and then subtracts 1 from every
    entry. Up to the clipping of the denominator, this gives
    ``sᵀ B_new s = sᵀ y``.

    Returns:
        ``B``, the same tensor object that was passed in.
    """
    ss = s.dot(s)
    numerator = s.dot(y) + ss - (s * B).dot(s)
    fourth_moment = _safe_clip((s**4).sum())
    B += s**2 * (numerator / fourth_moment) - 1
    return B
|
|
131
|
+
|
|
132
|
+
class DNRTR(_HessianUpdateStrategyDefaults):
    """Diagonal quasi-Newton method (DNRTR).

    Maintains a diagonal Hessian approximation updated with ``dnrtr_B_``.

    Reference:
        Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
    """
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        # dnrtr_B_ implements the update from the referenced paper; calling
        # diagonal_wqc_B_ here (the Leong et al. weighted update) would make
        # this class a duplicate of DiagonalWeightedQuasiCauchi
        return dnrtr_B_(B=B, s=s, y=y)

    def _init_M(self, size:int, device, dtype, is_inverse:bool):
        # the diagonal starts as all ones, i.e. the identity matrix
        return torch.ones(size, device=device, dtype=dtype)

    def make_Bv(self): return _diag_Bv(self)
    def make_Hv(self): return _diag_Hv(self)
|
|
144
|
+
|
|
145
|
+
# Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
def new_dqn_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
    """Diagonal quasi-Newton update of the Hessian diagonal ``B``, in place.

    Adds ``s² * sᵀy / Σs⁴`` to ``B``.

    Returns:
        ``B``, the same tensor object that was passed in.
    """
    # NOTE(review): the term is *added* to B, so sᵀ B_new s = sᵀ B s + sᵀ y
    # rather than sᵀ y; confirm against the referenced paper whether B should
    # be replaced instead of incremented.
    fourth_moment = _safe_clip((s**4).sum())
    curvature = s.dot(y)
    B += s**2 * (curvature / fourth_moment)
    return B
|
|
151
|
+
|
|
152
|
+
class NewDQN(_HessianUpdateStrategyDefaults):
    """Diagonal quasi-Newton method of Nosrati and Amini.

    Maintains a diagonal Hessian approximation updated with ``new_dqn_B_``.

    Reference:
        Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
    """
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        updated = new_dqn_B_(B=B, s=s, y=y)
        return updated

    def _init_M(self, size:int, device, dtype, is_inverse:bool):
        # the diagonal starts as all ones, i.e. the identity matrix
        return torch.ones(size, dtype=dtype, device=device)

    def make_Bv(self):
        return _diag_Bv(self)

    def make_Hv(self):
        return _diag_Hv(self)
|