torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from ...core import Transform
|
|
4
|
+
from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def mars_correction_(
    tensors_: TensorList,
    prev_: TensorList,
    beta: float | NumberList,
    scaling: float | NumberList,
    max_norm: float | NumberList | None,
):
    """Apply the MARS variance-reduction correction in place.

    Computes ``c = g + scaling * beta / (1 - beta) * (g - g_prev)``, where ``g`` is
    ``tensors_`` and ``g_prev`` is ``prev_``. Optionally clips the norm of ``c``.

    Args:
        tensors_: current tensors. They are overwritten with the corrected result.
        prev_: tensors from the previous call. They are overwritten with a copy of the
            current, uncorrected ``tensors_`` for use on the next call.
        beta: momentum coefficient that the correction is tuned for.
        scaling: multiplier applied to the correction term.
        max_norm: if not None, the corrected tensors are norm-clipped to at most this
            value, jointly across all tensors (``tensorwise=False``).

    Returns:
        ``tensors_`` after modification.
    """
    coeff = scaling * beta / (1 - beta)
    correction = (tensors_ - prev_).mul_(coeff)

    # remember the raw (uncorrected) tensors for the next call
    prev_.copy_(tensors_)

    corrected = tensors_.add_(correction)
    if max_norm is None:
        return corrected

    corrected.clip_norm_(max=max_norm, tensorwise=False)
    return corrected
|
|
22
|
+
|
|
23
|
+
class MARSCorrection(Transform):
    """MARS variance reduction correction.

    Each incoming update ``g`` is replaced by ``g + scaling * beta / (1 - beta) * (g - g_prev)``,
    where ``g_prev`` is the uncorrected update from the previous step. On the first step
    ``g_prev`` is initialized to the current update, so no correction is applied.

    Place any other momentum-based optimizer after this, and make sure the ``beta``
    parameter matches the momentum used by that optimizer.

    Args:
        beta (float, optional): use the same beta as you use in the momentum module. Defaults to 0.9.
        scaling (float, optional): controls the scale of gradient correction in variance reduction. Defaults to 0.025.
        max_norm (float | None, optional): clips the norm of the corrected update jointly across all
            parameters; None disables clipping. Defaults to 1.

    ## Examples:

    Mars-AdamW
    ```python
    optimizer = tz.Modular(
        model.parameters(),
        tz.m.MARSCorrection(beta=0.95),
        tz.m.Adam(beta1=0.95, beta2=0.99),
        tz.m.WeightDecay(1e-3),
        tz.m.LR(0.1)
    )
    ```

    Mars-Lion
    ```python
    optimizer = tz.Modular(
        model.parameters(),
        tz.m.MARSCorrection(beta=0.9),
        tz.m.Lion(beta1=0.9),
        tz.m.LR(0.1)
    )
    ```

    """
    def __init__(self, beta: float = 0.9, scaling: float = 0.025, max_norm: float | None = 1):
        super().__init__(dict(beta=beta, scaling=scaling, max_norm=max_norm), uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        beta, scaling = unpack_dicts(settings, 'beta', 'scaling', cls=NumberList)
        # clipping is applied to all tensors jointly, so a single value (the first parameter's) is used
        clip_value = settings[0]['max_norm']
        # previous uncorrected update; starts out equal to the current update
        prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)

        return mars_correction_(
            tensors_=TensorList(tensors),
            prev_=prev,
            beta=beta,
            scaling=scaling,
            max_norm=clip_value,
        )
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from ...core import Module, apply_transform, Chainable
|
|
6
|
+
from ...utils import NumberList, TensorList, as_tensorlist
|
|
7
|
+
from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
|
|
8
|
+
from ..functional import initial_step_size
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MatrixMomentum(Module):
    """Second order momentum method.

    Matrix momentum is useful for convex objectives. Empirically it also generalizes very well
    on elastic net logistic regression.

    The update produced by this module (which is subtracted from the parameters) is
    ``lr * u - s + mu * Hs``, where ``u`` is the incoming update, ``s`` is the difference between
    the current and previous parameters, and ``Hs`` is a Hessian-vector product with ``s``.
    On the first step ``s`` is not available yet, so ``lr * u`` scaled by ``initial_step_size`` is used.

    Notes:
        - ``mu`` needs to be tuned very carefully. It is supposed to be smaller than (1/largest eigenvalue),
          otherwise this will be very unstable. Setting ``adaptive=True`` multiplies ``mu`` by
          ``||s|| / ||y||``, where ``y`` is a gradient difference, which works well without having to tune ``mu``.
          With ``adapt_freq=None`` the gradient stored on the previous step is used for ``y``, which only
          makes sense for deterministic objectives. For stochastic objectives set ``adapt_freq``.

        - In most cases ``MatrixMomentum`` should be the first module in the chain because it relies on autograd.

        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument.

    Args:
        lr (float): learning rate that multiplies the incoming update.
        mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
        hvp_method (str, optional):
            Determines how Hessian-vector products are evaluated.

            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
              This requires creating a graph for the gradient.
            - ``"forward"``: Use a forward finite difference formula to
              approximate the HVP. This requires one extra gradient evaluation.
            - ``"central"``: Use a central finite difference formula for a
              more accurate HVP approximation. This requires two extra
              gradient evaluations.
            Defaults to "autograd".
        h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
        adaptive (bool, optional): if True, ``mu`` is multiplied by ``||s|| / ||y||``. Defaults to False.
        adapt_freq (int | None, optional):
            only used when ``adaptive=True``. If None, ``y`` is the difference between the current
            gradient and the gradient stored on the previous step (deterministic objectives). Otherwise,
            every ``adapt_freq`` steps the gradient at the previous parameters is re-evaluated with the
            closure (one extra evaluation) so that ``y`` uses the current mini-batch. The estimate is
            reused between re-evaluations. Defaults to None.
        hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.

    Reference:
        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
    """

    def __init__(
        self,
        lr: float,
        mu=0.1,
        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
        h: float = 1e-3,
        adaptive: bool = False,
        adapt_freq: int | None = None,
        hvp_tfm: Chainable | None = None,
    ):
        defaults = dict(lr=lr, mu=mu, hvp_method=hvp_method, h=h, adaptive=adaptive, adapt_freq=adapt_freq)
        super().__init__(defaults)

        if hvp_tfm is not None:
            self.set_child('hvp_tfm', hvp_tfm)

    def reset_for_online(self):
        super().reset_for_online()
        # forget the stored previous parameters
        self.clear_state_keys('p_prev')

    @staticmethod
    def _mu_multiplier(s: TensorList, y: TensorList) -> torch.Tensor:
        """Return ``||s|| / ||y||``, with ``||y||`` clipped away from zero."""
        denom = y.global_vector_norm()
        denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
        return s.global_vector_norm() / denom

    @torch.no_grad
    def update(self, var):
        assert var.closure is not None
        p = TensorList(var.params)
        p_prev = self.get_state(p, 'p_prev', init=var.params)

        hvp_method = self.defaults['hvp_method']
        h = self.defaults['h']
        step = self.global_state.get("step", 0)
        self.global_state["step"] = step + 1

        # on the very first call p_prev equals the current params, so there is no s yet
        if step > 0:
            s = p - p_prev

            Hs, _ = self.Hvp(s, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
            Hs = [t.detach() for t in Hs]

            if 'hvp_tfm' in self.children:
                Hs = TensorList(apply_transform(self.children['hvp_tfm'], Hs, params=p, grads=var.grad, var=var))

            self.store(p, ("Hs", "s"), (Hs, s))

            # -------------------------------- adaptive mu ------------------------------- #
            if self.defaults["adaptive"]:
                g = TensorList(var.get_grad())

                if self.defaults["adapt_freq"] is None:
                    # ---------------------------- deterministic case ---------------------------- #
                    # g_prev is stored by `apply` on the first step and refreshed here
                    g_prev = self.get_state(var.params, "g_prev", cls=TensorList)
                    y = g - g_prev
                    g_prev.copy_(g)
                    self.global_state["mu_mul"] = self._mu_multiplier(s, y)

                else:
                    # -------------------------------- stochastic -------------------------------- #
                    adapt_freq = self.defaults["adapt_freq"]

                    # step 0 never reaches this branch; use (step - 1) so that we adapt on the first step that does
                    if (step - 1) % adapt_freq == 0:
                        params = TensorList(var.params)
                        p_cur = params.clone()
                        # NOTE(review): g may alias the parameters' .grad tensors, which the closure's
                        # backward pass can modify in place, so keep an independent copy
                        g = g.clone()

                        # move to previous params and evaluate the gradient there with the current mini-batch
                        params.copy_(self.get_state(var.params, 'p_prev'))
                        try:
                            with torch.enable_grad():
                                var.closure()
                            g_prev = [param.grad if param.grad is not None else torch.zeros_like(param) for param in params]
                            y = g - g_prev
                        finally:
                            # always move back to current params, even if the closure raises
                            params.copy_(p_cur)

                        self.global_state["mu_mul"] = self._mu_multiplier(s, y)

        torch._foreach_copy_(p_prev, var.params)

    @torch.no_grad
    def apply(self, var):
        update = TensorList(var.get_update())
        lr, mu = self.get_settings(var.params, "lr", 'mu', cls=NumberList)

        if "mu_mul" in self.global_state:
            mu = mu * self.global_state["mu_mul"]

        # --------------------------------- 1st step --------------------------------- #
        # s (and therefore Hs) is not available yet, so make a small step
        step = self.global_state["step"]
        if step == 1:
            if self.defaults["adaptive"]: self.get_state(var.params, "g_prev", init=var.get_grad())
            update.mul_(lr) # separate so that initial_step_size can clip correctly
            update.mul_(initial_step_size(update, 1e-7))
            return var

        # -------------------------- matrix momentum update -------------------------- #
        s, Hs = self.get_state(var.params, 's', 'Hs', cls=TensorList)

        # subtracted from params: p_new = p - lr*u + s - mu*Hs
        update.mul_(lr).sub_(s).add_(Hs*mu)
        var.update = update
        return var
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from ...core import Chainable, Module, Target, Transform, apply_transform
|
|
6
|
+
from ...utils import NumberList, TensorList, unpack_dicts, unpack_states, generic_ne
|
|
7
|
+
from ..functional import ema_
|
|
8
|
+
from ..momentum.momentum import nag_
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def msam_(
    tensors: TensorList,
    params: TensorList,
    velocity_: TensorList,
    momentum: float | NumberList,
    lr: NumberList | None,
    rho: float | NumberList,
    weight_decay: float | NumberList,
    nesterov: bool = False,
    lerp: bool = False,

    # inner args
    inner: Module | None = None,
    grads: list[torch.Tensor] | None = None,
):
    """Compute the Momentum-SAM update as a function of the incoming tensors.

    With weights ``w`` and ``wh``, momentum ``mu`` and perturbation strength ``rho``::

        w   = wh + rho * v / ||v||
        v1  = mu*v + g
        w1  = w - lr*v1
        wh1 = w1 - rho * v1 / ||v1||

    Writing ``vn = rho * v / ||v||`` and ``v1n = rho * v1 / ||v1||`` gives
    ``wh1 = wh + vn - lr*v1 - v1n``. The ascent direction that is returned is
    therefore ``lr*v1 + v1n - vn``. It can't be fully decoupled from ``lr``, but it is
    expressed as a function of ``g``.

    If ``inner`` is given, ``lr*v1`` is replaced by the output of ``inner`` applied to
    ``tensors``. Otherwise ``lr`` must be provided. A non-zero ``weight_decay`` adds
    ``weight_decay * (params + vn)`` to the result.
    """
    tiny = torch.finfo(tensors[0].dtype).tiny * 2

    def scaled_unit(v: TensorList) -> TensorList:
        # rho * v / ||v||, with the denominator clipped from below to avoid division by zero
        return v / (v.global_vector_norm() / rho).clip(min=tiny)

    # perturbation from the previous velocity; it must be computed before the momentum
    # step below, which may update velocity_ in place
    vn = scaled_unit(velocity_)

    step_fn = nag_ if nesterov else ema_
    velocity_ = step_fn(tensors, velocity_, momentum, dampening=0, lerp=lerp)
    v1n = scaled_unit(velocity_)

    if inner is None:
        assert lr is not None
        inner_update = velocity_ * lr
    else:
        assert params is not None
        inner_update = TensorList(apply_transform(inner, tensors, params=params, grads=grads))

    update = inner_update.add_(v1n).sub_(vn)

    if generic_ne(weight_decay, 0):
        update.add_((params + vn).mul_(weight_decay))

    return update
|
|
71
|
+
|
|
72
|
+
class MSAM(Transform):
    """Momentum-SAM from https://arxiv.org/pdf/2401.12033.

    This implementation expresses the update rule as function of gradient. This way it can be used as a drop-in
    replacement for momentum strategies in other optimizers.

    To combine MSAM with other optimizers in the way done in the official implementation,
    e.g. to make Adam_MSAM, use ``tz.m.MSAMObjective`` module.

    Note:
        MSAM has a learning rate hyperparameter that can't really be removed from the update rule.
        To avoid compounding learning rate modifications, remove the ``tz.m.LR`` module if you had it.

    Args:
        lr (float): learning rate. Adding this module adds support for learning rate schedulers.
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        rho (float, optional): perturbation strength. Defaults to 0.3.
        weight_decay (float, optional):
            weight decay. It is applied to perturbed parameters, so it is different
            from applying :code:`tz.m.WeightDecay` after MSAM. Defaults to 0.
        nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
        lerp (bool, optional):
            whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.

    Examples:
        MSAM

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.MSAM(1e-3)
            )

        Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM.
        To make Adam_MSAM and such, use the :code:`tz.m.MSAMObjective` module.

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
                tz.m.Debias(0.9, 0.999),
            )
    """
    # subclasses that delegate the step to inner modules set this to False and don't store ``lr``
    _USES_LR = True
    def __init__(self, lr: float, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False,):
        defaults = dict(momentum=momentum,rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
        if self._USES_LR: defaults['lr'] = lr
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
        # lerp and nesterov are used as single values, taken from the first parameter's settings
        s = settings[0]
        lerp = s['lerp']
        nesterov = s['nesterov']

        if self._USES_LR:
            lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)

        else:
            lr=None
            momentum,rho,weight_decay = unpack_dicts(settings, 'momentum','rho','weight_decay', cls=NumberList)

        return msam_(
            TensorList(tensors),
            params=TensorList(params),
            velocity_=velocity,
            momentum=momentum,
            lr=lr,
            rho=rho,
            weight_decay=weight_decay,
            nesterov=nesterov,
            lerp=lerp,

            # inner args
            inner=self.children.get("modules", None),
            grads=grads,
        )
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class MSAMObjective(MSAM):
    """Momentum-SAM from https://arxiv.org/pdf/2401.12033, with the step computed by inner modules.

    Note:
        Please make sure to place ``tz.m.LR`` inside the ``modules`` argument. For example,
        ``tz.m.MSAMObjective([tz.m.Adam(), tz.m.LR(1e-3)])``. Putting LR after MSAM will lead
        to an incorrect update rule.

    Args:
        modules (Chainable): modules that will optimize the MSAM objective. Make sure :code:`tz.m.LR` is one of them.
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        rho (float, optional): perturbation strength. Defaults to 0.3.
        weight_decay (float, optional):
            weight decay. It is applied to perturbed parameters, so it is different
            from applying :code:`tz.m.WeightDecay` after MSAM. Defaults to 0.
        nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
        lerp (bool, optional):
            whether to use linear interpolation, if True, MSAM momentum becomes similar to exponential moving average.
            Defaults to False.

    Examples:
        AdamW-MSAM

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.MSAMObjective(
                    [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
                    rho=1.
                )
            )
    """
    # lr is not stored: the inner ``modules`` are responsible for the learning rate
    _USES_LR = False
    def __init__(self, modules: Chainable, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False):
        # lr=0 is a placeholder; it is discarded because _USES_LR is False
        super().__init__(lr=0, momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
        self.set_child('modules', modules)
|
|
188
|
+
|
|
@@ -19,6 +19,7 @@ def _is_at_least_2d(p: torch.Tensor):
|
|
|
19
19
|
|
|
20
20
|
# stolen from:
|
|
21
21
|
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
22
|
+
# actually at this stage its a frankenstein
|
|
22
23
|
@enable_compilation
|
|
23
24
|
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int) -> torch.Tensor:
|
|
24
25
|
"""
|
|
@@ -152,7 +153,7 @@ class Orthogonalize(TensorwiseTransform):
|
|
|
152
153
|
The Muon page says that embeddings and classifier heads should not be orthogonalized.
|
|
153
154
|
Usually only matrix parameters that are directly used in matmuls should be orthogonalized.
|
|
154
155
|
|
|
155
|
-
To make Muon, use Split with Adam on 1d params
|
|
156
|
+
To make Muon, use Split with Adam on 1d params
|
|
156
157
|
|
|
157
158
|
Args:
|
|
158
159
|
ns_steps (int, optional):
|
|
@@ -165,6 +166,29 @@ class Orthogonalize(TensorwiseTransform):
|
|
|
165
166
|
Newton-Schulz is very fast, SVD is extremely slow but can be slighly more precise.
|
|
166
167
|
target (str, optional):
|
|
167
168
|
what to set on var.
|
|
169
|
+
|
|
170
|
+
## Examples:
|
|
171
|
+
|
|
172
|
+
standard Muon with Adam fallback
|
|
173
|
+
```py
|
|
174
|
+
opt = tz.Modular(
|
|
175
|
+
model.head.parameters(),
|
|
176
|
+
tz.m.Split(
|
|
177
|
+
# apply muon only to 2D+ parameters
|
|
178
|
+
filter = lambda t: t.ndim >= 2,
|
|
179
|
+
true = [
|
|
180
|
+
tz.m.HeavyBall(),
|
|
181
|
+
tz.m.Orthogonalize(),
|
|
182
|
+
tz.m.LR(1e-2),
|
|
183
|
+
],
|
|
184
|
+
false = tz.m.Adam()
|
|
185
|
+
),
|
|
186
|
+
tz.m.LR(1e-2)
|
|
187
|
+
)
|
|
188
|
+
```
|
|
189
|
+
|
|
190
|
+
Reference:
|
|
191
|
+
Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon
|
|
168
192
|
"""
|
|
169
193
|
def __init__(self, ns_steps=5, adjust_lr=False, dual_norm_correction=False,
|
|
170
194
|
method: Literal['newton-schulz', 'svd'] = 'newton-schulz', target:Target='update'):
|
|
@@ -172,9 +196,9 @@ class Orthogonalize(TensorwiseTransform):
|
|
|
172
196
|
super().__init__(uses_grad=False, defaults=defaults, target=target)
|
|
173
197
|
|
|
174
198
|
@torch.no_grad
|
|
175
|
-
def apply_tensor(self, tensor, param, grad, loss, state,
|
|
199
|
+
def apply_tensor(self, tensor, param, grad, loss, state, setting):
|
|
176
200
|
orthogonalize, ns_steps, dual_norm_correction, adjust_lr, method = itemgetter(
|
|
177
|
-
'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(
|
|
201
|
+
'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(setting)
|
|
178
202
|
|
|
179
203
|
if not orthogonalize: return tensor
|
|
180
204
|
|
|
@@ -199,7 +223,7 @@ class DualNormCorrection(TensorwiseTransform):
|
|
|
199
223
|
def __init__(self, target: Target='update'):
|
|
200
224
|
super().__init__({}, uses_grad=True, target=target)
|
|
201
225
|
|
|
202
|
-
def apply_tensor(self, tensor, param, grad, loss, state,
|
|
226
|
+
def apply_tensor(self, tensor, param, grad, loss, state, setting):
|
|
203
227
|
assert grad is not None
|
|
204
228
|
if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
|
|
205
229
|
return _dual_norm_correction(tensor, grad, batch_first=False)
|
|
@@ -213,7 +237,7 @@ class MuonAdjustLR(Transform):
|
|
|
213
237
|
defaults = dict(alpha=alpha)
|
|
214
238
|
super().__init__(defaults=defaults, uses_grad=False, target=target)
|
|
215
239
|
|
|
216
|
-
def
|
|
240
|
+
def apply_tensors(self, tensors, params, grads, loss, states, settings):
|
|
217
241
|
alphas = [s['alpha'] for s in settings]
|
|
218
242
|
tensors_alphas = [(t, adjust_lr_for_muon(a, t.shape)) for t, a in zip(tensors, alphas) if _is_at_least_2d(t)]
|
|
219
243
|
tensors = [i[0] for i in tensors_alphas]
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from ...core import Module, Chainable, apply_transform
|
|
3
|
+
|
|
4
|
+
from ...utils.derivatives import jacobian_wrt, flatten_jacobian
|
|
5
|
+
from ...utils import vec_to_tensors, TensorList
|
|
6
|
+
from ...utils.linalg import linear_operator
|
|
7
|
+
from .lmadagrad import lm_adagrad_apply, lm_adagrad_update
|
|
8
|
+
|
|
9
|
+
class NaturalGradient(Module):
    """Natural gradient approximated via empirical fisher information matrix.

    To use this, either pass a vector of per-sample losses to the step method, or make sure
    the closure returns it. Gradients will be calculated via batched autograd within this module,
    you don't need to implement the backward pass. When using a closure, please add the ``backward``
    argument; it will always be False but it is required. See below for an example.

    Note:
        Empirical fisher information matrix may give a really bad approximation in some cases.
        If that is the case, set ``sqrt`` to True to perform whitening instead, which is way more robust.

    Args:
        reg (float, optional): regularization parameter. Defaults to 1e-8.
        sqrt (bool, optional):
            if True, uses square root of empirical fisher information matrix. Both EFIM and its square
            root can be calculated and stored efficiently without ndim^2 memory. Square root
            whitens the gradient and often performs much better, especially when you try to use NGD
            with a vector that isn't strictly per-sample gradients, but rather for example different losses.
        gn_grad (bool, optional):
            if True, uses Gauss-Newton G^T @ f as the gradient, which is effectively sum weighted by value
            and is equivalent to squaring the values. This way you can solve least-squares
            objectives with a NGD-like algorithm. If False, uses sum of per-sample gradients.
            This has an effect when ``sqrt=True``, and affects the ``grad`` attribute. It also makes the
            scalar loss reported by this module (``var.loss`` and the installed closure) the sum of
            squared values instead of the plain sum.
            Defaults to False.
        batched (bool, optional): whether to use vmapping. Defaults to True.

    Examples:

        training a neural network:
        ```python
        X = torch.randn(64, 20)
        y = torch.randn(64, 10)

        model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
        opt = tz.Modular(
            model.parameters(),
            tz.m.NaturalGradient(),
            tz.m.LR(3e-2)
        )

        for i in range(100):
            y_hat = model(X) # (64, 10)
            losses = (y_hat - y).pow(2).mean(0) # (10, )
            opt.step(loss=losses)
            if i % 10 == 0:
                print(f'{losses.mean() = }')
        ```

        training a neural network - closure version
        ```python
        X = torch.randn(64, 20)
        y = torch.randn(64, 10)

        model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
        opt = tz.Modular(
            model.parameters(),
            tz.m.NaturalGradient(),
            tz.m.LR(3e-2)
        )

        def closure(backward=True):
            y_hat = model(X) # (64, 10)
            return (y_hat - y).pow(2).mean(0) # (10, )

        for i in range(100):
            losses = opt.step(closure)
            if i % 10 == 0:
                print(f'{losses.mean() = }')
        ```

        minimizing the rosenbrock function with a mix of natural gradient, whitening and gauss-newton:
        ```python
        def rosenbrock(X):
            x1, x2 = X
            return torch.stack([(1 - x1).abs(), (10 * (x2 - x1**2).abs())])

        X = torch.tensor([-1.1, 2.5], requires_grad=True)
        opt = tz.Modular([X], tz.m.NaturalGradient(sqrt=True, gn_grad=True), tz.m.LR(0.05))

        for iter in range(200):
            losses = rosenbrock(X)
            opt.step(loss=losses)
            if iter % 20 == 0:
                print(f'{losses.mean() = }')
        ```
    """
    def __init__(self, reg:float = 1e-8, sqrt:bool=False, gn_grad:bool=False, batched:bool=True, ):
        super().__init__(defaults=dict(batched=batched, reg=reg, sqrt=sqrt, gn_grad=gn_grad))

    @torch.no_grad
    def update(self, var):
        """Compute the per-sample Jacobian ``G`` and the gradient ``g``, store them in ``global_state``,
        set ``var.grad`` and ``var.loss``, and (if a closure exists) replace it with a scalar closure."""
        params = var.params
        batched = self.defaults['batched']
        gn_grad = self.defaults['gn_grad']

        # The closure may legitimately be None when the per-sample losses were passed directly,
        # e.g. ``opt.step(loss=losses)`` as in the docstring examples. In that case
        # ``var.get_loss`` is expected to return the passed vector. If neither a loss nor a
        # closure is available, the isinstance check below fails.
        closure = var.closure

        with torch.enable_grad():
            f = var.get_loss(backward=False) # n_out
            assert isinstance(f, torch.Tensor)
            G_list = jacobian_wrt([f.ravel()], params, batched=batched)

        # flattened, detached values; one entry per row of G
        f_vec = f.detach().ravel()

        # Scalar loss consistent with what ``ngd_closure`` returns below, so that line searches
        # compare the value at the current point with values of the same objective.
        var.loss = f_vec.pow(2).sum() if gn_grad else f_vec.sum()

        G = self.global_state["G"] = flatten_jacobian(G_list) # (n_samples, ndim)

        if gn_grad:
            # Gauss-Newton gradient: per-sample gradients weighted by their values
            g = G.H @ f_vec
        else:
            # plain sum of per-sample gradients
            g = G.sum(0)
        self.global_state["g"] = g

        var.grad = vec_to_tensors(g, params)

        # set closure to calculate scalar value for line searches etc
        if closure is not None:

            def _scalar_loss():
                # reduce the per-sample output of the user closure to the scalar objective
                loss = closure(False)
                if gn_grad: loss = loss.pow(2)
                return loss.sum()

            def ngd_closure(backward=True):
                if backward:
                    var.zero_grad()
                    with torch.enable_grad():
                        loss = _scalar_loss()
                        loss.backward()
                    return loss
                return _scalar_loss()

            var.closure = ngd_closure

    @torch.no_grad
    def apply(self, var):
        """Turn the stored ``G`` (and ``g`` when ``sqrt=True``) into ``var.update``.

        If ``lm_adagrad_update`` returns None for either factor in the ``sqrt`` branch,
        ``var`` is returned without setting ``update``.
        """
        params = var.params
        reg = self.defaults['reg']
        sqrt = self.defaults['sqrt']

        G: torch.Tensor = self.global_state['G'] # (n_samples, n_dim)

        if sqrt:
            # this computes U, S <- SVD(M), then calculate update as U S^-1 Uᵀg,
            # but it computes it through eigendecomposition
            U, L = lm_adagrad_update(G.H, reg, 0)
            if U is None or L is None: return var

            v = lm_adagrad_apply(self.global_state["g"], U, L)
            var.update = vec_to_tensors(v, params)
            return var

        GGT = G @ G.H # (n_samples, n_samples)

        if reg != 0:
            GGT.add_(torch.eye(GGT.size(0), device=GGT.device, dtype=GGT.dtype).mul_(reg))

        # v = Gᴴ (G Gᴴ + reg I)⁻¹ 1, computed in the small n_samples x n_samples space
        z, _ = torch.linalg.solve_ex(GGT, torch.ones_like(GGT[0])) # pylint:disable=not-callable
        v = G.H @ z

        var.update = vec_to_tensors(v, params)
        return var

    def get_H(self, var):
        """Return ``GᴴG`` as a linear operator, or a scaled identity before the first ``update``."""
        if "G" not in self.global_state: return linear_operator.ScaledIdentity()
        G = self.global_state['G']
        return linear_operator.AtA(G)
|
|
@@ -36,7 +36,7 @@ class OrthoGrad(Transform):
|
|
|
36
36
|
defaults = dict(eps=eps, renormalize=renormalize)
|
|
37
37
|
super().__init__(defaults, uses_grad=False, target=target)
|
|
38
38
|
|
|
39
|
-
def
|
|
39
|
+
def apply_tensors(self, tensors, params, grads, loss, states, settings):
|
|
40
40
|
eps = settings[0]['eps']
|
|
41
41
|
renormalize = settings[0]['renormalize']
|
|
42
42
|
|