torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/experimental/structured_newton.py
DELETED
@@ -1,111 +0,0 @@
-# idea https://arxiv.org/pdf/2212.09841
-import warnings
-from collections.abc import Callable
-from functools import partial
-from typing import Literal
-
-import torch
-
-from ...core import Chainable, Module, apply_transform
-from ...utils import TensorList, vec_to_tensors
-from ...utils.derivatives import (
-    hessian_list_to_mat,
-    hessian_mat,
-    hvp,
-    hvp_fd_central,
-    hvp_fd_forward,
-    jacobian_and_hessian_wrt,
-)
-
-
-class StructuredNewton(Module):
-    """TODO. Please note that this is experimental and isn't guaranteed to work.
-    Args:
-        structure (str, optional): structure.
-        reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
-        hvp_method (str):
-            how to calculate hvp_method. Defaults to "autograd".
-        inner (Chainable | None, optional): inner modules. Defaults to None.
-
-    """
-    def __init__(
-        self,
-        structure: Literal[
-            "diagonal",
-            "diagonal1",
-            "diagonal_abs",
-            "tridiagonal",
-            "circulant",
-            "toeplitz",
-            "toeplitz_like",
-            "hankel",
-            "rank1",
-            "rank2", # any rank
-        ]
-        | str = "diagonal",
-        reg: float = 1e-6,
-        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
-        h: float = 1e-3,
-        inner: Chainable | None = None,
-    ):
-        defaults = dict(reg=reg, hvp_method=hvp_method, structure=structure, h=h)
-        super().__init__(defaults)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        reg = settings['reg']
-        hvp_method = settings['hvp_method']
-        structure = settings['structure']
-        h = settings['h']
-
-        # ------------------------ calculate grad and hessian ------------------------ #
-        if hvp_method == 'autograd':
-            grad = var.get_grad(create_graph=True)
-            def Hvp_fn1(x):
-                return hvp(params, grad, x, retain_graph=True)
-            Hvp_fn = Hvp_fn1
-
-        elif hvp_method == 'forward':
-            grad = var.get_grad()
-            def Hvp_fn2(x):
-                return hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1]
-            Hvp_fn = Hvp_fn2
-
-        elif hvp_method == 'central':
-            grad = var.get_grad()
-            def Hvp_fn3(x):
-                return hvp_fd_central(closure, params, x, h=h, normalize=True)[1]
-            Hvp_fn = Hvp_fn3
-
-        else: raise ValueError(hvp_method)
-
-        # -------------------------------- inner step -------------------------------- #
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=grad, var=var)
-
-        # hessian
-        if structure.startswith('diagonal'):
-            H = Hvp_fn([torch.ones_like(p) for p in params])
-            if structure == 'diagonal1': torch._foreach_clamp_min_(H, 1)
-            if structure == 'diagonal_abs': torch._foreach_abs_(H)
-            torch._foreach_add_(H, reg)
-            torch._foreach_div_(update, H)
-            var.update = update
-            return var
-
-        # hessian
-        raise NotImplementedError(structure)
-
-
-
-
-
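For reference, the removed `StructuredNewton` module builds its "diagonal" preconditioner from a single Hessian-vector product with the all-ones vector (i.e. Hessian row sums, which equal the diagonal only when the Hessian is diagonal) and then divides the update by it, plus a Tikhonov term. A minimal standalone sketch of that idea with plain `torch.autograd`; the function and variable names here are illustrative, not part of the torchzero API:

```python
import torch

def rowsum_newton_step(f, x, reg=1e-6):
    # gradient with a graph attached so it can be differentiated again for the Hvp
    g = torch.autograd.grad(f(x), x, create_graph=True)[0]
    # Hvp with the all-ones vector == Hessian row sums (the "diagonal" estimate above)
    h = torch.autograd.grad(g, x, grad_outputs=torch.ones_like(x))[0]
    # abs() mirrors the module's "diagonal_abs" option; reg is the Tikhonov regularizer
    return g.detach() / (h.detach().abs() + reg)

x = torch.randn(5, requires_grad=True)
step = rowsum_newton_step(lambda t: (t ** 4 + t ** 2).sum(), x)
```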
torchzero/modules/experimental/subspace_preconditioners.py
DELETED
@@ -1,138 +0,0 @@
-from collections import deque
-
-import torch
-# import visualbench as vb
-
-# import torchzero as tz
-
-from ...core import Transform, Chainable, apply_transform
-from ...utils.linalg import inv_sqrt_2x2, matrix_power_eigh, gram_schmidt
-from ...utils import TensorList, vec_to_tensors_
-
-
-def inverse_sqrt(M):
-    if M.shape[-1] == 2: return inv_sqrt_2x2(M, force_pd=True) # general formula for 2x2 matrices
-    return matrix_power_eigh(M, -1/2)
-
-def update_subspace_preconditioner_(
-    grad: torch.Tensor, # store grads and basis as vectors for matmul
-    basis: torch.Tensor, # ndim, k
-    accumulator_: torch.Tensor, # k, k
-    beta: float | None,
-):
-    projected = basis.T @ grad # k
-    outer = torch.outer(projected, projected)
-
-    if beta is None: accumulator_.add_(outer)
-    else: accumulator_.lerp_(outer, 1-beta)
-
-def apply_subspace_preconditioner(
-    tensor: torch.Tensor,
-    basis: torch.Tensor, # ndim, k
-    accumulator: torch.Tensor,
-):
-    preconditioner = inverse_sqrt(accumulator) # k,k
-
-    tensor_projected = basis.T @ tensor # k
-    update_projected = preconditioner @ tensor_projected # k
-    return basis @ update_projected # d
-
-class RandomSubspacePreconditioning(Transform):
-    """Whitens in random slowly changing subspace. Please note that this is experimental and isn't guaranteed to work."""
-    def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
-        defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
-        super().__init__(defaults, uses_grad=False)
-
-        if inner is not None: self.set_child('inner', inner)
-
-    def apply(self, tensors, params, grads, loss, states, settings):
-        settings = settings[0]
-        g = torch.cat([t.view(-1) for t in tensors])
-        k = settings['k']
-        beta = settings['beta']
-        basis_beta = settings['basis_beta']
-
-        if 'basis' not in self.global_state:
-            self.global_state['basis'] = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
-            self.global_state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
-
-        basis = self.global_state['basis']
-        accumulator = self.global_state['accumulator']
-
-        if basis_beta is not None:
-            basis.lerp_(torch.randn_like(basis), 1-basis_beta)
-
-        update_subspace_preconditioner_(g, basis, accumulator, beta)
-
-        if 'inner' in self.children:
-            tensors = apply_transform(self.children['inner'], tensors, params, grads)
-            g = torch.cat([t.view(-1) for t in tensors])
-
-        try:
-            preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
-        except torch.linalg.LinAlgError:
-            preconditioned = g.clip(-0.1, 0.1)
-        vec_to_tensors_(preconditioned, tensors)
-
-        return tensors
-
-
-class HistorySubspacePreconditioning(Transform):
-    """Whitens in subspace spanned by history of gradient differences.
-    Please note that this is experimental and isn't guaranteed to work.
-
-    Args:
-        beta - for preconditioner itself in the basis.
-        basis_beta - how much basis is allowed to change.
-    """
-    def __init__(self, k: int, beta: float | None = 0.99, basis_beta=0.99, inner: Chainable | None = None):
-        defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
-        super().__init__(defaults, uses_grad=False)
-
-        if inner is not None: self.set_child('inner', inner)
-
-    def apply(self, tensors, params, grads, loss, states, settings):
-        settings = settings[0]
-
-        g = torch.cat([t.view(-1) for t in tensors])
-        k = settings['k']
-        beta = settings['beta']
-        basis_beta = settings['basis_beta']
-
-        if 'history' not in self.global_state:
-            self.global_state['history'] = deque(maxlen=k)
-            self.global_state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
-            self.global_state['basis'] = torch.ones(g.numel(), k, device=g.device, dtype=g.dtype)
-
-
-        history: deque = self.global_state['history']
-        accumulator = self.global_state['accumulator']
-        basis = self.global_state['basis']
-
-        history.append(g)
-        if len(history) < k:
-            basis_t = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
-            history_basis = torch.stack(tuple(history), -1)
-            basis_t[:, -len(history):] = history_basis
-
-        else:
-            basis_t = torch.stack(tuple(history), -1)
-
-        basis_t[:,:-1] = basis_t[:, :-1] - basis_t[:, 1:]
-        basis_t = (basis_t - basis_t.mean()) / basis_t.std()
-
-        basis.lerp_(basis_t, 1-basis_beta)
-        update_subspace_preconditioner_(g, basis, accumulator, beta)
-
-        if 'inner' in self.children:
-            tensors = apply_transform(self.children['inner'], tensors, params, grads)
-            g = torch.cat([t.view(-1) for t in tensors])
-
-        try:
-            preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
-        except torch.linalg.LinAlgError:
-            preconditioned = g.clip(-0.1,0.1)
-        vec_to_tensors_(preconditioned, tensors)
-
-        return tensors
-
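Both removed subspace preconditioners follow the same recipe: project the gradient onto a k-dimensional basis, whiten it with the inverse square root of an accumulator of projected outer products, and map the result back to parameter space. A small self-contained sketch of that projection/whitening step (names are illustrative; this is not the torchzero `Transform` API):

```python
import torch

def whiten_in_subspace(g, basis, accumulator, eps=1e-12):
    # basis: (d, k) with k << d; accumulator: (k, k) sum/EMA of projected gradient outer products
    p = basis.T @ g                                   # project the gradient into the subspace
    eigval, eigvec = torch.linalg.eigh(accumulator)   # accumulator^(-1/2) via eigendecomposition
    inv_sqrt = eigvec @ torch.diag(eigval.clamp(min=eps).rsqrt()) @ eigvec.T
    return basis @ (inv_sqrt @ p)                     # whiten in the subspace, map back to R^d

d, k = 100, 8
basis = torch.linalg.qr(torch.randn(d, k)).Q          # a random orthonormal basis
g = torch.randn(d)
acc = torch.eye(k) + torch.outer(basis.T @ g, basis.T @ g)
update = whiten_in_subspace(g, basis, acc)
```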
torchzero/modules/experimental/tada.py
DELETED
@@ -1,38 +0,0 @@
-from collections import deque
-
-import torch
-
-from ...core import Chainable, TensorwiseTransform
-from ...utils.linalg import matrix_power_eigh
-
-
-class TAda(TensorwiseTransform):
-    """3rd order whitening (maybe normalizes skewness). Please note that this is experimental and isn't guaranteed to work."""
-    def __init__(self, history_size: int = 100, reg: float = 1e-8, update_freq: int = 1, concat_params: bool = True, inner: Chainable | None = None):
-        defaults = dict(history_size=history_size, reg=reg)
-        super().__init__(defaults, uses_grad=False, update_freq=update_freq, inner=inner, concat_params=concat_params)
-
-    @torch.no_grad
-    def update_tensor(self, tensor, param, grad, loss, state, settings):
-        reg = settings['reg']
-        if 'history' not in state:
-            state['history'] = deque(maxlen=settings['history_size'])
-
-        g = tensor.view(-1)
-        history = state['history']
-        history.append(g.clone())
-
-        I = torch.eye(tensor.numel(), device=tensor.device, dtype=tensor.dtype).mul_(reg)
-        g_k = history[0]
-        outer = torch.outer(g_k, g_k).mul_(torch.dot(g_k, g).clip(min=reg))
-        if len(history) > 1:
-            for g_k in list(history)[1:]:
-                outer += torch.outer(g_k, g_k).mul_(torch.dot(g_k, g).clip(min=reg))
-
-        state['outer'] = outer.add_(I)
-
-    @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, settings):
-        outer = state['outer']
-        P = matrix_power_eigh(outer, -1/2)
-        return (P @ tensor.ravel()).view_as(tensor)
torchzero/modules/line_search/trust_region.py
DELETED
@@ -1,73 +0,0 @@
-from operator import itemgetter
-
-import torch
-
-from .line_search import LineSearch
-
-
-class TrustRegion(LineSearch):
-    """Basic first order trust region method. Re-evaluates the function after stepping, if value decreased sufficiently,
-    step size is increased. If value increased, step size is decreased. This is prone to collapsing.
-
-    Args:
-        nplus (float, optional): multiplier to step size on successful steps. Defaults to 1.5.
-        nminus (float, optional): multiplier to step size on unsuccessful steps. Defaults to 0.75.
-        c (float, optional): descent condition. Defaults to 1e-4.
-        init (float, optional): initial step size. Defaults to 1.
-        backtrack (bool, optional): whether to undo the step if value increased. Defaults to True.
-        adaptive (bool, optional):
-            If enabled, when multiple consecutive steps have been successful or unsuccessful,
-            the corresponding multipliers are increased, otherwise they are reset. Defaults to True.
-    """
-    def __init__(self, nplus: float=1.5, nminus: float=0.75, c: float=1e-4, init: float = 1, backtrack: bool = True, adaptive: bool = True):
-        defaults = dict(nplus=nplus, nminus=nminus, c=c, init=init, backtrack=backtrack, adaptive=adaptive)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def search(self, update, var):
-
-        nplus, nminus, c, init, backtrack, adaptive = itemgetter('nplus','nminus','c','init','backtrack', 'adaptive')(self.settings[var.params[0]])
-        step_size = self.global_state.setdefault('step_size', init)
-        previous_success = self.global_state.setdefault('previous_success', False)
-        nplus_mul = self.global_state.setdefault('nplus_mul', 1)
-        nminus_mul = self.global_state.setdefault('nminus_mul', 1)
-
-
-        f_0 = self.evaluate_step_size(0, var, backward=False)
-
-        # directional derivative (0 if c = 0 because it is not needed)
-        if c == 0: d = 0
-        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), update))
-
-        # test step size
-        sufficient_f = f_0 + c * step_size * min(d, 0) # pyright:ignore[reportArgumentType]
-
-        f_1 = self.evaluate_step_size(step_size, var, backward=False)
-
-        proposed = step_size
-
-        # very good step
-        if f_1 < sufficient_f:
-            self.global_state['step_size'] *= nplus * nplus_mul
-
-            # two very good steps in a row - increase nplus_mul
-            if adaptive:
-                if previous_success: self.global_state['nplus_mul'] *= nplus
-                else: self.global_state['nplus_mul'] = 1
-
-        # acceptable step step
-        #elif f_1 <= f_0: pass
-
-        # bad step
-        if f_1 >= f_0:
-            self.global_state['step_size'] *= nminus * nminus_mul
-
-            # two bad steps in a row - decrease nminus_mul
-            if adaptive:
-                if previous_success: self.global_state['nminus_mul'] *= nminus
-                else: self.global_state['nminus_mul'] = 1
-
-            if backtrack: proposed = 0
-            else: proposed *= nminus * nminus_mul
-
-        return proposed
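The removed `TrustRegion` line search boils down to an accept/grow/shrink rule on a scalar step size. A condensed standalone sketch of that rule, without the adaptive `nplus_mul`/`nminus_mul` multipliers kept by the original module (names are illustrative, not the torchzero `LineSearch` interface):

```python
def adapt_step_size(f_0, f_1, step_size, directional_derivative,
                    nplus=1.5, nminus=0.75, c=1e-4, backtrack=True):
    # Armijo-style sufficient decrease threshold (directional derivative <= 0 for descent)
    sufficient = f_0 + c * step_size * min(directional_derivative, 0)
    if f_1 < sufficient:                 # good step: keep it and enlarge the step size
        return step_size * nplus, step_size
    if f_1 >= f_0:                       # bad step: shrink, and optionally revert the step
        return step_size * nminus, (0.0 if backtrack else step_size * nminus)
    return step_size, step_size          # mild decrease: keep the current step size

# returns (step size for the next iteration, step length actually taken now)
next_size, taken = adapt_step_size(f_0=1.0, f_1=0.9, step_size=0.5, directional_derivative=-2.0)
```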
torchzero/modules/lr/__init__.py
DELETED
torchzero/modules/lr/adaptive.py
DELETED
@@ -1,93 +0,0 @@
-"""Various step size strategies"""
-import random
-from typing import Any
-from operator import itemgetter
-import torch
-
-from ...core import Transform
-from ...utils import TensorList, NumberList, unpack_dicts
-
-
-class PolyakStepSize(Transform):
-    """Polyak's step-size method.
-
-    Args:
-        max (float | None, optional): maximum possible step size. Defaults to None.
-        min_obj_value (int, optional):
-            (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
-        use_grad (bool, optional):
-            if True, uses dot product of update and gradient to compute the step size.
-            Otherwise, dot product of update with itself is used, which has no geometric meaning so it probably won't work well.
-            Defaults to True.
-        parameterwise (bool, optional):
-            if True, calculate Polyak step-size for each parameter separately,
-            if False calculate one global step size for all parameters. Defaults to False.
-        alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
-    """
-    def __init__(self, max: float | None = None, min_obj_value: float = 0, use_grad=True, parameterwise=False, alpha: float = 1):
-
-        defaults = dict(alpha=alpha, max=max, min_obj_value=min_obj_value, use_grad=use_grad, parameterwise=parameterwise)
-        super().__init__(defaults, uses_grad=use_grad)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        assert grads is not None
-        tensors = TensorList(tensors)
-        grads = TensorList(grads)
-        alpha = NumberList(s['alpha'] for s in settings)
-
-        parameterwise, use_grad, max, min_obj_value = itemgetter('parameterwise', 'use_grad', 'max', 'min_obj_value')(settings[0])
-
-        if use_grad: denom = tensors.dot(grads)
-        else: denom = tensors.dot(tensors)
-
-        if parameterwise:
-            polyak_step_size: TensorList | Any = (loss - min_obj_value) / denom.where(denom!=0, 1)
-            polyak_step_size = polyak_step_size.where(denom != 0, 0)
-            if max is not None: polyak_step_size = polyak_step_size.clamp_max(max)
-
-        else:
-            if denom.abs() <= torch.finfo(denom.dtype).eps: polyak_step_size = 0 # converged
-            else: polyak_step_size = (loss - min_obj_value) / denom
-
-            if max is not None:
-                if polyak_step_size > max: polyak_step_size = max
-
-        tensors.mul_(alpha * polyak_step_size)
-        return tensors
-
-
-class RandomStepSize(Transform):
-    """Uses random global or layer-wise step size from `low` to `high`.
-
-    Args:
-        low (float, optional): minimum learning rate. Defaults to 0.
-        high (float, optional): maximum learning rate. Defaults to 1.
-        parameterwise (bool, optional):
-            if True, generate random step size for each parameter separately,
-            if False generate one global random step size. Defaults to False.
-    """
-    def __init__(self, low: float = 0, high: float = 1, parameterwise=False, seed:int|None=None):
-        defaults = dict(low=low, high=high, parameterwise=parameterwise,seed=seed)
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        s = settings[0]
-        parameterwise = s['parameterwise']
-
-        seed = s['seed']
-        if 'generator' not in self.global_state:
-            self.global_state['generator'] = random.Random(seed)
-        generator: random.Random = self.global_state['generator']
-
-        if parameterwise:
-            low, high = unpack_dicts(settings, 'low', 'high')
-            lr = [generator.uniform(l, h) for l, h in zip(low, high)]
-        else:
-            low = s['low']
-            high = s['high']
-            lr = generator.uniform(low, high)
-
-        torch._foreach_mul_(tensors, lr)
-        return tensors
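The removed `PolyakStepSize` transform scales the update by the classic Polyak rule alpha = (f(x) - f*) / <update, grad>, optionally capped at a maximum. A minimal standalone version of the global (non-parameterwise) case, with illustrative names rather than the torchzero API:

```python
import torch

def polyak_step_size(loss, update, grad, min_obj_value=0.0, max_size=None):
    # denominator: <update, grad> summed over all parameter tensors
    denom = sum((u * g).sum() for u, g in zip(update, grad))
    if denom.abs() <= torch.finfo(denom.dtype).eps:
        return 0.0                                   # treat a vanishing denominator as converged
    alpha = float((loss - min_obj_value) / denom)
    return alpha if max_size is None else min(alpha, max_size)

p = torch.randn(3, requires_grad=True)
loss = (p ** 2).sum()
loss.backward()
alpha = polyak_step_size(loss.item(), [p.grad], [p.grad])  # plain SGD case: update == grad
```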
torchzero/modules/lr/lr.py
DELETED
@@ -1,63 +0,0 @@
-"""Learning rate"""
-import torch
-
-from ...core import Transform
-from ...utils import NumberList, TensorList, generic_eq, unpack_dicts
-
-def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
-    """multiplies by lr if lr is not 1"""
-    if generic_eq(lr, 1): return tensors
-    if inplace: return tensors.mul_(lr)
-    return tensors * lr
-
-class LR(Transform):
-    """Learning rate. Adding this module also adds support for LR schedulers."""
-    def __init__(self, lr: float):
-        defaults=dict(lr=lr)
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)
-
-class StepSize(Transform):
-    """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
-    def __init__(self, step_size: float, key = 'step_size'):
-        defaults={"key": key, key: step_size}
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)
-
-
-def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberList, steps: float):
-    """returns warm up lr scalar"""
-    if step > steps: return end_lr
-    return start_lr + (end_lr - start_lr) * (step / steps)
-
-class Warmup(Transform):
-    """Learning rate warmup, linearly increases learning rate multiplier from :code:`start_lr` to :code:`end_lr` over :code:`steps` steps.
-
-    Args:
-        start_lr (_type_, optional): initial learning rate multiplier on first step. Defaults to 1e-5.
-        end_lr (float, optional): learning rate multiplier at the end and after warmup. Defaults to 1.
-        steps (int, optional): number of steps to perform warmup for. Defaults to 100.
-    """
-    def __init__(self, start_lr = 1e-5, end_lr:float = 1, steps = 100):
-        defaults = dict(start_lr=start_lr,end_lr=end_lr, steps=steps)
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
-        num_steps = settings[0]['steps']
-        step = self.global_state.get('step', 0)
-
-        target = lazy_lr(
-            TensorList(tensors),
-            lr=_warmup_lr(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
-            inplace=True
-        )
-        self.global_state['step'] = step + 1
-        return target
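A quick numeric illustration of the linear warmup multiplier removed with `lr/lr.py`: the effective learning-rate factor ramps from `start_lr` to `end_lr` over `steps` steps and then stays flat. The helper below mirrors the deleted `_warmup_lr`; the `base_lr` usage line is only an illustration:

```python
def warmup_lr(step, start_lr=1e-5, end_lr=1.0, steps=100):
    # linear interpolation for the first `steps` steps, constant end_lr afterwards
    if step > steps:
        return end_lr
    return start_lr + (end_lr - start_lr) * (step / steps)

base_lr = 3e-4
schedule = [base_lr * warmup_lr(t) for t in range(0, 201, 50)]  # ramps up, then stays at 3e-4
```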
torchzero/modules/momentum/matrix_momentum.py
DELETED
@@ -1,166 +0,0 @@
-from typing import Literal
-
-import torch
-
-from ...core import Module, apply_transform, Chainable
-from ...utils import NumberList, TensorList, as_tensorlist
-from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
-
-class MatrixMomentum(Module):
-    """
-    May be useful for ill conditioned stochastic quadratic objectives but I need to test this.
-    Evaluates hessian vector product on each step (via finite difference or autograd).
-
-    `mu` is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable.
-
-    Args:
-        mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
-        beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
-        hvp_method (str, optional):
-            How to calculate hessian-vector products.
-            Exact - "autograd", or finite difference - "forward", "central". Defaults to 'forward'.
-        h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
-        hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
-
-    Reference:
-        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
-    """
-
-    def __init__(
-        self,
-        mu=0.1,
-        beta: float = 1,
-        hvp_method: Literal["autograd", "forward", "central"] = "forward",
-        h: float = 1e-3,
-        hvp_tfm: Chainable | None = None,
-    ):
-        defaults = dict(mu=mu, beta=beta, hvp_method=hvp_method, h=h)
-        super().__init__(defaults)
-
-        if hvp_tfm is not None:
-            self.set_child('hvp_tfm', hvp_tfm)
-
-    @torch.no_grad
-    def step(self, var):
-        assert var.closure is not None
-        prev_update = self.get_state(var.params, 'prev_update', cls=TensorList)
-        hvp_method = self.settings[var.params[0]]['hvp_method']
-        h = self.settings[var.params[0]]['h']
-
-        mu,beta = self.get_settings(var.params, 'mu','beta', cls=NumberList)
-
-        if hvp_method == 'autograd':
-            with torch.enable_grad():
-                grad = var.get_grad(create_graph=True)
-                hvp_ = TensorList(hvp(var.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
-
-        elif hvp_method == 'forward':
-            var.get_grad()
-            l, hvp_ = hvp_fd_forward(var.closure, var.params, vec=prev_update, g_0=var.grad, h=h, normalize=True)
-            if var.loss_approx is None: var.loss_approx = l
-
-        elif hvp_method == 'central':
-            l, hvp_ = hvp_fd_central(var.closure, var.params, vec=prev_update, h=h, normalize=True)
-            if var.loss_approx is None: var.loss_approx = l
-
-        else:
-            raise ValueError(hvp_method)
-
-        if 'hvp_tfm' in self.children:
-            hvp_ = TensorList(apply_transform(self.children['hvp_tfm'], hvp_, params=var.params, grads=var.grad, var=var))
-
-        update = TensorList(var.get_update())
-
-        hvp_ = as_tensorlist(hvp_)
-        update.add_(prev_update - hvp_*mu)
-        prev_update.set_(update * beta)
-        var.update = update
-        return var
-
-
-class AdaptiveMatrixMomentum(Module):
-    """
-    May be useful for ill conditioned stochastic quadratic objectives but I need to test this.
-    Evaluates hessian vector product on each step (via finite difference or autograd).
-
-    This version estimates mu via a simple heuristic: ||s||/||y||, where s is parameter difference, y is gradient difference.
-
-    Args:
-        mu_mul (float, optional): multiplier to the estimated mu. Defaults to 1.
-        beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
-        hvp_method (str, optional):
-            How to calculate hessian-vector products.
-            Exact - "autograd", or finite difference - "forward", "central". Defaults to 'forward'.
-        h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
-        hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
-
-    Reference:
-        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
-    """
-
-    def __init__(
-        self,
-        mu_mul: float = 1,
-        beta: float = 1,
-        eps=1e-4,
-        hvp_method: Literal["autograd", "forward", "central"] = "forward",
-        h: float = 1e-3,
-        hvp_tfm: Chainable | None = None,
-    ):
-        defaults = dict(mu_mul=mu_mul, beta=beta, hvp_method=hvp_method, h=h, eps=eps)
-        super().__init__(defaults)
-
-        if hvp_tfm is not None:
-            self.set_child('hvp_tfm', hvp_tfm)
-
-    @torch.no_grad
-    def step(self, var):
-        assert var.closure is not None
-        prev_update, prev_params, prev_grad = self.get_state(var.params, 'prev_update', 'prev_params', 'prev_grad', cls=TensorList)
-
-        settings = self.settings[var.params[0]]
-        hvp_method = settings['hvp_method']
-        h = settings['h']
-        eps = settings['eps']
-
-        mu_mul, beta = self.get_settings(var.params, 'mu_mul','beta', cls=NumberList)
-
-        if hvp_method == 'autograd':
-            with torch.enable_grad():
-                grad = var.get_grad(create_graph=True)
-                hvp_ = TensorList(hvp(var.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
-
-        elif hvp_method == 'forward':
-            var.get_grad()
-            l, hvp_ = hvp_fd_forward(var.closure, var.params, vec=prev_update, g_0=var.grad, h=h, normalize=True)
-            if var.loss_approx is None: var.loss_approx = l
-
-        elif hvp_method == 'central':
-            l, hvp_ = hvp_fd_central(var.closure, var.params, vec=prev_update, h=h, normalize=True)
-            if var.loss_approx is None: var.loss_approx = l
-
-        else:
-            raise ValueError(hvp_method)
-
-        if 'hvp_tfm' in self.children:
-            hvp_ = TensorList(apply_transform(self.children['hvp_tfm'], hvp_, params=var.params, grads=var.grad, var=var))
-
-        # adaptive part
-        update = TensorList(var.get_update())
-
-        s_k = var.params - prev_params
-        prev_params.copy_(var.params)
-
-        assert var.grad is not None
-        y_k = var.grad - prev_grad
-        prev_grad.copy_(var.grad)
-
-        ada_mu = (s_k.global_vector_norm() / (y_k.global_vector_norm() + eps)) * mu_mul
-
-        # matrix momentum uppdate
-        hvp_ = as_tensorlist(hvp_)
-        update.add_(prev_update - hvp_*ada_mu)
-        prev_update.set_(update * beta)
-        var.update = update
-        return var
-
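The removed `MatrixMomentum` module implements the update d_new = g + d_prev - mu * (H d_prev): ordinary momentum corrected by a Hessian-vector product with the previous direction, with mu kept below 1/(largest Hessian eigenvalue). A self-contained sketch on a toy ill-conditioned quadratic, using double backward for the Hvp; the function names and the explicit learning rate are illustrative, not the torchzero API:

```python
import torch

def matrix_momentum_step(f, x, prev_direction, mu, lr):
    g = torch.autograd.grad(f(x), x, create_graph=True)[0]
    # Hessian-vector product with the previous direction via double backward
    Hd = torch.autograd.grad(g, x, grad_outputs=prev_direction)[0]
    direction = g.detach() + prev_direction - mu * Hd.detach()
    with torch.no_grad():
        x -= lr * direction
    return direction                      # carried over as prev_direction on the next step

curv = torch.linspace(1.0, 100.0, 10)     # quadratic with curvatures from 2 to 200
f = lambda t: (curv * t ** 2).sum()
x = torch.randn(10, requires_grad=True)
d = torch.zeros_like(x)
for _ in range(100):
    d = matrix_momentum_step(f, x, d, mu=1 / 400, lr=2e-3)  # mu below 1 / largest curvature
```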
|