torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/adaptive/lre_optimizers.py (new file)
@@ -0,0 +1,299 @@
+"""subspace optimizers to be used in a low rank eigenbasis
+
+three opts support this - GGT and experimental AdaNystrom and Eigengrad
+
+I could define repoject on a module but because most opts use per-parameter state that is complicated"""
+
+import math
+from abc import ABC, abstractmethod
+from typing import Any, cast
+
+import torch
+
+from ...linalg import matrix_power_eigh, torch_linalg
+from .lion import lion_
+
+class LREOptimizerBase(ABC):
+    """Optimizer to run in a low rank eigenbasis.
+
+    notes:
+
+    1. it shouldn't store any states in self, everything should be in state.
+    This is because this may be called on multiple parameters in a sequence
+
+    2. apply is always called first, than reproject whenever eigenbasis gets updated
+
+    3. L is variance in the eigenbasis.
+    """
+    @abstractmethod
+    def step(self, g: torch.Tensor, L: torch.Tensor, Q: torch.Tensor, state: dict) -> torch.Tensor:
+        ...
+
+    @abstractmethod
+    def reproject(self, L_old: torch.Tensor, Q_old: torch.Tensor,
+                  L_new: torch.Tensor, Q_new: torch.Tensor, state: dict) -> None:
+        ...
+
+class Whiten(LREOptimizerBase):
+    """This simply applies whitening and is equivalent to not running an optimizer in the eigenbasis"""
+    def step(self, g, L, Q, state): return (Q * L.rsqrt()) @ (Q.T @ g)
+    def reproject(self, L_old, Q_old, L_new, Q_new, state): pass
+
+class EMA(LREOptimizerBase):
+    """Maintains exponential moving average of gradients in the low rank eigenbasis. Nesterov setting is experimental"""
+    def __init__(self, beta=0.9, nesterov:bool=False, cautious:bool=False, whiten:bool=True):
+        self.beta = beta
+        self.nesterov = nesterov
+        self.whiten = whiten
+        self.cautious = cautious
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+
+        if "exp_avg" not in state:
+            state["exp_avg"] = torch.zeros_like(g)
+
+        exp_avg = state["exp_avg"]
+        exp_avg.lerp_(g, 1-self.beta)
+
+        if self.nesterov:
+            dir = (g + exp_avg * self.beta) / (1 + self.beta)
+        else:
+            dir = exp_avg
+
+        if self.cautious:
+            mask = (g * dir) > 0
+            dir *= mask
+
+        if self.whiten: return (Q * L.rsqrt()) @ dir
+        return Q @ dir
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+        C = Q_new.T @ Q_old
+        state["exp_avg"] = C @ state["exp_avg"]
+
+
+def adam(g:torch.Tensor, state:dict, beta1, beta2, eps):
+
+    if "exp_avg" not in state:
+        state["exp_avg"] = torch.zeros_like(g)
+        state["exp_avg_sq"] = torch.zeros_like(g)
+        state["current_step"] = 1
+
+    exp_avg = state["exp_avg"]
+    exp_avg_sq = state["exp_avg_sq"]
+    current_step = state["current_step"]
+
+    exp_avg.lerp_(g, 1-beta1)
+    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1-beta2)
+    denom = exp_avg_sq.sqrt().add_(eps)
+
+    bias_correction1 = 1.0 - (beta1 ** current_step)
+    bias_correction2 = 1.0 - (beta2 ** current_step)
+    alpha = math.sqrt(bias_correction2) / bias_correction1
+    state["current_step"] = current_step + 1
+
+    return (exp_avg * alpha) / denom
+
+def _squared_reproject(C: torch.Tensor, sq: torch.Tensor, exact: bool):
+    if exact:
+        return (C @ sq.diag_embed() @ C.T).diagonal()
+
+    return C.square() @ sq
+
+class Adam(LREOptimizerBase):
+    """Runs Adam in low rank eigenbasis."""
+    def __init__(self, beta1=0.9, beta2=0.95, cautious:bool=False, eps=1e-8, exact_reproject:bool=True):
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.eps = eps
+        self.cautious = cautious
+        self.exact_reproject = exact_reproject
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+
+        dir = adam(g, state, self.beta1, self.beta2, self.eps)
+
+        if self.cautious:
+            mask = (g * dir) > 0
+            dir *= mask
+
+        return Q @ dir
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+        C = Q_new.T @ Q_old
+
+        state["exp_avg"] = C @ state["exp_avg"]
+        state["exp_avg_sq"] = _squared_reproject(C, state["exp_avg_sq"], self.exact_reproject)
+
+
+class FullMatrixAdam(LREOptimizerBase):
+    """Runs full-matrix Adam in low rank eigenbasis.
+    The preconditioner is updated whenever basis is updated"""
+    def __init__(self, beta1=0.9, beta2=0.95, eps=1e-8, matrix_power=-1/2, abs=True, cautious:bool=False):
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.eps = eps
+        self.matrix_power = matrix_power
+        self.abs = abs
+        self.cautious = cautious
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+
+        # initialize
+        if "exp_avg" not in state:
+            state["exp_avg"] = torch.zeros_like(g)
+            state["covariance"] = torch.eye(g.numel(), device=g.device, dtype=g.dtype)
+            state["preconditioner"] = torch.eye(g.numel(), device=g.device, dtype=g.dtype)
+            state["reprojected"] = True
+            state["current_step"] = 1
+
+        exp_avg = state["exp_avg"]
+        covariance = state["covariance"]
+        current_step = state["current_step"]
+
+        # update buffers
+        exp_avg.lerp_(g, 1-self.beta1)
+        covariance.lerp_(g.outer(g), weight=1-self.beta2)
+
+        # correct bias
+        bias_correction1 = 1.0 - (self.beta1 ** current_step)
+        exp_avg = exp_avg / bias_correction1
+
+        # after reprojecting update the preconditioner
+        if state["reprojected"]:
+            state["reprojected"] = False
+
+            bias_correction2 = 1.0 - (self.beta2 ** current_step)
+            covariance = covariance / bias_correction2
+
+            reg = torch.eye(covariance.size(0), device=covariance.device, dtype=covariance.dtype).mul_(self.eps)
+            covariance = covariance + reg
+
+            # compute matrix power
+            try:
+                state["preconditioner"] = matrix_power_eigh(covariance, self.matrix_power, abs=self.abs)
+
+            except torch.linalg.LinAlgError:
+
+                # fallback to diagonal
+                state["preconditioner"] = covariance.diagonal().rsqrt().diag_embed()
+
+        # compute the update
+        state["current_step"] = current_step + 1
+        preconditioner = state["preconditioner"]
+        dir = preconditioner @ exp_avg
+
+        if self.cautious:
+            mask = (g * dir) > 0
+            dir *= mask
+
+        return Q @ dir
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+
+        state["reprojected"] = True
+
+        C = Q_new.T @ Q_old
+        state["exp_avg"] = C @ state["exp_avg"]
+        state["covariance"] = C @ state["covariance"] @ C.T
+
+class Lion(LREOptimizerBase):
+    """Runs Lion in the low rank eigenbasis."""
+    def __init__(self, beta1=0.9, beta2=0.99, cautious:bool=False):
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.cautious = cautious
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+
+        if "exp_avg" not in state:
+            state["exp_avg"] = torch.zeros_like(g)
+
+        dir = cast(torch.Tensor, lion_(g, state["exp_avg"], beta1=self.beta1, beta2=self.beta2))
+
+        if self.cautious:
+            mask = (g * dir) > 0
+            dir *= mask
+
+        return Q @ dir
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+        C = Q_new.T @ Q_old
+        state["exp_avg"] = C @ state["exp_avg"]
+
+
+class Grams(LREOptimizerBase):
+    """Runs Grams in low rank eigenbasis."""
+    def __init__(self, beta1=0.9, beta2=0.95, eps=1e-8, exact_reproject=True):
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.eps = eps
+        self.exact_reproject = exact_reproject
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+        dir = adam(g, state, self.beta1, self.beta2, self.eps)
+        return Q @ dir.copysign(g)
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+        C = Q_new.T @ Q_old
+
+        state["exp_avg"] = C @ state["exp_avg"]
+        state["exp_avg_sq"] = _squared_reproject(C, state["exp_avg_sq"], self.exact_reproject)
+
+
+class LaProp(LREOptimizerBase):
+    """Runs LaProp in low rank eigenbasis."""
+    def __init__(self, beta1=0.9, beta2=0.95, eps=1e-8, cautious:bool=False, exact_reproject=True):
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.eps = eps
+        self.cautious = cautious
+        self.exact_reproject = exact_reproject
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+
+        if "exp_avg" not in state:
+            state["exp_avg"] = torch.zeros_like(g)
+            state["exp_avg_sq"] = torch.zeros_like(g)
+            state["current_step"] = 1
+
+        exp_avg = state["exp_avg"]
+        exp_avg_sq = state["exp_avg_sq"]
+        current_step = state["current_step"]
+
+        # update second moments
+        exp_avg_sq.mul_(self.beta2).addcmul_(g, g, value=1-self.beta2)
+        bias_correction2 = 1.0 - (self.beta2 ** current_step)
+
+        # divide by bias corrected second moments
+        dir = g / (exp_avg_sq / bias_correction2).sqrt().add_(self.eps)
+
+        # update first moments and bias correct
+        exp_avg.lerp_(dir, 1-self.beta1)
+        bias_correction1 = 1.0 - (self.beta1 ** current_step)
+        dir = exp_avg / bias_correction1
+
+        if self.cautious:
+            mask = (g * dir) > 0
+            dir *= mask
+
+        state["current_step"] = current_step + 1
+        return Q @ dir
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+        C = Q_new.T @ Q_old
+
+        state["exp_avg"] = C @ state["exp_avg"]
+        state["exp_avg_sq"] = _squared_reproject(C, state["exp_avg_sq"], self.exact_reproject)
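The `reproject` methods above all transport per-basis buffers with the change-of-basis matrix `C = Q_new.T @ Q_old`. A minimal standalone sketch of that arithmetic (illustrative only, not part of the package; the sizes, seed, and buffer names are made up for the example):

```python
import torch

torch.manual_seed(0)
n, k = 8, 3

# two orthonormal bases standing in for consecutive low-rank eigenbases
Q_old, _ = torch.linalg.qr(torch.randn(n, k))
Q_new, _ = torch.linalg.qr(torch.randn(n, k))

# a first-moment buffer that was accumulated in the old basis
m_old = torch.randn(k)

# change of basis used by reproject(): C = Q_new.T @ Q_old
C = Q_new.T @ Q_old
m_new = C @ m_old

# the transported buffer matches the projection of the old full-space
# vector Q_old @ m_old onto the span of the new basis
lhs = Q_new @ m_new
rhs = Q_new @ (Q_new.T @ (Q_old @ m_old))
print(torch.allclose(lhs, rhs, atol=1e-6))
```

Second-moment buffers get the quadratic version of the same transport, which is what `_squared_reproject` computes from `C` and the stored `exp_avg_sq`.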
torchzero/modules/adaptive/mars.py
@@ -1,6 +1,6 @@
 import torch
 
-from ...core import
+from ...core import TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 
 
@@ -20,7 +20,7 @@ def mars_correction_(
 
     return c
 
-class MARSCorrection(
+class MARSCorrection(TensorTransform):
     """MARS variance reduction correction.
 
     Place any other momentum-based optimizer after this,
@@ -35,7 +35,7 @@ class MARSCorrection(Transform):
 
     Mars-AdamW
     ```python
-    optimizer = tz.
+    optimizer = tz.Optimizer(
         model.parameters(),
         tz.m.MARSCorrection(beta=0.95),
         tz.m.Adam(beta1=0.95, beta2=0.99),
@@ -46,7 +46,7 @@ class MARSCorrection(Transform):
 
     Mars-Lion
    ```python
-    optimizer = tz.
+    optimizer = tz.Optimizer(
         model.parameters(),
         tz.m.MARSCorrection(beta=0.9),
         tz.m.Lion(beta1=0.9),
@@ -61,11 +61,11 @@ class MARSCorrection(Transform):
         scaling: float = 0.025,
         max_norm: float | None = 1,
     ):
-        defaults=dict(beta=beta, scaling=scaling, max_norm=max_norm)
-        super().__init__(defaults
+        defaults = dict(beta=beta, scaling=scaling, max_norm=max_norm)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
         beta, scaling = unpack_dicts(settings, 'beta', 'scaling', cls=NumberList)
         max_norm = settings[0]['max_norm']
torchzero/modules/adaptive/matrix_momentum.py
@@ -1,14 +1,13 @@
 from typing import Literal
-
+
 import torch
 
-from ...core import
-from ...utils import NumberList, TensorList,
-from
-from ..functional import initial_step_size
+from ...core import Chainable, Transform, HVPMethod
+from ...utils import NumberList, TensorList, unpack_states, unpack_dicts
+from ..opt_utils import initial_step_size
 
 
-class MatrixMomentum(
+class MatrixMomentum(Transform):
     """Second order momentum method.
 
     Matrix momentum is useful for convex objectives, also for some reason it has very really good generalization on elastic net logistic regression.
@@ -23,17 +22,17 @@ class MatrixMomentum(Module):
     Args:
         mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
         hvp_method (str, optional):
-            Determines how
-
-            - ``"
-
-            - ``"
-
-
-
-
-
-
+            Determines how hessian-vector products are computed.
+
+            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+            Defaults to ``"autograd"``.
+        h (float, optional):
+            The step size for finite difference if ``hvp_method`` is
+            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
         hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
 
     Reference:
@@ -44,51 +43,45 @@ class MatrixMomentum(Module):
         self,
         lr:float,
         mu=0.1,
-        hvp_method:
+        hvp_method: HVPMethod = "autograd",
         h: float = 1e-3,
         adaptive:bool = False,
         adapt_freq: int | None = None,
-
+
+        inner: Chainable | None = None,
     ):
         defaults = dict(lr=lr, mu=mu, hvp_method=hvp_method, h=h, adaptive=adaptive, adapt_freq=adapt_freq)
-        super().__init__(defaults)
-
-        if hvp_tfm is not None:
-            self.set_child('hvp_tfm', hvp_tfm)
+        super().__init__(defaults, inner=inner)
 
     def reset_for_online(self):
         super().reset_for_online()
         self.clear_state_keys('p_prev')
 
     @torch.no_grad
-    def
-
-        p = TensorList(
-        p_prev =
+    def update_states(self, objective, states, settings):
+        step = self.increment_counter("step", 0)
+        p = TensorList(objective.params)
+        p_prev = unpack_states(states, p, 'p_prev', init=p)
 
-
-
-
-        self.global_state["step"] = step + 1
+        fs = settings[0]
+        hvp_method = fs['hvp_method']
+        h = fs['h']
 
         if step > 0:
             s = p - p_prev
 
-            Hs, _ =
+            Hs, _ = objective.hessian_vector_product(s, at_x0=True, rgrad=None, hvp_method=hvp_method, h=h, retain_graph=False)
             Hs = [t.detach() for t in Hs]
 
-            if 'hvp_tfm' in self.children:
-                Hs = TensorList(apply_transform(self.children['hvp_tfm'], Hs, params=p, grads=var.grad, var=var))
-
             self.store(p, ("Hs", "s"), (Hs, s))
 
         # -------------------------------- adaptive mu ------------------------------- #
-        if
-            g = TensorList(
+        if fs["adaptive"]:
+            g = TensorList(objective.get_grads())
 
-            if
+            if fs["adapt_freq"] is None:
                 # ---------------------------- deterministic case ---------------------------- #
-                g_prev =
+                g_prev = unpack_states(states, p, "g_prev", cls=TensorList)
                 y = g - g_prev
                 g_prev.copy_(g)
                 denom = y.global_vector_norm()
@@ -101,14 +94,14 @@ class MatrixMomentum(Module):
 
                 # we start on 1nd step, and want to adapt when we start, so use (step - 1)
                 if (step - 1) % adapt_freq == 0:
-                    assert
-                    params = TensorList(
+                    assert objective.closure is not None
+                    params = TensorList(objective.params)
                     p_cur = params.clone()
 
                     # move to previous params and evaluate p_prev with current mini-batch
-                    params.copy_(
+                    params.copy_(unpack_states(states, p, 'p_prev'))
                     with torch.enable_grad():
-
+                        objective.closure()
                     g_prev = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                     y = g - g_prev
 
@@ -119,12 +112,12 @@ class MatrixMomentum(Module):
             denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
             self.global_state["mu_mul"] = s.global_vector_norm() / denom
 
-        torch._foreach_copy_(p_prev,
+        torch._foreach_copy_(p_prev, objective.params)
 
     @torch.no_grad
-    def
-        update = TensorList(
-        lr,mu =
+    def apply_states(self, objective, states, settings):
+        update = TensorList(objective.get_updates())
+        lr, mu = unpack_dicts(settings, "lr", 'mu', cls=NumberList)
 
         if "mu_mul" in self.global_state:
             mu = mu * self.global_state["mu_mul"]
@@ -133,14 +126,17 @@ class MatrixMomentum(Module):
         # p_prev is not available so make a small step
         step = self.global_state["step"]
         if step == 1:
-            if self.defaults["adaptive"]:
+            if self.defaults["adaptive"]:
+                # initialize
+                unpack_states(states, objective.params, "g_prev", init=objective.get_grads())
+
             update.mul_(lr) # separate so that initial_step_size can clip correctly
             update.mul_(initial_step_size(update, 1e-7))
-            return
+            return objective
 
         # -------------------------- matrix momentum update -------------------------- #
-        s, Hs =
+        s, Hs = unpack_states(states, objective.params, 's', 'Hs', cls=TensorList)
 
         update.mul_(lr).sub_(s).add_(Hs*mu)
-
-        return
+        objective.updates = update
+        return objective