torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_opts.py +199 -198
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +1 -1
- torchzero/core/functional.py +1 -1
- torchzero/core/modular.py +5 -5
- torchzero/core/module.py +2 -2
- torchzero/core/objective.py +10 -10
- torchzero/core/transform.py +1 -1
- torchzero/linalg/__init__.py +3 -2
- torchzero/linalg/eigh.py +223 -4
- torchzero/linalg/orthogonalize.py +2 -4
- torchzero/linalg/qr.py +12 -0
- torchzero/linalg/solve.py +1 -3
- torchzero/linalg/svd.py +47 -20
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +10 -10
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/adam.py +1 -1
- torchzero/modules/adaptive/adan.py +1 -1
- torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +2 -1
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/msam.py +4 -4
- torchzero/modules/adaptive/muon.py +9 -6
- torchzero/modules/adaptive/natural_gradient.py +32 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rprop.py +2 -2
- torchzero/modules/adaptive/sam.py +4 -4
- torchzero/modules/adaptive/shampoo.py +28 -3
- torchzero/modules/adaptive/soap.py +3 -3
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/clipping/clipping.py +7 -7
- torchzero/modules/conjugate_gradient/cg.py +2 -2
- torchzero/modules/experimental/__init__.py +5 -0
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +2 -2
- torchzero/modules/experimental/newtonnewton.py +34 -40
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/rfdm.py +4 -4
- torchzero/modules/least_squares/gn.py +68 -45
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/escape.py +1 -1
- torchzero/modules/misc/gradient_accumulation.py +1 -1
- torchzero/modules/misc/misc.py +1 -1
- torchzero/modules/misc/multistep.py +4 -7
- torchzero/modules/misc/regularization.py +2 -2
- torchzero/modules/misc/split.py +1 -1
- torchzero/modules/misc/switch.py +2 -2
- torchzero/modules/momentum/cautious.py +3 -3
- torchzero/modules/momentum/momentum.py +1 -1
- torchzero/modules/ops/higher_level.py +1 -1
- torchzero/modules/ops/multi.py +1 -1
- torchzero/modules/projections/projection.py +5 -2
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +3 -3
- torchzero/modules/quasi_newton/lsr1.py +3 -3
- torchzero/modules/quasi_newton/quasi_newton.py +44 -29
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +17 -17
- torchzero/modules/second_order/inm.py +33 -25
- torchzero/modules/second_order/newton.py +132 -130
- torchzero/modules/second_order/newton_cg.py +3 -3
- torchzero/modules/second_order/nystrom.py +83 -32
- torchzero/modules/second_order/rsn.py +41 -44
- torchzero/modules/smoothing/laplacian.py +1 -1
- torchzero/modules/smoothing/sampling.py +2 -3
- torchzero/modules/step_size/adaptive.py +6 -6
- torchzero/modules/step_size/lr.py +2 -2
- torchzero/modules/trust_region/cubic_regularization.py +1 -1
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/variance_reduction/svrg.py +4 -5
- torchzero/modules/weight_decay/reinit.py +2 -2
- torchzero/modules/weight_decay/weight_decay.py +5 -5
- torchzero/modules/wrappers/optim_wrapper.py +4 -4
- torchzero/modules/zeroth_order/cd.py +1 -1
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/wrappers/nevergrad.py +0 -9
- torchzero/optim/wrappers/optuna.py +2 -0
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/derivatives.py +4 -4
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- torchzero/modules/adaptive/lmadagrad.py +0 -241
- torchzero-0.4.0.dist-info/RECORD +0 -191
- /torchzero/modules/{functional.py → opt_utils.py} +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/misc/split.py
CHANGED
torchzero/modules/misc/switch.py
CHANGED
@@ -19,7 +19,7 @@ class Alternate(Module):
 
     ```python
 
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Alternate(
             tz.m.Adam(),
@@ -89,7 +89,7 @@ class Switch(Alternate):
     Start with Adam, switch to L-BFGS after 1000th step and Truncated Newton on 2000th step.
 
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Switch(
             [tz.m.Adam(), tz.m.LR(1e-3)],
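The docstring examples in these hunks are cut off after a few lines; here is a minimal end-to-end sketch of the pattern they describe, assuming the 0.4.1 `tz.Optimizer` entry point shown above (the second alternating module and its switching interval are hypothetical placeholders, not taken from the diff):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 2)

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Alternate(
        tz.m.Adam(),
        tz.m.SignSGD(),   # hypothetical second module, not from the diff
        steps=2,          # hypothetical switching interval, not from the diff
    ),
    tz.m.LR(1e-3),
)

# standard PyTorch-style usage: compute gradients, then step
loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
opt.step()
```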
torchzero/modules/momentum/cautious.py
CHANGED
@@ -57,7 +57,7 @@ class Cautious(TensorTransform):
     Cautious Adam
 
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         bench.parameters(),
         tz.m.Adam(),
         tz.m.Cautious(),
@@ -173,7 +173,7 @@ class ScaleByGradCosineSimilarity(TensorTransform):
 
     Scaled Adam
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         bench.parameters(),
         tz.m.Adam(),
         tz.m.ScaleByGradCosineSimilarity(),
@@ -211,7 +211,7 @@ class ScaleModulesByCosineSimilarity(Module):
 
     Adam scaled by similarity to RMSprop
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         bench.parameters(),
         tz.m.ScaleModulesByCosineSimilarity(
             main = tz.m.Adam(),
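The `Cautious` hunk only touches the docstring, but the module it documents follows the cautious-optimizer recipe: components of the proposed update whose sign disagrees with the current gradient are zeroed out and the surviving entries are rescaled. A rough standalone sketch of that masking rule in plain PyTorch (the general recipe, not necessarily torchzero's exact implementation):

```python
import torch

def cautious(update: torch.Tensor, grad: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Zero entries of ``update`` whose sign disagrees with ``grad`` and
    rescale the rest so the average mask value stays 1 (general recipe only)."""
    mask = (update * grad > 0).to(update.dtype)
    mask = mask * (mask.numel() / mask.sum().clamp(min=eps))
    return update * mask

grad = torch.randn(5)
adam_update = torch.randn(5)   # stand-in for an Adam direction
print(cautious(adam_update, grad))
```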
torchzero/modules/ops/multi.py
CHANGED
torchzero/modules/projections/projection.py
CHANGED
@@ -149,8 +149,11 @@ class ProjectionBase(Module, ABC):
             Iterable[torch.Tensor]: unprojected tensors of the same shape as params
         """
 
+    def update(self, objective: Objective): raise RuntimeError("projections don't support update/apply")
+    def apply(self, objective: Objective): raise RuntimeError("projections don't support update/apply")
+
     @torch.no_grad
-    def
+    def step(self, objective: Objective):
         params = objective.params
         settings = [self.settings[p] for p in params]
 
@@ -266,7 +269,7 @@ class ProjectionBase(Module, ABC):
 
         # ----------------------------------- step ----------------------------------- #
         projected_obj.params = projected_params
-        projected_obj = self.children['modules'].
+        projected_obj = self.children['modules'].step(projected_obj)
 
         # empty fake params storage
         # this doesn't affect update/grad because it is a different python object, set_ changes storage on an object
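After this change, `ProjectionBase` only exposes `step`: parameters and gradients are projected, the wrapped modules step entirely in the projected space, and the result is unprojected back, while `update`/`apply` raise. As a self-contained illustration of that project → step → unproject flow (not torchzero's actual `ProjectionBase` API), a gradient step restricted to a low-dimensional subspace might look like this:

```python
import torch

def projected_sgd_step(param: torch.Tensor, grad: torch.Tensor,
                       basis: torch.Tensor, lr: float = 1e-2) -> None:
    """SGD step restricted to span(basis); columns of ``basis`` are assumed orthonormal.
    Illustrates the project -> inner step -> unproject pattern only."""
    g_proj = basis.T @ grad.flatten()                  # project the gradient
    step_proj = -lr * g_proj                           # inner "optimizer" acts on projected tensors
    param.add_((basis @ step_proj).view_as(param))     # unproject back to parameter space

p = torch.zeros(6)
g = torch.randn(6)
Q, _ = torch.linalg.qr(torch.randn(6, 3))              # orthonormal basis of a 3-d subspace
projected_sgd_step(p, g, Q)
```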
torchzero/modules/quasi_newton/lbfgs.py
CHANGED
@@ -7,7 +7,7 @@ import torch
 from ...core import Chainable, TensorTransform
 from ...utils import TensorList, as_tensorlist, unpack_states
 from ...linalg.linear_operator import LinearOperator
-from ..
+from ..opt_utils import initial_step_size
 from .damping import DampingStrategyType, apply_damping
 
 
@@ -188,7 +188,7 @@ class LBFGS(TensorTransform):
 
     L-BFGS with line search
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LBFGS(100),
         tz.m.Backtracking()
@@ -197,7 +197,7 @@ class LBFGS(TensorTransform):
 
     L-BFGS with trust region
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.TrustCG(tz.m.LBFGS())
     )
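The L-BFGS examples above are also truncated; one hedged way the full pattern could look, assuming the `tz.Optimizer` entry point from these hunks and a closure-based step (line searches such as `tz.m.Backtracking` need to re-evaluate the loss at trial points; the `backward=True` closure convention here is an assumption, not something shown in this diff):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

opt = tz.Optimizer(
    model.parameters(),
    tz.m.LBFGS(100),
    tz.m.Backtracking(),
)

def closure(backward=True):
    # re-evaluated by the line search at each trial point
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(10):
    opt.step(closure)
```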
torchzero/modules/quasi_newton/lsr1.py
CHANGED
@@ -7,7 +7,7 @@ import torch
 from ...core import Chainable, Module, TensorTransform, Objective, step
 from ...utils import NumberList, TensorList, as_tensorlist, generic_finfo_tiny, unpack_states, vec_to_tensors_
 from ...linalg.linear_operator import LinearOperator
-from ..
+from ..opt_utils import initial_step_size
 from .damping import DampingStrategyType, apply_damping
 
 
@@ -110,7 +110,7 @@ class LSR1(TensorTransform):
 
     L-SR1 with line search
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SR1(),
         tz.m.StrongWolfe(c2=0.1, fallback=True)
@@ -119,7 +119,7 @@ class LSR1(TensorTransform):
 
     L-SR1 with trust region
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.TrustCG(tz.m.LSR1())
     )
torchzero/modules/quasi_newton/quasi_newton.py
CHANGED
@@ -8,7 +8,7 @@ import torch
 from ...core import Chainable, Module, TensorTransform, Transform
 from ...utils import TensorList, set_storage_, unpack_states, safe_dict_update_
 from ...linalg import linear_operator
-from ..
+from ..opt_utils import initial_step_size, safe_clip
 
 
 
@@ -106,11 +106,12 @@ class HessianUpdateStrategy(TensorTransform, ABC):
         scale_first: bool = False,
         concat_params: bool = True,
         inverse: bool = True,
+        uses_loss: bool = False,
         inner: Chainable | None = None,
     ):
         if defaults is None: defaults = {}
         safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, ptol=ptol, ptol_restart=ptol_restart, gtol=gtol, inverse=inverse, beta=beta, restart_interval=restart_interval, scale_first=scale_first))
-        super().__init__(defaults,
+        super().__init__(defaults, uses_loss=uses_loss, concat_params=concat_params, update_freq=update_freq, inner=inner)
 
     def reset_for_online(self):
         super().reset_for_online()
@@ -141,18 +142,22 @@ class HessianUpdateStrategy(TensorTransform, ABC):
         return H
 
     # ------------------------------ common methods ------------------------------ #
-    def auto_initial_scale(self, s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float:
+    def auto_initial_scale(self, s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float | None:
         """returns multiplier to B on 2nd step if ``init_scale='auto'``. H should be divided by this!"""
         ys = y.dot(s)
         yy = y.dot(y)
-
-        return
+        tiny = torch.finfo(ys.dtype).tiny * 2
+        if ys > tiny and yy > tiny: return yy/ys
+        return None
 
-    def reset_P(self, P: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]) -> None:
+    def reset_P(self, P: torch.Tensor, s:torch.Tensor, y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]) -> None:
         """resets ``P`` which is either B or H"""
         set_storage_(P, self.initialize_P(s.numel(), device=P.device, dtype=P.dtype, is_inverse=inverse))
-        if init_scale == 'auto':
-
+        if init_scale == 'auto':
+            init_scale = self.auto_initial_scale(s,y)
+            state["scaled"] = init_scale is not None
+
+        if init_scale is not None and init_scale != 1:
             if inverse: P /= init_scale
             else: P *= init_scale
 
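In math notation, the filled-in `auto_initial_scale` is the usual quasi-Newton initial scaling: with s = x_k − x_{k−1} and y = g_k − g_{k−1} it returns the factor applied to B (and divided out of H), and only when both inner products are safely positive:

```latex
\gamma = \frac{y^\top y}{y^\top s},
\qquad B_0 \leftarrow \gamma I,
\qquad H_0 \leftarrow \tfrac{1}{\gamma} I,
\qquad \text{applied only if } y^\top s > \varepsilon \ \text{and}\ y^\top y > \varepsilon .
```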
@@ -182,6 +187,7 @@ class HessianUpdateStrategy(TensorTransform, ABC):
             state['f_prev'] = loss
             state['p_prev'] = p.clone()
             state['g_prev'] = g.clone()
+            state["scaled"] = False
             return
 
         state['f'] = loss
@@ -205,9 +211,13 @@ class HessianUpdateStrategy(TensorTransform, ABC):
         if gtol is not None and y.abs().max() <= gtol:
             return
 
-
-
-
+        # apply automatic initial scale if it hasn't been applied
+        if (not state["scaled"]) and (init_scale == 'auto'):
+            scale = self.auto_initial_scale(s,y)
+            if scale is not None:
+                state["scaled"] = True
+                if inverse: M /= self.auto_initial_scale(s,y)
+                else: M *= self.auto_initial_scale(s,y)
 
         beta = setting['beta']
         if beta is not None and beta != 0: M = M.clone() # because all of them update it in-place
@@ -367,22 +377,21 @@ def bfgs_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     B += term1.sub_(term2)
     return B
 
-
+
+def bfgs_H_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
     sy = s.dot(y)
     if sy <= tol: return H
 
-
-
-    Hy = H@y
-    scale1 = (sy + y.dot(Hy)) / sy_sq
-    term1 = s.outer(s).mul_(scale1)
+    rho = 1.0 / sy
+    Hy = H @ y
 
-
-    term2 =
+    term1 = (s.outer(s)).mul_(rho * (1 + rho * y.dot(Hy)))
+    term2 = (Hy.outer(s) + s.outer(Hy)).mul_(rho)
 
-    H
+    H.add_(term1).sub_(term2)
     return H
 
+
 class BFGS(_InverseHessianUpdateStrategyDefaults):
     """Broyden–Fletcher–Goldfarb–Shanno Quasi-Newton method. This is usually the most stable quasi-newton method.
 
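The rewritten `bfgs_H_` is the textbook BFGS update of the inverse-Hessian approximation H written with ρ = 1/(sᵀy); `term1` and `term2` in the new code are exactly the two terms below:

```latex
\rho = \frac{1}{s^\top y},
\qquad
H_{k+1} = H_k
        + \rho \left(1 + \rho\, y^\top H_k y\right) s s^\top
        - \rho \left(H_k y\, s^\top + s\, (H_k y)^\top\right).
```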
@@ -428,7 +437,7 @@ class BFGS(_InverseHessianUpdateStrategyDefaults):
     BFGS with backtracking line search:
 
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.BFGS(),
         tz.m.Backtracking()
@@ -437,7 +446,7 @@ class BFGS(_InverseHessianUpdateStrategyDefaults):
 
     BFGS with trust region
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LevenbergMarquardt(tz.m.BFGS(inverse=False)),
     )
@@ -505,7 +514,7 @@ class SR1(_InverseHessianUpdateStrategyDefaults):
 
     SR1 with trust region
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
     )
@@ -1015,7 +1024,7 @@ class GradientCorrection(TensorTransform):
     L-BFGS with gradient correction
 
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LBFGS(inner=tz.m.GradientCorrection()),
         tz.m.Backtracking()
@@ -1154,6 +1163,7 @@ class NewSSM(HessianUpdateStrategy):
             scale_first=scale_first,
             concat_params=concat_params,
             inverse=True,
+            uses_loss=True,
             inner=inner,
         )
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
@@ -1171,13 +1181,18 @@ class NewSSM(HessianUpdateStrategy):
 
 # this is supposed to be equivalent (and it is)
 def shor_r_(H:torch.Tensor, y:torch.Tensor, alpha:float):
-
-
-
-    term =
-    H.sub_(term, alpha=1-alpha**2)
+    Hy = H @ y
+    yHy = safe_clip(y.dot(Hy))
+    term = Hy.outer(Hy).div_(yHy)
+    H.sub_(term, alpha=(1-alpha**2))
     return H
 
+# def projected_gradient_(H:torch.Tensor, y:torch.Tensor):
+#     Hy = H @ y
+#     yHy = safe_clip(y.dot(Hy))
+#     H -= (Hy.outer(y) @ H).div_(yHy)
+#     return H
+
 class ShorR(HessianUpdateStrategy):
     """Shor’s r-algorithm.
 
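For reference, the filled-in `shor_r_` body applies a space-dilation style update of H along the gradient difference y, which in matrix form is:

```latex
H \leftarrow H \;-\; (1-\alpha^2)\,\frac{(Hy)(Hy)^\top}{y^\top H y}.
```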