torchzero 0.3.6__py3-none-any.whl → 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +1 -1
- tests/test_tensorlist.py +1 -1
- torchzero/core/preconditioner.py +12 -11
- torchzero/modules/experimental/__init__.py +3 -2
- torchzero/modules/experimental/absoap.py +8 -2
- torchzero/modules/experimental/adadam.py +1 -1
- torchzero/modules/experimental/adamY.py +1 -1
- torchzero/modules/experimental/adasoap.py +2 -2
- torchzero/modules/experimental/algebraic_newton.py +1 -1
- torchzero/modules/experimental/curveball.py +1 -1
- torchzero/modules/experimental/gradmin.py +1 -1
- torchzero/modules/experimental/newton_solver.py +88 -0
- torchzero/modules/experimental/{dsoap.py → soapy.py} +4 -4
- torchzero/modules/experimental/spectral.py +5 -3
- torchzero/modules/experimental/structured_newton.py +111 -0
- torchzero/modules/experimental/subspace_preconditioners.py +16 -9
- torchzero/modules/optimizers/soap.py +1 -2
- torchzero/modules/projections/projection.py +27 -1
- torchzero/modules/quasi_newton/cg.py +9 -9
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +4 -3
- torchzero/modules/quasi_newton/lbfgs.py +4 -3
- torchzero/modules/quasi_newton/lsr1.py +7 -3
- torchzero/modules/quasi_newton/quasi_newton.py +18 -17
- torchzero/modules/second_order/__init__.py +1 -1
- torchzero/modules/second_order/newton.py +11 -6
- torchzero/modules/second_order/newton_cg.py +3 -3
- torchzero/modules/second_order/nystrom.py +6 -6
- torchzero/utils/linalg/benchmark.py +20 -0
- torchzero/utils/linalg/solve.py +15 -14
- {torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/METADATA +2 -2
- {torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/RECORD +34 -31
- {torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/WHEEL +0 -0
- {torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/top_level.txt +0 -0
tests/test_opts.py
CHANGED
@@ -745,7 +745,7 @@ SSVM = Run(
 func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
 sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
 needs_closure=True,
-func='rosen', steps=50, loss=1e-…
+func='rosen', steps=50, loss=1e-10, merge_invariant=True,
 sphere_steps=10, sphere_loss=0,
 )

tests/test_tensorlist.py
CHANGED
@@ -835,7 +835,7 @@ def test_global_reductions(simple_tl: TensorList, global_method, vec_equiv_metho…
 expected = vec_equiv_func()

 if isinstance(result, bool): assert result == expected
-else: assert torch.allclose(result, expected), f"Tensors not close: {result = }, {expected = }"
+else: assert torch.allclose(result, expected, atol=1e-4), f"Tensors not close: {result = }, {expected = }"


 def test_global_vector_norm(simple_tl: TensorList):
torchzero/core/preconditioner.py
CHANGED
@@ -38,17 +38,18 @@ class Preconditioner(Transform):


 def _tensor_wise_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
-step = self.global_state.get('…
+step = self.global_state.get('__step', 0)
 states = [self.state[p] for p in params]
 settings = [self.settings[p] for p in params]
 global_settings = settings[0]
 update_freq = global_settings['__update_freq']

 scale_first = global_settings['__scale_first']
-scale_factor =…
+scale_factor = 1
 if scale_first and step == 0:
 # initial step size guess from pytorch LBFGS
-scale_factor = TensorList(tensors).abs().…
+scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
+scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)

 # update preconditioner
 if step % update_freq == 0:
@@ -63,13 +64,13 @@ class Preconditioner(Transform):

 # scale initial step, when preconditioner might not have been applied
 if scale_first and step == 0:
-torch.…
+torch._foreach_mul_(tensors, scale_factor)

-self.global_state['…
+self.global_state['__step'] = step + 1
 return tensors

 def _concat_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
-step = self.global_state.get('…
+step = self.global_state.get('__step', 0)
 tensors_vec = torch.cat([t.ravel() for t in tensors])
 params_vec = torch.cat([p.ravel() for p in params])
 grads_vec = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
@@ -80,10 +81,11 @@ class Preconditioner(Transform):
 update_freq = global_settings['__update_freq']

 scale_first = global_settings['__scale_first']
-scale_factor =…
+scale_factor = 1
 if scale_first and step == 0:
 # initial step size guess from pytorch LBFGS
-scale_factor = tensors_vec.abs().sum()
+scale_factor = 1 / tensors_vec.abs().sum().clip(min=1)
+scale_factor = scale_factor.clip(min=torch.finfo(tensors_vec.dtype).eps)

 # update preconditioner
 if step % update_freq == 0:
@@ -99,11 +101,10 @@ class Preconditioner(Transform):

 # scale initial step, when preconditioner might not have been applied
 if scale_first and step == 0:
-…
-tensors_vec /= scale_factor
+tensors_vec *= scale_factor

 tensors = vec_to_tensors(vec=tensors_vec, reference=tensors)
-self.global_state['…
+self.global_state['__step'] = step + 1
 return tensors

 @torch.no_grad
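Editor's note on the hunks above: the first preconditioner application now multiplies the initial update by a clipped reciprocal of its ℓ1 norm instead of dividing by the raw sum, so a small gradient no longer inflates the first step and a huge one cannot underflow it to zero. A minimal standalone sketch of the same rule (plain torch, not the library's TensorList code):

```python
import torch

def initial_scale_factor(update: torch.Tensor) -> torch.Tensor:
    # 1 / sum(|g|), but never scale the step up (clip(min=1) on the sum)
    # and never collapse it to nothing (floor at machine epsilon).
    scale = 1 / update.abs().sum().clip(min=1)
    return scale.clip(min=torch.finfo(update.dtype).eps)

u = torch.randn(1000) * 100
print((u * initial_scale_factor(u)).abs().sum())  # roughly 1: a conservative first step
```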
torchzero/modules/experimental/__init__.py
CHANGED
@@ -3,7 +3,7 @@ from .adadam import Adadam
 from .adamY import AdamY
 from .adasoap import AdaSOAP
 from .curveball import CurveBall
-from .…
+from .soapy import SOAPY
 from .gradmin import GradMin
 from .reduce_outward_lr import ReduceOutwardLR
 from .spectral import SpectralPreconditioner
@@ -11,4 +11,5 @@ from .subspace_preconditioners import (
 HistorySubspacePreconditioning,
 RandomSubspacePreconditioning,
 )
-from .tropical_newton import TropicalNewton
+from .tropical_newton import TropicalNewton
+from .newton_solver import NewtonSolver
torchzero/modules/experimental/absoap.py
CHANGED
@@ -140,11 +140,17 @@ Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys','sn', 'yn']
 class ABSOAP(Transform):
 """SOAP but with two extra letters included in its name in order to improve converence

+so what you can do is choose what goes into what ,and that is supposed to be good.
+
 new args

 scale by s whether to scale gradient differences by parameter differences

 y_to_ema2 whether to use gradient differences for exponential moving average too
+
+okay I changed these args into another ones
+
+BASICALLY THIS IS FOR MY EXPERIMENTS
 """
 def __init__(
 self,
@@ -213,7 +219,7 @@ class ABSOAP(Transform):
 if 'g_prev' not in state:
 state['p_prev'] = p.clone()
 state['g_prev'] = t.clone()
-updates.append(tensors[i].…
+updates.append(tensors[i].clip(-0.1,0.1))
 continue

 p_prev = state['p_prev']
@@ -285,7 +291,7 @@ class ABSOAP(Transform):
 state['Q'] = get_orthogonal_matrix(state['GG'])

 state['step'] = 0
-updates.append(tensors[i].…
+updates.append(tensors[i].clip(-0.1,0.1))
 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
 # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.

torchzero/modules/experimental/adasoap.py
CHANGED
@@ -218,9 +218,9 @@ class AdaSOAP(Transform):
 state['Q'] = get_orthogonal_matrix(GG_precond)

 state['step'] = 0
-updates.append(tensors[i].…
+updates.append(tensors[i].clip(-0.1,0.1))
 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-#…
+# that can mess with other modules scaling

 # Projecting gradients to the eigenbases of Shampoo's preconditioner
 # i.e. projecting to the eigenbases of matrices in state['GG']
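Across the SOAP variants above, the update returned before the preconditioner statistics exist is now the raw tensor clipped to ±0.1 rather than a normalized one. A short sketch of that fallback, assuming a plain tensor update:

```python
import torch

def first_step_fallback(update: torch.Tensor) -> torch.Tensor:
    # Nothing to precondition with yet, so take a bounded, conservative step:
    # element-wise clipping to [-0.1, 0.1] keeps the scale predictable for
    # whatever modules come next in the chain.
    return update.clip(-0.1, 0.1)

print(first_step_fallback(torch.tensor([5.0, -0.03, 0.2])))  # tensor([ 0.1000, -0.0300,  0.1000])
```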
torchzero/modules/experimental/algebraic_newton.py
CHANGED
@@ -71,7 +71,7 @@ def tikhonov(H: torch.Tensor, reg: float, algebra: ta.Algebra = ta.TropicalSemir…


 class AlgebraicNewton(Module):
-"""newton in other algebras, not…
+"""newton in other algebras, not that it works."""
 def __init__(
 self,
 reg: float | None = None,
torchzero/modules/experimental/curveball.py
CHANGED
@@ -13,7 +13,7 @@ def curveball(
 momentum: float | NumberList,
 precond_lr: float | NumberList,
 ):
-"""returns z_, clone it!!!"""
+"""returns z_, clone it!!! (no just negate it)"""
 delta = Hz + tensors
 z_.mul_(momentum).sub_(delta.mul_(precond_lr)) # z ← ρz − βΔ
 return z_
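For context, the helper changed above implements the CurveBall interpolation z ← ρz − β(Hz + g). A dense-tensor sketch of the recurrence (hvp_fn, rho and beta are illustrative names, not the module's API):

```python
import torch

def curveball_step(z, hvp_fn, grad, rho=0.9, beta=0.1):
    # z <- rho * z - beta * (Hz + g); updates the buffer in place,
    # matching the "returns z_" note in the docstring.
    delta = hvp_fn(z) + grad
    return z.mul_(rho).sub_(delta, alpha=beta)

H = torch.eye(3) * 2.0
g = torch.tensor([1.0, -2.0, 0.5])
z = torch.zeros(3)
for _ in range(200):
    curveball_step(z, lambda v: H @ v, g)
print(z)  # approaches -((1 - rho)/beta * I + H)^{-1} g, a damped Newton direction
```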
torchzero/modules/experimental/gradmin.py
CHANGED
@@ -14,7 +14,7 @@ from ..smoothing.gaussian import Reformulation


 class GradMin(Reformulation):
-"""Reformulates the objective to minimize sum of gradient magnitudes via autograd.
+"""Reformulates the objective to minimize sum of gradient magnitudes via autograd. This is not expected to be practical.

 Args:
 loss_term (float, optional): adds loss value times this to sum of gradient magnitudes. Defaults to 1.
torchzero/modules/experimental/newton_solver.py
ADDED
@@ -0,0 +1,88 @@
+from collections.abc import Callable, Iterable
+from typing import Any, Literal, overload
+
+import torch
+
+from ...core import Chainable, Module, apply, Modular
+from ...utils import TensorList, as_tensorlist
+from ...utils.derivatives import hvp
+from ..quasi_newton import LBFGS
+
+class NewtonSolver(Module):
+"""Matrix free newton via with any custom solver (usually it is better to just use NewtonCG or NystromPCG is even better)"""
+def __init__(
+self,
+solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
+maxiter=None,
+tol=1e-3,
+reg: float = 0,
+warm_start=True,
+inner: Chainable | None = None,
+):
+defaults = dict(tol=tol, maxiter=maxiter, reg=reg, warm_start=warm_start, solver=solver)
+super().__init__(defaults,)
+
+if inner is not None:
+self.set_child('inner', inner)
+
+@torch.no_grad
+def step(self, vars):
+params = TensorList(vars.params)
+closure = vars.closure
+if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+settings = self.settings[params[0]]
+solver_cls = settings['solver']
+maxiter = settings['maxiter']
+tol = settings['tol']
+reg = settings['reg']
+warm_start = settings['warm_start']
+
+# ---------------------- Hessian vector product function --------------------- #
+grad = vars.get_grad(create_graph=True)
+
+def H_mm(x):
+with torch.enable_grad():
+Hvp = TensorList(hvp(params, grad, x, create_graph=True))
+if reg != 0: Hvp = Hvp + (x*reg)
+return Hvp
+
+# -------------------------------- inner step -------------------------------- #
+b = as_tensorlist(grad)
+if 'inner' in self.children:
+b = as_tensorlist(apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars))
+
+# ---------------------------------- run cg ---------------------------------- #
+x0 = None
+if warm_start: x0 = self.get_state('prev_x', params=params, cls=TensorList) # initialized to 0 which is default anyway
+if x0 is None: x = b.zeros_like().requires_grad_(True)
+else: x = x0.clone().requires_grad_(True)
+
+solver = solver_cls(x)
+def lstsq_closure(backward=True):
+Hx = H_mm(x)
+loss = (Hx-b).pow(2).global_mean()
+if backward:
+solver.zero_grad()
+loss.backward(inputs=x)
+return loss
+
+if maxiter is None: maxiter = b.global_numel()
+loss = None
+initial_loss = lstsq_closure(False)
+if initial_loss > tol:
+for i in range(maxiter):
+loss = solver.step(lstsq_closure)
+assert loss is not None
+if min(loss, loss/initial_loss) < tol: break
+
+print(f'{loss = }')
+
+if warm_start:
+assert x0 is not None
+x0.copy_(x)
+
+vars.update = x.detach()
+return vars
+
+
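NewtonSolver (added above) recasts the Newton system Hx = g as the least-squares objective ‖Hx − g‖² and minimizes it with an inner optimizer, using only Hessian-vector products. The idea in isolation, on a dense matrix and with torch.optim.LBFGS standing in for the torchzero inner solver:

```python
import torch

# Solve H x = b by minimizing ||Hx - b||^2 with a generic optimizer.
H = torch.randn(20, 20)
H = H @ H.T + torch.eye(20)           # SPD test matrix
b = torch.randn(20)
x = torch.zeros(20, requires_grad=True)
solver = torch.optim.LBFGS([x], max_iter=100)

def lstsq_closure():
    solver.zero_grad()
    loss = (H @ x - b).pow(2).mean()  # the module's lstsq_closure does this via Hvps
    loss.backward()
    return loss

solver.step(lstsq_closure)
print(torch.dist(x.detach(), torch.linalg.solve(H, b)))  # near zero
```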
torchzero/modules/experimental/{dsoap.py → soapy.py}
RENAMED
@@ -3,7 +3,7 @@ from operator import itemgetter
 import torch

 from ...core import Chainable, Transform, apply
-from…
+from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims

 @torch.no_grad
 def update_soap_covariances_(
@@ -135,7 +135,7 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N…

 return final, exp_avg_sq

-class…
+class SOAPY(Transform):
 """SOAP but uses scaled gradient differences

 new args
@@ -195,7 +195,7 @@ class DSOAP(Transform):
 if 'g_prev' not in state:
 state['p_prev'] = p.clone()
 state['g_prev'] = t.clone()
-updates.append(tensors[i].…
+updates.append(tensors[i].clip(-0.1,0.1))
 continue

 p_prev = state['p_prev']
@@ -228,7 +228,7 @@ class DSOAP(Transform):
 state['Q'] = get_orthogonal_matrix(state['GG'])

 state['step'] = 0
-updates.append(tensors[i].…
+updates.append(tensors[i].clip(-0.1,0.1))
 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
 # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.

torchzero/modules/experimental/spectral.py
CHANGED
@@ -194,8 +194,10 @@ class SpectralPreconditioner(TensorwisePreconditioner):
 order (int, optional):
 whitening order, 1 approximates FIM (maybe), 2 - hessian (maybe), 3+ - god knows what.
 solver (str, optional): what to use for whitening. Defaults to 'svd'.
-…
-…
+A_beta (float | None, optional):
+beta for U (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
+B_beta (float | None, optional):
+beta for S (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
 interval (int, optional): How often to update history. Defaults to 1 (every step).
 concat_params (bool, optional):
 whether to apply preconditioning to each tensor (False, default) or to all tensors concatenated into a vector (True). Latter will be slower but captures interactions between layers. Defaults to True.
@@ -275,7 +277,7 @@ class SpectralPreconditioner(TensorwisePreconditioner):
 A = state.get('A', None)
 if A is None:
 # make a conservative step to avoid issues due to different GD scaling
-return tensor.…
+return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]

 B = state['B']
 update = solver.apply(tensor.view(-1), A, B).view_as(tensor)
torchzero/modules/experimental/structured_newton.py
ADDED
@@ -0,0 +1,111 @@
+# idea https://arxiv.org/pdf/2212.09841
+import warnings
+from collections.abc import Callable
+from functools import partial
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, apply
+from ...utils import TensorList, vec_to_tensors
+from ...utils.derivatives import (
+hessian_list_to_mat,
+hessian_mat,
+hvp,
+hvp_fd_central,
+hvp_fd_forward,
+jacobian_and_hessian_wrt,
+)
+
+
+class StructuredNewton(Module):
+"""TODO
+Args:
+structure (str, optional): structure.
+reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
+hvp_method (str):
+how to calculate hvp_method. Defaults to "autograd".
+inner (Chainable | None, optional): inner modules. Defaults to None.
+
+"""
+def __init__(
+self,
+structure: Literal[
+"diagonal",
+"diagonal1",
+"diagonal_abs",
+"tridiagonal",
+"circulant",
+"toeplitz",
+"toeplitz_like",
+"hankel",
+"rank1",
+"rank2", # any rank
+]
+| str = "diagonal",
+reg: float = 1e-6,
+hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+h: float = 1e-3,
+inner: Chainable | None = None,
+):
+defaults = dict(reg=reg, hvp_method=hvp_method, structure=structure, h=h)
+super().__init__(defaults)
+
+if inner is not None:
+self.set_child('inner', inner)
+
+@torch.no_grad
+def step(self, vars):
+params = TensorList(vars.params)
+closure = vars.closure
+if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+settings = self.settings[params[0]]
+reg = settings['reg']
+hvp_method = settings['hvp_method']
+structure = settings['structure']
+h = settings['h']
+
+# ------------------------ calculate grad and hessian ------------------------ #
+if hvp_method == 'autograd':
+grad = vars.get_grad(create_graph=True)
+def Hvp_fn1(x):
+return hvp(params, grad, x, retain_graph=True)
+Hvp_fn = Hvp_fn1
+
+elif hvp_method == 'forward':
+grad = vars.get_grad()
+def Hvp_fn2(x):
+return hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1]
+Hvp_fn = Hvp_fn2
+
+elif hvp_method == 'central':
+grad = vars.get_grad()
+def Hvp_fn3(x):
+return hvp_fd_central(closure, params, x, h=h, normalize=True)[1]
+Hvp_fn = Hvp_fn3
+
+else: raise ValueError(hvp_method)
+
+# -------------------------------- inner step -------------------------------- #
+update = vars.get_update()
+if 'inner' in self.children:
+update = apply(self.children['inner'], update, params=params, grads=grad, vars=vars)
+
+# hessian
+if structure.startswith('diagonal'):
+H = Hvp_fn([torch.ones_like(p) for p in params])
+if structure == 'diagonal1': torch._foreach_clamp_min_(H, 1)
+if structure == 'diagonal_abs': torch._foreach_abs_(H)
+torch._foreach_add_(H, reg)
+torch._foreach_div_(update, H)
+vars.update = update
+return vars
+
+# hessian
+raise NotImplementedError(structure)
+
+
+
+
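For the "diagonal*" structures, StructuredNewton builds its Hessian estimate from a single Hessian-vector product with the all-ones vector (the row sums of H, which equal the diagonal only when H is diagonal) and divides the update elementwise. A standalone sketch of that step on a dense Hessian:

```python
import torch

def diagonal_newton_step(H, g, reg=1e-6, absolute=True):
    d = H @ torch.ones(H.shape[0])   # one Hvp with the ones vector = row sums of H
    if absolute:
        d = d.abs()                  # the "diagonal_abs" variant
    return g / (d + reg)             # elementwise preconditioned step

H = torch.diag(torch.tensor([1.0, 4.0, 0.25]))
g = torch.ones(3)
print(diagonal_newton_step(H, g))    # ~[1, 0.25, 4]: exact Newton step for a diagonal H
```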
torchzero/modules/experimental/subspace_preconditioners.py
CHANGED
@@ -38,16 +38,19 @@ def apply_subspace_preconditioner(
 return basis @ update_projected # d

 class RandomSubspacePreconditioning(Transform):
-"""full matrix rmsprop in random subspace"""
-def __init__(self, k: int, beta: float | None = 0.99):
-defaults = dict(k=k, beta=beta)
+"""full matrix rmsprop in random slowly changing subspace"""
+def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
+defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
 super().__init__(defaults, uses_grad=False)

+if inner is not None: self.set_child('inner', inner)
+
 def transform(self, tensors, params, grads, vars):
 settings = self.settings[params[0]]
 g = torch.cat([t.view(-1) for t in tensors])
 k = settings['k']
 beta = settings['beta']
+basis_beta = settings['basis_beta']

 if 'basis' not in self.global_state:
 self.global_state['basis'] = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
@@ -56,13 +59,19 @@ class RandomSubspacePreconditioning(Transform):
 basis = self.global_state['basis']
 accumulator = self.global_state['accumulator']

+if basis_beta is not None:
+basis.lerp_(torch.randn_like(basis), 1-basis_beta)
+
 update_subspace_preconditioner_(g, basis, accumulator, beta)
+
+if 'inner' in self.children:
+tensors = apply(self.children['inner'], tensors, params, grads, vars)
+g = torch.cat([t.view(-1) for t in tensors])
+
 try:
 preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
 except torch.linalg.LinAlgError:
-…
-if denom <= 1e-10: denom = torch.ones_like(denom)
-preconditioned = g / g.abs().sum()
+preconditioned = g.clip(-0.1, 0.1)
 vec_to_tensors_(preconditioned, tensors)

 return tensors
@@ -119,9 +128,7 @@ class HistorySubspacePreconditioning(Transform):
 try:
 preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
 except torch.linalg.LinAlgError:
-…
-if denom <= 1e-10: denom = torch.ones_like(denom)
-preconditioned = g / g.abs().sum()
+preconditioned = g.clip(-0.1,0.1)
 vec_to_tensors_(preconditioned, tensors)

 return tensors
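A dense sketch of the subspace preconditioning above, including the new slowly drifting basis: project the gradient onto a random k-dimensional basis, keep a full-matrix second-moment estimate there, and solve the small k×k system (names and the exact accumulator update are illustrative, not the module's internals):

```python
import torch

def subspace_precondition(g, basis, acc, beta=0.99, basis_beta=0.99):
    basis.lerp_(torch.randn_like(basis), 1 - basis_beta)      # drift the basis (new basis_beta arg)
    gp = basis.T @ g                                           # project into the subspace
    acc.mul_(beta).add_(torch.outer(gp, gp), alpha=1 - beta)   # full-matrix EMA, k x k
    try:
        xp = torch.linalg.solve(acc, gp)                       # precondition inside the subspace
    except torch.linalg.LinAlgError:
        return g.clip(-0.1, 0.1)                               # same fallback as the module
    return basis @ xp                                          # map back to parameter space

n, k = 1000, 16
print(subspace_precondition(torch.randn(n), torch.randn(n, k), torch.eye(k)).shape)
```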
torchzero/modules/optimizers/soap.py
CHANGED
@@ -222,8 +222,7 @@ class SOAP(Transform):
 state['Q'] = get_orthogonal_matrix(state['GG'])

 state['step'] = 0
-updates.append(tensors[i].…
-# updates.append(tensors[i] / tensors[i].abs().sum())
+updates.append(tensors[i].clip(-0.1, 0.1))
 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
 # I use scaled update instead as to not mess up with next modules.

torchzero/modules/projections/projection.py
CHANGED
@@ -1,4 +1,5 @@
 import math
+from functools import partial
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from typing import Any, Literal
@@ -33,6 +34,25 @@ def _make_projected_closure(closure, vars: Vars, projection: "Projection",

 return projected_closure

+def _projected_get_grad_override(
+retain_graph: bool | None = None,
+create_graph: bool = False,
+projection: Any = ...,
+unprojected_vars: Any = ...,
+self: Any = ...,
+):
+assert isinstance(projection, Projection)
+assert isinstance(unprojected_vars, Vars)
+assert isinstance(self, Vars)
+
+if self.grad is not None: return self.grad
+grads = unprojected_vars.get_grad(retain_graph, create_graph)
+projected_grads = list(projection.project(grads, self, current='grads'))
+self.grad = projected_grads
+for p, g in zip(self.params, projected_grads):
+p.grad = g
+return self.grad
+

 class Projection(Module, ABC):
 """
@@ -137,6 +157,12 @@ class Projection(Module, ABC):

 # step
 projected_vars.params = self._projected_params
+projected_vars.get_grad = partial(
+_projected_get_grad_override,
+projection=self,
+unprojected_vars=vars,
+self=projected_vars,
+)
 projected_vars = self.children['modules'].step(projected_vars)

 # empty fake params storage
@@ -149,7 +175,7 @@ class Projection(Module, ABC):
 unprojected_vars = projected_vars.clone(clone_update=False)
 unprojected_vars.closure = vars.closure
 unprojected_vars.params = vars.params
-…
+unprojected_vars.grad = vars.grad

 if self._project_update:
 assert projected_vars.update is not None
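The new override works by binding extra arguments with functools.partial and assigning the result over the projected vars' get_grad, so child modules transparently pull projected gradients. The binding pattern in isolation (toy classes, not the library's Vars/Projection):

```python
from functools import partial

class Holder:
    def __init__(self, value):
        self.value = value
        self.get = lambda: self.value      # default accessor

def scaled_get(scale=1.0, self=None):
    # 'self' is passed explicitly through partial, mirroring
    # _projected_get_grad_override(..., self=projected_vars)
    return self.value * scale

h = Holder(3.0)
h.get = partial(scaled_get, scale=0.5, self=h)   # per-instance override
print(h.get())  # 1.5
```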
torchzero/modules/quasi_newton/cg.py
CHANGED
@@ -64,7 +64,7 @@ class ConguateGradientBase(Transform, ABC):
 # ------------------------------- Polak-Ribière ------------------------------ #
 def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
 denom = prev_g.dot(prev_g)
-if denom…
+if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
 return g.dot(g - prev_g) / denom

 class PolakRibiere(ConguateGradientBase):
@@ -76,8 +76,8 @@ class PolakRibiere(ConguateGradientBase):
 return polak_ribiere_beta(g, prev_g)

 # ------------------------------ Fletcher–Reeves ----------------------------- #
-def fletcher_reeves_beta(gg, prev_gg):
-if prev_gg…
+def fletcher_reeves_beta(gg: torch.Tensor, prev_gg: torch.Tensor):
+if prev_gg.abs() <= torch.finfo(gg.dtype).eps: return 0
 return gg / prev_gg

 class FletcherReeves(ConguateGradientBase):
@@ -98,7 +98,7 @@ class FletcherReeves(ConguateGradientBase):
 def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
 grad_diff = g - prev_g
 denom = prev_d.dot(grad_diff)
-if denom…
+if denom.abs() < torch.finfo(g[0].dtype).eps: return 0
 return (g.dot(grad_diff) / denom).neg()


@@ -114,7 +114,7 @@ class HestenesStiefel(ConguateGradientBase):
 # --------------------------------- Dai–Yuan --------------------------------- #
 def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
 denom = prev_d.dot(g - prev_g)
-if denom…
+if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
 return (g.dot(g) / denom).neg()

 class DaiYuan(ConguateGradientBase):
@@ -129,7 +129,7 @@ class DaiYuan(ConguateGradientBase):
 # -------------------------------- Liu-Storey -------------------------------- #
 def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):
 denom = prev_g.dot(prev_d)
-if denom…
+if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
 return g.dot(g - prev_g) / denom

 class LiuStorey(ConguateGradientBase):
@@ -159,7 +159,7 @@ class ConjugateDescent(Transform):
 self.global_state['denom'] = torch.tensor(0.).to(g[0])

 prev_gd = self.global_state.get('prev_gd', 0)
-if prev_gd…
+if abs(prev_gd) <= torch.finfo(g[0].dtype).eps: beta = 0
 else: beta = g.dot(g) / prev_gd

 # inner step
@@ -176,7 +176,7 @@ class ConjugateDescent(Transform):
 def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
 g_diff = g - prev_g
 denom = prev_d.dot(g_diff)
-if denom…
+if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0

 term1 = 1/denom
 # term2
@@ -198,7 +198,7 @@ class HagerZhang(ConguateGradientBase):
 def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
 grad_diff = g - prev_g
 denom = prev_d.dot(grad_diff)
-if denom…
+if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0

 # Dai-Yuan
 dy_beta = (g.dot(g) / denom).neg().clamp(min=0)
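Every hunk in this file replaces a bare `if denom` truthiness check with an explicit epsilon guard before dividing, restarting the conjugate direction (β = 0) when the denominator vanishes. The Polak-Ribière case in isolation, on plain tensors rather than the library's TensorList:

```python
import torch

def polak_ribiere_beta(g: torch.Tensor, prev_g: torch.Tensor) -> torch.Tensor:
    denom = prev_g.dot(prev_g)
    if denom.abs() <= torch.finfo(g.dtype).eps:   # degenerate previous gradient
        return torch.zeros((), dtype=g.dtype)     # restart: beta = 0 instead of nan
    return g.dot(g - prev_g) / denom

print(polak_ribiere_beta(torch.randn(5), torch.zeros(5)))  # tensor(0.) rather than nan
```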
torchzero/modules/quasi_newton/experimental/modular_lbfgs.py
CHANGED
@@ -37,10 +37,11 @@ def lbfgs(
 z_tfm: Any,
 ):
 if len(s_history) == 0 or y_k is None or ys_k is None:
-# dir = params.grad.sign() # may work fine

-# initial step size guess
-…
+# initial step size guess modified from pytorch L-BFGS
+scale = 1 / tensors_.abs().global_sum()
+if scale < 1e-5: scale = 1 / tensors_.abs().mean()
+return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]

 else:
 # 1st loop
torchzero/modules/quasi_newton/lbfgs.py
CHANGED
@@ -36,10 +36,11 @@ def lbfgs(
 step: int,
 ):
 if len(s_history) == 0 or y_k is None or ys_k is None:
-# dir = params.grad.sign() # may work fine

-# initial step size guess
-…
+# initial step size guess modified from pytorch L-BFGS
+scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
+scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
+return tensors_.mul_(scale_factor)

 else:
 # 1st loop
torchzero/modules/quasi_newton/lsr1.py
CHANGED
@@ -17,8 +17,9 @@ def lsr1_(
 ):
 if step == 0 or not s_history:
 # initial step size guess from pytorch
-…
-…
+scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
+scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
+return tensors_.mul_(scale_factor)

 m = len(s_history)

@@ -64,7 +65,10 @@ def lsr1_(
 Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy) # pyright:ignore[reportArgumentType]

 if scale_second and step == 1:
-…
+scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
+scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
+Hx.mul_(scale_factor)
+
 return Hx


torchzero/modules/quasi_newton/quasi_newton.py
CHANGED
@@ -68,6 +68,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
 M_key = 'H' if inverse else 'B'
 M = state.get(M_key, None)
 step = state.get('step', 0)
+state['step'] = step + 1
 init_scale = settings['init_scale']
 tol = settings['tol']
 tol_reset = settings['tol_reset']
@@ -91,13 +92,12 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
 state['p_prev'].copy_(p)
 state['g_prev'].copy_(g)

-…
-if reset_interval is not None and step % reset_interval == 0:
+if reset_interval is not None and step != 0 and step % reset_interval == 0:
 self._reset_M_(M, s, y, inverse, init_scale)
 return

 # tolerance on gradient difference to avoid exploding after converging
-…
+elif y.abs().max() <= tol:
 # reset history
 if tol_reset: self._reset_M_(M, s, y, inverse, init_scale)
 return
@@ -119,11 +119,12 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):

 @torch.no_grad
 def apply_tensor(self, tensor, param, grad, state, settings):
-step = state…
+step = state.get('step', 0)

 if settings['scale_second'] and step == 2:
-…
-…
+scale_factor = 1 / tensor.abs().sum().clip(min=1)
+scale_factor = scale_factor.clip(min=torch.finfo(tensor.dtype).eps)
+tensor = tensor * scale_factor

 inverse = settings['inverse']
 if inverse:
@@ -135,7 +136,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
 return torch.linalg.solve_ex(B, tensor.view(-1))[0].view_as(tensor) # pylint:disable=not-callable

 # to avoid typing all arguments for each method
-class…
+class HUpdateStrategy(HessianUpdateStrategy):
 def __init__(
 self,
 init_scale: float | Literal["auto"] = "auto",
@@ -174,7 +175,7 @@ def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
 H += term1.sub_(term2)
 return H

-class BFGS(…
+class BFGS(HUpdateStrategy):
 def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
 return bfgs_H_(H=H, s=s, y=y, tol=settings['tol'])

@@ -193,7 +194,7 @@ def sr1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
 H += torch.outer(z, z).div_(denom)
 return H

-class SR1(…
+class SR1(HUpdateStrategy):
 def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
 return sr1_H_(H=H, s=s, y=y, tol=settings['tol'])

@@ -213,7 +214,7 @@ def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
 H += term1.sub_(term2)
 return H

-class DFP(…
+class DFP(HUpdateStrategy):
 def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
 return dfp_H_(H=H, s=s, y=y, tol=settings['tol'])

@@ -254,19 +255,19 @@ def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
 H -= num/denom
 return H

-class BroydenGood(…
+class BroydenGood(HUpdateStrategy):
 def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
 return broyden_good_H_(H=H, s=s, y=y, tol=settings['tol'])

-class BroydenBad(…
+class BroydenBad(HUpdateStrategy):
 def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
 return broyden_bad_H_(H=H, s=s, y=y, tol=settings['tol'])

-class Greenstadt1(…
+class Greenstadt1(HUpdateStrategy):
 def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
 return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev, tol=settings['tol'])

-class Greenstadt2(…
+class Greenstadt2(HUpdateStrategy):
 def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
 return greenstadt2_H_(H=H, s=s, y=y, tol=settings['tol'])

@@ -287,7 +288,7 @@ def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float…
 H[:, j] += num.squeeze() / denom
 return H

-class ColumnUpdatingMethod(…
+class ColumnUpdatingMethod(HUpdateStrategy):
 """Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf"""
 def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
 return column_updating_H_(H=H, s=s, y=y, tol=settings['tol'])
@@ -307,7 +308,7 @@ def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor,…
 H -= num/denom
 return H, R

-class ThomasOptimalMethod(…
+class ThomasOptimalMethod(HUpdateStrategy):
 """Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975."""
 def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
 if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
@@ -364,7 +365,7 @@ def pearson2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
 H += num.div_(sy)
 return H

-class Pearson2(…
+class Pearson2(HUpdateStrategy):
 """finally found a reference in https://www.recotechnologies.com/~beigi/ps/asme-jdsmc-93-2.pdf"""
 def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
 return pearson2_H_(H=H, s=s, y=y, tol=settings['tol'])
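The truncated class statements above all re-parent the dense quasi-Newton methods onto the renamed HUpdateStrategy base, which drives updates of an inverse-Hessian approximation H. For reference, a minimal dense BFGS inverse update of the kind bfgs_H_ performs (standalone sketch, not the library function):

```python
import torch

def bfgs_inverse_update(H, s, y, tol=1e-10):
    # H <- (I - s y^T / s^T y) H (I - y s^T / s^T y) + s s^T / s^T y
    sy = s.dot(y)
    if sy.abs() <= tol:                 # skip the update on a degenerate pair
        return H
    I = torch.eye(H.shape[0], dtype=H.dtype)
    V = I - torch.outer(s, y) / sy
    return V @ H @ V.T + torch.outer(s, s) / sy

H = torch.eye(2)
s = torch.tensor([1.0, 0.0])
y = torch.tensor([2.0, 0.0])
print(bfgs_inverse_update(H, s, y))     # curvature along s now reflects y
```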
torchzero/modules/second_order/newton.py
CHANGED
@@ -1,14 +1,18 @@
 import warnings
+from collections.abc import Callable
 from functools import partial
 from typing import Literal
-
+
 import torch

-from ...core import Chainable,…
-from ...utils import…
+from ...core import Chainable, Module, apply
+from ...utils import TensorList, vec_to_tensors
 from ...utils.derivatives import (
 hessian_list_to_mat,
 hessian_mat,
+hvp,
+hvp_fd_central,
+hvp_fd_forward,
 jacobian_and_hessian_wrt,
 )

@@ -117,9 +121,10 @@ class Newton(Module):
 raise ValueError(hessian_method)

 # -------------------------------- inner step -------------------------------- #
+update = vars.get_update()
 if 'inner' in self.children:
-…
-g = torch.cat([t.view(-1) for t in…
+update = apply(self.children['inner'], update, params=params, grads=list(g_list), vars=vars)
+g = torch.cat([t.view(-1) for t in update])

 # ------------------------------- regulazition ------------------------------- #
 if eig_reg: H = eig_tikhonov_(H, reg)
@@ -139,4 +144,4 @@ class Newton(Module):
 if update is None: update = least_squares_solve(H, g)

 vars.update = vec_to_tensors(update, params)
-return vars
+return vars
torchzero/modules/second_order/newton_cg.py
CHANGED
@@ -66,9 +66,9 @@ class NewtonCG(Module):


 # -------------------------------- inner step -------------------------------- #
-b =…
+b = vars.get_update()
 if 'inner' in self.children:
-b = as_tensorlist(apply(self.children['inner'],…
+b = as_tensorlist(apply(self.children['inner'], b, params=params, grads=grad, vars=vars))

 # ---------------------------------- run cg ---------------------------------- #
 x0 = None
@@ -76,7 +76,7 @@ class NewtonCG(Module):
 x = cg(A_mm=H_mm, b=as_tensorlist(b), x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
 if warm_start:
 assert x0 is not None
-x0.…
+x0.copy_(x)

 vars.update = x
 return vars
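The warm-start fix above matters because assigning a new tensor to a local name does not touch the buffer stored in module state; only an in-place copy_ persists the CG solution for the next step. Reduced to its essence:

```python
import torch

state = {"prev_x": torch.zeros(3)}    # warm-start buffer kept across steps
x = torch.tensor([1.0, 2.0, 3.0])     # fresh CG solution for this step

# A plain rebinding (e.g. `x0 = x`) would leave state["prev_x"] at zeros,
# so every step would restart CG from scratch. The in-place copy is the fix:
state["prev_x"].copy_(x)
print(state["prev_x"])                # tensor([1., 2., 3.])
```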
torchzero/modules/second_order/nystrom.py
CHANGED
@@ -15,7 +15,7 @@ class NystromSketchAndSolve(Module):
 rank: int,
 reg: float = 1e-3,
 hvp_method: Literal["forward", "central", "autograd"] = "autograd",
-h=1e-…
+h=1e-2,
 inner: Chainable | None = None,
 seed: int | None = None,
 ):
@@ -74,9 +74,9 @@ class NystromSketchAndSolve(Module):


 # -------------------------------- inner step -------------------------------- #
-b =…
+b = vars.get_update()
 if 'inner' in self.children:
-b = apply(self.children['inner'],…
+b = apply(self.children['inner'], b, params=params, grads=grad, vars=vars)

 # ------------------------------ sketch&n&solve ------------------------------ #
 x = nystrom_sketch_and_solve(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), rank=rank, reg=reg, generator=generator)
@@ -93,7 +93,7 @@ class NystromPCG(Module):
 tol=1e-3,
 reg: float = 1e-6,
 hvp_method: Literal["forward", "central", "autograd"] = "autograd",
-h=1e-…
+h=1e-2,
 inner: Chainable | None = None,
 seed: int | None = None,
 ):
@@ -156,9 +156,9 @@ class NystromPCG(Module):


 # -------------------------------- inner step -------------------------------- #
-b =…
+b = vars.get_update()
 if 'inner' in self.children:
-b = apply(self.children['inner'],…
+b = apply(self.children['inner'], b, params=params, grads=grad, vars=vars)

 # ------------------------------ sketch&n&solve ------------------------------ #
 x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
torchzero/utils/linalg/benchmark.py
ADDED
@@ -0,0 +1,20 @@
+from collections.abc import Callable
+
+import torch
+
+
+def benchmark_solver(
+A: torch.Tensor | Callable[[torch.Tensor], torch.Tensor],
+b: torch.Tensor,
+solver: Callable[[Callable[[torch.Tensor], torch.Tensor], torch.Tensor]]
+):
+residuals = []
+def A_mm(x):
+if callable(A): Ax = A(x)
+else: Ax = A@x
+residuals.append(torch.linalg.vector_norm(Ax-b)) # pylint:disable=not-callable
+return Ax
+
+solver(A_mm, b)
+return residuals
+
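The new helper records the residual norm at every matrix-vector product, which makes solver comparisons easy. A hedged usage sketch; the import paths follow the RECORD entries below and the cg defaults shown in the next file:

```python
import torch
from torchzero.utils.linalg.benchmark import benchmark_solver  # path assumed from RECORD
from torchzero.utils.linalg.solve import cg                    # path assumed from RECORD

A = torch.randn(50, 50)
A = A @ A.T + 50 * torch.eye(50)      # SPD system
b = torch.randn(50)

# benchmark_solver wraps A into a residual-recording A_mm and hands it to the solver
residuals = benchmark_solver(A, b, solver=lambda A_mm, rhs: cg(A_mm, rhs, tol=1e-8))
print([round(float(r), 4) for r in residuals[:5]])             # residual norm per A @ x call
```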
torchzero/utils/linalg/solve.py
CHANGED
@@ -8,27 +8,27 @@ from .. import TensorList, generic_zeros_like, generic_vector_norm, generic_nume…
 def cg(
 A_mm: Callable[[torch.Tensor], torch.Tensor],
 b: torch.Tensor,
-x0_: torch.Tensor | None,
-tol: float | None,
-maxiter: int | None,
+x0_: torch.Tensor | None = None,
+tol: float | None = 1e-4,
+maxiter: int | None = None,
 reg: float = 0,
 ) -> torch.Tensor: ...
 @overload
 def cg(
 A_mm: Callable[[TensorList], TensorList],
 b: TensorList,
-x0_: TensorList | None,
-tol: float | None,
-maxiter: int | None,
+x0_: TensorList | None = None,
+tol: float | None = 1e-4,
+maxiter: int | None = None,
 reg: float | list[float] | tuple[float] = 0,
 ) -> TensorList: ...

 def cg(
 A_mm: Callable,
 b: torch.Tensor | TensorList,
-x0_: torch.Tensor | TensorList | None,
-tol: float | None,
-maxiter: int | None,
+x0_: torch.Tensor | TensorList | None = None,
+tol: float | None = 1e-4,
+maxiter: int | None = None,
 reg: float | list[float] | tuple[float] = 0,
 ):
 def A_mm_reg(x): # A_mm with regularization
@@ -90,7 +90,7 @@ def nystrom_sketch_and_solve(
 A_mm: Callable[[torch.Tensor], torch.Tensor],
 b: torch.Tensor,
 rank: int,
-reg: float,
+reg: float = 1e-3,
 generator=None,
 ) -> torch.Tensor:
 U, lambd = nystrom_approximation(
@@ -116,10 +116,10 @@ def nystrom_pcg(
 A_mm: Callable[[torch.Tensor], torch.Tensor],
 b: torch.Tensor,
 sketch_size: int,
-reg: float,
-x0_: torch.Tensor | None,
-tol: float | None,
-maxiter: int | None,
+reg: float = 1e-6,
+x0_: torch.Tensor | None = None,
+tol: float | None = 1e-4,
+maxiter: int | None = None,
 generator=None,
 ) -> torch.Tensor:
 U, lambd = nystrom_approximation(
@@ -166,3 +166,4 @@ def nystrom_pcg(
 z = P_inv @ residual
 beta = residual.dot(z) / rz
 p = z + p*beta
+
{torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchzero
-Version: 0.3.6
+Version: 0.3.9
 Summary: Modular optimization library for PyTorch.
 Author-email: Ivan Nikishev <nkshv2@gmail.com>
 License: MIT License
@@ -156,7 +156,7 @@ for epoch in range(100):
 * `Newton`: Classic Newton's method.
 * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
 * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
-* `NystromPCG`: NewtonCG with Nyström preconditioning (…
+* `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).

 * **Quasi-Newton**: Approximate second-order optimization methods.
 * `LBFGS`: Limited-memory BFGS.
{torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
 docs/source/conf.py,sha256=jd80ZT2IdCx7nlQrpOTJL8UhGBNm6KYyXlpp0jmRiAw,1849
 tests/test_identical.py,sha256=NZ7A8Rm1U9Q16d-cG2G_wccpPtNALyoKYJt9qMownMc,11568
 tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
-tests/test_opts.py,sha256=…
-tests/test_tensorlist.py,sha256=…
+tests/test_opts.py,sha256=TZVaCv2ZLdHSkL6snTEkqhTMHqlcO55L-c56k6Hh4xc,40850
+tests/test_tensorlist.py,sha256=Djpr5C0T5d_gz-j-P-bpo_X51DC4twbtT9c-xDSFbP0,72438
 tests/test_utils_optimizer.py,sha256=bvC0Ehvs2L8fohpyIF5Vfr9OKTycpnODWLPflXilU1c,8414
 tests/test_vars.py,sha256=3p9dsHk7SJpMd-WRD0ziBNq5FEHRBJGSxbMLD8ES4J0,6815
 torchzero/__init__.py,sha256=L7IJ1qZ3o8E9oRwlJZBK2_2yII_eeGEk57Of6EfVbrk,112
 torchzero/core/__init__.py,sha256=2JRyeGZprTexAeEPQOIl9fLFGBwzvya-AwKyt7XAmGQ,210
 torchzero/core/module.py,sha256=Razw3c71Kfegznm0vQxsii1KuTUCPBC9UGyq2v-KX4M,27568
-torchzero/core/preconditioner.py,sha256=…
+torchzero/core/preconditioner.py,sha256=n9oh7kZdt1kU3Wh472lnvLrsXwhR5Wqe6lIp7JuAJ_I,6336
 torchzero/core/transform.py,sha256=ajNJcX45ds-_lc5CqxgLfEFGil6_BYLerB0WvoTi8rM,10303
 torchzero/modules/__init__.py,sha256=BDeyuSd2s1WFUUXIo3tGTNp4aYp4A2B94cydpPW24nY,332
 torchzero/modules/functional.py,sha256=HXNzmPe7LsPadryEm7zrcEKqGej16QDwSgBkbEvggFM,6492
@@ -16,18 +16,20 @@ torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLB…
 torchzero/modules/clipping/clipping.py,sha256=I-5utyrqdKtF5yaH-9m2F3UqdfpPmA2bSSFUAZ_d60Q,12544
 torchzero/modules/clipping/ema_clipping.py,sha256=pLeNuEBLpJ74io2sHn_ZVYaQ6ydEfhpVfVEX2bFttd0,5947
 torchzero/modules/clipping/growth_clipping.py,sha256=OD-kdia2Rn-DvYlYV6EZlGPDVTh9tj-W9mpiZPc3hOQ,6772
-torchzero/modules/experimental/__init__.py,sha256=…
-torchzero/modules/experimental/absoap.py,sha256=…
-torchzero/modules/experimental/adadam.py,sha256=…
-torchzero/modules/experimental/adamY.py,sha256=…
-torchzero/modules/experimental/adasoap.py,sha256=…
-torchzero/modules/experimental/algebraic_newton.py,sha256=…
-torchzero/modules/experimental/curveball.py,sha256=…
-torchzero/modules/experimental/…
-torchzero/modules/experimental/…
+torchzero/modules/experimental/__init__.py,sha256=fEPDYDl7qhaFoferDRmG3ehwuqSvx4Vt2uOz0Y7h4to,483
+torchzero/modules/experimental/absoap.py,sha256=Z4MS4pDPSQ9IaTk8g57OfrsWcYVOT72x533KKtn2Zxk,13512
+torchzero/modules/experimental/adadam.py,sha256=OAPF1-NUbg79V3QOTYzsQlRC97C7XHj5boOLDqLz3PE,4029
+torchzero/modules/experimental/adamY.py,sha256=g1pAHwgdyDdKvObZ67lCSc36L99tl5jlQgOr4lMJCDo,4595
+torchzero/modules/experimental/adasoap.py,sha256=JdV6rB9xfqL3vbHpZCLmkJZKRObZ1nVoEmabtIeVT3E,11195
+torchzero/modules/experimental/algebraic_newton.py,sha256=sq5ZD_j_EtlxIjNnS0rKKwTSG_JuwsZOg9ZMMQTuQm0,5154
+torchzero/modules/experimental/curveball.py,sha256=Uk30uLEztTHD5IUJLJm9Nn3x31DF9kQHmeLFhc065us,3262
+torchzero/modules/experimental/gradmin.py,sha256=iJmEvDEdVdck0C-94pY3iGxnIoNv6Fu6vj3f7lS6aQM,3686
+torchzero/modules/experimental/newton_solver.py,sha256=iGI2LHLaZd2ovpbq1Vogs76os0zWG7VwM7nUz8RzxVg,3071
 torchzero/modules/experimental/reduce_outward_lr.py,sha256=kjtRwepBGBca77ToM-lw3b8ywptMtmSdC_jQfjJAwlY,1184
-torchzero/modules/experimental/…
-torchzero/modules/experimental/…
+torchzero/modules/experimental/soapy.py,sha256=Ishd2Jj6BbhjrLyC48zf-cjMmA1kJb_uKXESQBIML_s,10990
+torchzero/modules/experimental/spectral.py,sha256=8_n208V2yPY3z5pCym-FvwO7DGFhozNgWlpIBtQSdrI,12139
+torchzero/modules/experimental/structured_newton.py,sha256=uWczR-uAXHaFwf0mlOThv2sLG0irH6Gz1hKlGHtPAj4,3386
+torchzero/modules/experimental/subspace_preconditioners.py,sha256=WnHpga7Kx4-N2xU5vP3uUHRER70ymyNJCWbSx2zXWOk,4976
 torchzero/modules/experimental/tropical_newton.py,sha256=uq66ouhgrgc8iYGozDQ3_rtbubj8rKRwb1jfcdnlpHg,4903
 torchzero/modules/grad_approximation/__init__.py,sha256=DVFjf0cXuF70NA0nJ2WklpP01PQgrRZxUjUQjjQeSos,195
 torchzero/modules/grad_approximation/fdm.py,sha256=2PNNBIMup1xlOwLFAwAS3xAVd-7GGVyerMeKH1ug9LQ,3591
@@ -70,26 +72,26 @@ torchzero/modules/optimizers/orthograd.py,sha256=5BLnNJTYuGUClHmlxaXZ1jNvBR4zSFD…
 torchzero/modules/optimizers/rmsprop.py,sha256=d10Y9Ck-391tVysO3xMHg3g2Pe0UEZplgebEyDYi3Z4,4333
 torchzero/modules/optimizers/rprop.py,sha256=n4k5-9F3ppH0Xl-4l4vNXfqVf2r67vMPCkstUaQKPLw,10974
 torchzero/modules/optimizers/shampoo.py,sha256=AHHV6d71DqKDPCg52ShWIPIRSGtWkMc1v1XwXgDG3qY,8606
-torchzero/modules/optimizers/soap.py,sha256=…
+torchzero/modules/optimizers/soap.py,sha256=Kf2BAtIf2QY1V2ZJcUjRLcp2WfIVLd3mNclnaT3Nmds,11520
 torchzero/modules/optimizers/sophia_h.py,sha256=8pSlYVm66xWplzdP8MX3MCTzzIYHsxGzDEXJKA03Zgg,4279
 torchzero/modules/projections/__init__.py,sha256=OCxlh_-Tx-xpl31X03CeFJH9XveH563oEsWc8rUvX0A,196
 torchzero/modules/projections/dct.py,sha256=wxaEV6dTNiOqW_n2UHX0De6mMXTKDXK6UNcMNI4Rogk,2373
 torchzero/modules/projections/fft.py,sha256=OpCcEM1-A2dgk1umwRsBsvK7ObiHtsBKlkkcw0IX83Q,2961
 torchzero/modules/projections/galore.py,sha256=c9CZ0kHxpKEoyfc_lnmeHOkNp55jCppb7onN5YmWnN8,242
-torchzero/modules/projections/projection.py,sha256=…
+torchzero/modules/projections/projection.py,sha256=aYufSD3ftRUqVScPmqxwEFgP1P8ioxM8z9eyzaL7d10,10147
 torchzero/modules/projections/structural.py,sha256=QaCGHmzHCXj46sM-XZ5XlYU9BnuRKI2ReR3LE8y2R4g,5740
 torchzero/modules/quasi_newton/__init__.py,sha256=0iOlX73PHj9lQS3_2cJ5lyCdas904MnFfIvR8Popvzw,402
-torchzero/modules/quasi_newton/cg.py,sha256=…
-torchzero/modules/quasi_newton/lbfgs.py,sha256=…
-torchzero/modules/quasi_newton/lsr1.py,sha256=…
+torchzero/modules/quasi_newton/cg.py,sha256=lIJvfWAZ08r0o4uqaJnRG6pvcE2kBkJUkZ1MK37KMTk,9602
+torchzero/modules/quasi_newton/lbfgs.py,sha256=SMgesPMZ4ubVeG7R395SnAb5ffkyPHbzSQMqPlLGI7U,9211
+torchzero/modules/quasi_newton/lsr1.py,sha256=XmYyYANzQgQuFtOMW59znQrS-mprGRXazicfB9JAup8,6059
 torchzero/modules/quasi_newton/olbfgs.py,sha256=2YAOXlMnPGw22sNcIMH1hmggzAXQRbN59RSPUZNKUZY,8352
-torchzero/modules/quasi_newton/quasi_newton.py,sha256=…
+torchzero/modules/quasi_newton/quasi_newton.py,sha256=rUp4s3MbACcOjwpz00TAjl-olif50voTmC16vv5XrSE,17496
 torchzero/modules/quasi_newton/experimental/__init__.py,sha256=3qpZGgdsx6wpoafWaNWx-eamRl1FuxVCWQZq8Y7Cl98,39
-torchzero/modules/quasi_newton/experimental/modular_lbfgs.py,sha256=…
-torchzero/modules/second_order/__init__.py,sha256=…
-torchzero/modules/second_order/newton.py,sha256=…
-torchzero/modules/second_order/newton_cg.py,sha256=…
-torchzero/modules/second_order/nystrom.py,sha256=…
+torchzero/modules/quasi_newton/experimental/modular_lbfgs.py,sha256=ec6JKYX89xA_UlY9VrMB3hBjDyNKwkalS_4JQGA1qOY,10762
+torchzero/modules/second_order/__init__.py,sha256=jolCGaIVkID9hpxgx0Tc22wgjVlwuWekWjKTMe5jKXw,114
+torchzero/modules/second_order/newton.py,sha256=xxkrhFK4i5I9oOX3AGGh_6bXNDUSFq4D0pw3c7qgEd8,5925
+torchzero/modules/second_order/newton_cg.py,sha256=PILHRf2koop_cywE1RNGukT16alDO7prC4C3HlZcW30,2861
+torchzero/modules/second_order/nystrom.py,sha256=zdLSTQ_S5VViUt2sAmFNoDCCHKmHP2A7112czkZNlUk,6051
 torchzero/modules/smoothing/__init__.py,sha256=tUTGN0A-EQC7xuLV2AuHFWk-t7D6jIJlpV_3qyfRqLk,80
 torchzero/modules/smoothing/gaussian.py,sha256=YlT_G4MqAVkiWG56RHAwgt5SSPISpvQZQbSLh8mhF3I,6153
 torchzero/modules/smoothing/laplacian.py,sha256=Bfrs7D59SfdU7j-97UBKD1hs0obC-ZgjJvG7oKwaa0o,5065
@@ -116,13 +118,14 @@ torchzero/utils/python_tools.py,sha256=RFBqNj8w52dpJ983pUPPDbg2x1MX_-SsBnBMffWGG…
 torchzero/utils/tensorlist.py,sha256=qSbiliVo1euFAksdHHHRbPUdYYxfkw1dvhpXj71wGy0,53162
 torchzero/utils/torch_tools.py,sha256=ohqnnZRlqdfp5PAfMSbQDIEKygW0_ARjxSEBp3Zo9nU,4756
 torchzero/utils/linalg/__init__.py,sha256=Dzbho3_z7JDdKzYD-QdLArg0ZEoC2BVGdlE3JoAnXHQ,272
+torchzero/utils/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
 torchzero/utils/linalg/matrix_funcs.py,sha256=-LecWrPWbJvfeCgIzUhfWARa2aSZvJ12lHX7Jno38O4,3099
 torchzero/utils/linalg/orthogonalize.py,sha256=mDCkET7qgDZqf_y6oPYAK3d2L5HrB8gzOFPl0YoONaY,399
 torchzero/utils/linalg/qr.py,sha256=L-RXuYV-SIHI-Llq4y1rQ_Tz-yamds0_QNZeHapbjNE,2507
-torchzero/utils/linalg/solve.py,sha256=…
+torchzero/utils/linalg/solve.py,sha256=P0PMi0zro3G3Rd0X-JeoLk7tqYDB0js0aB4bpQ0OABU,5235
 torchzero/utils/linalg/svd.py,sha256=wBxl-JSciINV-N6zvM4SGdveqMr6idq51h68LyQQRYg,660
-torchzero-0.3.…
-torchzero-0.3.…
-torchzero-0.3.…
-torchzero-0.3.…
-torchzero-0.3.…
+torchzero-0.3.9.dist-info/licenses/LICENSE,sha256=r9ZciAoZoqKC_FNADE0ORukj1p1XhLXEbegdsAyqhJs,1087
+torchzero-0.3.9.dist-info/METADATA,sha256=aENIaMgy94tD6nakRWfApleVSy6bxW8-q3-mQeVSeGA,13941
+torchzero-0.3.9.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
+torchzero-0.3.9.dist-info/top_level.txt,sha256=YDdpIOb7HyKV9THOtOYsFFMTbxvCO0kiol4-83tDj-A,21
+torchzero-0.3.9.dist-info/RECORD,,
{torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/WHEEL
File without changes
{torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/licenses/LICENSE
File without changes
{torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/top_level.txt
File without changes