torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
|
@@ -1,196 +0,0 @@
|
|
|
1
|
-
from collections import deque
|
|
2
|
-
from functools import partial
|
|
3
|
-
from operator import itemgetter
|
|
4
|
-
from typing import Literal
|
|
5
|
-
|
|
6
|
-
import torch
|
|
7
|
-
|
|
8
|
-
from ...core import Chainable, Module, Transform, Var, apply_transform
|
|
9
|
-
from ...utils import NumberList, TensorList, as_tensorlist
|
|
10
|
-
from .lbfgs import _adaptive_damping, lbfgs
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
@torch.no_grad
|
|
14
|
-
def _store_sk_yk_after_step_hook(optimizer, var: Var, prev_params: TensorList, prev_grad: TensorList, damping, init_damping, eigval_bounds, s_history: deque[TensorList], y_history: deque[TensorList], sy_history: deque[torch.Tensor]):
|
|
15
|
-
assert var.closure is not None
|
|
16
|
-
with torch.enable_grad(): var.closure()
|
|
17
|
-
grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in var.params]
|
|
18
|
-
s_k = var.params - prev_params
|
|
19
|
-
y_k = grad - prev_grad
|
|
20
|
-
ys_k = s_k.dot(y_k)
|
|
21
|
-
|
|
22
|
-
if damping:
|
|
23
|
-
s_k, y_k, ys_k = _adaptive_damping(s_k, y_k, ys_k, init_damping=init_damping, eigval_bounds=eigval_bounds)
|
|
24
|
-
|
|
25
|
-
if ys_k > 1e-10:
|
|
26
|
-
s_history.append(s_k)
|
|
27
|
-
y_history.append(y_k)
|
|
28
|
-
sy_history.append(ys_k)
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
class OnlineLBFGS(Module):
    """Online L-BFGS.

    Parameter and gradient differences are sampled from the same mini-batch by performing an extra
    forward and backward pass. However I did a bunch of experiments and the online part doesn't seem
    to help. Normal L-BFGS is usually still better because it performs twice as many steps, and it is
    reasonably stable with normalization or grafting.

    Args:
        history_size (int, optional):
            number of past parameter differences and gradient differences to store. Defaults to 10.
        sample_grads (str, optional):
            - "before" - samples current mini-batch gradient at previous and current parameters,
              calculates y_k and adds it to history before stepping.
            - "after" - samples current mini-batch gradient at parameters before stepping and after
              updating parameters. s_k and y_k are added after parameter update, therefore they are
              delayed by 1 step.

            In practice both modes behave very similarly. Defaults to 'before'.
        tol (float | None, optional):
            tolerance for minimal gradient difference to avoid instability after converging to minima.
            If the largest absolute entry of y_k is at most ``tol``, the (possibly inner-transformed)
            update is returned without L-BFGS preconditioning. Defaults to 1e-10.
        damping (bool, optional):
            whether to use adaptive damping. Learning rate might need to be lowered with this enabled.
            Defaults to False.
        init_damping (float, optional):
            initial damping for adaptive damping. Defaults to 0.9.
        eigval_bounds (tuple, optional):
            eigenvalue bounds for adaptive damping. Defaults to (0.5, 50).
        z_beta (float | None, optional):
            optional EMA for initial H^-1 @ q. Acts as a kind of momentum but is prone to get stuck.
            Defaults to None.
        inner (Chainable | None, optional):
            optional inner modules applied after updating L-BFGS history and before preconditioning.
            Defaults to None.
    """
    def __init__(
        self,
        history_size=10,
        sample_grads: Literal['before', 'after'] = 'before',
        tol: float | None = 1e-10,
        damping: bool = False,
        init_damping=0.9,
        eigval_bounds=(0.5, 50),
        z_beta: float | None = None,
        inner: Chainable | None = None,
    ):
        defaults = dict(
            history_size=history_size, tol=tol, damping=damping, init_damping=init_damping,
            eigval_bounds=eigval_bounds, sample_grads=sample_grads, z_beta=z_beta,
        )
        super().__init__(defaults)

        # bounded histories: the oldest pair is dropped once ``history_size`` pairs are stored
        self.global_state['s_history'] = deque(maxlen=history_size)
        self.global_state['y_history'] = deque(maxlen=history_size)
        self.global_state['sy_history'] = deque(maxlen=history_size)

        if inner is not None:
            self.set_child('inner', inner)

    def reset(self):
        """Resets the internal state of the Online L-BFGS module (per-parameter state, step counter and history)."""
        self.state.clear()
        self.global_state['step'] = 0
        self.global_state['s_history'].clear()
        self.global_state['y_history'].clear()
        self.global_state['sy_history'].clear()

    @torch.no_grad
    def step(self, var):
        assert var.closure is not None

        params = as_tensorlist(var.params)
        update = as_tensorlist(var.get_update())
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        # history of s_k (parameter differences), y_k (gradient differences) and s_k . y_k
        s_history: deque[TensorList] = self.global_state['s_history']
        y_history: deque[TensorList] = self.global_state['y_history']
        sy_history: deque[torch.Tensor] = self.global_state['sy_history']

        tol, damping, init_damping, eigval_bounds, sample_grads, z_beta = itemgetter(
            'tol', 'damping', 'init_damping', 'eigval_bounds', 'sample_grads', 'z_beta')(self.settings[params[0]])

        # sample gradient at previous params with current mini-batch
        if sample_grads == 'before':
            prev_params = self.get_state(params, 'prev_params', cls=TensorList)
            if step == 0:
                s_k = y_k = ys_k = None
            else:
                s_k = params - prev_params

                # Temporarily move the parameters back to the previous point and evaluate the
                # current mini-batch there. Values are copied (not ``set_``) so each parameter
                # keeps its own storage and any views/aliases of it stay valid.
                current_params = params.clone()
                params.copy_(prev_params)
                with torch.enable_grad(): var.closure()
                # NOTE(review): assumes ``update`` is the gradient at the current params and is
                # not aliased to p.grad (a closure zeroing grads in place would break this).
                y_k = update - params.grad
                ys_k = s_k.dot(y_k)
                params.copy_(current_params)

                if damping:
                    s_k, y_k, ys_k = _adaptive_damping(s_k, y_k, ys_k, init_damping=init_damping, eigval_bounds=eigval_bounds)

                # only store pairs that satisfy the curvature condition
                if ys_k > 1e-10:
                    s_history.append(s_k)
                    y_history.append(y_k)
                    sy_history.append(ys_k)

            prev_params.copy_(params)

        # use previous s_k, y_k pair, samples gradient at current batch before and after updating parameters
        elif sample_grads == 'after':
            if len(s_history) == 0:
                s_k = y_k = ys_k = None
            else:
                s_k = s_history[-1]
                y_k = y_history[-1]
                ys_k = s_k.dot(y_k)

            # this will run after params are updated by Modular after running all future modules
            var.post_step_hooks.append(
                partial(
                    _store_sk_yk_after_step_hook,
                    prev_params=params.clone(),
                    prev_grad=update.clone(),
                    damping=damping,
                    init_damping=init_damping,
                    eigval_bounds=eigval_bounds,
                    s_history=s_history,
                    y_history=y_history,
                    sy_history=sy_history,
            ))

        else:
            raise ValueError(f"sample_grads must be 'before' or 'after', got {sample_grads!r}")

        # step with inner module before applying preconditioner
        if self.children:
            update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))

        # tolerance on gradient difference to avoid exploding after converging
        if tol is not None:
            if y_k is not None and y_k.abs().global_max() <= tol:
                var.update = update # may have been updated by inner module, probably makes sense to use it here?
                return var

        # lerp initial H^-1 @ q guess
        z_ema = None
        if z_beta is not None:
            z_ema = self.get_state(params, 'z_ema', cls=TensorList)

        # precondition
        direction = lbfgs(
            tensors_=as_tensorlist(update),
            s_history=s_history,
            y_history=y_history,
            sy_history=sy_history,
            y_k=y_k,
            ys_k=ys_k,
            z_beta=z_beta,
            z_ema=z_ema,
            step=step
        )

        var.update = direction

        return var
|
|
196
|
-
|
|
@@ -1,164 +0,0 @@
|
|
|
1
|
-
import warnings
|
|
2
|
-
from abc import ABC, abstractmethod
|
|
3
|
-
from collections.abc import Callable, Sequence
|
|
4
|
-
from functools import partial
|
|
5
|
-
from typing import Literal
|
|
6
|
-
|
|
7
|
-
import torch
|
|
8
|
-
|
|
9
|
-
from ...core import Modular, Module, Var
|
|
10
|
-
from ...utils import NumberList, TensorList
|
|
11
|
-
from ...utils.derivatives import jacobian_wrt
|
|
12
|
-
from ..grad_approximation import GradApproximator, GradTarget
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class Reformulation(Module, ABC):
    """Base class for modules that replace the closure seen by all later modules.

    Subclasses implement :meth:`closure`, which evaluates the reformulated objective using the
    original closure. :meth:`step` swaps ``var.closure`` for a wrapper that calls it and writes
    the returned gradient (if any) into ``p.grad``.
    """
    def __init__(self, defaults):
        super().__init__(defaults)

    @abstractmethod
    def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], var: Var) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
        """Evaluate the reformulated objective.

        Returns a ``(loss, grad)`` pair; when ``backward`` is False, ``grad`` may be None.
        """

    def pre_step(self, var: Var) -> Var | None:
        """Runs once before each step, whereas `closure` may run multiple times per step if further modules
        evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations.

        Returning a ``Var`` replaces the one used by :meth:`step`; returning None keeps it.
        """
        return var

    def step(self, var):
        maybe_var = self.pre_step(var)
        if isinstance(maybe_var, Var):
            var = maybe_var

        original_closure = var.closure
        if original_closure is None:
            raise RuntimeError("Reformulation requires closure")
        params = var.params

        def modified_closure(backward=True):
            loss, grads = self.closure(backward, original_closure, params, var)
            # expose the reformulated gradient through the usual ``p.grad`` attribute
            if grads is not None:
                for param, g in zip(params, grads):
                    param.grad = g
            return loss

        var.closure = modified_closure
        return var
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
def _decay_sigma_(self: Module, params):
    """Multiply each parameter's stored ``sigma`` by that parameter's ``decay`` setting."""
    for param in params:
        self.state[param]['sigma'] *= self.settings[param]['decay']
|
|
54
|
-
|
|
55
|
-
def _generate_perturbations_to_state_(self: Module, params: TensorList, n_samples, sigmas, generator):
    """Draw ``n_samples`` perturbations, scale them per parameter, and store them in state.

    Args:
        params: parameters; ``params.sample_like`` produces one tensor per parameter.
        n_samples: number of perturbations to draw.
        sigmas: one scale per parameter (same order as ``params``).
        generator: random generator forwarded to ``sample_like`` (may be None).

    Afterwards ``self.state[param]['perturbations']`` is a tuple of ``n_samples`` tensors,
    each multiplied by that parameter's sigma.
    """
    perturbations = [params.sample_like(generator=generator) for _ in range(n_samples)]

    # ``flat`` is sample-major (every parameter of sample 0, then of sample 1, ...), so the
    # per-parameter sigmas must be tiled once per sample, not repeated per parameter.
    flat = [p for sample in perturbations for p in sample]
    scales = [sigma for _ in range(n_samples) for sigma in sigmas]
    torch._foreach_mul_(flat, scales)

    for param, prt in zip(params, zip(*perturbations)):
        self.state[param]['perturbations'] = prt
|
|
60
|
-
|
|
61
|
-
def _clear_state_hook(optimizer: Modular, var: Var, self: Module):
    """Reset every module in ``optimizer.unrolled_modules`` except ``self``.

    ``var`` is not used; it is accepted presumably to match the post-step hook call signature.
    """
    for module in optimizer.unrolled_modules:
        if module is self:
            continue
        module.reset()
|
|
65
|
-
|
|
66
|
-
class GaussianHomotopy(Reformulation):
    """Optimize a smoothed objective: the loss averaged over ``n_samples`` parameter perturbations.

    Perturbations are drawn with ``TensorList.sample_like`` (presumably Gaussian, as the name
    suggests; TODO confirm) and scaled by a per-parameter ``sigma``. They are kept fixed between
    steps and redrawn with the new ``sigma`` whenever ``sigma`` decays.

    Args:
        n_samples (int): number of perturbations averaged per closure evaluation.
        init_sigma (float): initial perturbation scale.
        tol (float | None, optional):
            if the largest absolute parameter change since the previous step is at most ``tol``,
            ``sigma`` is decayed. None disables this rule. Defaults to 1e-4.
        decay (float, optional): factor ``sigma`` is multiplied by on each decay. Defaults to 0.5.
        max_steps (int | None, optional):
            if positive, ``sigma`` is also decayed every ``max_steps`` steps. Defaults to None.
        clear_state (bool, optional):
            if True, every other module of the optimizer is reset after a step in which ``sigma``
            decayed. Defaults to True.
        seed (int | None, optional): seed for the perturbation generator. Defaults to None.
    """
    def __init__(
        self,
        n_samples: int,
        init_sigma: float,
        tol: float | None = 1e-4,
        decay=0.5,
        max_steps: int | None = None,
        clear_state=True,
        seed: int | None = None,
    ):
        defaults = dict(n_samples=n_samples, init_sigma=init_sigma, tol=tol, decay=decay, max_steps=max_steps, clear_state=clear_state, seed=seed)
        super().__init__(defaults)

    def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
        """Return the cached generator, creating it on first call.

        A ``torch.Generator`` is used as is. An int seeds a new generator on the first
        parameter's device. None means no generator.
        """
        if 'generator' not in self.global_state:
            if isinstance(seed, torch.Generator): self.global_state['generator'] = seed
            elif seed is not None: self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
            else: self.global_state['generator'] = None
        return self.global_state['generator']

    def pre_step(self, var):
        """Generate perturbations if missing and apply the sigma decay rules."""
        params = TensorList(var.params)
        settings = self.settings[params[0]]
        n_samples = settings['n_samples']
        init_sigma = [self.settings[p]['init_sigma'] for p in params]
        sigmas = self.get_state(params, 'sigma', init=init_sigma)

        if any('perturbations' not in self.state[p] for p in params):
            generator = self._get_generator(settings['seed'], params)
            _generate_perturbations_to_state_(self, params=params, n_samples=n_samples, sigmas=sigmas, generator=generator)

        # sigma decay rules
        decayed = False

        # rule 1: decay every ``max_steps`` steps
        max_steps = settings['max_steps']
        if max_steps is not None and max_steps > 0:
            level_steps = self.global_state.get('level_steps', 0)
            # Compare before counting the current step, so that every level lasts exactly
            # ``max_steps`` steps. Incrementing first made every level after the first one step longer.
            if level_steps >= max_steps:
                _decay_sigma_(self, params)
                decayed = True
                level_steps = 0
            self.global_state['level_steps'] = level_steps + 1

        # rule 2: decay when the parameters barely moved since the previous step
        tol = settings['tol']
        if tol is not None:
            has_prev_params = any('prev_params' in self.state[p] for p in params)
            prev_params = self.get_state(params, 'prev_params', cls=TensorList, init='param')
            if has_prev_params and not decayed:
                if (params - prev_params).abs().global_max() <= tol:
                    _decay_sigma_(self, params)
                    decayed = True
            # Record the current params on every step, including steps where rule 1 fired,
            # so the next comparison is against the immediately preceding step.
            prev_params.copy_(params)

        if decayed:
            # Re-read sigma from state. If it is stored as a Python number, ``*=`` in
            # ``_decay_sigma_`` rebinds the state entry, and ``sigmas`` above would still
            # hold the pre-decay values.
            sigmas = [self.state[p]['sigma'] for p in params]
            generator = self._get_generator(settings['seed'], params)
            _generate_perturbations_to_state_(self, params=params, n_samples=n_samples, sigmas=sigmas, generator=generator)
            if settings['clear_state']:
                var.post_step_hooks.append(partial(_clear_state_hook, self=self))

    @torch.no_grad
    def closure(self, backward, closure, params, var):
        """Average the loss (and, if ``backward``, the gradient) over the stored perturbations.

        Each perturbation is added to the parameters, the closure is evaluated, and the
        perturbation is subtracted again. The subtraction also happens if the closure raises.
        """
        params = TensorList(params)
        n_samples = self.settings[params[0]]['n_samples']

        # state holds, per parameter, a tuple with one perturbation per sample; regroup per sample
        perturbations = list(zip(*(self.state[p]['perturbations'] for p in params)))

        loss = None
        grad: list[torch.Tensor] | None = None
        for i in range(n_samples):
            prt = perturbations[i]

            params.add_(prt)
            try:
                if backward:
                    with torch.enable_grad(): sample_loss = closure()
                    sample_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                    if grad is None:
                        # Copy: a later closure() may zero p.grad in place (zero_grad(set_to_none=False)),
                        # which would corrupt an accumulator aliased to it.
                        grad = [g.clone() for g in sample_grad]
                    else:
                        torch._foreach_add_(grad, sample_grad)
                else:
                    sample_loss = closure(False)
            finally:
                params.sub_(prt)

            loss = sample_loss if loss is None else loss + sample_loss

        assert loss is not None
        if n_samples > 1:
            loss = loss / n_samples
            if backward:
                assert grad is not None
                torch._foreach_div_(grad, n_samples)

        return loss, grad
|