torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/line_search/line_search.py CHANGED

@@ -8,7 +8,7 @@ from typing import Any, Literal
 import numpy as np
 import torch
 
-from ...core import Module, Var
+from ...core import Module, Objective
 from ...utils import tofloat, set_storage_
 from ..functional import clip_by_finfo
 
@@ -139,7 +139,7 @@ class LineSearchBase(Module, ABC):
         for c, n in zip(params, new_params):
             set_storage_(c, n)
 
-    def _loss(self, step_size: float, var: Var, closure, params: list[torch.Tensor],
+    def _loss(self, step_size: float, var: Objective, closure, params: list[torch.Tensor],
               update: list[torch.Tensor], backward:bool=False) -> float:
 
         # if step_size is 0, we might already know the loss
@@ -165,16 +165,16 @@ class LineSearchBase(Module, ABC):
         # if evaluated loss at step size 0, set it to var.loss
         if step_size == 0:
             var.loss = loss
-            if backward: var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+            if backward: var.grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
 
         return tofloat(loss)
 
-    def _loss_derivative_gradient(self, step_size: float, var: Var, closure,
+    def _loss_derivative_gradient(self, step_size: float, var: Objective, closure,
                                   params: list[torch.Tensor], update: list[torch.Tensor]):
         # if step_size is 0, we might already know the derivative
-        if (var.grad is not None) and (step_size == 0):
+        if (var.grads is not None) and (step_size == 0):
             loss = self._loss(step_size=step_size,var=var,closure=closure,params=params,update=update,backward=False)
-            derivative = - sum(t.sum() for t in torch._foreach_mul(var.grad, update))
+            derivative = - sum(t.sum() for t in torch._foreach_mul(var.grads, update))
 
         else:
             # loss with a backward pass sets params.grad
@@ -184,79 +184,79 @@ class LineSearchBase(Module, ABC):
             derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
                                                                     else torch.zeros_like(p) for p in params], update))
 
-            assert var.grad is not None
-        return loss, tofloat(derivative), var.grad
+            assert var.grads is not None
+        return loss, tofloat(derivative), var.grads
 
-    def _loss_derivative(self, step_size: float, var: Var, closure,
+    def _loss_derivative(self, step_size: float, var: Objective, closure,
                          params: list[torch.Tensor], update: list[torch.Tensor]):
         return self._loss_derivative_gradient(step_size=step_size, var=var,closure=closure,params=params,update=update)[:2]
 
-    def evaluate_f(self, step_size: float, var: Var, backward:bool=False):
+    def evaluate_f(self, step_size: float, var: Objective, backward:bool=False):
        """evaluate function value at alpha `step_size`."""
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
-        return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update(),backward=backward)
+        return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates(),backward=backward)
 
-    def evaluate_f_d(self, step_size: float, var: Var):
+    def evaluate_f_d(self, step_size: float, var: Objective):
        """evaluate function value and directional derivative in the direction of the update at step size `step_size`."""
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
-        return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
+        return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates())
 
-    def evaluate_f_d_g(self, step_size: float, var: Var):
+    def evaluate_f_d_g(self, step_size: float, var: Objective):
        """evaluate function value, directional derivative, and gradient list at step size `step_size`."""
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
-        return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
+        return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates())
 
-    def make_objective(self, var: Var, backward:bool=False):
+    def make_objective(self, var: Objective, backward:bool=False):
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
-        return partial(self._loss, var=var, closure=closure, params=var.params, update=var.get_update(), backward=backward)
+        return partial(self._loss, var=var, closure=closure, params=var.params, update=var.get_updates(), backward=backward)
 
-    def make_objective_with_derivative(self, var: Var):
+    def make_objective_with_derivative(self, var: Objective):
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
-        return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_update())
+        return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_updates())
 
-    def make_objective_with_derivative_and_gradient(self, var: Var):
+    def make_objective_with_derivative_and_gradient(self, var: Objective):
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
-        return partial(self._loss_derivative_gradient, var=var, closure=closure, params=var.params, update=var.get_update())
+        return partial(self._loss_derivative_gradient, var=var, closure=closure, params=var.params, update=var.get_updates())
 
     @abstractmethod
-    def search(self, update: list[torch.Tensor], var: Var) -> float:
+    def search(self, update: list[torch.Tensor], var: Objective) -> float:
        """Finds the step size to use"""
 
     @torch.no_grad
-    def step(self, var: Var) -> Var:
+    def apply(self, objective: Objective) -> Objective:
        self._reset()
 
-        params = var.params
+        params = objective.params
        self._initial_params = [p.clone() for p in params]
-        update = var.get_update()
+        update = objective.get_updates()
 
        try:
-            step_size = self.search(update=update, var=var)
+            step_size = self.search(update=update, var=objective)
        except MaxLineSearchItersReached:
            step_size = self._best_step_size
 
        step_size = clip_by_finfo(step_size, torch.finfo(update[0].dtype))
 
        # set loss_approx
-        if var.loss_approx is None: var.loss_approx = self._lowest_loss
+        if objective.loss_approx is None: objective.loss_approx = self._lowest_loss
 
        # if this is last module, directly update parameters to avoid redundant operations
-        if
+        if objective.modular is not None and self is objective.modular.modules[-1]:
            self.set_step_size_(step_size, params=params, update=update)
 
-            var.stop = True; var.skip_update = True
-            return var
+            objective.stop = True; objective.skip_update = True
+            return objective
 
        # revert parameters and multiply update by step size
        self.set_step_size_(0, params=params, update=update)
-        torch._foreach_mul_(var.update, step_size)
-        return var
+        torch._foreach_mul_(objective.updates, step_size)
+        return objective
 
 
 

torchzero/modules/line_search/strong_wolfe.py CHANGED

@@ -284,8 +284,8 @@ class StrongWolfe(LineSearchBase):
            'init_value', 'init', 'c1', 'c2', 'a_max', 'maxiter', 'maxzoom',
            'maxeval', 'interpolation', 'adaptive', 'plus_minus', 'fallback', 'tol_change')(self.defaults)
 
-        dir = as_tensorlist(var.get_update())
-        grad_list = var.get_grad()
+        dir = as_tensorlist(var.get_updates())
+        grad_list = var.get_grads()
 
        g_0 = -sum(t.sum() for t in torch._foreach_mul(grad_list, dir))
        f_0 = var.get_loss(False)
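The hunks above show the `LineSearchBase` API a subclass works against in 0.4.0: `search(update, var)` returns a step size, and helpers such as `evaluate_f(step_size, var)` evaluate the closure at a trial step size. As a rough, hypothetical illustration (not code from the package; the constructor is omitted because its signature is not shown in this diff):

```python
# Hypothetical sketch built on the LineSearchBase methods visible in the hunks above:
# try a few fixed step sizes and return the best one.
from torchzero.modules.line_search.line_search import LineSearchBase

class FixedGridLineSearch(LineSearchBase):
    """Illustrative only: keeps whichever of a few fixed step sizes gives the lowest loss."""

    step_sizes = (1.0, 0.5, 0.25, 0.1)

    def search(self, update, var) -> float:
        best_size = 0.0
        best_loss = self.evaluate_f(0.0, var)  # loss at the unmodified parameters
        for size in self.step_sizes:
            loss = self.evaluate_f(size, var)  # loss after stepping by `size` along the update
            if loss < best_loss:
                best_size, best_loss = size, loss
        return best_size
```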
torchzero/modules/misc/debug.py CHANGED

@@ -11,9 +11,9 @@ class PrintUpdate(Module):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)
 
-    def step(self, var):
-        self.defaults["print_fn"](f'{self.defaults["text"]}{var.update}')
-        return var
+    def apply(self, objective):
+        self.defaults["print_fn"](f'{self.defaults["text"]}{objective.updates}')
+        return objective
 
 class PrintShape(Module):
    """Prints shapes of the update."""
@@ -21,10 +21,10 @@ class PrintShape(Module):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)
 
-    def step(self, var):
-        shapes = [u.shape for u in var.update] if var.update is not None else None
+    def apply(self, objective):
+        shapes = [u.shape for u in objective.updates] if objective.updates is not None else None
        self.defaults["print_fn"](f'{self.defaults["text"]}{shapes}')
-        return var
+        return objective
 
 class PrintParams(Module):
    """Prints current update."""
@@ -32,9 +32,9 @@ class PrintParams(Module):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)
 
-    def step(self, var):
-        self.defaults["print_fn"](f'{self.defaults["text"]}{var.params}')
-        return var
+    def apply(self, objective):
+        self.defaults["print_fn"](f'{self.defaults["text"]}{objective.params}')
+        return objective
 
 
 class PrintLoss(Module):
@@ -43,6 +43,6 @@ class PrintLoss(Module):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)
 
-    def step(self, var):
-        self.defaults["print_fn"](f'{self.defaults["text"]}{var.get_loss(False)}')
-        return var
+    def apply(self, objective):
+        self.defaults["print_fn"](f'{self.defaults["text"]}{objective.get_loss(False)}')
+        return objective
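These small modules show the 0.4.0 protocol this release migrates to most plainly: `Module.step(self, var)` becomes `Module.apply(self, objective)`, the update tensors live on `objective.updates` (which may be `None`), and the module returns the objective. A minimal sketch of a custom module against that protocol; the class itself is hypothetical, inferred from the hunks, not part of the package:

```python
# Hypothetical module following the 0.4.0 pattern visible in the hunks:
# defaults passed to Module.__init__ as a dict, apply() receives and returns the objective.
import torch
from torchzero.core import Module

class ClampUpdate(Module):
    """Illustrative only: clamps every update tensor elementwise to [-value, value]."""
    def __init__(self, value: float = 1.0):
        super().__init__(dict(value=value))

    @torch.no_grad
    def apply(self, objective):
        value = self.defaults["value"]
        if objective.updates is not None:
            for u in objective.updates:
                u.clamp_(-value, value)
        return objective
```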
torchzero/modules/misc/escape.py CHANGED

@@ -3,7 +3,7 @@ import math
 from typing import Literal
 import torch
 
-from ...core import Modular, Module, Var, Chainable
+from ...core import Modular, Module, Objective, Chainable
 from ...utils import NumberList, TensorList
 
 
@@ -15,11 +15,11 @@ class EscapeAnnealing(Module):
 
 
     @torch.no_grad
-    def step(self, var):
-        closure = var.closure
+    def apply(self, objective):
+        closure = objective.closure
        if closure is None: raise RuntimeError("Escape requries closure")
 
-        params = TensorList(var.params)
+        params = TensorList(objective.params)
        settings = self.settings[params[0]]
        max_region = self.get_settings(params, 'max_region', cls=NumberList)
        max_iter = settings['max_iter']
@@ -41,7 +41,7 @@ class EscapeAnnealing(Module):
        self.global_state['n_bad'] = n_bad
 
        # no progress
-        f_0 = var.get_loss(False)
+        f_0 = objective.get_loss(False)
        if n_bad >= n_tol:
            for i in range(1, max_iter+1):
                alpha = max_region * (i / max_iter)
@@ -51,12 +51,12 @@ class EscapeAnnealing(Module):
                f_star = closure(False)
 
                if math.isfinite(f_star) and f_star < f_0-1e-12:
-                    var.update = None
-                    var.stop = True
-                    var.skip_update = True
-                    return var
+                    objective.updates = None
+                    objective.stop = True
+                    objective.skip_update = True
+                    return objective
 
                params.sub_(pert)
 
        self.global_state['n_bad'] = 0
-        return var
+        return objective
torchzero/modules/misc/gradient_accumulation.py CHANGED

@@ -3,74 +3,6 @@ import torch
 from ...core import Chainable, Module
 
 
-# class GradientAccumulation(Module):
-#     """Uses :code:`n` steps to accumulate gradients, after :code:`n` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
-
-#     Accumulating gradients for :code:`n` steps is equivalent to increasing batch size by :code:`n`. Increasing the batch size
-#     is more computationally efficient, but sometimes it is not feasible due to memory constraints.
-
-#     .. note::
-#         Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.
-
-#     Args:
-#         modules (Chainable): modules that perform a step every :code:`n` steps using the accumulated gradients.
-#         n (int): number of gradients to accumulate.
-#         mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
-#         stop (bool, optional):
-#             this module prevents next modules from stepping unless :code:`n` gradients have been accumulate. Setting this argument to False disables that. Defaults to True.
-
-#     Examples:
-#         Adam with gradients accumulated for 16 batches.
-
-#         .. code-block:: python
-
-#             opt = tz.Modular(
-#                 model.parameters(),
-#                 tz.m.GradientAccumulation(
-#                     [tz.m.Adam(), tz.m.LR(1e-2)],
-#                     n=16
-#                 )
-#             )
-
-#     """
-#     def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
-#         defaults = dict(n=n, mean=mean, stop=stop)
-#         super().__init__(defaults)
-#         self.set_child('modules', modules)
-
-
-#     @torch.no_grad
-#     def step(self, var):
-#         accumulator = self.get_state(var.params, 'accumulator')
-#         settings = self.defaults
-#         n = settings['n']; mean = settings['mean']; stop = settings['stop']
-#         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-#         # add update to accumulator
-#         torch._foreach_add_(accumulator, var.get_update())
-
-#         # step with accumulated updates
-#         if step % n == 0:
-#             if mean:
-#                 torch._foreach_div_(accumulator, n)
-
-#             var.update = [a.clone() for a in accumulator]
-#             var = self.children['modules'].step(var)
-
-#             # zero accumulator
-#             torch._foreach_zero_(accumulator)
-
-#         else:
-#             # prevent update
-#             if stop:
-#                 var.update = None
-#                 var.stop=True
-#                 var.skip_update=True
-
-#         return var
-
-
-
 
 
 class GradientAccumulation(Module):
    """Uses ``n`` steps to accumulate gradients, after ``n`` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
@@ -106,21 +38,21 @@ class GradientAccumulation(Module):
 
 
     @torch.no_grad
-    def step(self, var):
-        accumulator = self.get_state(var.params, 'accumulator')
+    def apply(self, objective):
+        accumulator = self.get_state(objective.params, 'accumulator')
        settings = self.defaults
        n = settings['n']; mean = settings['mean']; stop = settings['stop']
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+        step = self.increment_counter("step", 0)
 
        # add update to accumulator
-        torch._foreach_add_(accumulator, var.get_update())
+        torch._foreach_add_(accumulator, objective.get_updates())
 
        # step with accumulated updates
-        if step % n == 0:
+        if (step + 1) % n == 0:
            if mean:
                torch._foreach_div_(accumulator, n)
 
-            var.update = [a.clone() for a in accumulator]
+            objective.updates = accumulator
 
            # zero accumulator
            self.clear_state_keys('accumulator')
@@ -128,9 +60,9 @@ class GradientAccumulation(Module):
        else:
            # prevent update
            if stop:
-                var.update = None
-                var.stop=True
-                var.skip_update=True
+                objective.updates = None
+                objective.stop=True
+                objective.skip_update=True
 
-        return var
+        return objective
 
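The retained docstring describes accumulation over ``n`` steps as equivalent to an ``n``-times larger batch. A usage sketch based on the docstring's own example; the training loop and closure convention (`closure(backward=True)`) are illustrative assumptions following patterns seen elsewhere in this diff, not documented behaviour:

```python
# Illustrative only: accumulate gradients over 4 micro-batches before Adam steps.
# tz.Modular, tz.m.GradientAccumulation, tz.m.Adam and tz.m.LR come from the
# docstring example retained in the file; the loop and closure are assumptions.
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.GradientAccumulation([tz.m.Adam(), tz.m.LR(1e-2)], n=4),
)

for _ in range(8):
    x, y = torch.randn(16, 10), torch.randn(16, 1)

    def closure(backward=True):
        loss = torch.nn.functional.mse_loss(model(x), y)
        if backward:
            model.zero_grad()
            loss.backward()
        return loss

    opt.step(closure)  # with stop=True, parameters should only change on every 4th call
```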
torchzero/modules/misc/homotopy.py CHANGED

@@ -13,27 +13,27 @@ class HomotopyBase(Module):
    """transform the loss"""
 
    @torch.no_grad
-    def step(self, var):
-        if var.loss is not None:
-            var.loss = self.loss_transform(var.loss)
+    def apply(self, objective):
+        if objective.loss is not None:
+            objective.loss = self.loss_transform(objective.loss)
 
-        closure = var.closure
+        closure = objective.closure
        if closure is None: raise RuntimeError("SquareHomotopy requires closure")
 
        def homotopy_closure(backward=True):
            if backward:
                with torch.enable_grad():
                    loss = self.loss_transform(closure(False))
-                    grad = torch.autograd.grad(loss, var.params, allow_unused=True)
-                    for p,g in zip(var.params, grad):
+                    grad = torch.autograd.grad(loss, objective.params, allow_unused=True)
+                    for p,g in zip(objective.params, grad):
                        p.grad = g
            else:
                loss = self.loss_transform(closure(False))
 
            return loss
 
-        var.closure = homotopy_closure
-        return var
+        objective.closure = homotopy_closure
+        return objective
 
 class SquareHomotopy(HomotopyBase):
    def __init__(self): super().__init__()
@@ -57,3 +57,11 @@ class LambdaHomotopy(HomotopyBase):
        super().__init__(defaults)
 
    def loss_transform(self, loss): return self.defaults['fn'](loss)
+
+class FixedLossHomotopy(HomotopyBase):
+    def __init__(self, value: float = 1):
+        defaults = dict(value=value)
+        super().__init__(defaults)
+
+    def loss_transform(self, loss): return loss / loss.detach().clip(min=torch.finfo(loss.dtype).tiny * 2)
+
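The hunks above also make the extension point visible: a homotopy module only overrides `loss_transform`, and `HomotopyBase.apply` wraps the closure so the transformed loss is what gets differentiated. A hypothetical subclass for illustration (not part of the package), following the pattern of `SquareHomotopy` and `FixedLossHomotopy`:

```python
# Hypothetical subclass: only loss_transform is overridden, the closure wrapping
# and the apply() plumbing are inherited from HomotopyBase as shown in the hunk.
import torch
from torchzero.modules.misc.homotopy import HomotopyBase

class Log1pHomotopy(HomotopyBase):
    def __init__(self):
        super().__init__()

    def loss_transform(self, loss):
        # compress large losses; the clamp keeps the argument of log1p non-negative
        return torch.log1p(loss.clamp(min=0))
```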