torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/quasi_newton/lsr1.py

@@ -1,218 +1,253 @@
 from collections import deque
+from collections.abc import Sequence
 from operator import itemgetter

 import torch

 from ...core import Chainable, Module, Transform, Var, apply_transform
-from ...utils import NumberList, TensorList, as_tensorlist,
-from
-from
-
-
-def lsr1_(
-    tensors_: TensorList,
-    s_history: deque[TensorList],
-    y_history: deque[TensorList],
-    step: int,
-    scale_second: bool,
-):
-    if len(s_history) == 0:
-        # initial step size guess from pytorch
-        return safe_scaling_(TensorList(tensors_))
+from ...utils import NumberList, TensorList, as_tensorlist, generic_finfo_tiny, unpack_states, vec_to_tensors_
+from ...utils.linalg.linear_operator import LinearOperator
+from ..functional import initial_step_size
+from .damping import DampingStrategyType, apply_damping

+
+def lsr1_Hx(x, s_history: Sequence, y_history: Sequence,):
     m = len(s_history)
+    if m == 0: return x.clone()
+    eps = generic_finfo_tiny(x) * 2

-    w_list
-    ww_list: list = [None for _ in range(m)]
+    w_list = []
     wy_list: list = [None for _ in range(m)]

-    # 1st loop - all w_k = s_k - H_k_prev y_k
+    # # 1st loop - all w_k = s_k - H_k_prev y_k
     for k in range(m):
         s_k = s_history[k]
         y_k = y_history[k]

-
+        Hx = y_k.clone()
         for j in range(k):
             w_j = w_list[j]
             y_j = y_history[j]

             wy = wy_list[j]
             if wy is None: wy = wy_list[j] = w_j.dot(y_j)
+            if wy.abs() < eps: continue

-
-
-
-            if wy == 0: continue
+            alpha = w_j.dot(y_k) / wy
+            Hx.add_(w_j, alpha=alpha)

-
-
-        w_k = s_k - H_k
+        w_k = s_k - Hx
         w_list.append(w_k)

-    Hx =
+    Hx = x.clone()
+
+    # second loop
     for k in range(m):
         w_k = w_list[k]
         y_k = y_history[k]
         wy = wy_list[k]
-        ww = ww_list[k]

         if wy is None: wy = w_k.dot(y_k) # this happens when m = 1 so inner loop doesn't run
-        if
+        if wy.abs() < eps: continue

-
+        alpha = w_k.dot(x) / wy
+        Hx.add_(w_k, alpha=alpha)

-
+    return Hx

-
-
-    scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
-    Hx.mul_(scale_factor)
+def lsr1_Bx(x, s_history: Sequence, y_history: Sequence,):
+    return lsr1_Hx(x, s_history=y_history, y_history=s_history)

-
+class LSR1LinearOperator(LinearOperator):
+    def __init__(self, s_history: Sequence[torch.Tensor], y_history: Sequence[torch.Tensor]):
+        super().__init__()
+        self.s_history = s_history
+        self.y_history = y_history

+    def solve(self, b):
+        return lsr1_Hx(x=b, s_history=self.s_history, y_history=self.y_history)

-
-
+    def matvec(self, x):
+        return lsr1_Bx(x=x, s_history=self.s_history, y_history=self.y_history)

-
-
+    def size(self):
+        if len(self.s_history) == 0: raise RuntimeError()
+        n = len(self.s_history[0])
+        return (n, n)

-    .. note::
-        L-SR1 update rule uses a nested loop, computationally with history size `n` it is similar to L-BFGS with history size `(n^2)/2`. On small problems (ndim <= 2000) BFGS and SR1 may be faster than limited-memory versions.

-
-
-        for example using :code:`tz.m.StrongWolfe(plus_minus=True)` line search, or modifying the direction with :code:`tz.m.Cautious` or :code:`tz.m.ScaleByGradCosineSimilarity`.
+class LSR1(Transform):
+    """Limited-memory SR1 algorithm. A line search or trust region is recommended.

     Args:
         history_size (int, optional):
             number of past parameter differences and gradient differences to store. Defaults to 10.
-
-
-
-
+        ptol (float | None, optional):
+            skips updating the history if maximum absolute value of
+            parameter difference is less than this value. Defaults to None.
+        ptol_restart (bool, optional):
+            If true, whenever parameter difference is less then ``ptol``,
+            L-SR1 state will be reset. Defaults to None.
         gtol (float | None, optional):
-
-
-
-
-
-
-
-
-
+            skips updating the history if if maximum absolute value of
+            gradient difference is less than this value. Defaults to None.
+        ptol_restart (bool, optional):
+            If true, whenever gradient difference is less then ``gtol``,
+            L-SR1 state will be reset. Defaults to None.
+        scale_first (bool, optional):
+            makes first step, when hessian approximation is not available,
+            small to reduce number of line search iterations. Defaults to False.
+        update_freq (int, optional):
+            how often to update L-SR1 history. Larger values may be better for stochastic optimization. Defaults to 1.
+        damping (DampingStrategyType, optional):
+            damping to use, can be "powell" or "double". Defaults to None.
+        compact (bool, optional):
+            if True, uses a compact representation verstion of L-SR1. It is much faster computationally, but less stable.
         inner (Chainable | None, optional):
-
-
-
-
-
-
-
-
-
-
-
-
-
+            optional inner modules applied after updating L-SR1 history and before preconditioning. Defaults to None.
+
+    ## Examples:
+
+    L-SR1 with line search
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SR1(),
+        tz.m.StrongWolfe(c2=0.1, fallback=True)
+    )
+    ```
+
+    L-SR1 with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.TrustCG(tz.m.LSR1())
+    )
+    ```
     """
     def __init__(
         self,
-        history_size
-
-
-        gtol: float | None =
-
-
-        update_freq
-
+        history_size=10,
+        ptol: float | None = None,
+        ptol_restart: bool = False,
+        gtol: float | None = None,
+        gtol_restart: bool = False,
+        scale_first:bool=False,
+        update_freq = 1,
+        damping: DampingStrategyType = None,
         inner: Chainable | None = None,
     ):
         defaults = dict(
-            history_size=history_size,
-
-
-
+            history_size=history_size,
+            scale_first=scale_first,
+            ptol=ptol,
+            gtol=gtol,
+            ptol_restart=ptol_restart,
+            gtol_restart=gtol_restart,
+            damping = damping,
         )
-        super().__init__(defaults, uses_grad=False, inner=inner)
+        super().__init__(defaults, uses_grad=False, inner=inner, update_freq=update_freq)

         self.global_state['s_history'] = deque(maxlen=history_size)
         self.global_state['y_history'] = deque(maxlen=history_size)

-    def
+    def _reset_self(self):
         self.state.clear()
         self.global_state['step'] = 0
         self.global_state['s_history'].clear()
         self.global_state['y_history'].clear()

+    def reset(self):
+        self._reset_self()
+        for c in self.children.values(): c.reset()
+
     def reset_for_online(self):
         super().reset_for_online()
-        self.clear_state_keys('
+        self.clear_state_keys('p_prev', 'g_prev')
         self.global_state.pop('step', None)

     @torch.no_grad
     def update_tensors(self, tensors, params, grads, loss, states, settings):
-
-
+        p = as_tensorlist(params)
+        g = as_tensorlist(tensors)
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

-
-
+        # history of s and k
+        s_history: deque = self.global_state['s_history']
+        y_history: deque = self.global_state['y_history']

-
-
+        ptol = self.defaults['ptol']
+        gtol = self.defaults['gtol']
+        ptol_restart = self.defaults['ptol_restart']
+        gtol_restart = self.defaults['gtol_restart']
+        damping = self.defaults['damping']

-
-        l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
-        prev_l_params, prev_l_grad = unpack_states(states, tensors, 'prev_l_params', 'prev_l_grad', cls=TensorList)
+        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)

-
-
-
-
-
-
+        # 1st step - there are no previous params and grads, lsr1 will do normalized SGD step
+        if step == 0:
+            s = None; y = None; sy = None
+        else:
+            s = p - p_prev
+            y = g - g_prev

-
-
+            if damping is not None:
+                s, y = apply_damping(damping, s=s, y=y, g=g, H=self.get_H())

-
-
+            sy = s.dot(y)
+            # damping to be added here

-
-
-
+        below_tol = False
+        # tolerance on parameter difference to avoid exploding after converging
+        if ptol is not None:
+            if s is not None and s.abs().global_max() <= ptol:
+                if ptol_restart: self._reset_self()
+                sy = None
+                below_tol = True
+
+        # tolerance on gradient difference to avoid exploding when there is no curvature
+        if gtol is not None:
+            if y is not None and y.abs().global_max() <= gtol:
+                if gtol_restart: self._reset_self()
+                sy = None
+                below_tol = True
+
+        # store previous params and grads
+        if not below_tol:
+            p_prev.copy_(p)
+            g_prev.copy_(g)
+
+        # update effective preconditioning state
+        if sy is not None:
+            assert s is not None and y is not None and sy is not None
+
+            s_history.append(s)
+            y_history.append(y)
+
+    def get_H(self, var=...):
+        s_history = [tl.to_vec() for tl in self.global_state['s_history']]
+        y_history = [tl.to_vec() for tl in self.global_state['y_history']]
+        return LSR1LinearOperator(s_history, y_history)

     @torch.no_grad
     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-
-        s = self.global_state.pop('s')
-        y = self.global_state.pop('y')
+        scale_first = self.defaults['scale_first']

-
-        tol = setting['tol']
-        gtol = setting['gtol']
-        tol_reset = setting['tol_reset']
+        tensors = as_tensorlist(tensors)

-
-
-        if s is not None and s.abs().global_max() <= tol:
-            if tol_reset: self.reset()
-            return safe_scaling_(TensorList(tensors))
-
-        # tolerance on gradient difference to avoid exploding when there is no curvature
-        if tol is not None:
-            if y is not None and y.abs().global_max() <= gtol:
-                return safe_scaling_(TensorList(tensors))
+        s_history = self.global_state['s_history']
+        y_history = self.global_state['y_history']

         # precondition
-        dir =
-
-            s_history=
-            y_history=
-            step=self.global_state.get('step', 1),
-            scale_second=setting['scale_second'],
+        dir = lsr1_Hx(
+            x=tensors,
+            s_history=s_history,
+            y_history=y_history,
+        )

-
+        # scale 1st step
+        if scale_first and self.global_state.get('step', 1) == 1:
+            dir *= initial_step_size(dir, eps=1e-7)
+
+        return dir