torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_opts.py +199 -198
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +1 -1
- torchzero/core/functional.py +1 -1
- torchzero/core/modular.py +5 -5
- torchzero/core/module.py +2 -2
- torchzero/core/objective.py +10 -10
- torchzero/core/transform.py +1 -1
- torchzero/linalg/__init__.py +3 -2
- torchzero/linalg/eigh.py +223 -4
- torchzero/linalg/orthogonalize.py +2 -4
- torchzero/linalg/qr.py +12 -0
- torchzero/linalg/solve.py +1 -3
- torchzero/linalg/svd.py +47 -20
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +10 -10
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/adam.py +1 -1
- torchzero/modules/adaptive/adan.py +1 -1
- torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +2 -1
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/msam.py +4 -4
- torchzero/modules/adaptive/muon.py +9 -6
- torchzero/modules/adaptive/natural_gradient.py +32 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rprop.py +2 -2
- torchzero/modules/adaptive/sam.py +4 -4
- torchzero/modules/adaptive/shampoo.py +28 -3
- torchzero/modules/adaptive/soap.py +3 -3
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/clipping/clipping.py +7 -7
- torchzero/modules/conjugate_gradient/cg.py +2 -2
- torchzero/modules/experimental/__init__.py +5 -0
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +2 -2
- torchzero/modules/experimental/newtonnewton.py +34 -40
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/rfdm.py +4 -4
- torchzero/modules/least_squares/gn.py +68 -45
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/escape.py +1 -1
- torchzero/modules/misc/gradient_accumulation.py +1 -1
- torchzero/modules/misc/misc.py +1 -1
- torchzero/modules/misc/multistep.py +4 -7
- torchzero/modules/misc/regularization.py +2 -2
- torchzero/modules/misc/split.py +1 -1
- torchzero/modules/misc/switch.py +2 -2
- torchzero/modules/momentum/cautious.py +3 -3
- torchzero/modules/momentum/momentum.py +1 -1
- torchzero/modules/ops/higher_level.py +1 -1
- torchzero/modules/ops/multi.py +1 -1
- torchzero/modules/projections/projection.py +5 -2
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +3 -3
- torchzero/modules/quasi_newton/lsr1.py +3 -3
- torchzero/modules/quasi_newton/quasi_newton.py +44 -29
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +17 -17
- torchzero/modules/second_order/inm.py +33 -25
- torchzero/modules/second_order/newton.py +132 -130
- torchzero/modules/second_order/newton_cg.py +3 -3
- torchzero/modules/second_order/nystrom.py +83 -32
- torchzero/modules/second_order/rsn.py +41 -44
- torchzero/modules/smoothing/laplacian.py +1 -1
- torchzero/modules/smoothing/sampling.py +2 -3
- torchzero/modules/step_size/adaptive.py +6 -6
- torchzero/modules/step_size/lr.py +2 -2
- torchzero/modules/trust_region/cubic_regularization.py +1 -1
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/variance_reduction/svrg.py +4 -5
- torchzero/modules/weight_decay/reinit.py +2 -2
- torchzero/modules/weight_decay/weight_decay.py +5 -5
- torchzero/modules/wrappers/optim_wrapper.py +4 -4
- torchzero/modules/zeroth_order/cd.py +1 -1
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/wrappers/nevergrad.py +0 -9
- torchzero/optim/wrappers/optuna.py +2 -0
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/derivatives.py +4 -4
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- torchzero/modules/adaptive/lmadagrad.py +0 -241
- torchzero-0.4.0.dist-info/RECORD +0 -191
- /torchzero/modules/{functional.py → opt_utils.py} +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
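
The only file rename in this release is `/torchzero/modules/{functional.py → opt_utils.py}` (contents unchanged, `+0 -0`), and the hunks below show internal imports being updated to the new name (e.g. `from ..opt_utils import epsilon_step_size`). If any downstream code imported helpers from the old module path, it would presumably need the same change. A hedged sketch; `epsilon_step_size` is the only helper visible in this diff, and whether it is intended as public API is an assumption:

```python
# Hypothetical downstream adjustment for the 0.4.0 -> 0.4.1 module rename.
# Assumes epsilon_step_size is importable from the renamed module, as the
# internal imports later in this diff suggest.
try:
    from torchzero.modules.opt_utils import epsilon_step_size   # torchzero >= 0.4.1
except ImportError:
    from torchzero.modules.functional import epsilon_step_size  # torchzero <= 0.4.0
```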
@@ -6,10 +6,10 @@ from typing import Literal
 import torch
 
 from ...core import Chainable, Transform, HVPMethod
-from ...utils import
+from ...utils import vec_to_tensors_
 from ...linalg.linear_operator import Sketched
 
-from .newton import
+from .newton import _newton_update_state_, _newton_solve
 
 def _qr_orthonormalize(A:torch.Tensor):
     m,n = A.shape

@@ -20,12 +20,10 @@ def _qr_orthonormalize(A:torch.Tensor):
     q, _ = torch.linalg.qr(A) # pylint:disable=not-callable
     return q
 
+
 def _orthonormal_sketch(m, n, dtype, device, generator):
     return _qr_orthonormalize(torch.randn(m, n, dtype=dtype, device=device, generator=generator))
 
-def _gaussian_sketch(m, n, dtype, device, generator):
-    return torch.randn(m, n, dtype=dtype, device=device, generator=generator) / math.sqrt(m)
-
 def _rademacher_sketch(m, n, dtype, device, generator):
     rademacher = torch.bernoulli(torch.full((m,n), 0.5), generator = generator).mul_(2).sub_(1)
     return rademacher.mul_(1 / math.sqrt(m))

@@ -37,11 +35,10 @@ class SubspaceNewton(Transform):
         sketch_size (int):
             size of the random sketch. This many hessian-vector products will need to be evaluated each step.
         sketch_type (str, optional):
+            - "common_directions" - uses history steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt (default).
             - "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
-            - "rademacher" - approximately orthonormal scaled random rademacher basis.
-            - "
-            - "common_directions" - uses history steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt.
-            - "mixed" - random orthonormal basis but with four directions set to gradient, slow and fast gradient EMAs, and previous update direction (default).
+            - "rademacher" - approximately orthonormal (if dimension is large) scaled random rademacher basis. It is recommended to use at least "orthonormal" - it requires QR but it is still very cheap.
+            - "mixed" - random orthonormal basis but with four directions set to gradient, slow and fast gradient EMAs, and previous update direction.
         damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
         hvp_method (str, optional):
             How to compute hessian-matrix product:
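
The "rademacher" option described above is only approximately orthonormal, which is why the new docstring qualifies it with "if dimension is large". A quick standalone check of that claim in plain PyTorch, independent of torchzero, mirroring the `_rademacher_sketch` helper added in this diff:

```python
import math
import torch

def rademacher_sketch(m, n, generator=None):
    # +/-1 entries scaled by 1/sqrt(m), matching the helper added in this diff
    s = torch.bernoulli(torch.full((m, n), 0.5), generator=generator).mul_(2).sub_(1)
    return s / math.sqrt(m)

g = torch.Generator().manual_seed(0)
S = rademacher_sketch(100_000, 8, g)      # tall sketch: m >> n
gram = S.T @ S                            # should be close to the identity
print(torch.dist(gram, torch.eye(8)))     # small, and shrinks as m grows
```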
@@ -73,7 +70,7 @@ class SubspaceNewton(Transform):
 
     RSN with line search
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.RSN(),
         tz.m.Backtracking()

@@ -82,7 +79,7 @@ class SubspaceNewton(Transform):
 
     RSN with trust region
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LevenbergMarquardt(tz.m.RSN()),
     )

@@ -97,14 +94,14 @@ class SubspaceNewton(Transform):
     def __init__(
         self,
         sketch_size: int,
-        sketch_type: Literal["orthonormal", "
+        sketch_type: Literal["orthonormal", "common_directions", "mixed", "rademacher"] = "common_directions",
         damping:float=0,
+        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        update_freq: int = 1,
+        precompute_inverse: bool = False,
+        use_lstsq: bool = True,
         hvp_method: HVPMethod = "batched_autograd",
         h: float = 1e-2,
-        use_lstsq: bool = True,
-        update_freq: int = 1,
-        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
-        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
         seed: int | None = None,
         inner: Chainable | None = None,
     ):
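
Putting the updated constructor together with the docstring examples above, a 0.4.1-style setup might look like the following. This is a sketch assuming the API shown in this diff (`tz.Optimizer`, `tz.m.RSN`, `tz.m.Backtracking`); argument names are taken verbatim from the new signature, everything else (the model, the chosen values) is illustrative:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)   # placeholder model

opt = tz.Optimizer(
    model.parameters(),
    tz.m.RSN(                               # RSN is how the docstring examples refer to SubspaceNewton
        sketch_size=32,                     # number of hessian-vector products per update
        sketch_type="common_directions",    # new default in 0.4.1 per this diff
        update_freq=1,                      # new argument in 0.4.1
        precompute_inverse=False,           # new argument in 0.4.1
    ),
    tz.m.Backtracking(),                    # line search, as in the docstring example
)
```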
@@ -128,10 +125,7 @@ class SubspaceNewton(Transform):
         sketch_type = fs["sketch_type"]
         hvp_method = fs["hvp_method"]
 
-        if sketch_type
-            S = _gaussian_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
-
-        elif sketch_type == "rademacher":
+        if sketch_type == "rademacher":
             S = _rademacher_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
 
         elif sketch_type == 'orthonormal':

@@ -187,7 +181,7 @@ class SubspaceNewton(Transform):
             # form and orthogonalize sketching matrix
             S = torch.stack([g, slow_ema, fast_ema, prev_dir], dim=1)
             if sketch_size > 4:
-                S_random =
+                S_random = torch.randn(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator) / math.sqrt(ndim)
                 S = torch.cat([S, S_random], dim=1)
 
             S = _qr_orthonormalize(S)

@@ -200,38 +194,41 @@ class SubspaceNewton(Transform):
                            hvp_method=fs["hvp_method"], h=fs["h"])
         H_sketched = S.T @ HS
 
-
+        # update state
+        _newton_update_state_(
+            state = self.global_state,
+            H = H_sketched,
+            damping = fs["damping"],
+            eigval_fn = fs["eigval_fn"],
+            precompute_inverse = fs["precompute_inverse"],
+            use_lstsq = fs["use_lstsq"]
+
+        )
+
         self.global_state["S"] = S
 
     def apply_states(self, objective, states, settings):
-
+        updates = objective.get_updates()
+        fs = settings[0]
 
-
-
-
-
-
-            eigval_fn=self.defaults["eigval_fn"],
-            use_lstsq=self.defaults["use_lstsq"],
-            g_proj = lambda g: S.T @ g
-        )
+        S = self.global_state["S"]
+        b = torch.cat([t.ravel() for t in updates])
+        b_proj = S.T @ b
+
+        d_proj = _newton_solve(b=b_proj, state=self.global_state, use_lstsq=fs["use_lstsq"])
 
         d = S @ d_proj
-
+        vec_to_tensors_(d, updates)
         return objective
 
     def get_H(self, objective=...):
-
-
-        S: torch.Tensor = self.global_state["S"]
-
-        if eigval_fn is not None:
-            try:
-                L, Q = torch.linalg.eigh(H_sketched) # pylint:disable=not-callable
-                L: torch.Tensor = eigval_fn(L)
-                H_sketched = Q @ L.diag_embed() @ Q.mH
+        if "H" in self.global_state:
+            H_sketched = self.global_state["H"]
 
-
+        else:
+            L = self.global_state["L"]
+            Q = self.global_state["Q"]
+            H_sketched = Q @ L.diag_embed() @ Q.mH
 
+        S: torch.Tensor = self.global_state["S"]
         return Sketched(S, H_sketched)
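
The rewritten `apply_states` above reduces to a standard projected (sketched) Newton step: project the update `b` into the subspace spanned by `S`, solve the small sketched system, then lift the solution back. A minimal self-contained illustration of that algebra in plain PyTorch; this is not the library's code, and `_newton_solve` is replaced here by an explicit least-squares solve:

```python
import torch

torch.manual_seed(0)
ndim, k = 50, 6

H = torch.randn(ndim, ndim); H = H @ H.T + torch.eye(ndim)   # SPD "Hessian"
b = torch.randn(ndim)                                        # update / gradient to precondition
S, _ = torch.linalg.qr(torch.randn(ndim, k))                 # orthonormal sketch, as in _qr_orthonormalize

H_sketched = S.T @ (H @ S)       # k x k sketched Hessian (S.T @ HS in the diff)
b_proj = S.T @ b                 # project the right-hand side
d_proj = torch.linalg.lstsq(H_sketched, b_proj.unsqueeze(1)).solution.squeeze(1)
d = S @ d_proj                   # lift the subspace solution back to parameter space

# d is the Newton step restricted to span(S); with a square orthogonal S it
# would coincide with the full Newton step H^{-1} b.
```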
@@ -7,15 +7,14 @@ from typing import Literal, cast
 
 import torch
 
-from ...core import Chainable,
+from ...core import Chainable, Optimizer, Module, Objective
 from ...core.reformulation import Reformulation
 from ...utils import Distributions, NumberList, TensorList
 from ..termination import TerminationCriteriaBase, make_termination_criteria
 
 
 def _reset_except_self(objective: Objective, modules, self: Module):
-
-    for m in objective.modular.flat_modules:
+    for m in modules:
         if m is not self:
             m.reset()
 

@@ -8,7 +8,7 @@ import torch
 from ...core import Chainable, TensorTransform
 from ...utils import NumberList, TensorList, tofloat, unpack_dicts, unpack_states
 from ...linalg.linear_operator import ScaledIdentity
-from ..
+from ..opt_utils import epsilon_step_size
 
 def _acceptable_alpha(alpha, param:torch.Tensor):
     finfo = torch.finfo(param.dtype)

@@ -16,7 +16,7 @@ def _acceptable_alpha(alpha, param:torch.Tensor):
         return False
     return True
 
-def
+def _get_scaled_identity_H(self: TensorTransform, var):
     n = sum(p.numel() for p in var.params)
     p = var.params[0]
     alpha = self.global_state.get('alpha', 1)

@@ -87,7 +87,7 @@ class PolyakStepSize(TensorTransform):
         return tensors
 
     def get_H(self, objective):
-        return
+        return _get_scaled_identity_H(self, objective)
 
 
 def _bb_short(s: TensorList, y: TensorList, sy, eps):

@@ -176,7 +176,7 @@ class BarzilaiBorwein(TensorTransform):
             prev_g.copy_(g)
 
     def get_H(self, objective):
-        return
+        return _get_scaled_identity_H(self, objective)
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):

@@ -288,7 +288,7 @@ class BBStab(TensorTransform):
             prev_g.copy_(g)
 
     def get_H(self, objective):
-        return
+        return _get_scaled_identity_H(self, objective)
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):

@@ -384,4 +384,4 @@ class AdGD(TensorTransform):
         return tensors
 
     def get_H(self, objective):
-        return
+        return _get_scaled_identity_H(self, objective)
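
The four step-size modules above (PolyakStepSize, BarzilaiBorwein, BBStab, AdGD) now share a single `_get_scaled_identity_H` helper instead of each building its own `ScaledIdentity` operator inline. The idea is that a pure step-size method's curvature model is just a multiple of the identity, so modules that consume `get_H()` (presumably trust-region style wrappers, per the SubspaceNewton docstring) can treat it like any other linear operator. A hypothetical, minimal operator of that shape, not torchzero's actual `ScaledIdentity` class:

```python
import torch

class ScaledIdentityOp:
    """Hypothetical stand-in for a scaled-identity linear operator H = c * I."""
    def __init__(self, c: float, n: int):
        self.c, self.n = c, n
    def matvec(self, x: torch.Tensor) -> torch.Tensor:
        return self.c * x      # H @ x
    def solve(self, b: torch.Tensor) -> torch.Tensor:
        return b / self.c      # H^{-1} @ b

H = ScaledIdentityOp(c=0.5, n=10)
x = torch.randn(10)
assert torch.allclose(H.solve(H.matvec(x)), x)   # solve undoes matvec, as expected for c * I
```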
@@ -51,7 +51,7 @@ class Warmup(TensorTransform):
 
     .. code-block:: python
 
-        opt = tz.
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.Adam(),
             tz.m.LR(1e-2),

@@ -90,7 +90,7 @@ class WarmupNormClip(TensorTransform):
 
     .. code-block:: python
 
-        opt = tz.
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.Adam(),
             tz.m.WarmupNormClip(steps=1000)

@@ -44,7 +44,7 @@ class LevenbergMarquardt(TrustRegionBase):
     Gauss-Newton with Levenberg-Marquardt trust-region
 
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
     )

@@ -52,7 +52,7 @@ class LevenbergMarquardt(TrustRegionBase):
 
     LM-SR1
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
     )

@@ -8,8 +8,7 @@ from ...utils import tofloat
 
 
 def _reset_except_self(objective: Objective, modules, self: Module):
-
-    for m in objective.modular.flat_modules:
+    for m in modules:
         if m is not self:
             m.reset()
 

@@ -45,7 +44,7 @@ class SVRG(Module):
     ## Examples:
     SVRG-LBFGS
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SVRG(len(dataloader)),
         tz.m.LBFGS(),

@@ -55,7 +54,7 @@ class SVRG(Module):
 
     For extra variance reduction one can use Online versions of algorithms, although it won't always help.
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SVRG(len(dataloader)),
         tz.m.Online(tz.m.LBFGS()),

@@ -64,7 +63,7 @@ class SVRG(Module):
 
     Variance reduction can also be applied to gradient estimators.
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SPSA(),
         tz.m.SVRG(100),

@@ -6,8 +6,8 @@ from ...core import Module
 from ...utils import NumberList, TensorList
 
 
-def _reset_except_self(
-    for m in
+def _reset_except_self(objective, modules, self: Module):
+    for m in modules:
         if m is not self:
             m.reset()
 

@@ -33,7 +33,7 @@ class WeightDecay(TensorTransform):
 
     Adam with non-decoupled weight decay
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.WeightDecay(1e-3),
         tz.m.Adam(),

@@ -44,7 +44,7 @@ class WeightDecay(TensorTransform):
     Adam with decoupled weight decay that still scales with learning rate
     ```python
 
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Adam(),
         tz.m.WeightDecay(1e-3),

@@ -54,7 +54,7 @@ class WeightDecay(TensorTransform):
 
     Adam with fully decoupled weight decay that doesn't scale with learning rate
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Adam(),
         tz.m.LR(1e-3),

@@ -93,7 +93,7 @@ class RelativeWeightDecay(TensorTransform):
 
     Adam with non-decoupled relative weight decay
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.RelativeWeightDecay(1e-1),
         tz.m.Adam(),

@@ -103,7 +103,7 @@ class RelativeWeightDecay(TensorTransform):
 
     Adam with decoupled relative weight decay
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Adam(),
         tz.m.RelativeWeightDecay(1e-1),

@@ -11,7 +11,7 @@ class Wrap(Module):
     Wraps a pytorch optimizer to use it as a module.
 
     Note:
-        Custom param groups are supported only by ``set_param_groups``, settings passed to
+        Custom param groups are supported only by ``set_param_groups``, settings passed to Optimizer will be applied to all parameters.
 
     Args:
         opt_fn (Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer):

@@ -21,7 +21,7 @@ class Wrap(Module):
         **kwargs:
             Extra args to be passed to opt_fn. The function is called as ``opt_fn(parameters, *args, **kwargs)``.
         use_param_groups:
-            Whether to pass settings passed to
+            Whether to pass settings passed to Optimizer to the wrapped optimizer.
 
             Note that settings to the first parameter are used for all parameters,
             so if you specified per-parameter settings, they will be ignored.

@@ -32,7 +32,7 @@ class Wrap(Module):
     ```python
 
     from pytorch_optimizer import StableAdamW
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Wrap(StableAdamW, lr=1),
         tz.m.Cautious(),

@@ -83,7 +83,7 @@ class Wrap(Module):
 
         # settings passed in `set_param_groups` are the highest priority
        # schedulers will override defaults but not settings passed in `set_param_groups`
-        # this is consistent with how
+        # this is consistent with how Optimizer does it.
         if self._custom_param_groups is not None:
             setting = {k:v for k,v in setting if k not in self._custom_param_groups[0]}
 

@@ -29,7 +29,7 @@ class CD(Module):
             whether to use three points (three function evaluatins) to determine descent direction.
             if False, uses two points, but then ``adaptive`` can't be used. Defaults to True.
     """
-    def __init__(self, h:float=1e-3, grad:bool=
+    def __init__(self, h:float=1e-3, grad:bool=False, adaptive:bool=True, index:Literal['cyclic', 'cyclic2', 'random']="cyclic2", threepoint:bool=True,):
         defaults = dict(h=h, grad=grad, adaptive=adaptive, index=index, threepoint=threepoint)
         super().__init__(defaults)
 