torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/core/preconditioner.py
DELETED
@@ -1,138 +0,0 @@
-from abc import ABC, abstractmethod
-from collections import ChainMap, defaultdict
-from collections.abc import Mapping, Sequence
-from typing import Any, overload, final
-
-import torch
-
-from .module import Module, Chainable, Vars
-from .transform import apply, Transform, Target
-from ..utils import TensorList, vec_to_tensors
-
-class Preconditioner(Transform):
-    """Abstract class for a preconditioner."""
-    def __init__(
-        self,
-        defaults: dict | None,
-        uses_grad: bool,
-        concat_params: bool = False,
-        update_freq: int = 1,
-        scale_first: bool = False,
-        inner: Chainable | None = None,
-        target: Target = "update",
-    ):
-        if defaults is None: defaults = {}
-        defaults.update(dict(__update_freq=update_freq, __concat_params=concat_params, __scale_first=scale_first))
-        super().__init__(defaults, uses_grad=uses_grad, target=target)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    @abstractmethod
-    def update(self, tensors: list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]):
-        """updates the preconditioner with `tensors`, any internal state should be stored using `keys`"""
-
-    @abstractmethod
-    def apply(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> list[torch.Tensor]:
-        """applies preconditioner to `tensors`, any internal state should be stored using `keys`"""
-
-
-    def _tensor_wise_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
-        step = self.global_state.get('__step', 0)
-        states = [self.state[p] for p in params]
-        settings = [self.settings[p] for p in params]
-        global_settings = settings[0]
-        update_freq = global_settings['__update_freq']
-
-        scale_first = global_settings['__scale_first']
-        scale_factor = 1
-        if scale_first and step == 0:
-            # initial step size guess from pytorch LBFGS
-            scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
-            scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)
-
-        # update preconditioner
-        if step % update_freq == 0:
-            self.update(tensors=tensors, params=params, grads=grads, states=states, settings=settings)
-
-        # step with inner
-        if 'inner' in self.children:
-            tensors = apply(self.children['inner'], tensors=tensors, params=params, grads=grads, vars=vars)
-
-        # apply preconditioner
-        tensors = self.apply(tensors=tensors, params=params, grads=grads, states=states, settings=settings)
-
-        # scale initial step, when preconditioner might not have been applied
-        if scale_first and step == 0:
-            torch._foreach_mul_(tensors, scale_factor)
-
-        self.global_state['__step'] = step + 1
-        return tensors
-
-    def _concat_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
-        step = self.global_state.get('__step', 0)
-        tensors_vec = torch.cat([t.ravel() for t in tensors])
-        params_vec = torch.cat([p.ravel() for p in params])
-        grads_vec = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
-
-        states = [self.state[params[0]]]
-        settings = [self.settings[params[0]]]
-        global_settings = settings[0]
-        update_freq = global_settings['__update_freq']
-
-        scale_first = global_settings['__scale_first']
-        scale_factor = 1
-        if scale_first and step == 0:
-            # initial step size guess from pytorch LBFGS
-            scale_factor = 1 / tensors_vec.abs().sum().clip(min=1)
-            scale_factor = scale_factor.clip(min=torch.finfo(tensors_vec.dtype).eps)
-
-        # update preconditioner
-        if step % update_freq == 0:
-            self.update(tensors=[tensors_vec], params=[params_vec], grads=grads_vec, states=states, settings=settings)
-
-        # step with inner
-        if 'inner' in self.children:
-            tensors = apply(self.children['inner'], tensors=tensors, params=params, grads=grads, vars=vars)
-            tensors_vec = torch.cat([t.ravel() for t in tensors]) # have to recat
-
-        # apply preconditioner
-        tensors_vec = self.apply(tensors=[tensors_vec], params=[params_vec], grads=grads_vec, states=states, settings=settings)[0]
-
-        # scale initial step, when preconditioner might not have been applied
-        if scale_first and step == 0:
-            tensors_vec *= scale_factor
-
-        tensors = vec_to_tensors(vec=tensors_vec, reference=tensors)
-        self.global_state['__step'] = step + 1
-        return tensors
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        concat_params = self.settings[params[0]]['__concat_params']
-        if concat_params: return self._concat_transform(tensors, params, grads, vars)
-        return self._tensor_wise_transform(tensors, params, grads, vars)
-
-class TensorwisePreconditioner(Preconditioner, ABC):
-    @abstractmethod
-    def update_tensor(self, tensor: torch.Tensor, param:torch.Tensor, grad: torch.Tensor | None, state: dict[str, Any], settings: Mapping[str, Any]):
-        """update preconditioner with `tensor`"""
-
-    @abstractmethod
-    def apply_tensor(self, tensor: torch.Tensor, param:torch.Tensor, grad: torch.Tensor | None, state: dict[str, Any], settings: Mapping[str, Any]) -> torch.Tensor:
-        """apply preconditioner to `tensor`"""
-
-    @final
-    def update(self, tensors, params, grads, states, settings):
-        if grads is None: grads = [None]*len(tensors)
-        for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
-            self.update_tensor(t, p, g, state, setting)
-
-    @final
-    def apply(self, tensors, params, grads, states, settings):
-        preconditioned = []
-        if grads is None: grads = [None]*len(tensors)
-        for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
-            preconditioned.append(self.apply_tensor(t, p, g, state, setting))
-        return preconditioned
-
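Note: the removed base class split preconditioning into an update pass (accumulate statistics) and an apply pass (transform the update), with TensorwisePreconditioner mapping those hooks over each parameter tensor and Preconditioner.transform handling update frequency, the optional inner chain, and the first-step scaling. As a rough standalone sketch of that contract, and not torchzero code (the DiagonalSketch class, its simplified signatures, and the harness below are invented for illustration), an Adagrad-style diagonal preconditioner written against the same update/apply split could look like:

    # Standalone sketch (not torchzero code): mimics the update/apply split that
    # the removed TensorwisePreconditioner exposed via update_tensor/apply_tensor.
    import torch

    class DiagonalSketch:
        """Hypothetical Adagrad-style diagonal preconditioner following that hook pair."""
        def __init__(self, eps: float = 1e-8):
            self.eps = eps

        def update_tensor(self, tensor: torch.Tensor, state: dict):
            # "update" pass: accumulate per-element squared magnitudes
            if 'accum' not in state:
                state['accum'] = torch.zeros_like(tensor)
            state['accum'].add_(tensor * tensor)

        def apply_tensor(self, tensor: torch.Tensor, state: dict) -> torch.Tensor:
            # "apply" pass: scale the update by the accumulated statistics
            return tensor / (state['accum'].sqrt() + self.eps)

    # per-parameter state dict, one update/apply round per optimization step
    state: dict = {}
    pre = DiagonalSketch()
    g = torch.randn(5)
    pre.update_tensor(g, state)
    print(pre.apply_tensor(g, state))

The concat_params path in the removed class did the same thing after flattening all parameters into a single vector, which is why both update and apply received lists of tensors rather than single tensors.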
torchzero/modules/experimental/algebraic_newton.py
DELETED
@@ -1,145 +0,0 @@
-import warnings
-from functools import partial
-from typing import Literal
-from collections.abc import Callable
-import torch
-import torchalgebras as ta
-
-from ...core import Chainable, apply, Module
-from ...utils import vec_to_tensors, TensorList
-from ...utils.derivatives import (
-    hessian_list_to_mat,
-    hessian_mat,
-    jacobian_and_hessian_wrt,
-)
-
-class MaxItersReached(Exception): pass
-def tropical_lstsq(
-    H: torch.Tensor,
-    g: torch.Tensor,
-    solver,
-    maxiter,
-    tol,
-    algebra,
-    verbose,
-):
-    """it can run on any algebra with add despite it saying tropical"""
-    algebra = ta.get_algebra(algebra)
-
-    x = torch.zeros_like(g, requires_grad=True)
-    best_x = x.detach().clone()
-    best_loss = float('inf')
-    opt = solver([x])
-
-    niter = 0
-    def closure(backward=True):
-        nonlocal niter, best_x, best_loss
-        if niter == maxiter: raise MaxItersReached
-        niter += 1
-
-        g_hat = algebra.mm(H, x)
-        loss = torch.nn.functional.mse_loss(g_hat, g)
-        if loss < best_loss:
-            best_x = x.detach().clone()
-            best_loss = loss.detach()
-
-        if backward:
-            opt.zero_grad()
-            loss.backward()
-        return loss
-
-    loss = None
-    prev_loss = float('inf')
-    for i in range(maxiter):
-        try:
-            loss = opt.step(closure)
-            if loss == 0: break
-            if tol is not None and prev_loss - loss < tol: break
-            prev_loss = loss
-        except MaxItersReached:
-            break
-
-    if verbose: print(f'{best_loss = } after {niter} iters')
-    return best_x.detach()
-
-def tikhonov(H: torch.Tensor, reg: float, algebra: ta.Algebra = ta.TropicalSemiring()):
-    if reg!=0:
-        I = ta.AlgebraicTensor(torch.eye(H.size(-1), dtype=H.dtype, device=H.device), algebra)
-        I = I * reg
-        H = algebra.add(H, I.data)
-    return H
-
-
-class AlgebraicNewton(Module):
-    """newton in other algebras, not that it works."""
-    def __init__(
-        self,
-        reg: float | None = None,
-        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-        vectorize: bool = True,
-        solver=lambda p: torch.optim.LBFGS(p, line_search_fn='strong_wolfe'),
-        maxiter=1000,
-        tol: float | None = 1e-10,
-        algebra: ta.Algebra | str = 'tropical max',
-        verbose: bool = False,
-        inner: Chainable | None = None,
-    ):
-        defaults = dict(reg=reg, hessian_method=hessian_method, vectorize=vectorize)
-        super().__init__(defaults)
-
-        self.algebra = ta.get_algebra(algebra)
-        self.lstsq_args:dict = dict(solver=solver, maxiter=maxiter, tol=tol, algebra=algebra, verbose=verbose)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    @torch.no_grad
-    def step(self, vars):
-        params = TensorList(vars.params)
-        closure = vars.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        reg = settings['reg']
-        hessian_method = settings['hessian_method']
-        vectorize = settings['vectorize']
-
-        # ------------------------ calculate grad and hessian ------------------------ #
-        if hessian_method == 'autograd':
-            with torch.enable_grad():
-                loss = vars.loss = vars.loss_approx = closure(False)
-                g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                g_list = [t[0] for t in g_list] # remove leading dim from loss
-                vars.grad = g_list
-                H = hessian_list_to_mat(H_list)
-
-        elif hessian_method in ('func', 'autograd.functional'):
-            strat = 'forward-mode' if vectorize else 'reverse-mode'
-            with torch.enable_grad():
-                g_list = vars.get_grad(retain_graph=True)
-                H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
-                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-        else:
-            raise ValueError(hessian_method)
-
-        # -------------------------------- inner step -------------------------------- #
-        if 'inner' in self.children:
-            g_list = apply(self.children['inner'], list(g_list), params=params, grads=list(g_list), vars=vars)
-        g = torch.cat([t.view(-1) for t in g_list])
-
-        # ------------------------------- regulazition ------------------------------- #
-        if reg is not None: H = tikhonov(H, reg)
-
-        # ----------------------------------- solve ---------------------------------- #
-        tropical_update = tropical_lstsq(H, g, **self.lstsq_args)
-        # what now? w - u is not defined, it is defined for max version if u < w
-        # w = params.to_vec()
-        # w_hat = self.algebra.sub(w, tropical_update)
-        # update = w_hat - w
-        # no
-        # it makes sense to solve tropical system and sub normally
-        # the only thing is that tropical system can have no solutions
-
-        vars.update = vec_to_tensors(tropical_update, params)
-        return vars
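Note: for context on what algebra.mm(H, x) computes inside tropical_lstsq, and assuming 'tropical max' refers to the max-plus semiring (semiring "addition" is max, semiring "multiplication" is +), a plain-PyTorch stand-in for that matrix-vector product would be the following sketch; it is not the torchalgebras implementation used above.

    # Standalone sketch under the max-plus assumption stated above.
    import torch

    def maxplus_mv(H: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # (H "mm" x)_i = max_j (H[i, j] + x[j])
        return (H + x.unsqueeze(0)).amax(dim=1)

    H = torch.tensor([[0., 1.], [2., -1.]])
    x = torch.tensor([3., 0.])
    print(maxplus_mv(H, x))  # tensor([3., 5.])

Because the max is (sub)differentiable with respect to x, fitting x by running an optimizer on mse_loss(algebra.mm(H, x), g), as the deleted tropical_lstsq does, is at least well defined even when the tropical system has no exact solution, which the trailing comments in AlgebraicNewton.step also point out.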
torchzero/modules/experimental/soapy.py
DELETED
@@ -1,290 +0,0 @@
-from operator import itemgetter
-
-import torch
-
-from ...core import Chainable, Transform, apply
-from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
-
-@torch.no_grad
-def update_soap_covariances_(
-    grad: torch.Tensor,
-    GGs_: list[torch.Tensor | None],
-    beta: float | None,
-):
-    for i, GG in enumerate(GGs_):
-        if GG is None: continue
-
-        axes = list(range(i)) + list(range(i + 1, grad.ndim)) # this works fine with 1d params
-        if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
-        else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
-
-@torch.no_grad
-def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
-    """
-    Projects the gradient to the eigenbases of the preconditioner.
-    """
-    for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
-            tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
-        else:
-            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-            permute_order = list(range(1, len(tensors.shape))) + [0]
-            tensors = tensors.permute(permute_order)
-
-    return tensors
-
-@torch.no_grad
-def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
-    """
-    Projects the gradient back to the original space.
-    """
-    for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
-            tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
-        else:
-            permute_order = list(range(1, len(tensors.shape))) + [0]
-            tensors = tensors.permute(permute_order)
-
-    return tensors
-
-# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-@torch.no_grad
-def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
-    """
-    Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
-    """
-    matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m in mat:
-        if m is None: continue
-        if len(m) == 0:
-            matrix.append([])
-            continue
-        if m.dtype != torch.float:
-            original_type = m.dtype
-            original_device = m.device
-            matrix.append(m.float())
-        else:
-            float_data = True
-            matrix.append(m)
-
-    final = []
-    for m in matrix:
-        if len(m) == 0:
-            final.append([])
-            continue
-        try:
-            _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-        except Exception:
-            _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-            Q = Q.to(m.dtype)
-        Q = torch.flip(Q, [1])
-
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
-        final.append(Q)
-    return final
-
-# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
-@torch.no_grad
-def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
-    """
-    Computes the eigenbases of the preconditioner using one round of power iteration
-    followed by torch.linalg.qr decomposition.
-    """
-    matrix = []
-    orth_matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m,o in zip(GG, Q_list):
-        if m is None: continue
-        assert o is not None
-
-        if len(m) == 0:
-            matrix.append([])
-            orth_matrix.append([])
-            continue
-        if m.data.dtype != torch.float:
-            original_type = m.data.dtype
-            original_device = m.data.device
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-        else:
-            float_data = True
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-
-    final = []
-    for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
-        if len(m)==0:
-            final.append([])
-            continue
-        est_eig = torch.diag(o.T @ m @ o)
-        sort_idx = torch.argsort(est_eig, descending=True)
-        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
-        o = o[:,sort_idx]
-        power_iter = m @ o
-        Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
-
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
-        final.append(Q)
-
-    return final, exp_avg_sq
-
-class SOAPY(Transform):
-    """SOAP but uses scaled gradient differences
-
-    new args
-
-    scale by s whether to scale gradient differences by parameter differences
-
-    y_to_ema2 whether to use gradient differences for exponential moving average too
-    """
-    def __init__(
-        self,
-        beta1: float = 0.95,
-        beta2: float = 0.95,
-        shampoo_beta: float | None = 0.95,
-        precond_freq: int = 10,
-        merge_small: bool = True,
-        max_dim: int = 2_000,
-        precondition_1d: bool = True,
-        eps: float = 1e-8,
-        decay: float | None = None,
-        alpha: float = 1,
-        bias_correction: bool = True,
-        scale_by_s: bool = True,
-        y_to_ema2: bool = False,
-    ):
-        defaults = dict(
-            beta1=beta1,
-            beta2=beta2,
-            shampoo_beta=shampoo_beta,
-            precond_freq=precond_freq,
-            merge_small=merge_small,
-            max_dim=max_dim,
-            precondition_1d=precondition_1d,
-            eps=eps,
-            decay=decay,
-            bias_correction=bias_correction,
-            alpha=alpha,
-            scale_by_s=scale_by_s,
-            y_to_ema2=y_to_ema2,
-        )
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        updates = []
-        # update preconditioners
-        for i,(p,t) in enumerate(zip(params, tensors)):
-            state = self.state[p]
-            settings = self.settings[p]
-            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(settings)
-            scale_by_s = settings['scale_by_s']
-            y_to_ema2 = settings['y_to_ema2']
-
-            if merge_small:
-                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
-            if 'g_prev' not in state:
-                state['p_prev'] = p.clone()
-                state['g_prev'] = t.clone()
-                updates.append(tensors[i].clip(-0.1,0.1))
-                continue
-
-            p_prev = state['p_prev']
-            g_prev = state['g_prev']
-            s = p - p_prev
-            y = t - g_prev
-            if scale_by_s: y /= torch.linalg.norm(s).clip(min=1e-8) # pylint:disable=not-callable
-
-            state['p_prev'].copy_(p)
-            state['g_prev'].copy_(t)
-
-            # initialize state on 1st step
-            if 'GG' not in state:
-                state["exp_avg"] = torch.zeros_like(t)
-                if y_to_ema2: state["exp_avg_sq"] = torch.ones_like(t)
-                else: state["exp_avg_sq"] = torch.zeros_like(t)
-
-                if not precondition_1d and t.ndim <= 1:
-                    state['GG'] = []
-
-                else:
-                    state['GG'] = [torch.zeros(sh, sh, dtype=t.dtype, device=t.device) if 1<sh<max_dim else None for sh in t.shape]
-
-                # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
-                if len([i is not None for i in state['GG']]) == 0:
-                    state['GG'] = None
-
-                if state['GG'] is not None:
-                    update_soap_covariances_(y, GGs_=state['GG'], beta=shampoo_beta)
-                    state['Q'] = get_orthogonal_matrix(state['GG'])
-
-                state['step'] = 0
-                updates.append(tensors[i].clip(-0.1,0.1))
-                continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            z_projected = None
-            if state['GG'] is not None:
-                if y_to_ema2: z_projected = project(y, state['Q'])
-                else: z_projected = project(t, state['Q'])
-
-            # exponential moving averages
-            # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
-            exp_avg: torch.Tensor = state["exp_avg"]
-            exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
-
-            exp_avg.lerp_(t, 1-beta1)
-
-            if z_projected is None:
-                if y_to_ema2: exp_avg_sq.mul_(beta2).addcmul_(y, y, value=1-beta2)
-                else: exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
-            else:
-                exp_avg_sq.mul_(beta2).addcmul_(z_projected, z_projected, value=1-beta2)
-
-            # project exponential moving averages if they are accumulated unprojected
-            exp_avg_projected = exp_avg
-            if z_projected is not None:
-                exp_avg_projected = project(exp_avg, state['Q'])
-
-            exp_avg_sq_projected = exp_avg_sq
-
-            denom = exp_avg_sq_projected.sqrt().add_(eps)
-            # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
-
-            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-            # to the original space
-            update = exp_avg_projected / denom
-            if z_projected is not None:
-                update = project_back(update, state["Q"])
-
-            if settings['bias_correction']:
-                bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
-                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
-                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
-            elif alpha is not None:
-                update *= alpha
-
-            if merge_small:
-                update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
-
-            updates.append(update)
-            state["step"] += 1
-
-            # Update is done after the gradient step to avoid using current gradients in the projection.
-            if state['GG'] is not None:
-                update_soap_covariances_(y, state['GG'], shampoo_beta)
-                if state['step'] % settings['precond_freq'] == 0:
-                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
-
-        return updates
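Note: for a 2-D parameter, the tensordot loops in the deleted project/project_back helpers reduce to ordinary two-sided rotations by the per-mode eigenbasis matrices: project(G, [Q0, Q1]) is Q0.T @ G @ Q1 and project_back undoes it with Q0 @ G @ Q1.T. A minimal standalone check, using random orthogonal stand-ins rather than the eigenbases SOAPY stores in state['Q']:

    # Standalone sketch, not SOAPY state: Q0/Q1 are random orthogonal matrices.
    import torch

    m, n = 4, 3
    G = torch.randn(m, n)
    Q0, _ = torch.linalg.qr(torch.randn(m, m))
    Q1, _ = torch.linalg.qr(torch.randn(n, n))

    projected = Q0.T @ G @ Q1          # what project(G, [Q0, Q1]) computes
    restored = Q0 @ projected @ Q1.T   # what project_back(projected, [Q0, Q1]) computes
    print(torch.allclose(restored, G, atol=1e-5))  # True: orthogonal round trip

SOAPY keeps Adam's second moment in that rotated basis, which is why get_orthogonal_matrix_QR re-sorts exp_avg_sq along each dimension whenever it refreshes Q.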