torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -1,129 +0,0 @@
-from typing import Literal
-from collections.abc import Callable
-import torch
-
-from ...core import Module, Target, Transform, Chainable, apply_transform
-from ...utils import NumberList, TensorList, as_tensorlist
-from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
-
-def sophia_H(
-    tensors: TensorList,
-    h: TensorList | None,
-    exp_avg_: TensorList,
-    h_exp_avg_: TensorList,
-    beta1: float | NumberList,
-    beta2: float | NumberList,
-    update_freq: int,
-    precond_scale: float | NumberList,
-    clip: float | NumberList,
-    eps: float | NumberList,
-    step: int
-):
-    # momentum
-    exp_avg_.lerp_(tensors, 1-beta1)
-
-    # update preconditioner
-    if step % update_freq == 0:
-        assert h is not None
-        h_exp_avg_.lerp_(h, 1-beta2)
-
-    else:
-        assert h is None
-
-    denom = (h_exp_avg_ * precond_scale).clip_(min=eps)
-    return (exp_avg_ / denom).clip_(-clip, clip)
-
-
-class SophiaH(Module):
-    def __init__(
-        self,
-        beta1: float = 0.96,
-        beta2: float = 0.99,
-        update_freq: int = 10,
-        precond_scale: float = 1,
-        clip: float = 1,
-        eps: float = 1e-12,
-        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
-        fd_h: float = 1e-3,
-        n_samples = 1,
-        seed: int | None = None,
-        inner: Chainable | None = None
-    ):
-        defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, precond_scale=precond_scale, clip=clip, eps=eps, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
-        super().__init__(defaults)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    @torch.no_grad
-    def step(self, var):
-        params = var.params
-        settings = self.settings[params[0]]
-        hvp_method = settings['hvp_method']
-        fd_h = settings['fd_h']
-        update_freq = settings['update_freq']
-        n_samples = settings['n_samples']
-
-        seed = settings['seed']
-        generator = None
-        if seed is not None:
-            if 'generator' not in self.global_state:
-                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            generator = self.global_state['generator']
-
-        beta1, beta2, precond_scale, clip, eps = self.get_settings(params,
-            'beta1', 'beta2', 'precond_scale', 'clip', 'eps', cls=NumberList)
-
-        exp_avg, h_exp_avg = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)
-
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        closure = var.closure
-        assert closure is not None
-
-        h = None
-        if step % update_freq == 0:
-
-            grad=None
-            for i in range(n_samples):
-                u = [torch.randn(p.shape, device=p.device, dtype=p.dtype, generator=generator) for p in params]
-
-                if hvp_method == 'autograd':
-                    if grad is None: grad = var.get_grad(create_graph=True)
-                    assert grad is not None
-                    Hvp = hvp(params, grad, u, retain_graph=i < n_samples-1)
-
-                elif hvp_method == 'forward':
-                    loss, Hvp = hvp_fd_forward(closure, params, u, h=fd_h, g_0=var.get_grad(), normalize=True)
-
-                elif hvp_method == 'central':
-                    loss, Hvp = hvp_fd_central(closure, params, u, h=fd_h, normalize=True)
-
-                else:
-                    raise ValueError(hvp_method)
-
-                if h is None: h = Hvp
-                else: torch._foreach_add_(h, Hvp)
-
-            assert h is not None
-            if n_samples > 1: torch._foreach_div_(h, n_samples)
-
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
-
-        var.update = sophia_H(
-            tensors=TensorList(update),
-            h=TensorList(h) if h is not None else None,
-            exp_avg_=exp_avg,
-            h_exp_avg_=h_exp_avg,
-            beta1=beta1,
-            beta2=beta2,
-            update_freq=update_freq,
-            precond_scale=precond_scale,
-            clip=clip,
-            eps=eps,
-            step=step,
-        )
-        return var
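
Note on the removed SophiaH module above: written out, the update rule that the sophia_H helper applies is as follows (this only restates the code; g_t is the incoming update, h_t the averaged Hessian-vector-product estimate built in step, k = update_freq, \rho = precond_scale, c = clip, \varepsilon = eps, and the symbols are editorial rather than the package's own):

    m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
    \hat h_t = \beta_2 \hat h_{t-1} + (1 - \beta_2)\, h_t  \quad (only on steps with t mod k = 0, otherwise \hat h_t = \hat h_{t-1})
    \Delta_t = \operatorname{clip}\!\left( m_t / \max(\rho\, \hat h_t,\ \varepsilon),\ -c,\ c \right)

The max with \varepsilon and the final clip are elementwise, matching the clip_ calls in the code.
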
@@ -1,268 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Literal
-
-import torch
-
-from ...core import Chainable, TensorwiseTransform, Transform, apply_transform
-from ...utils import TensorList, as_tensorlist, unpack_dicts, unpack_states
-
-
-class ConguateGradientBase(Transform, ABC):
-    """all CGs are the same except beta calculation"""
-    def __init__(self, defaults = None, clip_beta: bool = False, reset_interval: int | None | Literal['auto'] = None, inner: Chainable | None = None):
-        if defaults is None: defaults = {}
-        defaults['reset_interval'] = reset_interval
-        defaults['clip_beta'] = clip_beta
-        super().__init__(defaults, uses_grad=False)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    def initialize(self, p: TensorList, g: TensorList):
-        """runs on first step when prev_grads and prev_dir are not available"""
-
-    @abstractmethod
-    def get_beta(self, p: TensorList, g: TensorList, prev_g: TensorList, prev_d: TensorList) -> float | torch.Tensor:
-        """returns beta"""
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        tensors = as_tensorlist(tensors)
-        params = as_tensorlist(params)
-
-        step = self.global_state.get('step', 0)
-        prev_dir, prev_grads = unpack_states(states, tensors, 'prev_dir', 'prev_grad', cls=TensorList)
-
-        # initialize on first step
-        if step == 0:
-            self.initialize(params, tensors)
-            prev_dir.copy_(tensors)
-            prev_grads.copy_(tensors)
-            self.global_state['step'] = step + 1
-            return tensors
-
-        # get beta
-        beta = self.get_beta(params, tensors, prev_grads, prev_dir)
-        if settings[0]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]
-        prev_grads.copy_(tensors)
-
-        # inner step
-        if 'inner' in self.children:
-            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))
-
-        # calculate new direction with beta
-        dir = tensors.add_(prev_dir.mul_(beta))
-        prev_dir.copy_(dir)
-
-        # resetting
-        self.global_state['step'] = step + 1
-        reset_interval = settings[0]['reset_interval']
-        if reset_interval == 'auto': reset_interval = tensors.global_numel() + 1
-        if reset_interval is not None and (step+1) % reset_interval == 0:
-            self.reset()
-
-        return dir
-
-# ------------------------------- Polak-Ribière ------------------------------ #
-def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
-    denom = prev_g.dot(prev_g)
-    if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
-    return g.dot(g - prev_g) / denom
-
-class PolakRibiere(ConguateGradientBase):
-    """Polak-Ribière-Polyak nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this."""
-    def __init__(self, clip_beta=True, reset_interval: int | None = None, inner: Chainable | None = None):
-        super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
-
-    def get_beta(self, p, g, prev_g, prev_d):
-        return polak_ribiere_beta(g, prev_g)
-
-# ------------------------------ Fletcher–Reeves ----------------------------- #
-def fletcher_reeves_beta(gg: torch.Tensor, prev_gg: torch.Tensor):
-    if prev_gg.abs() <= torch.finfo(gg.dtype).eps: return 0
-    return gg / prev_gg
-
-class FletcherReeves(ConguateGradientBase):
-    """Fletcher–Reeves nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
-        super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
-
-    def initialize(self, p, g):
-        self.global_state['prev_gg'] = g.dot(g)
-
-    def get_beta(self, p, g, prev_g, prev_d):
-        gg = g.dot(g)
-        beta = fletcher_reeves_beta(gg, self.global_state['prev_gg'])
-        self.global_state['prev_gg'] = gg
-        return beta
-
-# ----------------------------- Hestenes–Stiefel ----------------------------- #
-def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
-    grad_diff = g - prev_g
-    denom = prev_d.dot(grad_diff)
-    if denom.abs() < torch.finfo(g[0].dtype).eps: return 0
-    return (g.dot(grad_diff) / denom).neg()
-
-
-class HestenesStiefel(ConguateGradientBase):
-    """Hestenes–Stiefel nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
-        super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
-
-    def get_beta(self, p, g, prev_g, prev_d):
-        return hestenes_stiefel_beta(g, prev_d, prev_g)
-
-
-# --------------------------------- Dai–Yuan --------------------------------- #
-def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
-    denom = prev_d.dot(g - prev_g)
-    if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
-    return (g.dot(g) / denom).neg()
-
-class DaiYuan(ConguateGradientBase):
-    """Dai–Yuan nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
-        super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
-
-    def get_beta(self, p, g, prev_g, prev_d):
-        return dai_yuan_beta(g, prev_d, prev_g)
-
-
-# -------------------------------- Liu-Storey -------------------------------- #
-def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):
-    denom = prev_g.dot(prev_d)
-    if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
-    return g.dot(g - prev_g) / denom
-
-class LiuStorey(ConguateGradientBase):
-    """Liu-Storey nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
-        super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
-
-    def get_beta(self, p, g, prev_g, prev_d):
-        return liu_storey_beta(g, prev_d, prev_g)
-
-# ----------------------------- Conjugate Descent ---------------------------- #
-class ConjugateDescent(Transform):
-    """Conjugate Descent (CD). This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, inner: Chainable | None = None):
-        super().__init__(defaults={}, uses_grad=False)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        g = as_tensorlist(tensors)
-
-        prev_d = unpack_states(states, tensors, 'prev_dir', cls=TensorList, init=torch.zeros_like)
-        if 'denom' not in self.global_state:
-            self.global_state['denom'] = torch.tensor(0.).to(g[0])
-
-        prev_gd = self.global_state.get('prev_gd', 0)
-        if abs(prev_gd) <= torch.finfo(g[0].dtype).eps: beta = 0
-        else: beta = g.dot(g) / prev_gd
-
-        # inner step
-        if 'inner' in self.children:
-            g = as_tensorlist(apply_transform(self.children['inner'], g, params, grads))
-
-        dir = g.add_(prev_d.mul_(beta))
-        prev_d.copy_(dir)
-        self.global_state['prev_gd'] = g.dot(dir)
-        return dir
-
-
-# -------------------------------- Hager-Zhang ------------------------------- #
-def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
-    g_diff = g - prev_g
-    denom = prev_d.dot(g_diff)
-    if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
-
-    term1 = 1/denom
-    # term2
-    term2 = (g_diff - (2 * prev_d * (g_diff.pow(2).global_sum()/denom))).dot(g)
-    return (term1 * term2).neg()
-
-
-class HagerZhang(ConguateGradientBase):
-    """Hager-Zhang nonlinear conjugate gradient method,
-    This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
-        super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
-
-    def get_beta(self, p, g, prev_g, prev_d):
-        return hager_zhang_beta(g, prev_d, prev_g)
-
-
-# ----------------------------------- HS-DY ---------------------------------- #
-def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
-    grad_diff = g - prev_g
-    denom = prev_d.dot(grad_diff)
-    if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
-
-    # Dai-Yuan
-    dy_beta = (g.dot(g) / denom).neg().clamp(min=0)
-
-    # Hestenes–Stiefel
-    hs_beta = (g.dot(grad_diff) / denom).neg().clamp(min=0)
-
-    return max(0, min(dy_beta, hs_beta)) # type:ignore
-
-class HybridHS_DY(ConguateGradientBase):
-    """HS-DY hybrid conjugate gradient method.
-    This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
-        super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
-
-    def get_beta(self, p, g, prev_g, prev_d):
-        return hs_dy_beta(g, prev_d, prev_g)
-
-
-def projected_gradient_(H:torch.Tensor, y:torch.Tensor, tol: float):
-    Hy = H @ y
-    denom = y.dot(Hy)
-    if denom.abs() < tol: return H
-    H -= (H @ y.outer(y) @ H) / denom
-    return H
-
-class ProjectedGradientMethod(TensorwiseTransform):
-    """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
-
-    (This is not the same as projected gradient descent)
-    """
-
-    def __init__(
-        self,
-        tol: float = 1e-10,
-        reset_interval: int | None = None,
-        update_freq: int = 1,
-        scale_first: bool = False,
-        concat_params: bool = True,
-        inner: Chainable | None = None,
-    ):
-        defaults = dict(reset_interval=reset_interval, tol=tol)
-        super().__init__(defaults, uses_grad=False, scale_first=scale_first, concat_params=concat_params, update_freq=update_freq, inner=inner)
-
-    def update_tensor(self, tensor, param, grad, loss, state, settings):
-        step = state.get('step', 0)
-        state['step'] = step + 1
-        reset_interval = settings['reset_interval']
-        if reset_interval is None: reset_interval = tensor.numel() + 1 # as recommended
-
-        if ("H" not in state) or (step % reset_interval == 0):
-            state["H"] = torch.eye(tensor.numel(), device=tensor.device, dtype=tensor.dtype)
-            state['g_prev'] = tensor.clone()
-            return
-
-        H = state['H']
-        g_prev = state['g_prev']
-        state['g_prev'] = tensor.clone()
-        y = (tensor - g_prev).ravel()
-
-        projected_gradient_(H, y, settings['tol'])
-
-    def apply_tensor(self, tensor, param, grad, loss, state, settings):
-        H = state['H']
-        return (H @ tensor.view(-1)).view_as(tensor)
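
Note on the removed cg.py above: in the notation of those helpers (g the incoming update, g_prev the previous one, d_prev the previous direction, y = g - g_prev), the \beta values as implemented are

    \beta_{PR} = \frac{g^\top y}{g_{prev}^\top g_{prev}}, \quad
    \beta_{FR} = \frac{g^\top g}{g_{prev}^\top g_{prev}}, \quad
    \beta_{HS} = -\frac{g^\top y}{d_{prev}^\top y}, \quad
    \beta_{DY} = -\frac{g^\top g}{d_{prev}^\top y}, \quad
    \beta_{LS} = \frac{g^\top y}{g_{prev}^\top d_{prev}}, \quad
    \beta_{HZ} = -\frac{1}{d_{prev}^\top y}\left(y - 2\, d_{prev}\, \frac{\lVert y\rVert^2}{d_{prev}^\top y}\right)^{\!\top} g,

each falling back to 0 when its denominator is below machine epsilon. The new direction is then d = g + \beta\, d_prev, with the step length left to a following line search such as StrongWolfe, exactly as the docstrings state.
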
@@ -1 +0,0 @@
-from .modular_lbfgs import ModularLBFGS
@@ -1,266 +0,0 @@
-from collections import deque
-from operator import itemgetter
-from typing import Any
-
-import torch
-
-from ....core import Chainable, Module, Transform, Var, apply_transform, maybe_chain
-from ....utils import NumberList, TensorList, as_tensorlist
-
-
-def _adaptive_damping(
-    s_k: TensorList,
-    y_k: TensorList,
-    ys_k: torch.Tensor,
-    init_damping = 0.99,
-    eigval_bounds = (0.01, 1.5)
-):
-    # adaptive damping Al-Baali, M.: Quasi-Wolfe conditions for quasi-Newton methods for large-scale optimization. In: 40th Workshop on Large Scale Nonlinear Optimization, Erice, Italy, June 22–July 1 (2004)
-    sigma_l, sigma_h = eigval_bounds
-    u = ys_k / s_k.dot(s_k)
-    if u <= sigma_l < 1: tau = min((1-sigma_l)/(1-u), init_damping)
-    elif u >= sigma_h > 1: tau = min((sigma_h-1)/(u-1), init_damping)
-    else: tau = init_damping
-    y_k = tau * y_k + (1-tau) * s_k
-    ys_k = s_k.dot(y_k)
-
-    return s_k, y_k, ys_k
-
-def lbfgs(
-    tensors_: TensorList,
-    var: Var,
-    s_history: deque[TensorList],
-    y_history: deque[TensorList],
-    sy_history: deque[torch.Tensor],
-    y_k: TensorList | None,
-    ys_k: torch.Tensor | None,
-    z_tfm: Any,
-):
-    if len(s_history) == 0 or y_k is None or ys_k is None:
-
-        # initial step size guess modified from pytorch L-BFGS
-        scale = 1 / tensors_.abs().global_sum()
-        if scale < 1e-5: scale = 1 / tensors_.abs().mean()
-        return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
-
-    else:
-        # 1st loop
-        alpha_list = []
-        q = tensors_.clone()
-        for s_i, y_i, ys_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
-            p_i = 1 / ys_i # this is also denoted as ρ (rho)
-            alpha = p_i * s_i.dot(q)
-            alpha_list.append(alpha)
-            q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
-
-        # calculate z
-        # s.y/y.y is also this weird y-looking symbol I couldn't find
-        # z is it times q
-        # actually H0 = (s.y/y.y) * I, and z = H0 @ q
-        z = q * (ys_k / (y_k.dot(y_k)))
-
-        if z_tfm is not None:
-            z = TensorList(apply_transform(z_tfm, tensors=z, params=var.params, grads=var.grad, var=var))
-
-        # 2nd loop
-        for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
-            p_i = 1 / ys_i
-            beta_i = p_i * y_i.dot(z)
-            z.add_(s_i, alpha = alpha_i - beta_i)
-
-        return z
-
-def _apply_tfms_into_history(
-    self: Module,
-    params: list[torch.Tensor],
-    var: Var,
-    update: list[torch.Tensor],
-):
-    if 'params_history_tfm' in self.children:
-        params = apply_transform(self.children['params_history_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=var.grad, var=var)
-
-    if 'grad_history_tfm' in self.children:
-        update = apply_transform(self.children['grad_history_tfm'], tensors=as_tensorlist(update).clone(), params=params, grads=var.grad, var=var)
-
-    return params, update
-
-def _apply_tfms_into_precond(
-    self: Module,
-    params: list[torch.Tensor],
-    var: Var,
-    update: list[torch.Tensor],
-):
-    if 'params_precond_tfm' in self.children:
-        params = apply_transform(self.children['params_precond_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=var.grad, var=var)
-
-    if 'grad_precond_tfm' in self.children:
-        update = apply_transform(self.children['grad_precond_tfm'], tensors=update, params=params, grads=var.grad, var=var)
-
-    return params, update
-
-
-class ModularLBFGS(Module):
-    """L-BFGS with ability to apply transforms to many inner variables.
-
-    Args:
-        history_size (int, optional): number of past parameter differences and gradient differences to store. Defaults to 10.
-        tol (float | None, optional):
-            tolerance for minimal gradient difference to avoid instability after converging to minima. Defaults to 1e-10.
-        damping (bool, optional):
-            whether to use adaptive damping. Learning rate might need to be lowered with this enabled. Defaults to False.
-        init_damping (float, optional):
-            initial damping for adaptive dampening. Defaults to 0.9.
-        eigval_bounds (tuple, optional):
-            eigenvalue bounds for adaptive dampening. Defaults to (0.5, 50).
-        update_freq (int, optional):
-            how often to update L-BFGS history. Defaults to 1.
-        z_tfm (float | None, optional):
-            transform module applied to initial H^-1 @ q guess. Defaults to None.
-        params_history_tfm (AnyTransform | None, optional):
-            transform module applied to params before adding s_k to history. Defaults to None.
-        grad_history_tfm (AnyTransform | None, optional):
-            transform module applied to grads before adding y_k to history. Defaults to None.
-        params_precond_tfm (AnyTransform | None, optional):
-            transform module applied to params to calculate s_k before preconditioning. Defaults to None.
-        grad_precond_tfm (AnyTransform | None, optional):
-            transform module applied to grads to calculate y_k before preconditioning. Defaults to None.
-        update_precond_tfm (Chainable | None, optional):
-            transform module applied to grads that are being preconditioned. Defaults to None.
-    """
-    def __init__(
-        self,
-        history_size=10,
-        tol: float | None = 1e-10,
-        damping: bool = False,
-        init_damping=0.9,
-        eigval_bounds=(0.5, 50),
-        update_freq = 1,
-        params_history_tfm: Chainable | None = None,
-        grad_history_tfm: Chainable | None = None,
-        params_precond_tfm: Chainable | None = None,
-        grad_precond_tfm: Chainable | None = None,
-        update_precond_tfm: Chainable | None = None,
-        z_tfm: Chainable | None = None,
-    ):
-        defaults = dict(history_size=history_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, update_freq=update_freq)
-        super().__init__(defaults)
-
-        self.global_state['s_history'] = deque(maxlen=history_size)
-        self.global_state['y_history'] = deque(maxlen=history_size)
-        self.global_state['sy_history'] = deque(maxlen=history_size)
-
-        loc = locals().copy()
-        for k in ('update_precond_tfm', 'params_history_tfm', 'grad_history_tfm', 'params_precond_tfm', 'grad_precond_tfm','z_tfm'):
-            v = loc[k]
-            if v is not None:
-                self.set_child(k,v)
-
-    def reset(self):
-        """Resets the internal state of the L-SR1 module."""
-        # super().reset() # Clears self.state (per-parameter) if any, and "step"
-        self.state.clear()
-        self.global_state['step'] = 0
-        self.global_state['s_history'].clear()
-        self.global_state['y_history'].clear()
-        self.global_state['sy_history'].clear()
-
-    @torch.no_grad
-    def step(self, var):
-        params = as_tensorlist(var.params)
-        update = as_tensorlist(var.get_update())
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        # history of s and k
-        s_history: deque[TensorList] = self.global_state['s_history']
-        y_history: deque[TensorList] = self.global_state['y_history']
-        sy_history: deque[torch.Tensor] = self.global_state['sy_history']
-
-        tol, damping, init_damping, eigval_bounds, update_freq = itemgetter(
-            'tol', 'damping', 'init_damping', 'eigval_bounds', 'update_freq')(self.settings[params[0]])
-
-        # params_beta, grads_beta = self.get_settings('params_beta', 'grads_beta', params=params, cls=NumberList)
-        # l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
-
-        # params and update that go into history
-        params_h, update_h = _apply_tfms_into_history(
-            self,
-            params=params,
-            var=var,
-            update=update,
-        )
-
-        prev_params_h, prev_grad_h = self.get_state(params, 'prev_params_h', 'prev_grad_h', cls=TensorList)
-
-        # 1st step - there are no previous params and grads, `lbfgs` will do normalized SGD step
-        if step == 0:
-            s_k_h = None; y_k_h = None; ys_k_h = None
-        else:
-            s_k_h = params_h - prev_params_h
-            y_k_h = update_h - prev_grad_h
-            ys_k_h = s_k_h.dot(y_k_h)
-
-            if damping:
-                s_k_h, y_k_h, ys_k_h = _adaptive_damping(s_k_h, y_k_h, ys_k_h, init_damping=init_damping, eigval_bounds=eigval_bounds)
-
-        prev_params_h.copy_(params_h)
-        prev_grad_h.copy_(update_h)
-
-        # update effective preconditioning state
-        if step % update_freq == 0:
-            if ys_k_h is not None and ys_k_h > 1e-10:
-                assert s_k_h is not None and y_k_h is not None
-                s_history.append(s_k_h)
-                y_history.append(y_k_h)
-                sy_history.append(ys_k_h)
-
-        # step with inner module before applying preconditioner
-        if 'update_precond_tfm' in self.children:
-            update_precond_tfm = self.children['update_precond_tfm']
-            inner_var = update_precond_tfm.step(var.clone(clone_update=True))
-            var.update_attrs_from_clone_(inner_var)
-            tensors = inner_var.update
-            assert tensors is not None
-        else:
-            tensors = update.clone()
-
-        # transforms into preconditioner
-        params_p, update_p = _apply_tfms_into_precond(self, params=params, var=var, update=update)
-        prev_params_p, prev_grad_p = self.get_state(params, 'prev_params_p', 'prev_grad_p', cls=TensorList)
-
-        if step == 0:
-            s_k_p = None; y_k_p = None; ys_k_p = None
-
-        else:
-            s_k_p = params_p - prev_params_p
-            y_k_p = update_p - prev_grad_p
-            ys_k_p = s_k_p.dot(y_k_p)
-
-            if damping:
-                s_k_p, y_k_p, ys_k_p = _adaptive_damping(s_k_p, y_k_p, ys_k_p, init_damping=init_damping, eigval_bounds=eigval_bounds)
-
-        prev_params_p.copy_(params_p)
-        prev_grad_p.copy_(update_p)
-
-        # tolerance on gradient difference to avoid exploding after converging
-        if tol is not None:
-            if y_k_p is not None and y_k_p.abs().global_max() <= tol:
-                var.update = update # may have been updated by inner module, probably makes sense to use it here?
-                return var
-
-        # precondition
-        dir = lbfgs(
-            tensors_=as_tensorlist(tensors),
-            var=var,
-            s_history=s_history,
-            y_history=y_history,
-            sy_history=sy_history,
-            y_k=y_k_p,
-            ys_k=ys_k_p,
-            z_tfm=self.children.get('z_tfm', None),
-        )
-
-        var.update = dir
-
-        return var
-
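
Note on the removed modular_lbfgs.py above: the lbfgs helper is the standard two-loop recursion. In the code's notation (s_i, y_i the stored history pairs, \rho_i = 1/(s_i^\top y_i), written p_i in the code, q the incoming update, and \gamma = ys_k / (y_k^\top y_k) the initial-Hessian scale that the "weird y-looking symbol" comment is reaching for, i.e. H_0 = \gamma I), it computes z \approx H^{-1} q:

    \alpha_i = \rho_i\, s_i^\top q, \qquad q \leftarrow q - \alpha_i y_i  \quad (newest to oldest)
    z = \gamma\, q, \qquad \gamma = \frac{s_k^\top y_k}{y_k^\top y_k}  \quad (optionally passed through z_tfm)
    \beta_i = \rho_i\, y_i^\top z, \qquad z \leftarrow z + (\alpha_i - \beta_i)\, s_i  \quad (oldest to newest)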