torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
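
The listing shows the linear-algebra helpers moving from `torchzero/utils/linalg/` into a new top-level `torchzero/linalg/` package (e.g. `qr.py`, `solve.py`, `linear_operator.py`). A minimal sketch of how downstream imports could cope with both layouts; the module paths come from the listing above, but which names they publicly export is an assumption:

```py
# Hypothetical migration sketch for the utils/linalg -> linalg move shown above.
# Module paths are taken from the file listing; the exact public names inside
# them are an assumption and may differ in the released packages.

try:
    # torchzero >= 0.4.x: linalg is a top-level subpackage
    from torchzero import linalg as tz_linalg
except ImportError:
    # torchzero 0.3.x: the same helpers lived under torchzero.utils
    from torchzero.utils import linalg as tz_linalg  # type: ignore[no-redef]

print(tz_linalg.__name__)
```
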
torchzero/modules/adaptive/psgd/psgd_dense_newton.py
@@ -0,0 +1,174 @@
# pylint:disable=not-callable
"""all functions are from https://github.com/lixilinx/psgd_torch/blob/master/psgd.py"""
import math
import warnings
from typing import Literal

import torch

from ....core import Chainable, HVPMethod, Transform
from ....utils import Distributions, TensorList, vec_to_tensors_
from ._psgd_utils import _initialize_lra_state_
from .psgd import (
    lift2single,
    update_precond_dense_eq,
    update_precond_dense_q0p5eq1p5,
    update_precond_dense_qep,
    update_precond_dense_qeq,
    update_precond_dense_quad,
    update_precond_dense_quad4p,
)

# matches
class PSGDDenseNewton(Transform):
    """Dense hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

    Args:
        init_scale (float | None, optional):
            initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.
        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
        damping (float, optional):
            adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.
        grad_clip_max_norm (float, optional): clips norm of the update. Defaults to float("inf").
        update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
        dQ (str, optional): geometry for preconditioner update. Defaults to "Q0.5EQ1.5".
        hvp_method (HVPMethod, optional): how to compute hessian-vector products. Defaults to 'autograd'.
        h (float, optional):
            if ``hvp_method`` is ``"fd_central"`` or ``"fd_forward"``, controls finite difference step size.
            Defaults to 1e-3.
        distribution (Distributions, optional):
            distribution for random vectors for hessian-vector products. Defaults to 'normal'.

        inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.

    ###Examples:

    Pure Dense Newton PSGD:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.DenseNewton(),
        tz.m.LR(1e-3),
    )
    ```

    Applying preconditioner to momentum:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.DenseNewton(inner=tz.m.EMA(0.9)),
        tz.m.LR(1e-3),
    )
    ```
    """
    def __init__(
        self,
        init_scale: float | None = None,
        lr_preconditioner=0.1,
        betaL=0.9,
        damping=1e-9,
        grad_clip_max_norm=float("inf"),
        update_probability=1.0,
        dQ: Literal["QUAD4P", "QUAD", "QEP", "EQ", "QEQ", "Q0p5EQ1p5", "Q0.5EQ1.5"] = "Q0.5EQ1.5",

        hvp_method: HVPMethod = 'autograd',
        h: float = 1e-3,
        distribution: Distributions = 'normal',

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults["inner"], defaults["self"]
        super().__init__(defaults, inner=inner)


    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]

        # -------------------------------- initialize -------------------------------- #
        if "Q" not in self.global_state:

            p = objective.params[0]
            dQ = fs["dQ"]
            init_scale = fs["init_scale"]

            if init_scale is None:
                self.global_state["Q"] = None

            else:
                n = sum(p.numel() for p in objective.params)
                if dQ == "QUAD4P":
                    init_scale *= init_scale
                self.global_state["Q"] = torch.eye(n, dtype=p.dtype, device=p.device) * init_scale

            self.global_state["L"] = lift2single(torch.zeros([], dtype=p.dtype, device=p.device)) # Lipschitz smoothness constant estimation for the psgd criterion

            if dQ == "QUAD4P":
                self.global_state["update_precond"] = update_precond_dense_quad4p
                self.global_state["precond_grad"] = lambda Q, g: Q @ g
                assert torch.finfo(p.dtype).eps < 1e-6, "Directly fitting P needs at least single precision"

            elif dQ == "QUAD":
                self.global_state["update_precond"] = update_precond_dense_quad
                self.global_state["precond_grad"] = lambda Q, g: Q @ (Q @ g) # Q is symmetric; just save one transpose

            else:
                self.global_state["precond_grad"] = lambda Q, g: Q.T @ (Q @ g)
                if dQ == "QEP":
                    self.global_state["update_precond"] = update_precond_dense_qep
                elif dQ == "EQ":
                    self.global_state["update_precond"] = update_precond_dense_eq
                elif dQ == "QEQ":
                    self.global_state["update_precond"] = update_precond_dense_qeq
                else:
                    assert (dQ == "Q0p5EQ1p5") or (dQ == "Q0.5EQ1.5"), f"Invalid choice for dQ: '{dQ}'"
                    self.global_state["update_precond"] = update_precond_dense_q0p5eq1p5

        # ---------------------------------- update ---------------------------------- #
        Q = self.global_state["Q"]
        if (torch.rand([]) < fs["update_probability"]) or Q is None:

            # hessian-vector product
            vs = TensorList(objective.params).sample_like(distribution=fs["distribution"])
            Hvs, _ = objective.hessian_vector_product(z=vs, rgrad=None, at_x0=True, hvp_method=fs["hvp_method"], h=fs["h"])

            v = torch.cat([t.ravel() for t in vs]).unsqueeze(1)
            h = torch.cat([t.ravel() for t in Hvs]).unsqueeze(1)

            # initialize on the fly
            if Q is None:
                scale = (torch.mean(v*v))**(1/4) * (torch.mean(h**4) + fs["damping"]**4)**(-1/8)
                if fs["dQ"] == "QUAD4P": # Q actually is P in this case
                    scale *= scale
                Q = self.global_state["Q"] = torch.eye(len(v), dtype=v.dtype, device=v.device) * scale

            # update preconditioner
            self.global_state["update_precond"](
                Q=Q,
                L=self.global_state["L"],
                v=v,
                h=h,
                lr=fs["lr_preconditioner"],
                betaL=fs["betaL"],
                damping=fs["damping"],
            )

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        updates = objective.get_updates()

        # cat grads
        g = torch.cat([t.ravel() for t in updates]).unsqueeze(1) # column vec
        pre_grad = self.global_state["precond_grad"](self.global_state["Q"], g)

        # norm clipping
        grad_clip_max_norm = settings[0]["grad_clip_max_norm"]
        if grad_clip_max_norm < float("inf"): # clip preconditioned gradient
            grad_norm = torch.linalg.vector_norm(pre_grad)
            if grad_norm > grad_clip_max_norm:
                pre_grad *= grad_clip_max_norm / grad_norm

        vec_to_tensors_(pre_grad, updates)
        return objective

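The dense variants above keep a factor `Q` and apply the preconditioner `P = Q^T Q` in factored form. A standalone plain-PyTorch sketch (not the torchzero API) of why `Q.T @ (Q @ g)` matches forming `P` explicitly:

```py
# Standalone sketch (plain PyTorch, not the torchzero API): the dense variants
# store a factor Q and apply the preconditioner P = Q^T Q to a column vector g
# as Q.T @ (Q @ g), which avoids materializing the n x n matrix P.
import torch

n = 8
Q = torch.randn(n, n, dtype=torch.float64)
g = torch.randn(n, 1, dtype=torch.float64)

pre_factored = Q.T @ (Q @ g)   # what the module's precond_grad lambda does
pre_explicit = (Q.T @ Q) @ g   # same result, but forms P explicitly

assert torch.allclose(pre_factored, pre_explicit)
```
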
torchzero/modules/adaptive/psgd/psgd_kron_newton.py
@@ -0,0 +1,203 @@
# pylint:disable=not-callable
"""all functions are from https://github.com/lixilinx/psgd_torch/blob/master/psgd.py"""
import math
import warnings
from typing import Literal

import torch

from ....core import Chainable, HVPMethod, Transform
from ....utils import NumberList, TensorList, Distributions
from .psgd import (
    init_kron,
    precond_grad_kron,
    update_precond_kron_newton_eq,
    update_precond_kron_newton_q0p5eq1p5,
    update_precond_kron_newton_qep,
    update_precond_kron_newton_qeq,
    update_precond_kron_newton_quad,
    update_precond_kron_newton_quad4p,
)

# matches
class PSGDKronNewton(Transform):
    """Kron hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

    Args:
        max_dim (int, optional): dimensions with size larger than this use diagonal preconditioner. Defaults to 10_000.
        max_skew (float, optional):
            if memory used by full preconditioner (dim^2) is larger than total number of elements in a parameter times ``max_skew``, it uses a diagonal preconditioner. Defaults to 1.0.
        init_scale (float | None, optional):
            initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.
        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
        damping (float, optional): adds small noise to gradient when updating the preconditioner. Defaults to 1e-9.
        grad_clip_max_amp (float, optional): clips amplitude of the update. Defaults to float("inf").
        update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
        dQ (str, optional): geometry for preconditioner update. Defaults to "Q0.5EQ1.5".
        balance_probability (float, optional):
            probability of balancing the dynamic ranges of the factors of Q to avoid over/under-flow on each step. Defaults to 0.01.
        hvp_method (HVPMethod, optional): how to compute hessian-vector products. Defaults to 'autograd'.
        h (float, optional):
            if ``hvp_method`` is ``"fd_central"`` or ``"fd_forward"``, controls finite difference step size.
            Defaults to 1e-3.
        distribution (Distributions, optional):
            distribution for random vectors for hessian-vector products. Defaults to 'normal'.
        inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.


    ###Examples:

    Pure PSGD Kron Newton:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.KronNewton(),
        tz.m.LR(1e-3),
    )
    ```

    Applying preconditioner to momentum:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.KronNewton(inner=tz.m.EMA(0.9)),
        tz.m.LR(1e-3),
    )
    ```
    """
    def __init__(
        self,
        max_dim: int = 10_000,
        max_skew: float = 1.0,
        init_scale: float | None = None,
        lr_preconditioner: float = 0.1,
        betaL: float = 0.9,
        damping: float = 1e-9,
        grad_clip_max_amp: float = float("inf"),
        update_probability: float = 1.0,
        dQ: Literal["QEP", "EQ", "QEQ", "QUAD", "Q0.5EQ1.5", "Q0p5EQ1p5", "QUAD4P"] = "Q0.5EQ1.5",
        balance_probability: float = 0.01,

        hvp_method: HVPMethod = 'autograd',
        h: float = 1e-3,
        distribution: Distributions = 'normal',

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults["inner"], defaults["self"]
        super().__init__(defaults, inner=inner)


    def _initialize_state(self, param, state, setting):
        assert "initialized" not in state
        state["initialized"] = True

        # initialize preconditioners
        if setting["init_scale"] is None:
            warnings.warn("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
            state["QLs_exprs"] = None
        else:
            state["QLs_exprs"] = init_kron(
                param.squeeze(),
                Scale=setting["init_scale"],
                max_size=setting["max_dim"],
                max_skew=setting["max_skew"],
                dQ=setting["dQ"],
            )

        dQ = setting["dQ"]
        if dQ == "QUAD4P":
            assert torch.finfo(param.dtype).eps < 1e-6, "Directly fitting P needs at least single precision"
            state["update_precond"] = update_precond_kron_newton_quad4p
            state["precond_grad"] = lambda QL, exprs, G: exprs[0](*QL[0], G) # it's exprA(*Q, G)

        else:
            state["precond_grad"] = precond_grad_kron
            if dQ == "QEP":
                state["update_precond"] = update_precond_kron_newton_qep
            elif dQ == "EQ":
                state["update_precond"] = update_precond_kron_newton_eq
            elif dQ == "QEQ":
                state["update_precond"] = update_precond_kron_newton_qeq
            elif dQ == "QUAD":
                state["update_precond"] = update_precond_kron_newton_quad
            else:
                assert (dQ == "Q0.5EQ1.5") or (dQ == "Q0p5EQ1p5"), f"Invalid choice for dQ: '{dQ}'"
                state["update_precond"] = update_precond_kron_newton_q0p5eq1p5

    @torch.no_grad
    def update_states(self, objective, states, settings):

        # initialize states
        for param, state, setting in zip(objective.params, states, settings):
            if "initialized" not in state:
                self._initialize_state(param, state, setting)

        fs = settings[0]

        uninitialized = any(state["QLs_exprs"] is None for state in states)
        if (torch.rand([]) < fs["update_probability"]) or uninitialized:

            # hessian-vector product
            vs = TensorList(objective.params).sample_like(distribution=fs["distribution"])
            Hvs, _ = objective.hessian_vector_product(z=vs, rgrad=None, at_x0=True, hvp_method=fs["hvp_method"], h=fs["h"])

            # initialize on the fly (why does it use vs?)
            if uninitialized:

                scale = (sum([torch.sum(torch.abs(v)**2) for v in vs])/sum([v.numel() for v in vs])) ** (1/4) # (mean(|v|^2))^(1/4)

                scale = scale * (max([torch.mean((torch.abs(h))**4) for h in Hvs]) + fs["damping"]**4) ** (-1/8) # (mean(|v|^2))^(1/4) * (mean(|h|^4))^(-1/8)

                for h, state, setting in zip(Hvs, states, settings):
                    if state["QLs_exprs"] is None:
                        state["QLs_exprs"] = init_kron(
                            h.squeeze(),
                            Scale=scale,
                            max_size=setting["max_dim"],
                            max_skew=setting["max_skew"],
                            dQ=setting["dQ"],
                        )

            # update preconditioner
            for v, h, state, setting in zip(vs, Hvs, states, settings):
                state["update_precond"](
                    *state["QLs_exprs"],
                    v.squeeze(),
                    h.squeeze(),
                    lr=setting["lr_preconditioner"],
                    betaL=setting["betaL"],
                    damping=setting["damping"],
                    balance_prob=setting["balance_probability"]
                )

    @torch.no_grad
    def apply_states(self, objective, states, settings):

        params = objective.params
        tensors = objective.get_updates()
        pre_tensors = []

        # precondition
        for param, tensor, state in zip(params, tensors, states):
            t = state["precond_grad"](
                *state["QLs_exprs"],
                tensor.squeeze(),
            )
            pre_tensors.append(t.view_as(param))

        # norm clipping
        grad_clip_max_amp = settings[0]["grad_clip_max_amp"]
        if grad_clip_max_amp < math.inf:
            pre_tensors = TensorList(pre_tensors)
            num_params = sum(t.numel() for t in pre_tensors)

            avg_amp = pre_tensors.dot(pre_tensors.conj()).div(num_params).sqrt()

            if avg_amp > grad_clip_max_amp:
                torch._foreach_mul_(pre_tensors, grad_clip_max_amp / avg_amp)

        objective.updates = pre_tensors
        return objective

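`apply_states` above clips by average amplitude rather than by total norm: the root-mean-square over all update entries is compared against `grad_clip_max_amp`. A plain-PyTorch restatement of that rule (an illustration, not the `TensorList`-based implementation):

```py
# Standalone sketch (plain PyTorch): the Kron modules clip by average amplitude,
# i.e. the root-mean-square of all update entries, rather than by total norm.
import math
import torch

tensors = [torch.randn(4, 3), torch.randn(10)]
grad_clip_max_amp = 0.5

num_params = sum(t.numel() for t in tensors)
avg_amp = math.sqrt(sum(t.pow(2).sum().item() for t in tensors) / num_params)

if avg_amp > grad_clip_max_amp:
    scale = grad_clip_max_amp / avg_amp
    tensors = [t * scale for t in tensors]
```
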
torchzero/modules/adaptive/psgd/psgd_kron_whiten.py
@@ -0,0 +1,185 @@
# pylint:disable=not-callable
"""all functions are from https://github.com/lixilinx/psgd_torch/blob/master/psgd.py"""
import math
import warnings
from typing import Literal

import torch

from ....core import Chainable, TensorTransform
from ....utils import NumberList, TensorList
from .psgd import (
    init_kron,
    precond_grad_kron,
    update_precond_kron_whiten_eq,
    update_precond_kron_whiten_q0p5eq1p5,
    update_precond_kron_whiten_qep,
    update_precond_kron_whiten_qeq,
    update_precond_kron_whiten_quad,
    update_precond_kron_whiten_quad4p,
)

# matches
class PSGDKronWhiten(TensorTransform):
    """Kron whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

    Args:
        max_dim (int, optional): dimensions with size larger than this use diagonal preconditioner. Defaults to 10_000.
        max_skew (float, optional):
            if memory used by full preconditioner (dim^2) is larger than total number of elements in a parameter times ``max_skew``, it uses a diagonal preconditioner. Defaults to 1.0.
        init_scale (float | None, optional):
            initial scale of the preconditioner. If None, determined from magnitude of the first gradient. Defaults to None.
        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
        damping (float, optional): adds small noise to gradient when updating the preconditioner. Defaults to 1e-9.
        grad_clip_max_amp (float, optional): clips amplitude of the update. Defaults to float("inf").
        update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
        dQ (str, optional): geometry for preconditioner update. Defaults to "Q0.5EQ1.5".
        balance_probability (float, optional):
            probability of balancing the dynamic ranges of the factors of Q to avoid over/under-flow on each step. Defaults to 0.01.

        inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.

    ###Examples:

    Pure PSGD Kron:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.KronWhiten(),
        tz.m.LR(1e-3),
    )
    ```

    Momentum into preconditioner (whitens momentum):
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.EMA(0.9),
        tz.m.KronWhiten(),
        tz.m.LR(1e-3),
    )
    ```

    Updating the preconditioner from gradients and applying it to momentum:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.KronWhiten(inner=tz.m.EMA(0.9)),
        tz.m.LR(1e-3),
    )
    ```

    """
    def __init__(
        self,
        max_dim: int = 10_000,
        max_skew: float = 1.0,
        init_scale: float | None = None,
        lr_preconditioner: float = 0.1,
        betaL: float = 0.9,
        damping: float = 1e-9,
        grad_clip_max_amp: float = float("inf"),
        update_probability: float = 1.0,
        dQ: Literal["QEP", "EQ", "QEQ", "QUAD", "Q0.5EQ1.5", "Q0p5EQ1p5", "QUAD4P"] = "Q0.5EQ1.5",
        balance_probability: float = 0.01,

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults["inner"], defaults["self"]
        super().__init__(defaults, inner=inner)

    @torch.no_grad
    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        # initialize preconditioners
        if setting["init_scale"] is None:
            # warnings.warn("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
            state["QLs_exprs"] = None
        else:
            state["QLs_exprs"] = init_kron(
                param.squeeze(),
                Scale=setting["init_scale"],
                max_size=setting["max_dim"],
                max_skew=setting["max_skew"],
                dQ=setting["dQ"],
            )

        dQ = setting["dQ"]
        if dQ == "QUAD4P":
            assert torch.finfo(param.dtype).eps < 1e-6, "Directly fitting P needs at least single precision"
            state["update_precond"] = update_precond_kron_whiten_quad4p
            state["precond_grad"] = lambda QL, exprs, G: exprs[0](*QL[0], G) # it's exprA(*Q, G)

        else:
            state["precond_grad"] = precond_grad_kron
            if dQ == "QEP":
                state["update_precond"] = update_precond_kron_whiten_qep
            elif dQ == "EQ":
                state["update_precond"] = update_precond_kron_whiten_eq
            elif dQ == "QEQ":
                state["update_precond"] = update_precond_kron_whiten_qeq
            elif dQ == "QUAD":
                state["update_precond"] = update_precond_kron_whiten_quad
            else:
                assert (dQ == "Q0.5EQ1.5") or (dQ == "Q0p5EQ1p5"), f"Invalid choice for dQ: '{dQ}'"
                state["update_precond"] = update_precond_kron_whiten_q0p5eq1p5

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):

        # initialize on the fly if not initialized
        if any(state["QLs_exprs"] is None for state in states):

            scale = max([torch.mean((torch.abs(g))**4) for g in tensors])
            scale = (scale + settings[0]["damping"]**4)**(-1/8)

            for param, state, setting in zip(params, states, settings):
                if state["QLs_exprs"] is None:
                    state["QLs_exprs"] = init_kron(
                        param.squeeze(),
                        Scale=scale,
                        max_size=setting["max_dim"],
                        max_skew=setting["max_skew"],
                        dQ=setting["dQ"],
                    )


        # update preconditioners
        # (could also try per-parameter probability)
        if torch.rand([]) < settings[0]["update_probability"]: # update Q
            for tensor, state, setting in zip(tensors, states, settings):
                state["update_precond"](
                    *state["QLs_exprs"],
                    tensor.squeeze(),
                    lr=setting["lr_preconditioner"],
                    betaL=setting["betaL"],
                    damping=setting["damping"],
                    balance_prob=setting["balance_probability"]
                )

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):

        pre_tensors = []

        # precondition
        for param, tensor, state in zip(params, tensors, states):
            t = state["precond_grad"](
                *state["QLs_exprs"],
                tensor.squeeze(),
            )
            pre_tensors.append(t.view_as(param))

        # norm clipping
        grad_clip_max_amp = settings[0]["grad_clip_max_amp"]
        if grad_clip_max_amp < math.inf:
            pre_tensors = TensorList(pre_tensors)
            num_params = sum(t.numel() for t in pre_tensors)

            avg_amp = pre_tensors.dot(pre_tensors.conj()).div(num_params).sqrt()

            if avg_amp > grad_clip_max_amp:
                torch._foreach_mul_(pre_tensors, grad_clip_max_amp / avg_amp)

        return pre_tensors

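Unlike the Newton variants, the whitening preconditioner is updated from the gradients themselves, with no Hessian-vector products; when `init_scale` is None the initial scale is taken from the first gradients as `(max_p mean(|g_p|^4) + damping^4)^(-1/8)`. A standalone sketch of that heuristic:

```py
# Standalone sketch (plain PyTorch) of the on-the-fly scale heuristic used above
# when init_scale is None: scale = (max_p mean(|g_p|^4) + damping^4) ** (-1/8).
import torch

grads = [torch.randn(5, 3), torch.randn(7)]
damping = 1e-9

scale = max(torch.mean(g.abs() ** 4) for g in grads)
scale = (scale + damping ** 4) ** (-1 / 8)
print(float(scale))
```
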
torchzero/modules/adaptive/psgd/psgd_lra_newton.py
@@ -0,0 +1,118 @@
# pylint:disable=not-callable
"""all functions are from https://github.com/lixilinx/psgd_torch/blob/master/psgd.py"""
import math
import warnings

import torch

from ....core import Chainable, HVPMethod, Transform
from ....utils import Distributions, TensorList, vec_to_tensors_
from .psgd import lift2single, precond_grad_lra, update_precond_lra_newton
from ._psgd_utils import _initialize_lra_state_

# matches
class PSGDLRANewton(Transform):
    """Low rank hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

    Args:
        rank (int, optional):
            Preconditioner has a diagonal part and a low rank part, whose rank is decided by this setting. Defaults to 10.
        init_scale (float | None, optional):
            initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.
        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
        damping (float, optional):
            adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.
        grad_clip_max_norm (float, optional): clips norm of the update. Defaults to float("inf").
        update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
        hvp_method (HVPMethod, optional): how to compute hessian-vector products. Defaults to 'autograd'.
        h (float, optional):
            if ``hvp_method`` is ``"fd_central"`` or ``"fd_forward"``, controls finite difference step size.
            Defaults to 1e-3.
        distribution (Distributions, optional):
            distribution for random vectors for hessian-vector products. Defaults to 'normal'.

        inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.

    ###Examples:

    Pure LRA Newton PSGD:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.LRANewton(),
        tz.m.LR(1e-3),
    )
    ```

    Applying preconditioner to momentum:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.LRANewton(inner=tz.m.EMA(0.9)),
        tz.m.LR(1e-3),
    )
    ```
    """
    def __init__(
        self,
        rank: int = 10,
        init_scale: float | None = None,
        lr_preconditioner=0.1,
        betaL=0.9,
        damping=1e-9,
        grad_clip_max_norm=float("inf"),
        update_probability=1.0,

        hvp_method: HVPMethod = 'autograd',
        h: float = 1e-3,
        distribution: Distributions = 'normal',

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults["inner"], defaults["self"]
        super().__init__(defaults, inner=inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]

        # initialize
        if "UVd" not in self.global_state:
            p = torch.cat([t.ravel() for t in objective.params])
            _initialize_lra_state_(p, self.global_state, fs)

        UVd = self.global_state["UVd"]
        if (torch.rand([]) < fs["update_probability"]) or (UVd[2] is None):

            # hessian-vector product
            vs = TensorList(objective.params).sample_like(distribution=fs["distribution"])
            Hvs, _ = objective.hessian_vector_product(z=vs, rgrad=None, at_x0=True, hvp_method=fs["hvp_method"], h=fs["h"])

            v = torch.cat([t.ravel() for t in vs]).unsqueeze(1)
            h = torch.cat([t.ravel() for t in Hvs]).unsqueeze(1)

            if UVd[2] is None:
                UVd[2] = (torch.mean(v*v))**(1/4) * (torch.mean(h**4) + fs["damping"]**4)**(-1/8) * torch.ones_like(v)

            # update preconditioner
            update_precond_lra_newton(UVd=UVd, Luvd=self.global_state["Luvd"], v=v, h=h, lr=fs["lr_preconditioner"], betaL=fs["betaL"], damping=fs["damping"])


    @torch.no_grad
    def apply_states(self, objective, states, settings):
        updates = objective.get_updates()

        g = torch.cat([t.ravel() for t in updates]).unsqueeze(1) # column vec
        pre_grad = precond_grad_lra(UVd=self.global_state["UVd"], g=g)

        # norm clipping
        grad_clip_max_norm = settings[0]["grad_clip_max_norm"]
        if grad_clip_max_norm < float("inf"): # clip preconditioned gradient
            grad_norm = torch.linalg.vector_norm(pre_grad)
            if grad_norm > grad_clip_max_norm:
                pre_grad *= grad_clip_max_norm / grad_norm

        vec_to_tensors_(pre_grad, updates)
        return objective

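The low-rank state `UVd` above holds a diagonal part `d` of shape `(n, 1)` plus low-rank factors; assuming `U` and `V` are `(n, rank)` matrices, which matches the description of the `rank` argument but is not spelled out in this hunk, the state costs roughly `n * (2 * rank + 1)` numbers versus `n^2` for the dense variant. A back-of-the-envelope comparison (illustration only, not torchzero code):

```py
# Illustration only: rough memory footprint of the low-rank-plus-diagonal state
# [U, V, d] (shapes assumed to be (n, rank), (n, rank), (n, 1)) versus a dense
# (n, n) factor Q.
n, rank = 100_000, 10

lra_floats = n * (2 * rank + 1)   # U + V + d
dense_floats = n * n              # Q

print(f"LRA:   {lra_floats * 4 / 2**20:.1f} MiB in float32")    # ~8 MiB
print(f"dense: {dense_floats * 4 / 2**30:.1f} GiB in float32")  # ~37 GiB
```
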