torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +55 -22
- tests/test_tensorlist.py +3 -3
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +20 -130
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +111 -0
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +76 -26
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +15 -15
- torchzero/modules/quasi_newton/lsr1.py +18 -17
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +257 -48
- torchzero/modules/second_order/newton.py +38 -21
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +19 -19
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.8.dist-info/RECORD +0 -130
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/cg.py

```diff
@@ -1,14 +1,15 @@
 from abc import ABC, abstractmethod
+from typing import Literal
 
 import torch
 
-from ...core import Chainable, Transform,
-from ...utils import TensorList, as_tensorlist
+from ...core import Chainable, TensorwiseTransform, Transform, apply_transform
+from ...utils import TensorList, as_tensorlist, unpack_dicts, unpack_states
 
 
 class ConguateGradientBase(Transform, ABC):
     """all CGs are the same except beta calculation"""
-    def __init__(self, defaults = None, clip_beta: bool = False, reset_interval: int | None = None, inner: Chainable | None = None):
+    def __init__(self, defaults = None, clip_beta: bool = False, reset_interval: int | None | Literal['auto'] = None, inner: Chainable | None = None):
         if defaults is None: defaults = {}
         defaults['reset_interval'] = reset_interval
         defaults['clip_beta'] = clip_beta
@@ -25,12 +26,12 @@ class ConguateGradientBase(Transform, ABC):
         """returns beta"""
 
     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         tensors = as_tensorlist(tensors)
         params = as_tensorlist(params)
 
         step = self.global_state.get('step', 0)
-        prev_dir, prev_grads =
+        prev_dir, prev_grads = unpack_states(states, tensors, 'prev_dir', 'prev_grad', cls=TensorList)
 
         # initialize on first step
         if step == 0:
@@ -42,12 +43,12 @@ class ConguateGradientBase(Transform, ABC):
 
         # get beta
         beta = self.get_beta(params, tensors, prev_grads, prev_dir)
-        if
+        if settings[0]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]
         prev_grads.copy_(tensors)
 
         # inner step
         if 'inner' in self.children:
-            tensors = as_tensorlist(
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))
 
         # calculate new direction with beta
         dir = tensors.add_(prev_dir.mul_(beta))
@@ -55,7 +56,8 @@
 
         # resetting
         self.global_state['step'] = step + 1
-        reset_interval =
+        reset_interval = settings[0]['reset_interval']
+        if reset_interval == 'auto': reset_interval = tensors.global_numel() + 1
         if reset_interval is not None and (step+1) % reset_interval == 0:
             self.reset()
 
@@ -64,7 +66,7 @@
 # ------------------------------- Polak-Ribière ------------------------------ #
 def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
     denom = prev_g.dot(prev_g)
-    if denom
+    if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
     return g.dot(g - prev_g) / denom
 
 class PolakRibiere(ConguateGradientBase):
@@ -76,13 +78,13 @@ class PolakRibiere(ConguateGradientBase):
         return polak_ribiere_beta(g, prev_g)
 
 # ------------------------------ Fletcher–Reeves ----------------------------- #
-def fletcher_reeves_beta(gg, prev_gg):
-    if prev_gg
+def fletcher_reeves_beta(gg: torch.Tensor, prev_gg: torch.Tensor):
+    if prev_gg.abs() <= torch.finfo(gg.dtype).eps: return 0
     return gg / prev_gg
 
 class FletcherReeves(ConguateGradientBase):
     """Fletcher–Reeves nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None =
+    def __init__(self, reset_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
 
     def initialize(self, p, g):
@@ -98,13 +100,13 @@ class FletcherReeves(ConguateGradientBase):
 def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
     grad_diff = g - prev_g
     denom = prev_d.dot(grad_diff)
-    if denom
+    if denom.abs() < torch.finfo(g[0].dtype).eps: return 0
     return (g.dot(grad_diff) / denom).neg()
 
 
 class HestenesStiefel(ConguateGradientBase):
     """Hestenes–Stiefel nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
 
     def get_beta(self, p, g, prev_g, prev_d):
@@ -114,12 +116,12 @@ class HestenesStiefel(ConguateGradientBase):
 # --------------------------------- Dai–Yuan --------------------------------- #
 def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
     denom = prev_d.dot(g - prev_g)
-    if denom
+    if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
     return (g.dot(g) / denom).neg()
 
 class DaiYuan(ConguateGradientBase):
     """Dai–Yuan nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
 
     def get_beta(self, p, g, prev_g, prev_d):
@@ -129,12 +131,12 @@ class DaiYuan(ConguateGradientBase):
 # -------------------------------- Liu-Storey -------------------------------- #
 def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):
     denom = prev_g.dot(prev_d)
-    if denom
+    if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
     return g.dot(g - prev_g) / denom
 
 class LiuStorey(ConguateGradientBase):
     """Liu-Storey nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
 
     def get_beta(self, p, g, prev_g, prev_d):
@@ -151,20 +153,20 @@ class ConjugateDescent(Transform):
 
 
     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         g = as_tensorlist(tensors)
 
-        prev_d =
+        prev_d = unpack_states(states, tensors, 'prev_dir', cls=TensorList, init=torch.zeros_like)
         if 'denom' not in self.global_state:
             self.global_state['denom'] = torch.tensor(0.).to(g[0])
 
         prev_gd = self.global_state.get('prev_gd', 0)
-        if prev_gd
+        if abs(prev_gd) <= torch.finfo(g[0].dtype).eps: beta = 0
         else: beta = g.dot(g) / prev_gd
 
         # inner step
         if 'inner' in self.children:
-            g = as_tensorlist(
+            g = as_tensorlist(apply_transform(self.children['inner'], g, params, grads))
 
         dir = g.add_(prev_d.mul_(beta))
         prev_d.copy_(dir)
@@ -176,7 +178,7 @@
 def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
     g_diff = g - prev_g
     denom = prev_d.dot(g_diff)
-    if denom
+    if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
 
     term1 = 1/denom
     # term2
@@ -187,7 +189,7 @@ def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
 class HagerZhang(ConguateGradientBase):
     """Hager-Zhang nonlinear conjugate gradient method,
     This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
 
     def get_beta(self, p, g, prev_g, prev_d):
@@ -198,7 +200,7 @@ class HagerZhang(ConguateGradientBase):
 def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
     grad_diff = g - prev_g
     denom = prev_d.dot(grad_diff)
-    if denom
+    if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
 
     # Dai-Yuan
     dy_beta = (g.dot(g) / denom).neg().clamp(min=0)
@@ -211,8 +213,56 @@ def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
 class HybridHS_DY(ConguateGradientBase):
     """HS-DY hybrid conjugate gradient method.
     This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
 
     def get_beta(self, p, g, prev_g, prev_d):
         return hs_dy_beta(g, prev_d, prev_g)
+
+
+def projected_gradient_(H:torch.Tensor, y:torch.Tensor, tol: float):
+    Hy = H @ y
+    denom = y.dot(Hy)
+    if denom.abs() < tol: return H
+    H -= (H @ y.outer(y) @ H) / denom
+    return H
+
+class ProjectedGradientMethod(TensorwiseTransform):
+    """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+    (This is not the same as projected gradient descent)
+    """
+
+    def __init__(
+        self,
+        tol: float = 1e-10,
+        reset_interval: int | None = None,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(reset_interval=reset_interval, tol=tol)
+        super().__init__(defaults, uses_grad=False, scale_first=scale_first, concat_params=concat_params, update_freq=update_freq, inner=inner)
+
+    def update_tensor(self, tensor, param, grad, loss, state, settings):
+        step = state.get('step', 0)
+        state['step'] = step + 1
+        reset_interval = settings['reset_interval']
+        if reset_interval is None: reset_interval = tensor.numel() + 1 # as recommended
+
+        if ("H" not in state) or (step % reset_interval == 0):
+            state["H"] = torch.eye(tensor.numel(), device=tensor.device, dtype=tensor.dtype)
+            state['g_prev'] = tensor.clone()
+            return
+
+        H = state['H']
+        g_prev = state['g_prev']
+        state['g_prev'] = tensor.clone()
+        y = (tensor - g_prev).ravel()
+
+        projected_gradient_(H, y, settings['tol'])
+
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
+        H = state['H']
+        return (H @ tensor.view(-1)).view_as(tensor)
```
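For reference, the beta helpers in this file implement the usual nonlinear-CG choices. Writing $g_k$ for the current gradient, $g_{k-1}$ for the previous one and $d_{k-1}$ for the stored previous direction, the textbook formulas are (the implementations above may differ by a sign, depending on the convention used for the stored direction):

$$
\beta^{PR}_k=\frac{g_k^\top(g_k-g_{k-1})}{g_{k-1}^\top g_{k-1}},\qquad
\beta^{FR}_k=\frac{g_k^\top g_k}{g_{k-1}^\top g_{k-1}},\qquad
\beta^{HS}_k=\frac{g_k^\top(g_k-g_{k-1})}{d_{k-1}^\top(g_k-g_{k-1})},\qquad
\beta^{DY}_k=\frac{g_k^\top g_k}{d_{k-1}^\top(g_k-g_{k-1})},\qquad
\beta^{LS}_k=-\frac{g_k^\top(g_k-g_{k-1})}{d_{k-1}^\top g_{k-1}}.
$$

The new `projected_gradient_` helper applies the rank-one projection $H \leftarrow H - \dfrac{H y y^\top H}{y^\top H y}$ with $y$ the gradient difference, and leaves $H$ unchanged when $|y^\top H y|$ falls below `tol`.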

torchzero/modules/quasi_newton/experimental/modular_lbfgs.py

```diff
@@ -4,7 +4,7 @@ from typing import Any
 
 import torch
 
-from ....core import Chainable, Module, Transform,
+from ....core import Chainable, Module, Transform, Var, apply_transform, maybe_chain
 from ....utils import NumberList, TensorList, as_tensorlist
 
 
@@ -28,7 +28,7 @@ def _adaptive_damping(
 
 def lbfgs(
     tensors_: TensorList,
-
+    var: Var,
     s_history: deque[TensorList],
     y_history: deque[TensorList],
     sy_history: deque[torch.Tensor],
@@ -60,7 +60,7 @@ def lbfgs(
         z = q * (ys_k / (y_k.dot(y_k)))
 
         if z_tfm is not None:
-            z = TensorList(
+            z = TensorList(apply_transform(z_tfm, tensors=z, params=var.params, grads=var.grad, var=var))
 
         # 2nd loop
         for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
@@ -73,28 +73,28 @@
 def _apply_tfms_into_history(
     self: Module,
     params: list[torch.Tensor],
-
+    var: Var,
     update: list[torch.Tensor],
 ):
     if 'params_history_tfm' in self.children:
-        params =
+        params = apply_transform(self.children['params_history_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=var.grad, var=var)
 
     if 'grad_history_tfm' in self.children:
-        update =
+        update = apply_transform(self.children['grad_history_tfm'], tensors=as_tensorlist(update).clone(), params=params, grads=var.grad, var=var)
 
     return params, update
 
 def _apply_tfms_into_precond(
     self: Module,
     params: list[torch.Tensor],
-
+    var: Var,
     update: list[torch.Tensor],
 ):
     if 'params_precond_tfm' in self.children:
-        params =
+        params = apply_transform(self.children['params_precond_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=var.grad, var=var)
 
     if 'grad_precond_tfm' in self.children:
-        update =
+        update = apply_transform(self.children['grad_precond_tfm'], tensors=update, params=params, grads=var.grad, var=var)
 
     return params, update
 
@@ -165,9 +165,9 @@ class ModularLBFGS(Module):
         self.global_state['sy_history'].clear()
 
     @torch.no_grad
-    def step(self,
-        params = as_tensorlist(
-        update = as_tensorlist(
+    def step(self, var):
+        params = as_tensorlist(var.params)
+        update = as_tensorlist(var.get_update())
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -186,11 +186,11 @@ class ModularLBFGS(Module):
         params_h, update_h = _apply_tfms_into_history(
             self,
             params=params,
-
+            var=var,
             update=update,
         )
 
-        prev_params_h, prev_grad_h = self.get_state('prev_params_h', 'prev_grad_h',
+        prev_params_h, prev_grad_h = self.get_state(params, 'prev_params_h', 'prev_grad_h', cls=TensorList)
 
         # 1st step - there are no previous params and grads, `lbfgs` will do normalized SGD step
         if step == 0:
@@ -217,16 +217,16 @@ class ModularLBFGS(Module):
         # step with inner module before applying preconditioner
         if 'update_precond_tfm' in self.children:
             update_precond_tfm = self.children['update_precond_tfm']
-
-
-            tensors =
+            inner_var = update_precond_tfm.step(var.clone(clone_update=True))
+            var.update_attrs_from_clone_(inner_var)
+            tensors = inner_var.update
             assert tensors is not None
         else:
             tensors = update.clone()
 
         # transforms into preconditioner
-        params_p, update_p = _apply_tfms_into_precond(self, params=params,
-        prev_params_p, prev_grad_p = self.get_state('prev_params_p', 'prev_grad_p',
+        params_p, update_p = _apply_tfms_into_precond(self, params=params, var=var, update=update)
+        prev_params_p, prev_grad_p = self.get_state(params, 'prev_params_p', 'prev_grad_p', cls=TensorList)
 
         if step == 0:
             s_k_p = None; y_k_p = None; ys_k_p = None
@@ -245,13 +245,13 @@ class ModularLBFGS(Module):
         # tolerance on gradient difference to avoid exploding after converging
         if tol is not None:
             if y_k_p is not None and y_k_p.abs().global_max() <= tol:
-
-                return
+                var.update = update # may have been updated by inner module, probably makes sense to use it here?
+                return var
 
         # precondition
         dir = lbfgs(
             tensors_=as_tensorlist(tensors),
-
+            var=var,
             s_history=s_history,
             y_history=y_history,
             sy_history=sy_history,
@@ -260,7 +260,7 @@
             z_tfm=self.children.get('z_tfm', None),
         )
 
-
+        var.update = dir
 
-        return
+        return var
 
```
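The `lbfgs(...)` helper whose signature gains `var: Var` above is built around the standard two-loop recursion (first loop, scaling by `ys_k / y_k·y_k`, second loop). A minimal standalone sketch of that recursion on flat tensors, in plain `torch` and with hypothetical names, leaving out torchzero's `TensorList`, damping and `z_tfm` handling:

```python
import torch

def two_loop_direction(grad, s_history, y_history):
    """Textbook L-BFGS two-loop recursion: returns H @ grad, where H is the
    implicit inverse-Hessian approximation built from the stored (s, y) pairs."""
    q = grad.clone()
    rhos = [1.0 / torch.dot(y, s) for s, y in zip(s_history, y_history)]

    # first loop: newest pair to oldest
    alphas = []
    for s, y, rho in zip(reversed(s_history), reversed(y_history), reversed(rhos)):
        alpha = rho * torch.dot(s, q)
        q -= alpha * y
        alphas.append(alpha)

    # initial scaling H0 = gamma * I with gamma = s.y / y.y from the newest pair
    s_new, y_new = s_history[-1], y_history[-1]
    z = q * (torch.dot(s_new, y_new) / torch.dot(y_new, y_new))

    # second loop: oldest pair to newest
    for s, y, rho, alpha in zip(s_history, y_history, rhos, reversed(alphas)):
        beta = rho * torch.dot(y, z)
        z += s * (alpha - beta)

    return z  # the descent direction is -z when grad is the raw gradient
```

Here `s_history` and `y_history` hold flattened parameter and gradient differences, oldest first, mirroring the deques the module maintains.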

torchzero/modules/quasi_newton/lbfgs.py

```diff
@@ -2,7 +2,7 @@ from collections import deque
 from operator import itemgetter
 import torch
 
-from ...core import Transform, Chainable, Module,
+from ...core import Transform, Chainable, Module, Var, apply_transform
 from ...utils import TensorList, as_tensorlist, NumberList
 
 
@@ -38,9 +38,9 @@
     if len(s_history) == 0 or y_k is None or ys_k is None:
 
         # initial step size guess modified from pytorch L-BFGS
-
-
-        return tensors_.mul_(
+        scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
+        scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
+        return tensors_.mul_(scale_factor)
 
     else:
         # 1st loop
@@ -154,9 +154,9 @@ class LBFGS(Module):
         self.global_state['sy_history'].clear()
 
     @torch.no_grad
-    def step(self,
-        params = as_tensorlist(
-        update = as_tensorlist(
+    def step(self, var):
+        params = as_tensorlist(var.params)
+        update = as_tensorlist(var.get_update())
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -167,10 +167,10 @@ class LBFGS(Module):
 
         tol, damping, init_damping, eigval_bounds, update_freq, z_beta, tol_reset = itemgetter(
             'tol', 'damping', 'init_damping', 'eigval_bounds', 'update_freq', 'z_beta', 'tol_reset')(self.settings[params[0]])
-        params_beta, grads_beta = self.get_settings('params_beta', 'grads_beta'
+        params_beta, grads_beta = self.get_settings(params, 'params_beta', 'grads_beta')
 
         l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
-        prev_l_params, prev_l_grad = self.get_state('prev_l_params', 'prev_l_grad',
+        prev_l_params, prev_l_grad = self.get_state(params, 'prev_l_params', 'prev_l_grad', cls=TensorList)
 
         # 1st step - there are no previous params and grads, `lbfgs` will do normalized SGD step
         if step == 0:
@@ -196,19 +196,19 @@ class LBFGS(Module):
 
         # step with inner module before applying preconditioner
         if self.children:
-            update = TensorList(
+            update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))
 
         # tolerance on gradient difference to avoid exploding after converging
         if tol is not None:
             if y_k is not None and y_k.abs().global_max() <= tol:
-
+                var.update = update # may have been updated by inner module, probably makes sense to use it here?
                 if tol_reset: self.reset()
-                return
+                return var
 
         # lerp initial H^-1 @ q guess
         z_ema = None
         if z_beta is not None:
-            z_ema = self.get_state('z_ema',
+            z_ema = self.get_state(var.params, 'z_ema', cls=TensorList)
 
         # precondition
         dir = lbfgs(
@@ -223,7 +223,7 @@ class LBFGS(Module):
             step=step
         )
 
-
+        var.update = dir
 
-        return
+        return var
 
```
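In the rewritten first-step branch above, the returned update is just the incoming tensors scaled by an explicit factor; written out, the guess is

$$
\text{scale}=\max\!\Big(\varepsilon,\ \frac{1}{\max\!\big(1,\ \textstyle\sum_i |g_i|\big)}\Big),\qquad \Delta = \text{scale}\cdot g,
$$

where $g$ is the incoming update, the sum runs over all of its elements, and $\varepsilon$ is the machine epsilon of its dtype. The same guess appears in `lsr1_` in the next file.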

torchzero/modules/quasi_newton/lsr1.py

```diff
@@ -3,7 +3,7 @@ from operator import itemgetter
 
 import torch
 
-from ...core import Chainable, Module, Transform,
+from ...core import Chainable, Module, Transform, Var, apply_transform
 from ...utils import NumberList, TensorList, as_tensorlist
 
 from .lbfgs import _lerp_params_update_
@@ -17,9 +17,9 @@
 ):
     if step == 0 or not s_history:
         # initial step size guess from pytorch
-
-
-        return tensors_.mul_(
+        scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
+        scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
+        return tensors_.mul_(scale_factor)
 
     m = len(s_history)
 
@@ -65,9 +65,10 @@
         Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy) # pyright:ignore[reportArgumentType]
 
     if scale_second and step == 1:
-
-
-        Hx.mul_(
+        scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
+        scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
+        Hx.mul_(scale_factor)
+
     return Hx
 
 
@@ -122,9 +123,9 @@ class LSR1(Module):
 
 
     @torch.no_grad
-    def step(self,
-        params = as_tensorlist(
-        update = as_tensorlist(
+    def step(self, var: Var):
+        params = as_tensorlist(var.params)
+        update = as_tensorlist(var.get_update())
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -134,10 +135,10 @@ class LSR1(Module):
         settings = self.settings[params[0]]
         tol, update_freq, scale_second = itemgetter('tol', 'update_freq', 'scale_second')(settings)
 
-        params_beta, grads_beta_ = self.get_settings('params_beta', 'grads_beta'
+        params_beta, grads_beta_ = self.get_settings(params, 'params_beta', 'grads_beta') # type: ignore
         l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta_)
 
-        prev_l_params, prev_l_grad = self.get_state('prev_l_params', 'prev_l_grad',
+        prev_l_params, prev_l_grad = self.get_state(params, 'prev_l_params', 'prev_l_grad', cls=TensorList)
 
         y_k = None
         if step != 0:
@@ -152,13 +153,13 @@ class LSR1(Module):
             prev_l_grad.copy_(l_update)
 
         if 'inner' in self.children:
-            update = TensorList(
+            update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))
 
         # tolerance on gradient difference to avoid exploding after converging
         if tol is not None:
             if y_k is not None and y_k.abs().global_max() <= tol:
-
-                return
+                var.update = update
+                return var
 
         dir = lsr1_(
             tensors_=update,
@@ -168,6 +169,6 @@ class LSR1(Module):
             scale_second=scale_second,
         )
 
-
+        var.update = dir
 
-        return
+        return var
```
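For context, the `Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy)` line kept in the hunk above is the matrix-free form of the symmetric rank-one (SR1) inverse-Hessian update: with $w_k = s_k - H_k y_k$ (and `wy` presumably $w_k^\top y_k$),

$$
H_{k+1}=H_k+\frac{w_k w_k^\top}{w_k^\top y_k},\qquad
H_{k+1}x = H_k x + \frac{w_k^\top x}{w_k^\top y_k}\,w_k,
$$

which `lsr1_` accumulates pair by pair before the optional `scale_second` rescaling shown above.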

torchzero/modules/quasi_newton/olbfgs.py

```diff
@@ -5,17 +5,17 @@ from typing import Literal
 
 import torch
 
-from ...core import Chainable, Module, Transform,
+from ...core import Chainable, Module, Transform, Var, apply_transform
 from ...utils import NumberList, TensorList, as_tensorlist
 from .lbfgs import _adaptive_damping, lbfgs
 
 
 @torch.no_grad
-def _store_sk_yk_after_step_hook(optimizer,
-    assert
-    with torch.enable_grad():
-    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in
-    s_k =
+def _store_sk_yk_after_step_hook(optimizer, var: Var, prev_params: TensorList, prev_grad: TensorList, damping, init_damping, eigval_bounds, s_history: deque[TensorList], y_history: deque[TensorList], sy_history: deque[torch.Tensor]):
+    assert var.closure is not None
+    with torch.enable_grad(): var.closure()
+    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in var.params]
+    s_k = var.params - prev_params
     y_k = grad - prev_grad
     ys_k = s_k.dot(y_k)
 
@@ -95,11 +95,11 @@ class OnlineLBFGS(Module):
         self.global_state['sy_history'].clear()
 
     @torch.no_grad
-    def step(self,
-        assert
+    def step(self, var):
+        assert var.closure is not None
 
-        params = as_tensorlist(
-        update = as_tensorlist(
+        params = as_tensorlist(var.params)
+        update = as_tensorlist(var.get_update())
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -113,7 +113,7 @@ class OnlineLBFGS(Module):
 
         # sample gradient at previous params with current mini-batch
         if sample_grads == 'before':
-            prev_params = self.get_state('prev_params',
+            prev_params = self.get_state(params, 'prev_params', cls=TensorList)
             if step == 0:
                 s_k = None; y_k = None; ys_k = None
             else:
@@ -121,7 +121,7 @@ class OnlineLBFGS(Module):
 
                 current_params = params.clone()
                 params.set_(prev_params)
-                with torch.enable_grad():
+                with torch.enable_grad(): var.closure()
                 y_k = update - params.grad
                 ys_k = s_k.dot(y_k)
                 params.set_(current_params)
@@ -146,7 +146,7 @@ class OnlineLBFGS(Module):
             ys_k = s_k.dot(y_k)
 
         # this will run after params are updated by Modular after running all future modules
-
+        var.post_step_hooks.append(
             partial(
                 _store_sk_yk_after_step_hook,
                 prev_params=params.clone(),
@@ -164,18 +164,18 @@ class OnlineLBFGS(Module):
 
         # step with inner module before applying preconditioner
         if self.children:
-            update = TensorList(
+            update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))
 
         # tolerance on gradient difference to avoid exploding after converging
         if tol is not None:
             if y_k is not None and y_k.abs().global_max() <= tol:
-
-                return
+                var.update = update # may have been updated by inner module, probably makes sense to use it here?
+                return var
 
         # lerp initial H^-1 @ q guess
         z_ema = None
         if z_beta is not None:
-            z_ema = self.get_state('z_ema',
+            z_ema = self.get_state(params, 'z_ema', cls=TensorList)
 
         # precondition
         dir = lbfgs(
@@ -190,7 +190,7 @@ class OnlineLBFGS(Module):
             step=step
         )
 
-
+        var.update = dir
 
-        return
+        return var
 
```
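The `var.post_step_hooks.append(...)` call above registers `_store_sk_yk_after_step_hook`, which re-runs the closure after the parameter update so that $s_k$ and $y_k$ are measured on the same mini-batch; that is the defining trick of online L-BFGS. A rough standalone sketch of the idea in plain PyTorch (hypothetical helper, assuming `closure()` zeroes and repopulates every `p.grad`):

```python
import torch

def same_batch_curvature_pair(closure, params, direction, lr=1.0):
    """Compute (s_k, y_k) with both gradients taken on the same mini-batch,
    so y_k reflects curvature rather than mini-batch noise."""
    with torch.enable_grad():
        closure()                                  # gradient at the old params
    g_prev = [p.grad.clone() for p in params]
    x_prev = [p.detach().clone() for p in params]

    with torch.no_grad():                          # take the step
        for p, d in zip(params, direction):
            p.sub_(d, alpha=lr)

    with torch.enable_grad():
        closure()                                  # gradient at the new params, same batch
    s_k = [p.detach() - xp for p, xp in zip(params, x_prev)]
    y_k = [p.grad - gp for p, gp in zip(params, g_prev)]
    return s_k, y_k
```

The resulting pair can then be pushed into the usual (s, y) history and consumed by the two-loop recursion sketched earlier.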