torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/cubic_adam.py (new file)

@@ -0,0 +1,160 @@
+from typing import Any, Literal
+
+import torch
+
+from ...core import TensorTransform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+from ..adaptive.lre_optimizers import LREOptimizerBase, _squared_reproject
+
+
+def signed_cbrt(x: TensorList | Any) -> Any:
+    return x.sign() * x.abs().pow(1/3)
+
+def _clip_min_magnitude(x: torch.Tensor, eps: float):
+    return x.sign() * x.abs().clamp(min=eps)
+
+_cubic_adam_mode = Literal["signed_cbrt", "unsigned_cbrt", "halve"]
+
+def _cubic_minimize(A: torch.Tensor | Any, B: torch.Tensor | Any, C: torch.Tensor | Any, eps):
+    """minimizes (A/3)x^3 + (A/2)x^2 + Cx"""
+    discriminant = B**2 - 4 * A * C
+
+    denom = _clip_min_magnitude(2 * A, eps)
+    root = discriminant.clamp(min=0).sqrt_()
+
+    x0 = (-B + root) / denom
+    x1 = (-B - root) / denom
+
+    f0 = (A/3)*x0**3 + (B/2)*x0**2 + C*x0
+    f1 = (A/3)*x1**3 + (B/2)*x1**2 + C*x1
+
+    x_star = x0.where(f0 < f1, x1)
+
+    adam = -C / (B + eps)
+    return adam.where(discriminant < 0, x_star)
+
+def cubic_adam_(
+    tensors: TensorList,
+    exp_avg_: TensorList,
+    exp_avg_sq_: TensorList,
+    exp_avg_cu_: TensorList,
+    alpha: float | NumberList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    beta3: float | NumberList,
+    eps: float | NumberList,
+    debiased: bool,
+    step: int,
+
+    mode: _cubic_adam_mode = 'signed_cbrt'
+):
+    exp_avg_.lerp_(tensors, 1-beta1)
+    exp_avg_sq_.lerp_(tensors**2, 1-beta2)
+    exp_avg_cu_.lerp_(tensors**3, 1-beta3)
+
+    if debiased:
+        m1 = exp_avg_ / (1 - beta1 ** step)
+        m2 = exp_avg_sq_ / (1 - beta2 ** step)
+        m3 = exp_avg_cu_ / (1 - beta3 ** step)
+    else:
+        m1, m2, m3 = exp_avg_, exp_avg_sq_, exp_avg_cu_
+
+    # adam minimizes ax^2 + bx
+    # we are going to minimize ax^3 + bx^2 + cx
+
+    if mode == "signed_cbrt": A = signed_cbrt(m3)
+    elif mode == "unsigned_cbrt": A = m3.abs().pow(1/3)
+    elif mode == 'halve': A = 0.5 * m3
+    else: raise ValueError(mode)
+
+    B = m2.sqrt()
+    C = m1
+    x_star = _cubic_minimize(A, B, C, eps)
+    return x_star.mul_(-alpha)
+
+class CubicAdam(TensorTransform):
+    """Adam which has 3rd momentum and minimizes a cubic polynomial."""
+    def __init__(
+        self,
+        beta1: float = 0.9,
+        beta2: float = 0.99,
+        beta3: float = 0.99,
+        eps: float = 1e-8,
+        debiased:bool=True,
+        alpha: float = 1.,
+
+        mode: _cubic_adam_mode = 'signed_cbrt'
+    ):
+        defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps,debiased=debiased,alpha=alpha,mode=mode)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        beta1,beta2,beta3,eps,alpha=unpack_dicts(settings, 'beta1','beta2','beta3','eps','alpha', cls=NumberList)
+        exp_avg, exp_avg_sq, exp_avg_cu = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'exp_avg_cu', cls=TensorList)
+
+        return cubic_adam_(
+            tensors=TensorList(tensors),
+            exp_avg_=exp_avg,
+            exp_avg_sq_=exp_avg_sq,
+            exp_avg_cu_=exp_avg_cu,
+            alpha=alpha,
+            beta1=beta1,
+            beta2=beta2,
+            beta3=beta3,
+            eps=eps,
+            debiased=settings[0]['debiased'],
+            step=step,
+
+            mode=settings[0]["mode"]
+        )
+
+class SubspaceCubicAdam(LREOptimizerBase):
+    """Runs cubic Adam in low rank eigenbasis."""
+    def __init__(self, beta1=0.9, beta2=0.95, beta3=0.95, eps=1e-8, mode: _cubic_adam_mode = 'signed_cbrt', cautious:bool=False, exact_reproject:bool=True):
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.beta3 = beta3
+        self.eps = eps
+        self.cautious = cautious
+        self.mode: _cubic_adam_mode = mode
+        self.exact_reproject = exact_reproject
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+
+        if "exp_avg" not in state:
+            state["exp_avg"] = torch.zeros_like(g)
+            state["exp_avg_sq"] = torch.zeros_like(g)
+            state["exp_avg_cu"] = torch.zeros_like(g)
+            state["current_step"] = 1
+
+        dir = cubic_adam_(
+            tensors = TensorList([g]),
+            exp_avg_ = TensorList([state["exp_avg"]]),
+            exp_avg_sq_ = TensorList([state["exp_avg_sq"]]),
+            exp_avg_cu_ = TensorList([state["exp_avg_cu"]]),
+            alpha = 1,
+            beta1 = self.beta1,
+            beta2 = self.beta2,
+            beta3 = self.beta3,
+            eps = self.eps,
+            debiased = True,
+            step = state["current_step"],
+
+            mode=self.mode,
+        )[0]
+
+        state["current_step"] += 1
+        return Q @ dir
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+
+        C = Q_new.T @ Q_old
+
+        state["exp_avg"] = C @ state["exp_avg"]
+        state["exp_avg_sq"] = _squared_reproject(C, state["exp_avg_sq"], exact=self.exact_reproject)
+        state["exp_avg_cu"] = C.pow(3) @ state["exp_avg_cu"] # exact reproject with 1_000_000 is feasible
torchzero/modules/experimental/curveball.py

@@ -1,25 +1,25 @@
 from typing import Literal
-
+
 import torch
 
-from ...core import
-from ...utils import NumberList, TensorList,
-
+from ...core import Chainable, Transform, step, HVPMethod
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
 
 def curveball(
     tensors: TensorList,
     z_: TensorList,
-
+    Hzz: TensorList,
     momentum: float | NumberList,
     precond_lr: float | NumberList,
 ):
     """returns z_, clone it!!! (no just negate it)"""
-    delta =
+    delta = Hzz + tensors
     z_.mul_(momentum).sub_(delta.mul_(precond_lr)) # z ← ρz − βΔ
     return z_
 
 
-class CurveBall(
+class CurveBall(Transform):
     """CurveBall method from https://arxiv.org/pdf/1805.08095#page=4.09.
 
     For now this implementation does not include automatic ρ, α and β hyper-parameters in closed form, therefore it is expected to underperform compared to official implementation (https://github.com/jotaf98/pytorch-curveball/tree/master) so I moved this to experimental.
@@ -36,7 +36,7 @@ class CurveBall(Module):
         self,
         precond_lr: float=1e-3,
         momentum: float=0.9,
-        hvp_method:
+        hvp_method: HVPMethod = "autograd",
         h: float = 1e-3,
         reg: float = 1,
         inner: Chainable | None = None,
@@ -44,46 +44,30 @@ class CurveBall(Module):
         defaults = dict(precond_lr=precond_lr, momentum=momentum, hvp_method=hvp_method, h=h, reg=reg)
         super().__init__(defaults)
 
-
+        self.set_child('inner', inner)
 
     @torch.no_grad
-    def
-
-
-
-
-        h = settings['h']
+    def apply_states(self, objective, states, settings):
+        params = objective.params
+        fs = settings[0]
+        hvp_method = fs['hvp_method']
+        h = fs['h']
 
-        precond_lr, momentum, reg =
+        precond_lr, momentum, reg = unpack_dicts(settings, 'precond_lr', 'momentum', 'reg', cls=NumberList)
 
-
-        closure = var.closure
+        closure = objective.closure
         assert closure is not None
 
-        z, Hz =
-
-        if hvp_method == 'autograd':
-            grad = var.get_grad(create_graph=True)
-            Hvp = hvp(params, grad, z)
-
-        elif hvp_method == 'forward':
-            loss, Hvp = hvp_fd_forward(closure, params, z, h=h, g_0=var.get_grad(), normalize=True)
-
-        elif hvp_method == 'central':
-            loss, Hvp = hvp_fd_central(closure, params, z, h=h, normalize=True)
-
-        else:
-            raise ValueError(hvp_method)
-
-
-        Hz.set_(Hvp + z*reg)
+        z, Hz = unpack_states(states, params, 'z', 'Hz', cls=TensorList)
+        Hz, _ = objective.hessian_vector_product(z, rgrad=None, at_x0=True, hvp_method=hvp_method, h=h)
 
+        Hz = TensorList(Hz)
+        Hzz = Hz.add_(z * reg)
 
-
-
-        update = apply_transform(self.children['inner'], update, params, grads=var.grad, var=var)
+        objective = self.inner_step("inner", objective, must_exist=False)
+        updates = objective.get_updates()
 
-        z = curveball(TensorList(
-
+        z = curveball(TensorList(updates), z, Hzz, momentum=momentum, precond_lr=precond_lr)
+        objective.updates = z.neg()
 
-        return
+        return objective
torchzero/modules/experimental/eigen_sr1.py (new file)

@@ -0,0 +1,182 @@
+import torch
+
+from ...core import Transform
+from ...linalg.orthogonalize import orthogonalize, OrthogonalizeMethod
+from ...linalg.eigh import eigh_plus_uuT, regularize_eigh
+from ...utils import TensorList, unpack_states, vec_to_tensors_
+from ..opt_utils import safe_clip
+from .eigengrad import _eigengrad_update_state_, eigengrad_apply
+
+
+def sr1_u(L: torch.Tensor, Q: torch.Tensor, s:torch.Tensor, y: torch.Tensor, tol:float):
+    """u from u u^T correction and its sign"""
+    r = y - torch.linalg.multi_dot([Q, L.diag_embed(), Q.T, s]) # pylint:disable=not-callable
+    rs = r.dot(s)
+
+    if rs.abs() < tol * torch.linalg.vector_norm(r) * torch.linalg.vector_norm(s): # pylint:disable=not-callable
+        return None, None
+
+    u = r / rs.abs().sqrt()
+    return u, torch.sign(rs)
+
+class EigenSR1(Transform):
+    def __init__(
+        self,
+        rank: int = 100,
+        tol: float = 1e-32,
+        eig_tol: float | None = None,
+        damping: float = 0,
+        rdamping: float = 0,
+        abs: bool = False,
+        mm_tol: float = 1e-7,
+        mm_truncate: int | None = None,
+        mm_damping: float = 1e-4,
+        mm_rdamping: float = 0,
+        mm_abs: bool = True,
+        id_reg: float | None = None,
+        column_space_tol=1e-9,
+        beta: float = 0.95,
+        balance_tol: float = 10,
+        balance_strength: float = 1e-1,
+
+        eigenbasis_optimizer = None,
+        update_freq: int = 1,
+        init_steps: int = 10,
+        orthogonalize_interval: int | None = 1,
+        orthogonalize_method: OrthogonalizeMethod = 'qr',
+
+        hvp_method = "autograd",
+        h = 1e-3,
+        inner = None,
+
+    ):
+        defaults = locals().copy()
+        for k in ["self", "inner"]:
+            del defaults[k]
+
+        super().__init__(defaults)
+
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+        step = self.increment_counter("step", 0)
+
+        if step % fs["update_freq"] == 0:
+
+            params = TensorList(objective.params)
+
+            # compute y as hessian-vector product with s (random vecs during init steps)
+            if ("p_prev" not in self.global_state) or (step < fs["init_steps"]):
+                s_list = params.sample_like('rademacher')
+
+            else:
+                p_prev = self.global_state["p_prev"]
+                s_list = params - p_prev
+
+                if s_list.dot(s_list) < torch.finfo(s_list[0].dtype).tiny * 2:
+                    s_list = params.sample_like('rademacher')
+
+            self.global_state["p_prev"] = params
+
+            # compute y as hessian-vector product with s
+            Hz_list, _ = objective.hessian_vector_product(s_list, rgrad=None, at_x0=True, hvp_method=fs["hvp_method"], h=fs["h"])
+
+            s = torch.cat([t.ravel() for t in s_list])
+            y = torch.cat([t.ravel() for t in Hz_list])
+
+            # keep track of exponential moving average of hessian diagonal and balance eigenvalues
+            if (fs["balance_strength"] != 0) and (step > fs["init_steps"]) and ("L" in self.global_state):
+
+                D = s * y # hutchinson estimator
+                exp_avg = self.global_state.get("exp_avg", None)
+
+                if exp_avg is None:
+                    exp_avg = self.global_state["exp_avg"] = D
+
+                else:
+                    exp_avg.lerp_(D, weight=1-fs["beta"])
+
+                L = self.global_state["L"]
+                L_abs = L.abs()
+                tau = L_abs.amax() / exp_avg.abs().amax()
+
+                if tau > fs["balance_tol"]:
+                    L_balanced = L_abs.pow((1 / tau) ** (1 / fs["balance_strength"])).copysign(L)
+                    self.global_state["L"] = torch.where(L_abs > 1, L_balanced, L)
+
+            # initialize L and Q on 1st step
+            if "L" not in self.global_state:
+
+                L = torch.zeros(1, dtype=s.dtype, device=s.device) # rank, rank
+                Q = torch.zeros([s.numel(), 1], dtype=s.dtype, device=s.device) # ndim, rank
+
+                u, sign = sr1_u(L=L, Q=Q, s=s, y=y, tol=0)
+                assert u is not None and sign is not None
+
+                # for uu^T u is eigenvector and u^T u is eigenvalue
+                norm = torch.linalg.vector_norm(u).clip(min=torch.finfo(u.dtype).tiny * 2) # pylint:disable=not-callable
+
+                self.global_state["L"] = self.global_state["L_reg"] = (u.dot(u).unsqueeze(0) / norm) * sign # (rank,)
+                self.global_state["Q"] = self.global_state["Q_reg"] = u.unsqueeze(-1) / norm # (m, rank)
+
+            # update hessian
+            else:
+                try:
+                    L = self.global_state["L"]
+                    Q = self.global_state["Q"]
+
+                    H_step = self.increment_counter("H_step", start=0)
+                    if H_step % fs["orthogonalize_interval"] == 0:
+                        Q = orthogonalize(Q, method=fs["orthogonalize_method"])
+
+                    u, sign = sr1_u(L=L, Q=Q, s=s, y=y, tol=fs["tol"])
+
+                    if (u is not None) and (sign is not None):
+
+                        # compute new factors
+                        L_new, Q_new = eigh_plus_uuT(L, Q, u, tol=fs["column_space_tol"], alpha=sign.item(), retry_float64=True)
+
+                        # truncate/regularize new factors (those go into the accumulator)
+                        L_new, Q_new = regularize_eigh(L=L_new, Q=Q_new, truncate=min(fs["rank"], s.numel()),
+                                                       tol=fs["eig_tol"], damping=fs["damping"], rdamping=fs["rdamping"])
+
+                        _eigengrad_update_state_(state=self.global_state, setting=fs, L_new=L_new, Q_new=Q_new)
+
+                except torch.linalg.LinAlgError:
+                    pass
+
+
+
+    def apply_states(self, objective, states, settings):
+        fs = settings[0]
+        updates = objective.get_updates()
+
+        if "eigenbasis_state" not in self.global_state:
+            self.global_state["eigenbasis_state"] = {}
+
+        step = self.global_state["step"] # starts at 0
+        if step < fs["init_steps"]:
+
+            # skip update first init_steps to let hessian kick-start
+            objective.stop = True
+            objective.skip_update = True
+            return objective
+
+        if "L_reg" not in self.global_state:
+            TensorList(updates).clip_(-0.1, 0.1)
+            return objective
+
+        dir = eigengrad_apply(
+            tensor = torch.cat([t.ravel() for t in updates]),
+            L_reg = self.global_state["L_reg"],
+            Q_reg = self.global_state["Q_reg"],
+            beta = None,
+            step = None,
+            debias = False,
+            id_reg = fs["id_reg"],
+            eigenbasis_optimizer = fs["eigenbasis_optimizer"],
+            eigenbasis_state = self.global_state["eigenbasis_state"],
+            whiten_fn = lambda x: x
+        )
+
+        vec_to_tensors_(dir, updates)
+        return objective
torchzero/modules/experimental/eigengrad.py (new file)

@@ -0,0 +1,207 @@
+# pylint: disable = non-ascii-name
+from collections.abc import Mapping
+
+import torch
+
+from ...core import Chainable, TensorTransform
+from ...linalg.eigh import eigh_plus_uuT, regularize_eigh
+from ...linalg.orthogonalize import OrthogonalizeMethod, orthogonalize
+from ...linalg.linear_operator import Eigendecomposition
+from ..adaptive.lre_optimizers import LREOptimizerBase
+
+
+def _eigengrad_update_state_(state:dict, setting: Mapping, L_new: torch.Tensor | None, Q_new:torch.Tensor | None):
+    """stores L, Q, L_reg, Q_reg and reprojects eigenbasis opt (this is also used on other eigen based modules)"""
+    if (L_new is not None) and (Q_new is not None):
+
+        # re-orthogonalize
+        orthogonalize_interval = setting["orthogonalize_interval"]
+        if orthogonalize_interval is not None:
+            Q_step = state.get("Q_step", 0)
+            state["Q_step"] = Q_step + 1
+            if Q_step % orthogonalize_interval == 0:
+                Q_new = orthogonalize(Q_new, method=setting["orthogonalize_method"])
+
+        # take absolute value (for hessian)
+        if setting.get("abs", False):
+            L_new = L_new.abs()
+
+        # store
+        state["L"] = L_new
+        state["Q"] = Q_new
+
+        # absolute value for matmul
+        if setting.get("mm_abs", False):
+            L_new = L_new.abs()
+
+        # regularize for matmul
+        # this second round of regularization is only used for preconditioning
+        # and doesn't affect the accumulator
+        L_reg_new, Q_reg_new = regularize_eigh(L=L_new, Q=Q_new,
+                                               truncate=setting["mm_truncate"],
+                                               tol=setting["mm_tol"],
+                                               damping=setting["mm_damping"],
+                                               rdamping=setting["mm_rdamping"],
+                                               )
+
+        # print(f'{state["L_reg"] = }, {L_reg_new = }')
+
+        # reproject eigenbasis optimizer
+        if (L_reg_new is not None) and (Q_reg_new is not None):
+            eigenbasis_optimizer: LREOptimizerBase | None = setting["eigenbasis_optimizer"]
+            if eigenbasis_optimizer is not None:
+                eigenbasis_optimizer.reproject(L_old=state["L_reg"], Q_old=state["Q_reg"], L_new=L_reg_new,
+                                               Q_new=Q_reg_new, state=state["eigenbasis_state"])
+
+            state["L_reg"] = L_reg_new
+            state["Q_reg"] = Q_reg_new
+
+
+def eigengrad_apply(
+    tensor: torch.Tensor,
+    L_reg: torch.Tensor,
+    Q_reg: torch.Tensor,
+    beta: float | None,
+    step: int | None,
+    debias: bool,
+    id_reg: float | None,
+    eigenbasis_optimizer: LREOptimizerBase | None,
+    eigenbasis_state: dict,
+
+    whiten_fn = torch.sqrt
+):
+    # debias
+    if debias:
+        assert beta is not None and step is not None
+        L_reg = L_reg / (1 - beta ** step)
+
+    # step with eigenbasis optimizer
+    if eigenbasis_optimizer is not None:
+        if (id_reg is not None) and (id_reg != 0):
+            raise RuntimeError("id_reg is not compatible with eigenbasis_optimizer")
+
+        update = eigenbasis_optimizer.step(tensor.ravel(), L=L_reg, Q=Q_reg, state=eigenbasis_state)
+        return update.view_as(tensor)
+
+    # or just whiten
+    # L_reg = L_reg.clip(min=torch.finfo(L_reg.dtype).tiny * 2)
+
+    if id_reg is None or id_reg == 0:
+        G = Eigendecomposition(whiten_fn(L_reg), Q_reg, use_nystrom=False)
+        dir = G.solve(tensor.ravel())
+
+    else:
+        G = Eigendecomposition(whiten_fn(L_reg), Q_reg, use_nystrom=True)
+        dir = G.solve_plus_diag(tensor.ravel(), diag=id_reg)
+
+    return dir.view_as(tensor)
+
+class Eigengrad(TensorTransform):
+    """we can easily compute rank 1 symmetric update to a low rank eigendecomposition.
+    So this stores covariance matrix as it.
+
+
+    Args:
+        rank (int): maximum allowed rank
+        beta (float, optional): beta for covariance matrix exponential moving average. Defaults to 0.95.
+        eig_tol (float, optional):
+            removes eigenvalues this much smaller than largest eigenvalue when updating the preconditioner. Defaults to 1e-7.
+        damping (float, optional):
+            added to eigenvalues when updating the preconditioner. Defaults to 1e-8.
+        rdamping (float, optional):
+            added to eigenvalues when updating the preconditioner, relative to largest eigenvalue. Defaults to 0.
+        mm_tol (float, optional):
+            removes eigenvalues this much smaller than largest eigenvalue when computing the update. Defaults to 1e-7.
+        mm_truncate (int | None, optional):
+            uses top k eigenvalues to compute the update. Defaults to None.
+        mm_damping (float, optional):
+            added to eigenvalues when computing the update. Defaults to 1e-4.
+        mm_rdamping (float, optional):
+            added to eigenvalues when computing the update, relative to largest eigenvalue. Defaults to 0.
+        id_reg (float, optional):
+            multiplier to identity matrix added to preconditioner before computing update
+            If this value is given, solution from Nyström sketch-and-solve will be used to compute the update.
+            This value can't be too small (i.e. less than 1e-5) or the solver will be very unstable. Defaults to None.
+        column_space_tol (float, optional):
+            tolerance for deciding if new eigenvector is within column space of the covariance matrix. Defaults to 1e-9.
+        concat_params (bool, optional):
+            whether to precondition all parameters at once if True, or each separately if False. Defaults to True.
+        update_freq (int, optional): update frequency. Defaults to 1.
+        inner (Chainable | None, optional): inner modules. Defaults to None.
+
+    """
+
+    def __init__(
+        self,
+        rank: int = 100,
+        beta=0.95,
+        eig_tol: float | None = 1e-5,
+        damping: float = 0,
+        rdamping: float = 0,
+        mm_tol: float = 0,
+        mm_truncate: int | None = None,
+        mm_damping: float = 1e-4,
+        mm_rdamping: float = 0,
+        id_reg: float | None = None,
+        column_space_tol = 1e-9,
+
+        orthogonalize_interval: int | None = None,
+        orthogonalize_method: OrthogonalizeMethod = 'qr',
+
+        eigenbasis_optimizer: LREOptimizerBase | None = None,
+        concat_params: bool = True,
+        update_freq: int = 1,
+        inner: Chainable | None = None,
+    ):
+        defaults = locals().copy()
+        for k in ["self", "concat_params", "inner", "update_freq"]:
+            del defaults[k]
+
+        super().__init__(defaults, concat_params=concat_params, inner=inner, update_freq=update_freq)
+
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+        state["step"] = state.get("step", 0) + 1
+        beta = setting["beta"]
+
+        if "L" not in state:
+            # for uu^T u is eigenvector and u^T u is eigenvalue
+            norm = torch.linalg.vector_norm(tensor).clip(min=torch.finfo(tensor.dtype).tiny * 2) # pylint:disable=not-callable
+
+            state["L"] = state["L_reg"] = (tensor.dot(tensor).unsqueeze(0) / norm) # (rank,)
+            state["Q"] = state["Q_reg"] = tensor.unsqueeze(-1) / norm # (m, rank)
+
+        else:
+            try:
+                L = state["L"]
+                Q = state["Q"]
+
+                # compute new factors
+                L_new, Q_new = eigh_plus_uuT(L*beta, Q, tensor, alpha=(1-beta), tol=setting["column_space_tol"], retry_float64=True)
+
+                # truncate/regularize new factors (those go into the accumulator)
+                L_new, Q_new = regularize_eigh(L=L_new, Q=Q_new, truncate=setting["rank"], tol=setting["eig_tol"],
+                                               damping=setting["damping"], rdamping=setting["rdamping"])
+
+                _eigengrad_update_state_(state=state, setting=setting, L_new=L_new, Q_new=Q_new)
+
+            except torch.linalg.LinAlgError:
+                pass
+
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+        if "L_reg" not in state:
+            return tensor.clip(-0.1, 0.1)
+
+        if "eigenbasis_state" not in state:
+            state["eigenbasis_state"] = {}
+
+        return eigengrad_apply(
+            tensor = tensor,
+            L_reg = state["L_reg"],
+            Q_reg = state["Q_reg"],
+            beta = setting["beta"],
+            step = state["step"],
+            debias = True,
+            id_reg = setting["id_reg"],
+            eigenbasis_optimizer = setting["eigenbasis_optimizer"],
+            eigenbasis_state = state["eigenbasis_state"]
+        )
torchzero/modules/experimental/gradmin.py

@@ -5,7 +5,7 @@ from typing import Literal
 
 import torch
 
-from ...core import Module,
+from ...core import Module, Objective, Chainable
 from ...utils import NumberList, TensorList
 from ...utils.derivatives import jacobian_wrt
 from ..grad_approximation import GradApproximator, GradTarget
@@ -43,7 +43,7 @@ class GradMin(Reformulation):
         super().__init__(defaults, modules=modules)
 
     @torch.no_grad
-    def closure(self, backward, closure, params,
+    def closure(self, backward, closure, params, objective):
         settings = self.settings[params[0]]
         loss_term = settings['loss_term']
         relative = settings['relative']