torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0

torchzero/modules/experimental/expanded_lbfgs.py (new file)
@@ -0,0 +1,141 @@
+from collections import deque
+from operator import itemgetter
+import torch
+
+from ...core import Transform, Chainable, Module, Var, apply_transform
+from ...utils import TensorList, as_tensorlist, NumberList
+from ...modules.quasi_newton.lbfgs import _adaptive_damping, lbfgs, _lerp_params_update_
+
+class ExpandedLBFGS(Module):
+    """L-BFGS but uses differences between more pairs than just consequtive. Window size controls how far away the pairs are allowed to be.
+    """
+    def __init__(
+        self,
+        history_size=10,
+        window_size:int=3,
+        tol: float | None = 1e-10,
+        damping: bool = False,
+        init_damping=0.9,
+        eigval_bounds=(0.5, 50),
+        params_beta: float | None = None,
+        grads_beta: float | None = None,
+        update_freq = 1,
+        z_beta: float | None = None,
+        tol_reset: bool = False,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(history_size=history_size, window_size=window_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, params_beta=params_beta, grads_beta=grads_beta, update_freq=update_freq, z_beta=z_beta, tol_reset=tol_reset)
+        super().__init__(defaults)
+
+        self.global_state['s_history'] = deque(maxlen=history_size)
+        self.global_state['y_history'] = deque(maxlen=history_size)
+        self.global_state['sy_history'] = deque(maxlen=history_size)
+        self.global_state['p_history'] = deque(maxlen=window_size)
+        self.global_state['g_history'] = deque(maxlen=window_size)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    def reset(self):
+        self.state.clear()
+        self.global_state['step'] = 0
+        self.global_state['s_history'].clear()
+        self.global_state['y_history'].clear()
+        self.global_state['sy_history'].clear()
+        self.global_state['p_history'].clear()
+        self.global_state['g_history'].clear()
+
+    @torch.no_grad
+    def step(self, var):
+        params = as_tensorlist(var.params)
+        update = as_tensorlist(var.get_update())
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        # history of s and k
+        s_history: deque[TensorList] = self.global_state['s_history']
+        y_history: deque[TensorList] = self.global_state['y_history']
+        sy_history: deque[torch.Tensor] = self.global_state['sy_history']
+        p_history: deque[TensorList] = self.global_state['p_history']
+        g_history: deque[TensorList] = self.global_state['g_history']
+
+        tol, damping, init_damping, eigval_bounds, update_freq, z_beta, tol_reset = itemgetter(
+            'tol', 'damping', 'init_damping', 'eigval_bounds', 'update_freq', 'z_beta', 'tol_reset')(self.settings[params[0]])
+        params_beta, grads_beta = self.get_settings(params, 'params_beta', 'grads_beta')
+
+        l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
+        prev_l_params, prev_l_grad = self.get_state(params, 'prev_l_params', 'prev_l_grad', cls=TensorList)
+
+        # 1st step - there are no previous params and grads, lbfgs will do normalized GD step
+        if step == 0:
+            s = None; y = None; ys = None
+        else:
+            s = l_params - prev_l_params
+            y = l_update - prev_l_grad
+            ys = s.dot(y)
+
+            if damping:
+                s, y, ys = _adaptive_damping(s, y, ys, init_damping=init_damping, eigval_bounds=eigval_bounds)
+
+        prev_l_params.copy_(l_params)
+        prev_l_grad.copy_(l_update)
+
+        # update effective preconditioning state
+        if step % update_freq == 0:
+            if ys is not None and ys > 1e-10:
+                assert s is not None and y is not None
+                s_history.append(s)
+                y_history.append(y)
+                sy_history.append(ys)
+
+            if len(p_history) > 1:
+                for p_i, g_i in zip(list(p_history)[:-1], list(g_history)[:-1]):
+                    s_i = l_params - p_i
+                    y_i = l_update - g_i
+                    ys_i = s_i.dot(y_i)
+
+                    if ys_i > 1e-10:
+                        if damping:
+                            s_i, y_i, ys_i = _adaptive_damping(s_i, y_i, ys_i, init_damping=init_damping, eigval_bounds=eigval_bounds)
+
+                        s_history.append(s_i)
+                        y_history.append(y_i)
+                        sy_history.append(ys_i)
+
+            p_history.append(l_params.clone())
+            g_history.append(l_update.clone())
+
+
+        # step with inner module before applying preconditioner
+        if self.children:
+            update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))
+
+        # tolerance on gradient difference to avoid exploding after converging
+        if tol is not None:
+            if y is not None and y.abs().global_max() <= tol:
+                var.update = update # may have been updated by inner module, probably makes sense to use it here?
+                if tol_reset: self.reset()
+                return var
+
+        # lerp initial H^-1 @ q guess
+        z_ema = None
+        if z_beta is not None:
+            z_ema = self.get_state(var.params, 'z_ema', cls=TensorList)
+
+        # precondition
+        dir = lbfgs(
+            tensors_=as_tensorlist(update),
+            s_history=s_history,
+            y_history=y_history,
+            sy_history=sy_history,
+            y=y,
+            sy=ys,
+            z_beta = z_beta,
+            z_ema = z_ema,
+            step=step
+        )
+
+        var.update = dir
+
+        return var
+
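
The class above extends standard L-BFGS by harvesting curvature pairs from non-consecutive iterates inside a window. A minimal stand-alone sketch of that pair-collection idea on plain tensors; `collect_expanded_pairs` and the toy quadratic are invented for this example and are not part of torchzero:

```python
import torch
from collections import deque

def collect_expanded_pairs(p_history, g_history, window_size=3, eps=1e-10):
    """Collect s = p_i - p_j and y = g_i - g_j for every pair of iterates
    at most `window_size` apart, keeping only pairs with positive curvature."""
    s_list, y_list = [], []
    n = len(p_history)
    for i in range(n):
        for j in range(i):
            if i - j <= window_size:
                s = p_history[i] - p_history[j]
                y = g_history[i] - g_history[j]
                if s.dot(y) > eps:
                    s_list.append(s)
                    y_list.append(y)
    return s_list, y_list

# iterates of gradient descent on a toy quadratic f(x) = 0.5 * x^T A x
A = torch.diag(torch.tensor([1.0, 10.0, 100.0]))
x = torch.ones(3)
p_hist, g_hist = deque(maxlen=4), deque(maxlen=4)
for _ in range(4):
    g = A @ x
    p_hist.append(x.clone())
    g_hist.append(g)
    x = x - 0.01 * g

s_list, y_list = collect_expanded_pairs(list(p_hist), list(g_hist))
print(len(s_list), "curvature pairs from", len(p_hist), "iterates")  # 6 pairs vs 3 consecutive ones
```

Consecutive differences alone would yield only 3 pairs here; the windowed pairing feeds a denser curvature history into the same two-loop recursion.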

torchzero/modules/{projections → experimental}/fft.py
@@ -2,12 +2,12 @@ import torch
 
 from ...core import Chainable
 from ...utils import vec_to_tensors
-from
+from ..projections import ProjectionBase
 
 
-class FFTProjection(
+class FFTProjection(ProjectionBase):
     # norm description copied from pytorch docstring
-    """Project update into
+    """Project update into Fourier space of real-valued inputs.
 
     Args:
         modules (Chainable): modules that will optimize the projected update.
@@ -45,8 +45,8 @@ class FFTProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
 
     @torch.no_grad
-    def project(self, tensors,
-        settings =
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        settings = settings[0]
        one_d = settings['one_d']
        norm = settings['norm']
 
@@ -60,14 +60,14 @@ class FFTProjection(Projection):
        return [torch.view_as_real(torch.fft.rfftn(t, norm=norm)) if t.numel() > 1 else t for t in tensors] # pylint:disable=not-callable
 
    @torch.no_grad
-    def unproject(self,
-        settings =
+    def unproject(self, projected_tensors, params, grads, loss, projected_states, projected_settings, current):
+        settings = projected_settings[0]
        one_d = settings['one_d']
        norm = settings['norm']
 
        if one_d:
-            vec = torch.view_as_complex(
+            vec = torch.view_as_complex(projected_tensors[0])
            unprojected_vec = torch.fft.irfft(vec, n=self.global_state['length'], norm=norm) # pylint:disable=not-callable
-            return vec_to_tensors(unprojected_vec, reference=
+            return vec_to_tensors(unprojected_vec, reference=params)
 
-        return [torch.fft.irfftn(torch.view_as_complex(t.contiguous()), s=p.shape, norm=norm) if t.numel() > 1 else t for t, p in zip(
+        return [torch.fft.irfftn(torch.view_as_complex(t.contiguous()), s=p.shape, norm=norm) if t.numel() > 1 else t for t, p in zip(projected_tensors, params)] # pylint:disable=not-callable
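
For context on what `project`/`unproject` compute: the module round-trips real tensors through `rfftn`, exposing the complex spectrum to downstream modules as a real-valued tensor via `view_as_real`. A minimal sketch of that round trip, independent of the torchzero API:

```python
import torch

t = torch.randn(4, 6)
proj = torch.view_as_real(torch.fft.rfftn(t))                                 # real view of the complex spectrum
back = torch.fft.irfftn(torch.view_as_complex(proj.contiguous()), s=t.shape)  # s= restores the original shape
print(proj.shape, torch.allclose(t, back, atol=1e-6))                         # torch.Size([4, 4, 2]) True
```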

torchzero/modules/experimental/hnewton.py (new file)
@@ -0,0 +1,85 @@
+from collections import deque
+
+import torch
+
+from ...core import TensorwiseTransform
+
+
+def eigh_solve(H: torch.Tensor, g: torch.Tensor):
+    try:
+        L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
+        return Q @ ((Q.mH @ g) / L)
+    except torch.linalg.LinAlgError:
+        return None
+
+
+class HNewton(TensorwiseTransform):
+    """This treats gradient differences as Hvps with vectors being parameter differences, using past gradients that are close to each other. Basically this is another limited memory quasi newton method to test.
+
+    .. warning::
+        Experimental.
+
+    """
+    def __init__(self, history_size: int, window_size: int, reg: float=0, tol: float = 1e-8, concat_params:bool=True, inner=None):
+        defaults = dict(history_size=history_size, window_size=window_size, reg=reg, tol=tol)
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner)
+
+    def update_tensor(self, tensor, param, grad, loss, state, setting):
+
+        history_size = setting['history_size']
+
+        if 'param_history' not in state:
+            state['param_history'] = deque(maxlen=history_size)
+            state['grad_history'] = deque(maxlen=history_size)
+
+        param_history: deque = state['param_history']
+        grad_history: deque = state['grad_history']
+        param_history.append(param.ravel())
+        grad_history.append(tensor.ravel())
+
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        window_size = setting['window_size']
+        reg = setting['reg']
+        tol = setting['tol']
+
+        param_history: deque = state['param_history']
+        grad_history: deque = state['grad_history']
+        g = tensor.ravel()
+
+        n = len(param_history)
+        s_list = []
+        y_list = []
+
+        for i in range(n):
+            for j in range(i):
+                if i - j <= window_size:
+                    p_i, g_i = param_history[i], grad_history[i]
+                    p_j, g_j = param_history[j], grad_history[j]
+                    s = p_i - p_j # vec in hvp
+                    y = g_i - g_j # hvp
+                    if s.dot(y) > tol:
+                        s_list.append(s)
+                        y_list.append(y)
+
+        if len(s_list) < 1:
+            scale = (1 / tensor.abs().sum()).clip(min=torch.finfo(tensor.dtype).eps, max=1)
+            tensor.mul_(scale)
+            return tensor
+
+        S = torch.stack(s_list, 1)
+        Y = torch.stack(y_list, 1)
+
+        B = S.T @ Y
+        if reg != 0: B.add_(torch.eye(B.size(0), device=B.device, dtype=B.dtype).mul_(reg))
+        g_proj = g @ S
+
+        newton_proj, info = torch.linalg.solve_ex(B, g_proj) # pylint:disable=not-callable
+        if info != 0:
+            newton_proj = -torch.linalg.lstsq(B, g_proj).solution # pylint:disable=not-callable
+        newton = S @ newton_proj
+        return newton.view_as(tensor)
+
+
+        # scale = (1 / tensor.abs().sum()).clip(min=torch.finfo(tensor.dtype).eps, max=1)
+        # tensor.mul_(scale)
+        # return tensor
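
The core of `apply_tensor` above is a Newton solve restricted to the subspace spanned by parameter differences, with gradient differences standing in for Hessian-vector products (y ≈ H s). A hedged stand-alone sketch of that subspace solve; `subspace_newton_step` is a name invented here, and `lstsq` is used in place of `solve_ex` for simplicity:

```python
import torch

def subspace_newton_step(S, Y, g, reg=0.0):
    """Solve (S^T Y + reg*I) z = S^T g and return the step S @ z."""
    B = S.T @ Y
    if reg != 0:
        B = B + reg * torch.eye(B.size(0), dtype=B.dtype)
    z = torch.linalg.lstsq(B, S.T @ g).solution  # tolerates a singular B
    return S @ z

# with exact curvature pairs (Y = A @ S) this is the Galerkin approximation of A^{-1} g in span(S)
A = torch.diag(torch.tensor([2.0, 5.0, 9.0]))
S = torch.randn(3, 2)
Y = A @ S
g = A @ torch.ones(3)
print(subspace_newton_step(S, Y, g))
```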

torchzero/modules/quasi_newton/lbfgs.py
@@ -4,8 +4,8 @@ from typing import Any
 
 import torch
 
-from
-from
+from ...core import Chainable, Module, Transform, Var, apply_transform, maybe_chain
+from ...utils import NumberList, TensorList, as_tensorlist
 
 
 def _adaptive_damping(
@@ -43,32 +43,31 @@ def lbfgs(
        if scale < 1e-5: scale = 1 / tensors_.abs().mean()
        return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
 
-    [contents of removed lines 46-70 not captured in this view]
-    return z
+    # 1st loop
+    alpha_list = []
+    q = tensors_.clone()
+    for s_i, y_i, ys_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
+        p_i = 1 / ys_i # this is also denoted as ρ (rho)
+        alpha = p_i * s_i.dot(q)
+        alpha_list.append(alpha)
+        q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
+
+    # calculate z
+    # s.y/y.y is also this weird y-looking symbol I couldn't find
+    # z is it times q
+    # actually H0 = (s.y/y.y) * I, and z = H0 @ q
+    z = q * (ys_k / (y_k.dot(y_k)))
+
+    if z_tfm is not None:
+        z = TensorList(apply_transform(z_tfm, tensors=z, params=var.params, grads=var.grad, var=var))
+
+    # 2nd loop
+    for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
+        p_i = 1 / ys_i
+        beta_i = p_i * y_i.dot(z)
+        z.add_(s_i, alpha = alpha_i - beta_i)
+
+    return z
 
 def _apply_tfms_into_history(
     self: Module,
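
The added block is the classic two-loop recursion, which applies the L-BFGS inverse-Hessian estimate to a vector without materialising a matrix. A self-contained sketch on flat tensors, assuming nothing from torchzero; with a full-rank history on a quadratic it reproduces the exact Newton step:

```python
import torch

def two_loop(g, s_hist, y_hist, sy_hist):
    q = g.clone()
    alphas = []
    for s, y, sy in zip(reversed(s_hist), reversed(y_hist), reversed(sy_hist)):
        alpha = s.dot(q) / sy        # alpha_i = rho_i * s_i . q
        alphas.append(alpha)
        q -= alpha * y
    # initial scaling H0 = (s.y / y.y) * I from the most recent pair
    z = q * (sy_hist[-1] / y_hist[-1].dot(y_hist[-1]))
    for s, y, sy, alpha in zip(s_hist, y_hist, sy_hist, reversed(alphas)):
        beta = y.dot(z) / sy
        z += (alpha - beta) * s
    return z

A = torch.diag(torch.tensor([1.0, 4.0, 16.0]))
s_hist = [torch.eye(3)[i] for i in range(3)]           # steps spanning the whole space
y_hist = [A @ s for s in s_hist]                       # exact curvature pairs y_k = A s_k
sy_hist = [s.dot(y) for s, y in zip(s_hist, y_hist)]
g = torch.ones(3)
print(torch.allclose(two_loop(g, s_hist, y_hist, sy_hist), torch.linalg.solve(A, g)))  # True
```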

torchzero/modules/experimental/newtonnewton.py
@@ -22,8 +22,9 @@ from ..second_order.newton import (
 
 
 class NewtonNewton(Module):
-    """
-
+    """Applies Newton-like preconditioning to Newton step.
+
+    This is a method that I thought of and then it worked. Here is how it works:
 
     1. Calculate newton step by solving Hx=g
 
@@ -34,6 +35,9 @@ class NewtonNewton(Module):
     4. Optionally, repeat (if order is higher than 3.)
 
     Memory is n^order. It tends to converge faster on convex functions, but can be unstable on non-convex. Orders higher than 3 are usually too unsable and have little benefit.
+
+    3rd order variant can minimize some convex functions with up to 100 variables in less time than Newton's method,
+    this is if pytorch can vectorize hessian computation efficiently.
     """
     def __init__(
         self,
@@ -83,6 +87,6 @@ class NewtonNewton(Module):
        if x is None: x = least_squares_solve(H, xp)
        xp = x.squeeze()
 
-        var.update = vec_to_tensors(xp, params)
+        var.update = vec_to_tensors(xp.nan_to_num_(0,0,0), params)
        return var
 
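
The only functional change in this hunk is sanitising the solved step before it is unpacked back into parameter shapes: `nan_to_num_(0, 0, 0)` maps NaN, +inf and -inf all to zero, so an ill-conditioned solve degrades to a zero step for the affected coordinates instead of corrupting the parameters. For reference:

```python
import torch

xp = torch.tensor([1.0, float("nan"), float("inf"), float("-inf")])
print(xp.nan_to_num_(0, 0, 0))  # tensor([1., 0., 0., 0.])
```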

torchzero/modules/experimental/parabolic_search.py (new file)
@@ -0,0 +1,220 @@
+import math
+from collections.abc import Mapping
+from operator import itemgetter
+
+import torch
+
+from ...core import Module
+from ...utils import TensorList
+
+
+
+def adaptive_tracking(
+    f,
+    f_0,
+    f_1,
+    t_0,
+    maxiter: int
+):
+
+    t = t_0
+    f_t = f(t)
+
+    # backtrack
+    if f_t > f_0:
+        if f_1 > f_0: t = min(0.5, t_0/2)
+        while f_t > f_0:
+            maxiter -= 1
+            if maxiter < 0: return 0, f_0
+            t = t/2
+            f_t = f(t) if t!=1 else f_1
+        return t, f_t
+
+    # forwardtrack
+    f_prev = f_t
+    t *= 2
+    f_t = f(t)
+    if f_prev < f_t: return t/2, f_prev
+    while f_prev >= f_t:
+        maxiter -= 1
+        if maxiter < 0: return t, f_t
+        f_prev = f_t
+        t *= 2
+        f_t = f(t)
+    return t/2, f_prev
+
+
+
+class ParabolaSearch(Module):
+    """"""
+    def __init__(
+        self,
+        step_size: float = 1e-2,
+        adaptive: bool=True,
+        normalize: bool=False,
+        # method: str | None = None,
+        maxiter: int | None = 10,
+        # bracket=None,
+        # bounds=None,
+        # tol: float | None = None,
+        # options=None,
+    ):
+        if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
+        defaults = dict(step_size=step_size, adaptive=adaptive, normalize=normalize, maxiter=maxiter)
+        super().__init__(defaults)
+
+        import scipy.optimize
+        self.scopt = scipy.optimize
+
+
+    @torch.no_grad
+    def step(self, var):
+        x_0 = TensorList(var.params)
+        closure = var.closure
+        assert closure is not None
+        settings = self.settings[x_0[0]]
+        step_size = settings['step_size']
+        adaptive = settings['adaptive']
+        normalize = settings['normalize']
+        maxiter = settings['maxiter']
+        if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
+
+        grad = TensorList(var.get_grad())
+        f_0 = var.get_loss(False)
+
+        scale = 1
+        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+        if adaptive: scale = grad.abs().mean().clip(min=1e-8)
+
+        # make step
+        v_0 = grad * (step_size/scale)
+        x_0 -= v_0
+        with torch.enable_grad():
+            f_1 = closure()
+        grad = x_0.grad
+
+        x_0 += v_0
+        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+        v_1 = grad * (step_size/scale)
+        a = v_1 - v_0
+
+        def parabolic_objective(t: float):
+            nonlocal x_0
+
+            step = v_0*t + 0.5*a*t**2
+            x_0 -= step
+            value = closure(False)
+            x_0 += step
+            return value.detach().cpu()
+
+        prev_t = self.global_state.get('prev_t', 2)
+        t, f = adaptive_tracking(parabolic_objective, f_0=f_0, f_1=f_1, t_0=prev_t, maxiter=maxiter)
+        self.global_state['prev_t'] = t
+
+        # method, bracket, bounds, tol, options, maxiter = itemgetter(
+        # 'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
+
+        # if maxiter is not None:
+        # options = dict(options) if isinstance(options, Mapping) else {}
+        # options['maxiter'] = maxiter
+
+        # res = self.scopt.minimize_scalar(parabolic_objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
+        # t = res.x
+
+        var.update = v_0*t + 0.5*a*t**2
+        return var
+
+class CubicParabolaSearch(Module):
+    """"""
+    def __init__(
+        self,
+        step_size: float = 1e-2,
+        adaptive: bool=True,
+        normalize: bool=False,
+        # method: str | None = None,
+        maxiter: int | None = 10,
+        # bracket=None,
+        # bounds=None,
+        # tol: float | None = None,
+        # options=None,
+    ):
+        if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
+        defaults = dict(step_size=step_size, adaptive=adaptive, normalize=normalize, maxiter=maxiter)
+        super().__init__(defaults)
+
+        import scipy.optimize
+        self.scopt = scipy.optimize
+
+
+    @torch.no_grad
+    def step(self, var):
+        x_0 = TensorList(var.params)
+        closure = var.closure
+        assert closure is not None
+        settings = self.settings[x_0[0]]
+        step_size = settings['step_size']
+        adaptive = settings['adaptive']
+        maxiter = settings['maxiter']
+        normalize = settings['normalize']
+        if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
+
+        grad = TensorList(var.get_grad())
+        f_0 = var.get_loss(False)
+
+        scale = 1
+        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+        if adaptive: scale = grad.abs().mean().clip(min=1e-8)
+
+        # make step
+        v_0 = grad * (step_size/scale)
+        x_0 -= v_0
+        with torch.enable_grad():
+            f_1 = closure()
+        grad = x_0.grad
+
+        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+        v_1 = grad * (step_size/scale)
+        a_0 = v_1 - v_0
+
+        # make another step
+        x_0 -= v_1
+        with torch.enable_grad():
+            f_2 = closure()
+        grad = x_0.grad
+
+        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+        v_2 = grad * (step_size/scale)
+        a_1 = v_2 - v_1
+
+        j = a_1 - a_0
+
+        x_0 += v_0
+        x_0 += v_1
+
+        def parabolic_objective(t: float):
+            nonlocal x_0
+
+            step = v_0*t + (1/2)*a_0*t**2 + (1/6)*j*t**3
+            x_0 -= step
+            value = closure(False)
+            x_0 += step
+            return value
+
+
+        prev_t = self.global_state.get('prev_t', 2)
+        t, f = adaptive_tracking(parabolic_objective, f_0=f_0, f_1=f_1, t_0=prev_t, maxiter=maxiter)
+        self.global_state['prev_t'] = t
+
+        # method, bracket, bounds, tol, options, maxiter = itemgetter(
+        # 'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
+
+        # if maxiter is not None:
+        # options = dict(options) if isinstance(options, Mapping) else {}
+        # options['maxiter'] = maxiter
+
+        # res = self.scopt.minimize_scalar(parabolic_objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
+        # t = res.x
+
+        var.update = v_0*t + (1/2)*a_0*t**2 + (1/6)*j*t**3
+        return var
+
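
ParabolaSearch probes one small gradient step, re-evaluates the gradient, treats the change of the scaled gradient as a curvature term `a`, and then searches over `t` along the parabola `x - (v_0*t + 0.5*a*t**2)`. A hedged stand-alone sketch on a toy quadratic, using a plain grid search in place of `adaptive_tracking`; all names are local to the example:

```python
import torch

A = torch.diag(torch.tensor([1.0, 10.0]))
def loss(x): return 0.5 * x @ A @ x
def grad(x): return A @ x

x = torch.tensor([1.0, 1.0])
h = 1e-2                      # small probing step size
v0 = grad(x) * h
v1 = grad(x - v0) * h         # gradient after the probing step
a = v1 - v0                   # finite-difference curvature along the path

ts = torch.linspace(0, 50, 500)
vals = torch.stack([loss(x - (v0 * t + 0.5 * a * t**2)) for t in ts])
print(float(ts[vals.argmin()]), float(vals.min()), float(loss(x)))  # best t, its loss, loss at x
```

CubicParabolaSearch extends the same idea with a second probe to estimate a jerk term `j` and searches along a cubic path instead.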

torchzero/modules/experimental/reduce_outward_lr.py
@@ -4,19 +4,19 @@ from ...core import Target, Transform
 from ...utils import TensorList, unpack_states, unpack_dicts
 
 class ReduceOutwardLR(Transform):
-    """
-    When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
+    """When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
 
     This means updates that move weights towards zero have higher learning rates.
 
-
+    .. warning::
+        This sounded good but after testing turns out it sucks.
     """
     def __init__(self, mul = 0.5, use_grad=False, invert=False, target: Target = 'update'):
        defaults = dict(mul=mul, use_grad=use_grad, invert=invert)
        super().__init__(defaults, uses_grad=use_grad, target=target)
 
    @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        params = TensorList(params)
        tensors = TensorList(tensors)
 
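
Written out element-wise, the literal rule from the docstring is: wherever the update and the weight share a sign, scale that entry of the update by `mul`. A small stand-alone sketch; `reduce_outward` is a name made up for this example, and the actual `apply_tensors` body is not shown in this hunk:

```python
import torch

def reduce_outward(update, param, mul=0.5):
    same_sign = torch.sign(update) == torch.sign(param)
    return torch.where(same_sign, update * mul, update)

p = torch.tensor([1.0, -2.0, 3.0])
u = torch.tensor([0.5, 0.5, -1.0])
print(reduce_outward(u, p))  # tensor([0.2500, 0.5000, -1.0000])
```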