torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/adaptive/matrix_momentum.py (new file)
@@ -0,0 +1,146 @@
+from typing import Literal
+from collections.abc import Callable
+import torch
+
+from ...core import Module, apply_transform, Chainable
+from ...utils import NumberList, TensorList, as_tensorlist
+from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+from ..functional import initial_step_size
+
+
+class MatrixMomentum(Module):
+    """Second order momentum method.
+
+    Matrix momentum is useful for convex objectives; also, for some reason it has really good generalization on elastic net logistic regression.
+
+    Notes:
+        - ``mu`` needs to be tuned very carefully. It is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable. I have devised an adaptive version of this - ``tz.m.AdaptiveMatrixMomentum``, and it works well without having to tune ``mu``, however the adaptive version doesn't work on stochastic objectives.
+
+        - In most cases ``MatrixMomentum`` should be the first module in the chain because it relies on autograd.
+
+        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument.
+
+    Args:
+        mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
+        h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
+        hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
+
+    Reference:
+        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
+    """
+
+    def __init__(
+        self,
+        lr: float,
+        mu=0.1,
+        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+        h: float = 1e-3,
+        adaptive: bool = False,
+        adapt_freq: int | None = None,
+        hvp_tfm: Chainable | None = None,
+    ):
+        defaults = dict(lr=lr, mu=mu, hvp_method=hvp_method, h=h, adaptive=adaptive, adapt_freq=adapt_freq)
+        super().__init__(defaults)
+
+        if hvp_tfm is not None:
+            self.set_child('hvp_tfm', hvp_tfm)
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('p_prev')
+
+    @torch.no_grad
+    def update(self, var):
+        assert var.closure is not None
+        p = TensorList(var.params)
+        p_prev = self.get_state(p, 'p_prev', init=var.params)
+
+        hvp_method = self.defaults['hvp_method']
+        h = self.defaults['h']
+        step = self.global_state.get("step", 0)
+        self.global_state["step"] = step + 1
+
+        if step > 0:
+            s = p - p_prev
+
+            Hs, _ = self.Hvp(s, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
+            Hs = [t.detach() for t in Hs]
+
+            if 'hvp_tfm' in self.children:
+                Hs = TensorList(apply_transform(self.children['hvp_tfm'], Hs, params=p, grads=var.grad, var=var))
+
+            self.store(p, ("Hs", "s"), (Hs, s))
+
+            # -------------------------------- adaptive mu ------------------------------- #
+            if self.defaults["adaptive"]:
+                g = TensorList(var.get_grad())
+
+                if self.defaults["adapt_freq"] is None:
+                    # ---------------------------- deterministic case ---------------------------- #
+                    g_prev = self.get_state(var.params, "g_prev", cls=TensorList)
+                    y = g - g_prev
+                    g_prev.copy_(g)
+                    denom = y.global_vector_norm()
+                    denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
+                    self.global_state["mu_mul"] = s.global_vector_norm() / denom
+
+                else:
+                    # -------------------------------- stochastic -------------------------------- #
+                    adapt_freq = self.defaults["adapt_freq"]
+
+                    # we start on the 1st step, and want to adapt when we start, so use (step - 1)
+                    if (step - 1) % adapt_freq == 0:
+                        assert var.closure is not None
+                        params = TensorList(var.params)
+                        p_cur = params.clone()
+
+                        # move to previous params and evaluate p_prev with current mini-batch
+                        params.copy_(self.get_state(var.params, 'p_prev'))
+                        with torch.enable_grad():
+                            var.closure()
+                        g_prev = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                        y = g - g_prev
+
+                        # move back to current params
+                        params.copy_(p_cur)
+
+                        denom = y.global_vector_norm()
+                        denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
+                        self.global_state["mu_mul"] = s.global_vector_norm() / denom
+
+        torch._foreach_copy_(p_prev, var.params)
+
+    @torch.no_grad
+    def apply(self, var):
+        update = TensorList(var.get_update())
+        lr, mu = self.get_settings(var.params, "lr", 'mu', cls=NumberList)
+
+        if "mu_mul" in self.global_state:
+            mu = mu * self.global_state["mu_mul"]
+
+        # --------------------------------- 1st step --------------------------------- #
+        # p_prev is not available so make a small step
+        step = self.global_state["step"]
+        if step == 1:
+            if self.defaults["adaptive"]: self.get_state(var.params, "g_prev", init=var.get_grad())
+            update.mul_(lr)  # separate so that initial_step_size can clip correctly
+            update.mul_(initial_step_size(update, 1e-7))
+            return var
+
+        # -------------------------- matrix momentum update -------------------------- #
+        s, Hs = self.get_state(var.params, 's', 'Hs', cls=TensorList)
+
+        update.mul_(lr).sub_(s).add_(Hs*mu)
+        var.update = update
+        return var
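A minimal usage sketch for the new ``MatrixMomentum`` module, following the closure convention described in its docstring (the closure accepts a ``backward`` flag); the model, data, and hyperparameters below are illustrative assumptions, not taken from the package.

```python
import torch
import torch.nn as nn
import torchzero as tz

model = nn.Linear(10, 1)
X, y = torch.randn(32, 10), torch.randn(32, 1)

# MatrixMomentum placed first in the chain, since it relies on autograd for HVPs
opt = tz.Modular(model.parameters(), tz.m.MatrixMomentum(lr=1e-2, mu=0.1))

def closure(backward=True):
    # the module re-evaluates this closure to form Hessian-vector products,
    # so it must accept (and sometimes ignore) the `backward` argument
    loss = (model(X) - y).pow(2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(50):
    opt.step(closure)
```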
torchzero/modules/adaptive/msam.py
@@ -42,13 +42,15 @@ def msam_(
     # can't really decouple it from lr
     # but at least it is now expressed as function of g

-    denom =
+    denom = velocity_.global_vector_norm() / rho
+    denom = denom.clip(min=torch.finfo(tensors[0].dtype).tiny * 2)
     vn = velocity_ / denom

     mom_ = nag_ if nesterov else ema_
     velocity_ = mom_(tensors, velocity_, momentum, dampening=0, lerp=lerp)

-    denom =
+    denom = velocity_.global_vector_norm() / rho
+    denom = denom.clip(min=torch.finfo(tensors[0].dtype).tiny * 2)
     v1n = velocity_ / denom

     if inner is not None:
@@ -74,11 +76,11 @@ class MSAM(Transform):
     replacement for momentum strategies in other optimizers.

     To combine MSAM with other optimizers in the way done in the official implementation,
-    e.g. to make Adam_MSAM, use
+    e.g. to make Adam_MSAM, use the ``tz.m.MSAMObjective`` module.

-
+    Note:
     MSAM has a learning rate hyperparameter that can't really be removed from the update rule.
-    To avoid compounding learning rate mofications, remove the
+    To avoid compounding learning rate modifications, remove the ``tz.m.LR`` module if you had it.

     Args:
         lr (float): learning rate. Adding this module adds support for learning rate schedulers.
@@ -112,10 +114,10 @@ class MSAM(Transform):
         tz.m.Debias(0.9, 0.999),
     )
     """
-
+    _USES_LR = True
     def __init__(self, lr: float, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False,):
         defaults = dict(momentum=momentum,rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
-        if self.
+        if self._USES_LR: defaults['lr'] = lr
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad
@@ -125,7 +127,7 @@ class MSAM(Transform):
         lerp = s['lerp']
         nesterov = s['nesterov']

-        if self.
+        if self._USES_LR:
             lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)

         else:
@@ -152,9 +154,9 @@ class MSAM(Transform):
 class MSAMObjective(MSAM):
     """Momentum-SAM from https://arxiv.org/pdf/2401.12033.

-
-    Please make sure to place
-
+    Note:
+        Please make sure to place ``tz.m.LR`` inside the ``modules`` argument. For example,
+        ``tz.m.MSAMObjective([tz.m.Adam(), tz.m.LR(1e-3)])``. Putting LR after MSAM will lead
         to an incorrect update rule.

     Args:
@@ -179,7 +181,7 @@ class MSAMObjective(MSAM):
         )
     )
     """
-
+    _USES_LR = False
     def __init__(self, modules: Chainable, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False):
         super().__init__(lr=0, momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
         self.set_child('modules', modules)
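The docstring edits above describe two distinct ways to use MSAM; a short sketch of both, using only module names that appear in those docstrings (``model`` is assumed to be an existing ``nn.Module``):

```python
import torchzero as tz

# standalone MSAM: it carries its own learning rate, so do not add a separate
# tz.m.LR module after it (that would compound the learning rate modification)
opt_msam = tz.Modular(model.parameters(), tz.m.MSAM(lr=1e-3))

# Adam_MSAM as in the official implementation: wrap the inner chain with
# MSAMObjective and keep tz.m.LR *inside* the `modules` argument
opt_adam_msam = tz.Modular(
    model.parameters(),
    tz.m.MSAMObjective([tz.m.Adam(), tz.m.LR(1e-3)]),
)
```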
torchzero/modules/adaptive/muon.py
@@ -167,26 +167,25 @@ class Orthogonalize(TensorwiseTransform):
         target (str, optional):
             what to set on var.

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    )
+    ## Examples:
+
+    standard Muon with Adam fallback
+    ```py
+    opt = tz.Modular(
+        model.head.parameters(),
+        tz.m.Split(
+            # apply muon only to 2D+ parameters
+            filter = lambda t: t.ndim >= 2,
+            true = [
+                tz.m.HeavyBall(),
+                tz.m.Orthogonalize(),
+                tz.m.LR(1e-2),
+            ],
+            false = tz.m.Adam()
+        ),
+        tz.m.LR(1e-2)
+    )
+    ```

     Reference:
         Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon
torchzero/modules/adaptive/natural_gradient.py (new file)
@@ -0,0 +1,175 @@
+import torch
+from ...core import Module, Chainable, apply_transform
+
+from ...utils.derivatives import jacobian_wrt, flatten_jacobian
+from ...utils import vec_to_tensors, TensorList
+from ...utils.linalg import linear_operator
+from .lmadagrad import lm_adagrad_apply, lm_adagrad_update
+
+class NaturalGradient(Module):
+    """Natural gradient approximated via empirical fisher information matrix.
+
+    To use this, either pass vector of per-sample losses to the step method, or make sure
+    the closure returns it. Gradients will be calculated via batched autograd within this module,
+    you don't need to implement the backward pass. When using closure, please add the ``backward`` argument,
+    it will always be False but it is required. See below for an example.
+
+    Note:
+        Empirical fisher information matrix may give a really bad approximation in some cases.
+        If that is the case, set ``sqrt`` to True to perform whitening instead, which is way more robust.
+
+    Args:
+        reg (float, optional): regularization parameter. Defaults to 1e-8.
+        sqrt (bool, optional):
+            if True, uses square root of empirical fisher information matrix. Both EFIM and its square
+            root can be calculated and stored efficiently without ndim^2 memory. Square root
+            whitens the gradient and often performs much better, especially when you try to use NGD
+            with a vector that isn't strictly per-sample gradients, but rather for example different losses.
+        gn_grad (bool, optional):
+            if True, uses Gauss-Newton G^T @ f as the gradient, which is effectively sum weighted by value
+            and is equivalent to squaring the values. This way you can solve least-squares
+            objectives with a NGD-like algorithm. If False, uses sum of per-sample gradients.
+            This has an effect when ``sqrt=True``, and affects the ``grad`` attribute.
+            Defaults to False.
+        batched (bool, optional): whether to use vmapping. Defaults to True.
+
+    Examples:
+
+    training a neural network:
+    ```python
+    X = torch.randn(64, 20)
+    y = torch.randn(64, 10)
+
+    model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.NaturalGradient(),
+        tz.m.LR(3e-2)
+    )
+
+    for i in range(100):
+        y_hat = model(X) # (64, 10)
+        losses = (y_hat - y).pow(2).mean(0) # (10, )
+        opt.step(loss=losses)
+        if i % 10 == 0:
+            print(f'{losses.mean() = }')
+    ```
+
+    training a neural network - closure version
+    ```python
+    X = torch.randn(64, 20)
+    y = torch.randn(64, 10)
+
+    model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.NaturalGradient(),
+        tz.m.LR(3e-2)
+    )
+
+    def closure(backward=True):
+        y_hat = model(X) # (64, 10)
+        return (y_hat - y).pow(2).mean(0) # (10, )
+
+    for i in range(100):
+        losses = opt.step(closure)
+        if i % 10 == 0:
+            print(f'{losses.mean() = }')
+    ```
+
+    minimizing the rosenbrock function with a mix of natural gradient, whitening and gauss-newton:
+    ```python
+    def rosenbrock(X):
+        x1, x2 = X
+        return torch.stack([(1 - x1).abs(), (10 * (x2 - x1**2).abs())])
+
+    X = torch.tensor([-1.1, 2.5], requires_grad=True)
+    opt = tz.Modular([X], tz.m.NaturalGradient(sqrt=True, gn_grad=True), tz.m.LR(0.05))
+
+    for iter in range(200):
+        losses = rosenbrock(X)
+        opt.step(loss=losses)
+        if iter % 20 == 0:
+            print(f'{losses.mean() = }')
+    ```
+    """
+    def __init__(self, reg:float = 1e-8, sqrt:bool=False, gn_grad:bool=False, batched:bool=True, ):
+        super().__init__(defaults=dict(batched=batched, reg=reg, sqrt=sqrt, gn_grad=gn_grad))
+
+    @torch.no_grad
+    def update(self, var):
+        params = var.params
+        batched = self.defaults['batched']
+        gn_grad = self.defaults['gn_grad']
+
+        closure = var.closure
+        assert closure is not None
+
+        with torch.enable_grad():
+            f = var.get_loss(backward=False) # n_out
+            assert isinstance(f, torch.Tensor)
+            G_list = jacobian_wrt([f.ravel()], params, batched=batched)
+
+        var.loss = f.sum()
+        G = self.global_state["G"] = flatten_jacobian(G_list) # (n_samples, ndim)
+
+        if gn_grad:
+            g = self.global_state["g"] = G.H @ f.detach()
+
+        else:
+            g = self.global_state["g"] = G.sum(0)
+
+        var.grad = vec_to_tensors(g, params)
+
+        # set closure to calculate scalar value for line searches etc
+        if var.closure is not None:
+            def ngd_closure(backward=True):
+                if backward:
+                    var.zero_grad()
+                    with torch.enable_grad():
+                        loss = closure(False)
+                        if gn_grad: loss = loss.pow(2)
+                        loss = loss.sum()
+                        loss.backward()
+                    return loss

+                loss = closure(False)
+                if gn_grad: loss = loss.pow(2)
+                return loss.sum()
+
+            var.closure = ngd_closure
+
+    @torch.no_grad
+    def apply(self, var):
+        params = var.params
+        reg = self.defaults['reg']
+        sqrt = self.defaults['sqrt']
+
+        G: torch.Tensor = self.global_state['G'] # (n_samples, n_dim)
+
+        if sqrt:
+            # this computes U, S <- SVD(M), then calculates the update as U S^-1 Uᵀg,
+            # but it computes it through an eigendecomposition
+            U, L = lm_adagrad_update(G.H, reg, 0)
+            if U is None or L is None: return var
+
+            v = lm_adagrad_apply(self.global_state["g"], U, L)
+            var.update = vec_to_tensors(v, params)
+            return var
+
+        GGT = G @ G.H # (n_samples, n_samples)
+
+        if reg != 0:
+            GGT.add_(torch.eye(GGT.size(0), device=GGT.device, dtype=GGT.dtype).mul_(reg))
+
+        z, _ = torch.linalg.solve_ex(GGT, torch.ones_like(GGT[0])) # pylint:disable=not-callable
+        v = G.H @ z
+
+        var.update = vec_to_tensors(v, params)
+        return var
+
+
+    def get_H(self, var):
+        if "G" not in self.global_state: return linear_operator.ScaledIdentity()
+        G = self.global_state['G']
+        return linear_operator.AtA(G)
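In the non-``sqrt`` branch of ``apply`` above, the update is obtained by solving an ``(n_samples, n_samples)`` system against a vector of ones rather than forming the ``ndim x ndim`` empirical Fisher. A small, self-contained check of the push-through identity this relies on (illustrative sketch, not part of the package):

```python
import torch

# With G the (n_samples, ndim) per-sample gradient matrix, the summed gradient is
# g = G^T @ ones, and (G^T G + reg*I_d)^{-1} G^T = G^T (G G^T + reg*I_n)^{-1},
# so the natural-gradient step can be computed from the small Gram matrix G G^T,
# which is what apply() does via solve_ex(GGT, ones).
torch.manual_seed(0)
n, d, reg = 8, 50, 1e-3
G = torch.randn(n, d, dtype=torch.float64)
g = G.sum(0)                                    # equals G.T @ torch.ones(n)

full = torch.linalg.solve(G.T @ G + reg * torch.eye(d, dtype=G.dtype), g)
z = torch.linalg.solve(G @ G.T + reg * torch.eye(n, dtype=G.dtype), torch.ones(n, dtype=G.dtype))
gram = G.T @ z

print(torch.allclose(full, gram))  # True up to numerical error
```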
torchzero/modules/adaptive/rprop.py
@@ -258,8 +258,6 @@ class BacktrackOnSignChange(Transform):
     This is part of RProp update rule.

     Args:
-        normalize (bool, optional): renormalize update after masking. Defaults to False.
-        eps (_type_, optional): epsilon for normalization. Defaults to 1e-6.
         use_grad (bool, optional):
             if True, tracks sign change of the gradient,
             otherwise track sign change of the update. Defaults to True.
torchzero/modules/adaptive/sam.py
@@ -63,7 +63,7 @@ class SAM(Module):
         zero_grad = var.zero_grad
         if closure is None: raise RuntimeError("SAM requires a closure passed to the optimizer step")
         p, rho = self.get_settings(var.params, 'p', 'rho', cls=NumberList)
-        s = self.
+        s = self.defaults
         eps = s['eps']
         asam = s['asam']

torchzero/modules/adaptive/shampoo.py
@@ -17,6 +17,7 @@ def update_shampoo_preconditioner_(
     update_freq: int,
     exp_override: int | None,
     beta: float | None,
+    reg: float
 ):
     for i, (accumulator, preconditioner) in enumerate(zip(accumulators_, preconditioners_)):
         if accumulator is None: continue
@@ -28,6 +29,8 @@ def update_shampoo_preconditioner_(

         if step % update_freq == 0:
             matrix_exp = -1/(grad.ndim*2) if exp_override is None else -1/exp_override
+            if reg != 0:
+                accumulator = accumulator + torch.eye(accumulator.size(0), device=accumulator.device, dtype=accumulator.dtype).mul_(reg)
             set_storage_(preconditioner, matrix_power_eigh(accumulator, matrix_exp))


@@ -99,7 +102,6 @@ class Shampoo(Transform):
         decay (float | None, optional): slowly decays preconditioners. Defaults to None.
         beta (float | None, optional):
             if None calculates sum as in standard shampoo, otherwise uses EMA of preconditioners. Defaults to None.
-        matrix_eps (float, optional): epsilon for matrix operations. Defaults to 1e-10.
         update_freq (int, optional): preconditioner update frequency. Defaults to 10.
         exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to 2.
         merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
@@ -140,6 +142,7 @@ class Shampoo(Transform):
         self,
         decay: float | None = None,
         beta: float | None = None,
+        reg: float = 1e-12,
         update_freq: int = 10,
         exp_override: int | None = 2,
         merge_small: bool = True,
@@ -148,7 +151,7 @@ class Shampoo(Transform):
         adagrad_eps: float = 1e-8,
         inner: Chainable | None = None,
     ):
-        defaults = dict(decay=decay, beta=beta, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps)
+        defaults = dict(decay=decay, beta=beta, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps, reg=reg)
         super().__init__(defaults, uses_grad=False)

         if inner is not None:
@@ -159,8 +162,8 @@ class Shampoo(Transform):

         # update preconditioners
         for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
-            beta, update_freq, exp_override, merge_small, max_dim, precondition_1d = itemgetter(
-                'beta', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d')(setting)
+            beta, update_freq, exp_override, merge_small, max_dim, precondition_1d, reg = itemgetter(
+                'beta', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d', "reg")(setting)

             if merge_small:
                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -195,6 +198,7 @@ class Shampoo(Transform):
                 update_freq=update_freq,
                 exp_override=exp_override,
                 beta=beta,
+                reg=reg,
             )

         # inner step
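The new ``reg`` option adds a small multiple of the identity to the Shampoo accumulator before the inverse matrix root is taken. A minimal sketch of why that matters; ``inverse_root_eigh`` below is an assumed stand-in for the package's ``matrix_power_eigh`` helper (eigendecomposition with the power applied to the eigenvalues), not the actual implementation:

```python
import torch

def inverse_root_eigh(A: torch.Tensor, exp: float, reg: float) -> torch.Tensor:
    # Tikhonov term first, then a symmetric eigendecomposition with the power
    # applied to the eigenvalues
    if reg != 0:
        A = A + reg * torch.eye(A.size(0), device=A.device, dtype=A.dtype)
    L, Q = torch.linalg.eigh(A)
    return (Q * L.pow(exp)) @ Q.T

A = torch.zeros(3, 3)  # degenerate accumulator, e.g. no gradient signal yet
print(inverse_root_eigh(A, -0.25, reg=0.0))    # inf/nan entries
print(inverse_root_eigh(A, -0.25, reg=1e-12))  # finite, roughly 1e3 * identity
```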
torchzero/modules/adaptive/soap.py
@@ -1,9 +1,10 @@
 from operator import itemgetter
+import warnings

 import torch

 from ...core import Chainable, Transform, apply_transform
-from ...modules.
+from ...modules.adaptive.shampoo import _merge_small_dims, _unmerge_small_dims

 @torch.no_grad
 def update_soap_covariances_(
@@ -52,36 +53,23 @@ def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
     """
     Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
     """
-    matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m in mat:
-        if m is None or len(m) == 0:
-            matrix.append([])
-            continue
-        if m.dtype != torch.float:
-            original_type = m.dtype
-            original_device = m.device
-            matrix.append(m.float())
-        else:
-            float_data = True
-            matrix.append(m)

     final = []
-    for m in
-
+    for m in mat:
+
+        if m is None or len(m) == 0:
             final.append([])
             continue
+
         try:
             _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-        except
+        except torch.linalg.LinAlgError:
             _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
             Q = Q.to(m.dtype)
-        Q = torch.flip(Q, [1])

-
-        Q = Q.to(original_device).type(original_type)
+        Q = torch.flip(Q, [1])
         final.append(Q)
+
     return final

 # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
@@ -91,40 +79,24 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N
     Computes the eigenbases of the preconditioner using one round of power iteration
     followed by torch.linalg.qr decomposition.
     """
-
-
-
-
-
+    final = []
+
+    for ind, (m,o) in enumerate(zip(GG, Q_list)):
+
+        # skip 1d or large dims
         if m is None or len(m) == 0:
-
-            orth_matrix.append([])
+            final.append([])
             continue
         assert o is not None
-        if m.data.dtype != torch.float:
-            original_type = m.data.dtype
-            original_device = m.data.device
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-        else:
-            float_data = True
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())

-    final = []
-    for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
-        if len(m)==0:
-            final.append([])
-            continue
         est_eig = torch.diag(o.T @ m @ o)
         sort_idx = torch.argsort(est_eig, descending=True)
         exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
-        o = o[:,sort_idx]
-        power_iter = m @ o
-        Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable

-
-
+        power_iter = m @ o[:, sort_idx]
+        Q, _ = torch.linalg.qr(power_iter.to(torch.float32)) # pylint:disable=not-callable
+        Q = Q.to(power_iter.dtype)
+
         final.append(Q)

     return final, exp_avg_sq
@@ -226,7 +198,10 @@ class SOAP(Transform):

             if state['GG'] is not None:
                 update_soap_covariances_(t, GGs_=state['GG'], beta=shampoo_beta)
-                state['Q'] = get_orthogonal_matrix(state['GG'])
+                try: state['Q'] = get_orthogonal_matrix(state['GG'])
+                except torch.linalg.LinAlgError as e:
+                    warnings.warn(f"torch.linalg.eigh raised an error when initializing SOAP Q matrices on 1st step, diagonal preconditioning will be used for this parameter. The error was:\n{e}")
+                    state["GG"] = None

             state['step'] = 0
             updates.append(tensors[i].clip(-0.1, 0.1))
@@ -283,6 +258,8 @@ class SOAP(Transform):
             if state['GG'] is not None:
                 update_soap_covariances_(t, state['GG'], shampoo_beta)
                 if state['step'] % setting['precond_freq'] == 0:
-
-
+                    try:
+                        state['Q'], state['exp_avg_sq_projected'] = get_orthogonal_matrix_QR(exp_avg_sq_projected, state['GG'], state['Q'])
+                    except torch.linalg.LinAlgError:
+                        pass
         return updates
torchzero/modules/adaptive/sophia_h.py
@@ -4,8 +4,6 @@ import torch

 from ...core import Module, Target, Transform, Chainable, apply_transform
 from ...utils import NumberList, TensorList, as_tensorlist
-from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
-
 def sophia_H(
     tensors: TensorList,
     h: TensorList | None,
@@ -72,7 +70,7 @@ class SophiaH(Module):
             more accurate HVP approximation. This requires two extra
             gradient evaluations.
             Defaults to "autograd".
-
+        fd_h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
         n_samples (int, optional):
            number of hessian-vector products with random vectors to evaluate each time when updating
            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
@@ -159,6 +157,7 @@ class SophiaH(Module):

             Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
                                   h=fd_h, normalize=True, retain_grad=i < n_samples-1)
+            Hvp = tuple(Hvp)

             if h is None: h = Hvp
             else: torch._foreach_add_(h, Hvp)