torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -1,174 +1,253 @@
 from collections import deque
+from collections.abc import Sequence
 from operator import itemgetter

 import torch

 from ...core import Chainable, Module, Transform, Var, apply_transform
-from ...utils import NumberList, TensorList, as_tensorlist
-
-from
-
-def lsr1_(
-    tensors_: TensorList,
-    s_history: deque[TensorList],
-    y_history: deque[TensorList],
-    step: int,
-    scale_second: bool,
-):
-    if step == 0 or not s_history:
-        # initial step size guess from pytorch
-        scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
-        scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
-        return tensors_.mul_(scale_factor)
+from ...utils import NumberList, TensorList, as_tensorlist, generic_finfo_tiny, unpack_states, vec_to_tensors_
+from ...utils.linalg.linear_operator import LinearOperator
+from ..functional import initial_step_size
+from .damping import DampingStrategyType, apply_damping

+
+def lsr1_Hx(x, s_history: Sequence, y_history: Sequence,):
     m = len(s_history)
+    if m == 0: return x.clone()
+    eps = generic_finfo_tiny(x) * 2

-    w_list
-    ww_list: list = [None for _ in range(m)]
+    w_list = []
     wy_list: list = [None for _ in range(m)]

-    # 1st loop - all w_k = s_k - H_k_prev y_k
+    # # 1st loop - all w_k = s_k - H_k_prev y_k
     for k in range(m):
         s_k = s_history[k]
         y_k = y_history[k]

-
+        Hx = y_k.clone()
         for j in range(k):
             w_j = w_list[j]
             y_j = y_history[j]

             wy = wy_list[j]
             if wy is None: wy = wy_list[j] = w_j.dot(y_j)
+            if wy.abs() < eps: continue

-
-
-
-            if wy == 0: continue
-
-            H_k.add_(w_j, alpha=w_j.dot(y_k) / wy) # pyright:ignore[reportArgumentType]
+            alpha = w_j.dot(y_k) / wy
+            Hx.add_(w_j, alpha=alpha)

-        w_k = s_k -
+        w_k = s_k - Hx
         w_list.append(w_k)

-    Hx =
+    Hx = x.clone()
+
+    # second loop
     for k in range(m):
         w_k = w_list[k]
         y_k = y_history[k]
         wy = wy_list[k]
-        ww = ww_list[k]

         if wy is None: wy = w_k.dot(y_k) # this happens when m = 1 so inner loop doesn't run
-        if
+        if wy.abs() < eps: continue
+
+        alpha = w_k.dot(x) / wy
+        Hx.add_(w_k, alpha=alpha)

-
+    return Hx

-
+def lsr1_Bx(x, s_history: Sequence, y_history: Sequence,):
+    return lsr1_Hx(x, s_history=y_history, y_history=s_history)

-
-
-
-
+class LSR1LinearOperator(LinearOperator):
+    def __init__(self, s_history: Sequence[torch.Tensor], y_history: Sequence[torch.Tensor]):
+        super().__init__()
+        self.s_history = s_history
+        self.y_history = y_history
+
+    def solve(self, b):
+        return lsr1_Hx(x=b, s_history=self.s_history, y_history=self.y_history)
+
+    def matvec(self, x):
+        return lsr1_Bx(x=x, s_history=self.s_history, y_history=self.y_history)
+
+    def size(self):
+        if len(self.s_history) == 0: raise RuntimeError()
+        n = len(self.s_history[0])
+        return (n, n)

-    return Hx

+class LSR1(Transform):
+    """Limited-memory SR1 algorithm. A line search or trust region is recommended.

-class LSR1(Module):
-    """Limited Memory SR1 (L-SR1)
     Args:
-        history_size (int, optional):
-            and gradient differences
-
-
-
-
-
-
-
-
-
-
-            Defaults to
-
-
+        history_size (int, optional):
+            number of past parameter differences and gradient differences to store. Defaults to 10.
+        ptol (float | None, optional):
+            skips updating the history if maximum absolute value of
+            parameter difference is less than this value. Defaults to None.
+        ptol_restart (bool, optional):
+            If true, whenever parameter difference is less then ``ptol``,
+            L-SR1 state will be reset. Defaults to None.
+        gtol (float | None, optional):
+            skips updating the history if if maximum absolute value of
+            gradient difference is less than this value. Defaults to None.
+        ptol_restart (bool, optional):
+            If true, whenever gradient difference is less then ``gtol``,
+            L-SR1 state will be reset. Defaults to None.
+        scale_first (bool, optional):
+            makes first step, when hessian approximation is not available,
+            small to reduce number of line search iterations. Defaults to False.
+        update_freq (int, optional):
+            how often to update L-SR1 history. Larger values may be better for stochastic optimization. Defaults to 1.
+        damping (DampingStrategyType, optional):
+            damping to use, can be "powell" or "double". Defaults to None.
+        compact (bool, optional):
+            if True, uses a compact representation verstion of L-SR1. It is much faster computationally, but less stable.
+        inner (Chainable | None, optional):
+            optional inner modules applied after updating L-SR1 history and before preconditioning. Defaults to None.
+
+    ## Examples:
+
+    L-SR1 with line search
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SR1(),
+        tz.m.StrongWolfe(c2=0.1, fallback=True)
+    )
+    ```
+
+    L-SR1 with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.TrustCG(tz.m.LSR1())
+    )
+    ```
     """
     def __init__(
         self,
-        history_size
-
-
-
-
-
+        history_size=10,
+        ptol: float | None = None,
+        ptol_restart: bool = False,
+        gtol: float | None = None,
+        gtol_restart: bool = False,
+        scale_first:bool=False,
+        update_freq = 1,
+        damping: DampingStrategyType = None,
         inner: Chainable | None = None,
     ):
         defaults = dict(
-            history_size=history_size,
-
-
+            history_size=history_size,
+            scale_first=scale_first,
+            ptol=ptol,
+            gtol=gtol,
+            ptol_restart=ptol_restart,
+            gtol_restart=gtol_restart,
+            damping = damping,
         )
-        super().__init__(defaults)
+        super().__init__(defaults, uses_grad=False, inner=inner, update_freq=update_freq)

         self.global_state['s_history'] = deque(maxlen=history_size)
         self.global_state['y_history'] = deque(maxlen=history_size)

-
-        self.set_child('inner', inner)
-
-    def reset(self):
+    def _reset_self(self):
         self.state.clear()
         self.global_state['step'] = 0
         self.global_state['s_history'].clear()
         self.global_state['y_history'].clear()

+    def reset(self):
+        self._reset_self()
+        for c in self.children.values(): c.reset()
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('p_prev', 'g_prev')
+        self.global_state.pop('step', None)

     @torch.no_grad
-    def
-
-
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        p = as_tensorlist(params)
+        g = as_tensorlist(tensors)
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

-
-
-
-
-
+        # history of s and k
+        s_history: deque = self.global_state['s_history']
+        y_history: deque = self.global_state['y_history']
+
+        ptol = self.defaults['ptol']
+        gtol = self.defaults['gtol']
+        ptol_restart = self.defaults['ptol_restart']
+        gtol_restart = self.defaults['gtol_restart']
+        damping = self.defaults['damping']
+
+        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)
+
+        # 1st step - there are no previous params and grads, lsr1 will do normalized SGD step
+        if step == 0:
+            s = None; y = None; sy = None
+        else:
+            s = p - p_prev
+            y = g - g_prev
+
+            if damping is not None:
+                s, y = apply_damping(damping, s=s, y=y, g=g, H=self.get_H())
+
+            sy = s.dot(y)
+            # damping to be added here
+
+        below_tol = False
+        # tolerance on parameter difference to avoid exploding after converging
+        if ptol is not None:
+            if s is not None and s.abs().global_max() <= ptol:
+                if ptol_restart: self._reset_self()
+                sy = None
+                below_tol = True
+
+        # tolerance on gradient difference to avoid exploding when there is no curvature
+        if gtol is not None:
+            if y is not None and y.abs().global_max() <= gtol:
+                if gtol_restart: self._reset_self()
+                sy = None
+                below_tol = True
+
+        # store previous params and grads
+        if not below_tol:
+            p_prev.copy_(p)
+            g_prev.copy_(g)
+
+        # update effective preconditioning state
+        if sy is not None:
+            assert s is not None and y is not None and sy is not None
+
+            s_history.append(s)
+            y_history.append(y)
+
+    def get_H(self, var=...):
+        s_history = [tl.to_vec() for tl in self.global_state['s_history']]
+        y_history = [tl.to_vec() for tl in self.global_state['y_history']]
+        return LSR1LinearOperator(s_history, y_history)

-
-
-
-        prev_l_params, prev_l_grad = self.get_state(params, 'prev_l_params', 'prev_l_grad', cls=TensorList)
-
-        y_k = None
-        if step != 0:
-            if step % update_freq == 0:
-                s_k = l_params - prev_l_params
-                y_k = l_update - prev_l_grad
-
-                s_history.append(s_k)
-                y_history.append(y_k)
-
-            prev_l_params.copy_(l_params)
-            prev_l_grad.copy_(l_update)
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        scale_first = self.defaults['scale_first']

-
-        update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))
+        tensors = as_tensorlist(tensors)

-
-
-        if y_k is not None and y_k.abs().global_max() <= tol:
-            var.update = update
-            return var
+        s_history = self.global_state['s_history']
+        y_history = self.global_state['y_history']

-
-
+        # precondition
+        dir = lsr1_Hx(
+            x=tensors,
             s_history=s_history,
             y_history=y_history,
-            step=step,
-            scale_second=scale_second,
         )

-
+        # scale 1st step
+        if scale_first and self.global_state.get('step', 1) == 1:
+            dir *= initial_step_size(dir, eps=1e-7)

-        return
+        return dir
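
For context on the new `lsr1_Hx`/`lsr1_Bx` helpers above, this is the recursion they appear to apply, written in the usual L-SR1 notation rather than in the code's variable names (the identity initialization H_0 = I and the symbols s_k, y_k, w_k, H_k are the standard textbook convention, assumed here, not taken from the diff except where they coincide):

```latex
% Matrix-free limited-memory SR1, applied to a vector v (assuming H_0 = I):
%   first loop:   w_k = s_k - H_k y_k,  where  H_k y_k = y_k + \sum_{j<k} (w_j^T y_k / w_j^T y_j) w_j
%   second loop:  H_m v = v + \sum_{k<m} (w_k^T v / w_k^T y_k) w_k
\[
  H_{k+1} = H_k + \frac{(s_k - H_k y_k)(s_k - H_k y_k)^\top}{(s_k - H_k y_k)^\top y_k}
  \quad\Longrightarrow\quad
  H_m v = v + \sum_{k=0}^{m-1} \frac{w_k^\top v}{w_k^\top y_k}\, w_k,
  \qquad w_k = s_k - H_k y_k .
\]
```

Because the SR1 update is symmetric in s and y, the direct Hessian approximation B_m is obtained by exchanging the two histories, which is what `lsr1_Bx` does by calling `lsr1_Hx` with the arguments swapped; pairs whose denominator |w_k^T y_k| falls below the new `eps` threshold are skipped, the usual SR1 safeguard against a vanishing denominator.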