torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
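The path moves listed above (for example `torchzero/{utils/linalg → linalg}/linear_operator.py`) and the docstring hunks below suggest two user-visible changes: linear-algebra helpers now live under `torchzero.linalg`, and the examples construct optimizers with `tz.Optimizer`. A minimal sketch under that assumption (names are taken from this diff, not verified against the 0.4.1 documentation):

```python
import torch
from torch import nn
import torchzero as tz

# moved in 0.4.1: previously importable as torchzero.utils.linalg.linear_operator
from torchzero.linalg import linear_operator

# optimizer construction as shown in the updated NaturalGradient docstring below
model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
opt = tz.Optimizer(
    model.parameters(),
    tz.m.NaturalGradient(),
    tz.m.LR(3e-2),
)
```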
torchzero/modules/adaptive/natural_gradient.py:

@@ -1,12 +1,12 @@
 import torch
-from ...core import
+from ...core import Transform
 
 from ...utils.derivatives import jacobian_wrt, flatten_jacobian
-from ...utils import vec_to_tensors
-from ...
-from .
+from ...utils import vec_to_tensors
+from ...linalg import linear_operator
+from .ggt import ggt_update
 
-class NaturalGradient(
+class NaturalGradient(Transform):
     """Natural gradient approximated via empirical fisher information matrix.
 
     To use this, either pass vector of per-sample losses to the step method, or make sure
@@ -27,9 +27,9 @@ class NaturalGradient(Module):
             with a vector that isn't strictly per-sample gradients, but rather for example different losses.
         gn_grad (bool, optional):
             if True, uses Gauss-Newton G^T @ f as the gradient, which is effectively sum weighted by value
-            and is equivalent to squaring the values.
-
-            This has an effect when ``sqrt=
+            and is equivalent to squaring the values. That makes the kernel trick solver incorrect, but for
+            some reason it still works. If False, uses sum of per-sample gradients.
+            This has an effect when ``sqrt=False``, and affects the ``grad`` attribute.
             Defaults to False.
         batched (bool, optional): whether to use vmapping. Defaults to True.
 
@@ -41,7 +41,7 @@ class NaturalGradient(Module):
         y = torch.randn(64, 10)
 
         model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
-        opt = tz.
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.NaturalGradient(),
             tz.m.LR(3e-2)
@@ -61,7 +61,7 @@ class NaturalGradient(Module):
         y = torch.randn(64, 10)
 
         model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
-        opt = tz.
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.NaturalGradient(),
             tz.m.LR(3e-2)
@@ -84,7 +84,7 @@ class NaturalGradient(Module):
             return torch.stack([(1 - x1).abs(), (10 * (x2 - x1**2).abs())])
 
         X = torch.tensor([-1.1, 2.5], requires_grad=True)
-        opt = tz.
+        opt = tz.Optimizer([X], tz.m.NaturalGradient(sqrt=True, gn_grad=True), tz.m.LR(0.05))
 
         for iter in range(200):
             losses = rosenbrock(X)
@@ -97,20 +97,27 @@ class NaturalGradient(Module):
         super().__init__(defaults=dict(batched=batched, reg=reg, sqrt=sqrt, gn_grad=gn_grad))
 
     @torch.no_grad
-    def
-        params =
-
-
-
-
-
-
+    def update_states(self, objective, states, settings):
+        params = objective.params
+        closure = objective.closure
+        fs = settings[0]
+        batched = fs['batched']
+        gn_grad = fs['gn_grad']
+
+        # compute per-sample losses
+        f = objective.loss
+        if f is None:
+            assert closure is not None
+            with torch.enable_grad():
+                f = objective.get_loss(backward=False) # n_out
+        assert isinstance(f, torch.Tensor)
+
+        # compute per-sample gradients
         with torch.enable_grad():
-            f = var.get_loss(backward=False) # n_out
-            assert isinstance(f, torch.Tensor)
             G_list = jacobian_wrt([f.ravel()], params, batched=batched)
 
-
+        # set scalar loss and it's grad to objective
+        objective.loss = f.sum()
         G = self.global_state["G"] = flatten_jacobian(G_list) # (n_samples, ndim)
 
         if gn_grad:
@@ -119,13 +126,15 @@ class NaturalGradient(Module):
         else:
             g = self.global_state["g"] = G.sum(0)
 
-
+        objective.grads = vec_to_tensors(g, params)
 
         # set closure to calculate scalar value for line searches etc
-        if
+        if closure is not None:
+
             def ngd_closure(backward=True):
+
                 if backward:
-
+                    objective.zero_grad()
                     with torch.enable_grad():
                         loss = closure(False)
                         if gn_grad: loss = loss.pow(2)
@@ -137,39 +146,52 @@ class NaturalGradient(Module):
                     if gn_grad: loss = loss.pow(2)
                 return loss.sum()
 
-
+            objective.closure = ngd_closure
 
     @torch.no_grad
-    def
-        params =
-
-
+    def apply_states(self, objective, states, settings):
+        params = objective.params
+        fs = settings[0]
+        reg = fs['reg']
+        sqrt = fs['sqrt']
 
         G: torch.Tensor = self.global_state['G'] # (n_samples, n_dim)
 
         if sqrt:
             # this computes U, S <- SVD(M), then calculate update as U S^-1 Uᵀg,
             # but it computes it through eigendecompotision
-
-
+            L, U = ggt_update(G.H, damping=reg, rdamping=1e-16, truncate=0, eig_tol=1e-12)
+
+            if U is None or L is None:
+
+                # fallback to element-wise
+                g = self.global_state["g"]
+                g /= G.square().mean(0).sqrt().add(reg)
+                objective.updates = vec_to_tensors(g, params)
+                return objective
 
-
-
-
+            # whiten
+            z = U.T @ self.global_state["g"]
+            v = (U * L.rsqrt()) @ z
+            objective.updates = vec_to_tensors(v, params)
+            return objective
 
-
+        # we need (G^T G)v = g
+        # where g = G^T
+        # so we need to solve (G^T G)v = G^T
+        GGt = G @ G.H # (n_samples, n_samples)
 
         if reg != 0:
-
+            GGt.add_(torch.eye(GGt.size(0), device=GGt.device, dtype=GGt.dtype).mul_(reg))
 
-        z, _ = torch.linalg.solve_ex(
+        z, _ = torch.linalg.solve_ex(GGt, torch.ones_like(GGt[0])) # pylint:disable=not-callable
         v = G.H @ z
 
-
-        return
+        objective.updates = vec_to_tensors(v, params)
+        return objective
 
 
-    def get_H(self,
+    def get_H(self, objective=...):
         if "G" not in self.global_state: return linear_operator.ScaledIdentity()
         G = self.global_state['G']
         return linear_operator.AtA(G)
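The non-`sqrt` branch of the new `apply_states` relies on the kernel trick spelled out in its comments: with per-sample gradients `G` of shape `(n_samples, ndim)` and `g = Gᵀ·1`, solving the small system `(GGᵀ + reg·I)z = 1` and taking `v = Gᵀz` yields the same vector as the parameter-space solve `(GᵀG + reg·I)v = g`, without ever forming an `ndim × ndim` matrix. A standalone check of that identity in plain torch (not the torchzero module API):

```python
import torch

n_samples, ndim, reg = 8, 100, 1e-3
G = torch.randn(n_samples, ndim, dtype=torch.float64)  # per-sample gradients

# small (n_samples x n_samples) system, as in apply_states
GGt = G @ G.T + reg * torch.eye(n_samples, dtype=G.dtype)
z = torch.linalg.solve(GGt, torch.ones(n_samples, dtype=G.dtype))
v = G.T @ z

# same direction from the direct (ndim x ndim) regularized solve
g = G.sum(0)  # g = G^T @ ones
v_direct = torch.linalg.solve(G.T @ G + reg * torch.eye(ndim, dtype=G.dtype), g)
print(torch.allclose(v, v_direct))  # True, up to numerical error
```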
torchzero/modules/adaptive/orthograd.py:

@@ -1,13 +1,9 @@
-from
-import math
-import warnings
-from collections.abc import Iterable, Sequence
-from typing import Literal
+from collections.abc import Iterable
 
 import torch
 
-from ...core import
-from ...utils import
+from ...core import TensorTransform
+from ...utils import TensorList
 
 def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
     """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.
@@ -19,29 +15,29 @@ def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
     reference
         https://arxiv.org/abs/2501.04697
     """
-    params =
+    params = TensorList(params).with_grad()
    grad = params.grad
    grad -= (params.dot(grad)/(params.dot(params) + eps)) * params
 
 
-class OrthoGrad(
+class OrthoGrad(TensorTransform):
    """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.
 
    Args:
        eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
        renormalize (bool, optional): whether to graft projected gradient to original gradient norm. Defaults to True.
-        target (Target, optional): what to set on var. Defaults to 'update'.
    """
-    def __init__(self, eps: float = 1e-8, renormalize=True
+    def __init__(self, eps: float = 1e-8, renormalize=True):
        defaults = dict(eps=eps, renormalize=renormalize)
-        super().__init__(defaults
+        super().__init__(defaults)
 
-
+    @torch.no_grad
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        eps = settings[0]['eps']
        renormalize = settings[0]['renormalize']
 
-        params =
-        target =
+        params = TensorList(params)
+        target = TensorList(tensors)
 
        scale = params.dot(target)/(params.dot(params) + eps)
        if renormalize:
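For reference, the projection applied by `orthograd_` and `OrthoGrad.multi_tensor_apply` above is `g ← g − (⟨w, g⟩ / (⟨w, w⟩ + eps))·w`, optionally regrafted to the original gradient norm. A minimal single-tensor sketch (the module itself works on a whole `TensorList`, i.e. one dot product across all parameters):

```python
import torch

def orthogonalize(w: torch.Tensor, g: torch.Tensor,
                  eps: float = 1e-30, renormalize: bool = True) -> torch.Tensor:
    w, g = w.reshape(-1), g.reshape(-1)
    # remove the component of g that is parallel to the weights w
    proj = g - (w.dot(g) / (w.dot(w) + eps)) * w
    if renormalize:
        # graft: keep the projected direction, restore the original gradient norm
        proj = proj * (g.norm() / proj.norm().clamp(min=eps))
    return proj

w, g = torch.randn(10), torch.randn(10)
print(orthogonalize(w, g).dot(w))  # ~0: orthogonal to the weights
```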
torchzero/modules/adaptive/psgd/_psgd_utils.py (new file):

@@ -0,0 +1,37 @@
+# pylint:disable=not-callable
+import warnings
+
+import torch
+
+from .psgd import lift2single
+
+
+def _initialize_lra_state_(tensor: torch.Tensor, state, setting):
+    n = tensor.numel()
+    rank = max(min(setting["rank"], n-1), 1)
+    dtype=tensor.dtype
+    device=tensor.device
+
+    U = torch.randn((n, rank), dtype=dtype, device=device)
+    U *= 0.1**0.5 / torch.linalg.vector_norm(U)
+
+    V = torch.randn((n, rank), dtype=dtype, device=device)
+    V *= 0.1**0.5 / torch.linalg.vector_norm(V)
+
+    if setting["init_scale"] is None:
+        # warnings.warn("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
+        d = None
+    else:
+        d = torch.ones(n, 1, dtype=dtype, device=device) * setting["init_scale"]
+
+    state["UVd"] = [U, V, d]
+    state["Luvd"] = [lift2single(torch.zeros([], dtype=dtype, device=device)) for _ in range(3)]
+
+
+
+def _wrap_with_no_backward(opt):
+    """to use original psgd opts with visualbench"""
+    class _Wrapped:
+        def step(self, closure):
+            return opt.step(lambda: closure(False))
+    return _Wrapped()