torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/misc/debug.py (new file)
@@ -0,0 +1,48 @@
+from collections import deque
+
+import torch
+
+from ...core import Module
+from ...utils.tensorlist import Distributions
+
+class PrintUpdate(Module):
+    """Prints current update."""
+    def __init__(self, text = 'update = ', print_fn = print):
+        defaults = dict(text=text, print_fn=print_fn)
+        super().__init__(defaults)
+
+    def step(self, var):
+        self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.update}')
+        return var
+
+class PrintShape(Module):
+    """Prints shapes of the update."""
+    def __init__(self, text = 'shapes = ', print_fn = print):
+        defaults = dict(text=text, print_fn=print_fn)
+        super().__init__(defaults)
+
+    def step(self, var):
+        shapes = [u.shape for u in var.update] if var.update is not None else None
+        self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{shapes}')
+        return var
+
+class PrintParams(Module):
+    """Prints current update."""
+    def __init__(self, text = 'params = ', print_fn = print):
+        defaults = dict(text=text, print_fn=print_fn)
+        super().__init__(defaults)
+
+    def step(self, var):
+        self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.params}')
+        return var
+
+
+class PrintLoss(Module):
+    """Prints var.get_loss()."""
+    def __init__(self, text = 'loss = ', print_fn = print):
+        defaults = dict(text=text, print_fn=print_fn)
+        super().__init__(defaults)
+
+    def step(self, var):
+        self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.get_loss(False)}')
+        return var
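The modules above are simple printing hooks. Below is a hypothetical usage sketch (not part of the diff); it assumes the `tz.Modular` / `tz.m` interface shown in the GradientAccumulation docstring later in this diff, and that these debug classes are exported under `tz.m` like the other new modules.

    import torch
    import torchzero as tz  # assumed import alias, matching the `tz.` prefix used in docstrings below

    model = torch.nn.Linear(4, 1)

    # Interleave debug modules with an optimizer chain to trace each step.
    opt = tz.Modular(
        model.parameters(),
        tz.m.PrintShape(),                        # prints the shapes of the incoming update (the gradients here)
        tz.m.Adam(),
        tz.m.PrintUpdate(text='adam update = '),  # prints the update produced by Adam
        tz.m.LR(1e-2),
    )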
torchzero/modules/misc/escape.py (new file)
@@ -0,0 +1,60 @@
+import torch
+
+from ...core import Module
+from ...utils import TensorList, NumberList
+
+
+class EscapeAnnealing(Module):
+    """If parameters stop changing, this runs a backward annealing random search"""
+    def __init__(self, max_region:float = 1, max_iter:int = 1000, tol=1e-6, n_tol: int = 10):
+        defaults = dict(max_region=max_region, max_iter=max_iter, tol=tol, n_tol=n_tol)
+        super().__init__(defaults)
+
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise RuntimeError("Escape requries closure")
+
+        params = TensorList(var.params)
+        settings = self.settings[params[0]]
+        max_region = self.get_settings(params, 'max_region', cls=NumberList)
+        max_iter = settings['max_iter']
+        tol = settings['tol']
+        n_tol = settings['n_tol']
+
+        n_bad = self.global_state.get('n_bad', 0)
+
+        prev_params = self.get_state(params, 'prev_params', cls=TensorList)
+        diff = params-prev_params
+        prev_params.copy_(params)
+
+        if diff.abs().global_max() <= tol:
+            n_bad += 1
+
+        else:
+            n_bad = 0
+
+        self.global_state['n_bad'] = n_bad
+
+        # no progress
+        f_0 = var.get_loss(False)
+        if n_bad >= n_tol:
+            for i in range(1, max_iter+1):
+                alpha = max_region * (i / max_iter)
+                pert = params.sample_like(distribution='sphere').mul_(alpha)
+
+                params.add_(pert)
+                f_star = closure(False)
+
+                if f_star < f_0-1e-10:
+                    var.update = None
+                    var.stop = True
+                    var.skip_update = True
+                    return var
+
+                else:
+                    params.sub_(pert)
+
+            self.global_state['n_bad'] = 0
+        return var
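For orientation, here is a minimal standalone sketch of the escape strategy EscapeAnnealing implements, written against plain torch tensors rather than the torchzero Module API. The helper name `escape` and the direction sampling are illustrative only; the module itself uses `sample_like(distribution='sphere')`.

    import torch

    @torch.no_grad()
    def escape(params, closure, max_region=1.0, max_iter=1000):
        """Try random perturbations of growing radius; keep the first one that lowers the loss."""
        f_0 = closure()
        for i in range(1, max_iter + 1):
            alpha = max_region * (i / max_iter)
            pert = [torch.randn_like(p) for p in params]
            # scale the joint perturbation to radius alpha (rough stand-in for a sphere sample)
            norm = torch.cat([q.ravel() for q in pert]).norm().clamp_min(1e-12)
            for p, q in zip(params, pert):
                p.add_(q, alpha=float(alpha / norm))
            if closure() < f_0 - 1e-10:
                return True                      # escaped: keep the perturbed parameters
            for p, q in zip(params, pert):
                p.sub_(q, alpha=float(alpha / norm))
        return False                             # no improving perturbation found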
torchzero/modules/misc/gradient_accumulation.py (new file)
@@ -0,0 +1,70 @@
+import torch
+
+from ...core import Chainable, Module
+
+
+class GradientAccumulation(Module):
+    """Uses :code:`n` steps to accumulate gradients, after :code:`n` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
+
+    Accumulating gradients for :code:`n` steps is equivalent to increasing batch size by :code:`n`. Increasing the batch size
+    is more computationally efficient, but sometimes it is not feasible due to memory constraints.
+
+    .. note::
+        Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.
+
+    Args:
+        modules (Chainable): modules that perform a step every :code:`n` steps using the accumulated gradients.
+        n (int): number of gradients to accumulate.
+        mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
+        stop (bool, optional):
+            this module prevents next modules from stepping unless :code:`n` gradients have been accumulate. Setting this argument to False disables that. Defaults to True.
+
+    Examples:
+        Adam with gradients accumulated for 16 batches.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.GradientAccumulation(
+                    modules=[tz.m.Adam(), tz.m.LR(1e-2)],
+                    n=16
+                )
+            )
+
+    """
+    def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
+        defaults = dict(n=n, mean=mean, stop=stop)
+        super().__init__(defaults)
+        self.set_child('modules', modules)
+
+
+    @torch.no_grad
+    def step(self, var):
+        accumulator = self.get_state(var.params, 'accumulator')
+        settings = self.settings[var.params[0]]
+        n = settings['n']; mean = settings['mean']; stop = settings['stop']
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        # add update to accumulator
+        torch._foreach_add_(accumulator, var.get_update())
+
+        # step with accumulated updates
+        if step % n == 0:
+            if mean:
+                torch._foreach_div_(accumulator, n)
+
+            var.update = [a.clone() for a in accumulator]
+            var = self.children['modules'].step(var)
+
+            # zero accumulator
+            torch._foreach_zero_(accumulator)
+
+        else:
+            # prevent update
+            if stop:
+                var.stop=True
+                var.skip_update=True
+
+        return var
+
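A hypothetical training-loop sketch for the module above (not taken from the diff); it assumes `tz.Modular` is driven like a standard `torch.optim` optimizer with `step()` and `zero_grad()`. Only every 16th call lets the inner modules update the parameters; the other calls just accumulate and set the stop/skip_update flags.

    import torch
    import torchzero as tz  # assumed import alias

    model = torch.nn.Linear(8, 1)
    loss_fn = torch.nn.MSELoss()
    opt = tz.Modular(
        model.parameters(),
        tz.m.GradientAccumulation(modules=[tz.m.Adam(), tz.m.LR(1e-2)], n=16),
    )

    for _ in range(64):
        x, y = torch.randn(32, 8), torch.randn(32, 1)
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()       # parameters only move on every 16th call; other calls accumulate
        opt.zero_grad()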
torchzero/modules/misc/misc.py (new file)
@@ -0,0 +1,316 @@
+from collections import deque
+from collections.abc import Iterable
+from operator import itemgetter
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
+from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
+
+
+class Previous(TensorwiseTransform):
+    """Maintains an update from n steps back, for example if n=1, returns previous update"""
+    def __init__(self, n=1, target: Target = 'update'):
+        defaults = dict(n=n)
+        super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+
+    @torch.no_grad
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        n = setting['n']
+
+        if 'history' not in state:
+            state['history'] = deque(maxlen=n+1)
+
+        state['history'].append(tensor)
+
+        return state['history'][0]
+
+
+class LastDifference(Transform):
+    """Outputs difference between past two updates."""
+    def __init__(self,target: Target = 'update'):
+        super().__init__({}, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        prev_tensors = unpack_states(states, tensors, 'prev_tensors') # initialized to 0
+        difference = torch._foreach_sub(tensors, prev_tensors)
+        for p, c in zip(prev_tensors, tensors): p.set_(c)
+        return difference
+
+class LastGradDifference(Module):
+    """Outputs difference between past two gradients."""
+    def __init__(self):
+        super().__init__({})
+
+    @torch.no_grad
+    def step(self, var):
+        grad = var.get_grad()
+        prev_grad = self.get_state(var.params, 'prev_grad') # initialized to 0
+        difference = torch._foreach_sub(grad, prev_grad)
+        for p, c in zip(prev_grad, grad): p.copy_(c)
+        var.update = list(difference)
+        return var
+
+class LastParamDifference(Module):
+    """Outputs difference between past two parameters, which is the effective previous update."""
+    def __init__(self):
+        super().__init__({})
+
+    @torch.no_grad
+    def step(self, var):
+        params = var.params
+        prev_params = self.get_state(var.params, 'prev_params') # initialized to 0
+        difference = torch._foreach_sub(params, prev_params)
+        for p, c in zip(prev_params, params): p.copy_(c)
+        var.update = list(difference)
+        return var
+
+
+
+class LastProduct(Transform):
+    """Outputs difference between past two updates."""
+    def __init__(self,target: Target = 'update'):
+        super().__init__({}, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
+        prod = torch._foreach_mul(tensors, prev)
+        for p, c in zip(prev, tensors): p.set_(c)
+        return prod
+
+class LastRatio(Transform):
+    """Outputs ratio between past two updates, the numerator is determined by :code:`numerator` argument."""
+    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', target: Target = 'update'):
+        defaults = dict(numerator=numerator)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
+        numerator = settings[0]['numerator']
+        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
+        else: ratio = torch._foreach_div(prev, tensors)
+        for p, c in zip(prev, tensors): p.set_(c)
+        return ratio
+
+class LastAbsoluteRatio(Transform):
+    """Outputs ratio between absolute values of past two updates the numerator is determined by :code:`numerator` argument."""
+    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8, target: Target = 'update'):
+        defaults = dict(numerator=numerator, eps=eps)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
+        numerator = settings[0]['numerator']
+        eps = NumberList(s['eps'] for s in settings)
+
+        torch._foreach_abs_(tensors)
+        torch._foreach_clamp_min_(prev, eps)
+
+        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
+        else: ratio = torch._foreach_div(prev, tensors)
+        for p, c in zip(prev, tensors): p.set_(c)
+        return ratio
+
+class GradSign(Transform):
+    """Copies gradient sign to update."""
+    def __init__(self, target: Target = 'update'):
+        super().__init__({}, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert grads is not None
+        return [t.copysign_(g) for t,g in zip(tensors, grads)]
+
+class UpdateSign(Transform):
+    """Outputs gradient with sign copied from the update."""
+    def __init__(self, target: Target = 'update'):
+        super().__init__({}, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert grads is not None
+        return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place
+
+class GraftToGrad(Transform):
+    """Grafts update to the gradient, that is update is rescaled to have the same norm as the gradient."""
+    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+        super().__init__(defaults, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert grads is not None
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+        return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)
+
+class GraftGradToUpdate(Transform):
+    """Outputs gradient grafted to update, that is gradient rescaled to have the same norm as the update."""
+    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+        super().__init__(defaults, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert grads is not None
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+        return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)
+
+
+class GraftToParams(Transform):
+    """Grafts update to the parameters, that is update is rescaled to have the same norm as the parameters, but no smaller than :code:`eps`."""
+    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-4, target: Target = 'update'):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+        return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)
+
+class Relative(Transform):
+    """Multiplies update by absolute parameter values to make it relative to their magnitude, :code:`min_value` is minimum allowed value to avoid getting stuck at 0."""
+    def __init__(self, min_value:float = 1e-4, target: Target = 'update'):
+        defaults = dict(min_value=min_value)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
+        torch._foreach_mul_(tensors, mul)
+        return tensors
+
+class FillLoss(Module):
+    """Outputs tensors filled with loss value times :code:`alpha`"""
+    def __init__(self, alpha: float = 1, backward: bool = True):
+        defaults = dict(alpha=alpha, backward=backward)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        alpha = self.get_settings(var.params, 'alpha')
+        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+        var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
+        return var
+
+class MulByLoss(Module):
+    """Multiplies update by loss times :code:`alpha`"""
+    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
+        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
+        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
+        torch._foreach_mul_(var.update, mul)
+        return var
+
+class DivByLoss(Module):
+    """Divides update by loss times :code:`alpha`"""
+    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
+        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
+        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
+        torch._foreach_div_(var.update, mul)
+        return var
+
+
+class NoiseSign(Transform):
+    """Outputs random tensors with sign copied from the update."""
+    def __init__(self, distribution:Distributions = 'normal', alpha = 1):
+        defaults = dict(distribution=distribution, alpha=alpha)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        alpha = [s['alpha'] for s in settings]
+        distribution = self.settings[params[0]]['distribution']
+        return TensorList(tensors).sample_like(alpha, distribution).copysign_(tensors)
+
+class HpuEstimate(Transform):
+    """returns ``y/||s||``, where ``y`` is difference between current and previous update (gradient), ``s`` is difference between current and previous parameters. The returned tensors are a finite difference approximation to hessian times previous update."""
+    def __init__(self):
+        defaults = dict()
+        super().__init__(defaults, uses_grad=False)
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_params', 'prev_update')
+
+    @torch.no_grad
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        prev_params, prev_update = self.get_state(params, 'prev_params', 'prev_update') # initialized to 0
+        s = torch._foreach_sub(params, prev_params)
+        y = torch._foreach_sub(tensors, prev_update)
+        for p, c in zip(prev_params, params): p.copy_(c)
+        for p, c in zip(prev_update, tensors): p.copy_(c)
+        torch._foreach_div_(y, torch.linalg.norm(torch.cat([t.ravel() for t in s])).clip(min=1e-8)) # pylint:disable=not-callable
+        self.store(params, ['s', 'y'], [s, y])
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        return [self.state[p]['y'] for p in params]
+
+class RandomHvp(Module):
+    """Returns a hessian-vector product with a random vector"""
+
+    def __init__(
+        self,
+        n_samples: int = 1,
+        distribution: Distributions = "normal",
+        update_freq: int = 1,
+        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+        h=1e-3,
+    ):
+        defaults = dict(n_samples=n_samples, distribution=distribution, hvp_method=hvp_method, h=h, update_freq=update_freq)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        params = TensorList(var.params)
+        settings = self.settings[params[0]]
+        n_samples = settings['n_samples']
+        distribution = settings['distribution']
+        hvp_method = settings['hvp_method']
+        h = settings['h']
+        update_freq = settings['update_freq']
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        D = None
+        if step % update_freq == 0:
+
+            rgrad = None
+            for i in range(n_samples):
+                u = params.sample_like(distribution=distribution)
+
+                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                      h=h, normalize=True, retain_grad=i < n_samples-1)
+
+                if D is None: D = Hvp
+                else: torch._foreach_add_(D, Hvp)
+
+            if n_samples > 1: torch._foreach_div_(D, n_samples)
+            if update_freq != 1:
+                assert D is not None
+                D_buf = self.get_state(params, "D", cls=TensorList)
+                D_buf.set_(D)
+
+        if D is None:
+            D = self.get_state(params, "D", cls=TensorList)
+
+        var.update = list(D)
+        return var
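The grafting transforms above all perform the operation described in their docstrings: rescale one set of tensors to the norm of another. A minimal self-contained illustration of the global (non-tensorwise) case:

    import torch

    update = torch.tensor([3.0, 4.0])   # ||update|| = 5
    grad = torch.tensor([0.6, 0.8])     # ||grad|| = 1
    eps = 1e-6

    # GraftToGrad-style rescaling: keep the direction of `update`, take the norm of `grad`.
    grafted = update * (grad.norm() / update.norm().clamp_min(eps))
    print(grafted)                      # tensor([0.6000, 0.8000])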
torchzero/modules/misc/multistep.py (new file)
@@ -0,0 +1,158 @@
+from collections.abc import Iterable
+
+import torch
+
+from ...core import Chainable, Module, Var
+from ...utils import TensorList
+
+def _sequential_step(self: Module, var: Var, sequential: bool):
+    params = var.params
+    steps = self.settings[params[0]]['steps']
+
+    if sequential: modules = self.get_children_sequence() * steps
+    else: modules = [self.children['module']] * steps
+
+    if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
+
+    # store original params unless this is last module and can update params directly
+    params_before_steps = None if (var.is_last and var.last_module_lrs is None) else [p.clone() for p in params]
+
+    # first step - pass var as usual
+    var = modules[0].step(var)
+    new_var = var
+
+    # subsequent steps - update parameters and create new var
+    if len(modules) > 1:
+        for m in modules[1:]:
+
+            # update params
+            if (not new_var.skip_update):
+                if new_var.last_module_lrs is not None:
+                    torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
+
+                torch._foreach_sub_(params, new_var.get_update())
+
+            # create new var since we are at a new point, that means grad, update and loss will be None
+            new_var = Var(params=new_var.params, closure=new_var.closure,
+                          model=new_var.model, current_step=new_var.current_step + 1)
+
+            # step
+            new_var = m.step(new_var)
+
+    # final parameter update
+    if (not new_var.skip_update):
+        if new_var.last_module_lrs is not None:
+            torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
+
+        torch._foreach_sub_(params, new_var.get_update())
+
+    # if last module, update is applied so return new var
+    if params_before_steps is None:
+        new_var.stop = True
+        new_var.skip_update = True
+        return new_var
+
+    # otherwise use parameter difference as update
+    var.update = list(torch._foreach_sub(params_before_steps, params))
+    for p, bef in zip(params, params_before_steps):
+        p.set_(bef) # pyright:ignore[reportArgumentType]
+    return var
+
+class Multistep(Module):
+    """Performs :code:`steps` inner steps with :code:`module` per each step.
+
+    The update is taken to be the parameter difference between parameters before and after the inner loop."""
+    def __init__(self, module: Chainable, steps: int):
+        defaults = dict(steps=steps)
+        super().__init__(defaults)
+        self.set_child('module', module)
+
+    @torch.no_grad
+    def step(self, var):
+        return _sequential_step(self, var, sequential=False)
+
+class Sequential(Module):
+    """On each step, this sequentially steps with :code:`modules` :code:`steps` times.
+
+    The update is taken to be the parameter difference between parameters before and after the inner loop."""
+    def __init__(self, modules: Iterable[Chainable], steps: int=1):
+        defaults = dict(steps=steps)
+        super().__init__(defaults)
+        self.set_children_sequence(modules)
+
+    @torch.no_grad
+    def step(self, var):
+        return _sequential_step(self, var, sequential=True)
+
+
+class NegateOnLossIncrease(Module):
+    """Uses an extra forward pass to evaluate loss at :code:`parameters+update`,
+    if loss is larger than at :code:`parameters`,
+    the update is set to 0 if :code:`backtrack=False` and to :code:`-update` otherwise"""
+    def __init__(self, backtrack=False):
+        defaults = dict(backtrack=backtrack)
+        super().__init__(defaults=defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
+        backtrack = self.settings[var.params[0]]['backtrack']
+
+        update = var.get_update()
+        f_0 = var.get_loss(backward=False)
+
+        torch._foreach_sub_(var.params, update)
+        f_1 = closure(False)
+
+        if f_1 <= f_0:
+            if var.is_last and var.last_module_lrs is None:
+                var.stop = True
+                var.skip_update = True
+                return var
+
+            torch._foreach_add_(var.params, update)
+            return var
+
+        torch._foreach_add_(var.params, update)
+        if backtrack:
+            torch._foreach_neg_(var.update)
+        else:
+            torch._foreach_zero_(var.update)
+        return var
+
+
+class Online(Module):
+    """Allows certain modules to be used for mini-batch optimization."""
+    def __init__(self, module: Chainable,):
+        super().__init__()
+
+        self.set_child('module', module)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise ValueError("Closure must be passed for Online")
+        step = self.global_state.get('step', 0) + 1
+        self.global_state['step'] = step
+        params = TensorList(var.params)
+        p_cur = params.clone()
+        p_prev = self.get_state(params, 'p_prev', cls=TensorList)
+        module = self.children['module']
+
+        if step == 1:
+            var = module.step(var.clone(clone_update=False))
+
+            p_prev.copy_(params)
+            return var
+
+        # restore previous params
+        var_prev = Var(params=params, closure=closure, model=var.model, current_step=var.current_step)
+        params.set_(p_prev)
+        module.reset_for_online()
+        module.update(var_prev)
+
+        # restore current params
+        params.set_(p_cur)
+        p_prev.copy_(params)
+        return module.step(var.clone(clone_update=False))
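To summarize what `_sequential_step` computes when it is not the last module, here is a small standalone sketch of the Multistep idea, independent of the Module/Var machinery: run k inner steps, report the net parameter change as a single update, and restore the original parameters. The `inner_step` callable is a placeholder for whatever the inner module does.

    import torch

    @torch.no_grad()
    def multistep_update(params, inner_step, k):
        before = [p.clone() for p in params]
        for _ in range(k):
            inner_step(params)                    # mutates params in place
        update = [b - p for b, p in zip(before, params)]
        for p, b in zip(params, before):
            p.copy_(b)                            # restore original parameters
        return update                             # net displacement over the k inner steps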