torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/misc/basic.py
DELETED
@@ -1,115 +0,0 @@
-from collections.abc import Callable, Iterable
-
-import torch
-
-from ...tensorlist import TensorList
-
-from ...core import OptimizerModule, _Chainable
-
-
-class Alpha(OptimizerModule):
-    """Multiplies update by the learning rate, won't get picked up by learning rate schedulers."""
-    def __init__(self, alpha = 1e-3):
-        defaults = dict(alpha = alpha)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        # multiply ascent direction by lr in-place
-        lr = self.get_group_key('alpha')
-        ascent *= lr
-        return ascent
-
-class Clone(OptimizerModule):
-    """Clones the update. Some modules update ascent in-place, so this may be
-    useful if you need to preserve it."""
-    def __init__(self):
-        super().__init__({})
-
-    @torch.no_grad
-    def _update(self, vars, ascent): return ascent.clone()
-
-class Identity(OptimizerModule):
-    """Does nothing."""
-    def __init__(self, *args, **kwargs):
-        super().__init__({})
-
-    @torch.no_grad
-    def _update(self, vars, ascent): return ascent
-
-class Lambda(OptimizerModule):
-    """Applies a function to the ascent direction.
-    The function must take a TensorList as the argument, and return the modified tensorlist.
-
-    Args:
-        f (Callable): function
-    """
-    def __init__(self, f: Callable[[TensorList], TensorList]):
-        super().__init__({})
-        self.f = f
-
-    @torch.no_grad()
-    def _update(self, vars, ascent): return self.f(ascent)
-
-class Grad(OptimizerModule):
-    """Uses gradient as the update. This is useful for chains."""
-    def __init__(self):
-        super().__init__({})
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        ascent = vars.ascent = vars.maybe_compute_grad_(self.get_params())
-        return ascent
-
-class Zeros(OptimizerModule):
-    def __init__(self):
-        super().__init__({})
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        return ascent.zeros_like()
-
-class Fill(OptimizerModule):
-    def __init__(self, value):
-        super().__init__({"value": value})
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        return ascent.fill(self.get_group_key('value'))
-
-
-class GradToUpdate(OptimizerModule):
-    """sets gradient and .grad attributes to current update"""
-    def __init__(self):
-        super().__init__({})
-
-    def _update(self, vars, ascent):
-        vars.set_grad_(ascent, self.get_params())
-        return ascent
-
-class MakeClosure(OptimizerModule):
-    """Makes a closure that sets `.grad` attribute to the update generated by `modules`"""
-    def __init__(self, modules: _Chainable):
-        super().__init__({})
-        self._set_child_('modules', modules)
-
-    def step(self, vars):
-        if vars.closure is None: raise ValueError("MakeClosure requires a closure")
-
-        params = self.get_params()
-        orig_closure = vars.closure
-        orig_state = vars.copy(True)
-
-        def new_closure(backward = True):
-            if backward:
-                cloned_state = orig_state.copy(True)
-                g = self.children['modules'].return_ascent(cloned_state)
-                params.set_grad_(g)
-                return cloned_state.get_loss()
-
-            else:
-                return orig_closure(False)
-
-        vars.closure = new_closure # type:ignore
-        return self._update_params_or_step_with_next(vars)
-
torchzero/modules/misc/lr.py
DELETED
@@ -1,96 +0,0 @@
-import random
-from collections.abc import Callable, Iterable
-from functools import partial
-from typing import TYPE_CHECKING, Any, overload
-
-import torch
-
-from ...tensorlist import TensorList
-
-from ...core import OptimizerModule
-
-if TYPE_CHECKING:
-    from ...optim import Modular
-
-def _init_scheduler_hook(opt: "Modular", module: "LR", scheduler_cls, **kwargs):
-    """post init hook that initializes the lr scheduler to the LR module and sets `_scheduler_step_fn`."""
-    scheduler = scheduler_cls(module, **kwargs)
-    module._scheduler_step_fn = scheduler.step
-
-def _set_momentum_hook(optimizer, state, momentum):
-    for module in optimizer.unrolled_modules:
-        if 'momentum' in module.defaults:
-            for g in module.param_groups:
-                g['momentum'] = momentum
-        elif 'beta1' in module.defaults:
-            for g in module.param_groups:
-                g['beta1'] = momentum
-
-class LR(OptimizerModule):
-    """Multiplies update by the learning rate. Optionally uses an lr scheduler.
-
-    Args:
-        lr (float, optional): learning rate. Defaults to 1e-3.
-        scheduler (Callable[..., torch.optim.lr_scheduler.LRScheduler | Any] | None, optional):
-            A scheduler class, for example `torch.optim.lr_scheduler.OneCycleLR`. Defaults to None.
-        cycle_momentum (bool, optional):
-            enables schedulers that support it to affect momentum (like OneCycleLR).
-            The momentum will be cycled on ALL modules that have `momentum` or `beta1` setting.
-            This does not support external optimizers, wrapped with `Wrap`. Defaults to True.
-        sheduler_step_every (int, optional):
-            step with scheduler every n optimizer steps.
-            Useful when the scheduler steps once per epoch. Defaults to 1.
-        **kwargs:
-            kwargs to pass to `scheduler`.
-    """
-    IS_LR_MODULE = True
-    def __init__(
-        self,
-        lr: float = 1e-3,
-        scheduler_cls: Callable[..., torch.optim.lr_scheduler.LRScheduler | Any] | None = None,
-        cycle_momentum: bool = True,
-        sheduler_step_every: int = 1,
-        # *args,
-        **kwargs,
-    ):
-
-        defaults = dict(lr = lr)
-
-        if (scheduler_cls is not None) and cycle_momentum:
-            defaults['momentum'] = 0
-        super().__init__(defaults)
-
-        self._scheduler_step_fn = None
-        self.sheduler_step_every = sheduler_step_every
-        self.cycle_momentum = cycle_momentum
-        self.cur = 0
-
-        if scheduler_cls is not None:
-            self.post_init_hooks.append(lambda opt, module: _init_scheduler_hook(opt, module, scheduler_cls, **kwargs))
-
-        self._skip = False
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        # step with scheduler
-        if self._scheduler_step_fn is not None:
-            if self.cur != 0 and self.cur % self.sheduler_step_every == 0:
-                self._scheduler_step_fn()
-
-        # add a hook to cycle momentum
-        if self.cycle_momentum:
-            vars.add_post_step_hook(_set_momentum_hook)
-
-        # remove init hook to delete reference to scheduler
-        if self.cur == 0 and len(self.post_init_hooks) == 1:
-            del self.post_init_hooks[0]
-
-        # skip if lr was applied by previous module (LR fusing)
-        if not self._skip:
-            # multiply ascent direction by lr in-place
-            lr = self.get_group_key('lr')
-            ascent *= lr
-
-        self.cur += 1
-        self._skip = False
-        return ascent
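The scheduler plumbing above boils down to a counter: the attached scheduler's .step() fires once every `sheduler_step_every` optimizer steps, and otherwise LR simply scales the update by lr in-place. A standalone illustration of that cadence with a stock torch.optim optimizer and scheduler (not the torchzero API; the variable names are only for illustration):

import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-1)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5)

scheduler_step_every = 3   # mirrors `sheduler_step_every` in the deleted LR module
for step in range(9):
    opt.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()
    # the scheduler advances only every n-th optimizer step, as in LR._update above
    if step != 0 and step % scheduler_step_every == 0:
        sched.step()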
torchzero/modules/misc/multistep.py
DELETED
@@ -1,51 +0,0 @@
-from collections.abc import Callable, Iterable
-
-import torch
-
-from ...tensorlist import TensorList
-
-from ...core import OptimizerModule, _Chainable
-
-
-class Multistep(OptimizerModule):
-    """Performs multiple steps (per batch), passes total update to the next module.
-
-    Args:
-        modules (_Chainable): modules to perform multiple steps with.
-        num_steps (int, optional): number of steps to perform. Defaults to 2.
-    """
-    def __init__(self, modules: _Chainable, num_steps: int = 2):
-        super().__init__({})
-        self.num_steps = num_steps
-
-        self._set_child_('modules', modules)
-
-    def step(self, vars):
-        # no next module, just perform multiple steps
-        if self.next_module is None:
-            ret = None
-            for step in range(self.num_steps):
-                state_copy = vars.copy(clone_ascent=True) if step != self.num_steps - 1 else vars
-                ret = self.children['modules'].step(state_copy)
-
-                # since parameters are updated after stepping, grad and fx0 must be erased as they are no longer correct
-                vars.grad = None; vars.fx0 = None
-
-            return ret
-
-        # accumulate steps and pass to next module
-        p0 = self.get_params().clone()
-        for step in range(self.num_steps):
-            state_copy = vars.copy(clone_ascent=True) if step != self.num_steps - 1 else vars
-            self.children['modules'].step(state_copy)
-
-            # since parameters are updated after stepping, grad and fx0 must be erased as they are no longer correct
-            vars.grad = None; vars.fx0 = None
-
-        p1 = self.get_params()
-        vars.ascent = p0 - p1
-
-        # undo ascent
-        p1.set_(p0)
-
-        return self._update_params_or_step_with_next(vars, p1)
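Multistep's trick is to run the inner module several times, measure the total parameter displacement p0 - p1, hand that off as a single update, and then restore the parameters. A minimal plain-PyTorch sketch of the same idea, using an inner SGD step as a stand-in for the child module (illustrative only, not the torchzero API):

import torch

params = [torch.randn(3, requires_grad=True)]
opt = torch.optim.SGD(params, lr=0.1)          # stand-in for the inner child module
num_steps = 2

p0 = [p.detach().clone() for p in params]      # snapshot before the inner steps
for _ in range(num_steps):
    opt.zero_grad()
    loss = (params[0] ** 2).sum()
    loss.backward()
    opt.step()

# total displacement over all inner steps, i.e. the accumulated update
total_update = [a - p.detach() for a, p in zip(p0, params)]

# restore parameters, as Multistep does with p1.set_(p0), so the accumulated
# update can be passed on and applied (or further transformed) downstream
with torch.no_grad():
    for p, a in zip(params, p0):
        p.copy_(a)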
torchzero/modules/misc/on_increase.py
DELETED
@@ -1,53 +0,0 @@
-import torch
-
-from ...core import OptimizerModule
-
-
-class NegateOnLossIncrease(OptimizerModule):
-    """Performs an additional evaluation to check if update increases the loss. If it does,
-    negates or backtracks the update.
-
-    Args:
-        backtrack (bool, optional):
-            if True, sets update to minus update, otherwise sets it to zero. Defaults to True.
-    """
-    def __init__(self, backtrack = True):
-        super().__init__({})
-        self.backtrack = backtrack
-
-    @torch.no_grad()
-    def step(self, vars):
-        if vars.closure is None: raise ValueError('NegateOnLossIncrease requires closure.')
-        if vars.fx0 is None: vars.fx0 = vars.closure(False)
-
-        # subtract ascent direction to params and see if loss decreases
-        params = self.get_params()
-        ascent_direction = vars.maybe_use_grad_(params)
-        params -= ascent_direction
-        vars.fx0_approx = vars.closure(False)
-
-        # if this has no children, update params and return loss
-        if self.next_module is None:
-            if params is None: params = self.get_params()
-
-            if vars.fx0_approx > vars.fx0:
-                # loss increased, so we negate the ascent direction
-                # we are currently at params - ascent direction
-                # so we add twice the ascent direction
-                params.add_(ascent_direction, alpha = 2 if self.backtrack else 1)
-
-            # else: we are already at a lower loss point
-            return vars.get_loss()
-
-        # otherwise undo the ascent direction because it is passed to the child
-        params += ascent_direction
-
-        # if loss increases, negate ascent direction
-        if vars.fx0_approx > vars.fx0:
-            if self.backtrack: ascent_direction.neg_()
-            else: ascent_direction.zero_()
-
-        # otherwise undo the ascent direction and pass the updated ascent direction to the child
-        return self.next_module.step(vars)
-
-
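Stripped of the module plumbing, NegateOnLossIncrease is: take a tentative step, re-evaluate the loss, and if it got worse either flip the sign of the update (backtrack) or cancel it. A plain-PyTorch sketch of that decision, where the closure is merely a stand-in for `vars.closure` (names here are illustrative, not torchzero code):

import torch

param = torch.randn(3, requires_grad=True)

def closure():                      # stand-in for vars.closure(False): loss at current params
    return (param ** 2).sum()

loss0 = closure()
loss0.backward()
update = param.grad.clone()         # ascent direction (here simply the gradient)
backtrack = True                    # mirrors the `backtrack` flag above

with torch.no_grad():
    param.sub_(update)              # tentative step along the update
    loss1 = closure()
    if loss1 > loss0:
        # loss increased: we are at params - update, so adding 2*update flips the
        # step (backtrack), while adding 1*update merely undoes it (zero the update)
        param.add_(update, alpha=2 if backtrack else 1)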
torchzero/modules/operations/__init__.py
DELETED
@@ -1,29 +0,0 @@
-from .multi import (
-    Add,
-    AddMagnitude,
-    Div,
-    Divide,
-    Interpolate,
-    Lerp,
-    Mul,
-    Pow,
-    Power,
-    RDiv,
-    RPow,
-    RSub,
-    Sub,
-    Subtract,
-)
-from .reduction import Mean, Product, Sum
-from .singular import (
-    Abs,
-    Cos,
-    MagnitudePower,
-    NanToNum,
-    Negate,
-    Operation,
-    Reciprocal,
-    Sign,
-    Sin,
-    sign_grad_,
-)
torchzero/modules/operations/multi.py
DELETED
@@ -1,298 +0,0 @@
-from collections.abc import Iterable
-import torch
-
-from ...core import OptimizerModule
-
-_Value = int | float | OptimizerModule | Iterable[OptimizerModule]
-
-class Add(OptimizerModule):
-    """add `value` to update. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
-    def __init__(self, value: _Value):
-        super().__init__({})
-
-        if not isinstance(value, (int, float)):
-            self._set_child_('value', value)
-
-        self.value = value
-
-    @torch.no_grad()
-    def _update(self, vars, ascent):
-        if isinstance(self.value, (int, float)):
-            return ascent.add_(self.value)
-
-        state_copy = vars.copy(clone_ascent = True)
-        v = self.children['value'].return_ascent(state_copy)
-        return ascent.add_(v)
-
-
-class Sub(OptimizerModule):
-    """subtracts `value` from update. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
-    def __init__(self, subtrahend: _Value):
-        super().__init__({})
-
-        if not isinstance(subtrahend, (int, float)):
-            self._set_child_('subtrahend', subtrahend)
-
-        self.subtrahend = subtrahend
-
-    @torch.no_grad()
-    def _update(self, vars, ascent):
-        if isinstance(self.subtrahend, (int, float)):
-            return ascent.sub_(self.subtrahend)
-
-        state_copy = vars.copy(clone_ascent = True)
-        subtrahend = self.children['subtrahend'].return_ascent(state_copy)
-        return ascent.sub_(subtrahend)
-
-class RSub(OptimizerModule):
-    """subtracts update from `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
-    def __init__(self, minuend: _Value):
-        super().__init__({})
-
-        if not isinstance(minuend, (int, float)):
-            self._set_child_('minuend', minuend)
-
-        self.minuend = minuend
-
-    @torch.no_grad()
-    def _update(self, vars, ascent):
-        if isinstance(self.minuend, (int, float)):
-            return ascent.sub_(self.minuend).neg_()
-
-        state_copy = vars.copy(clone_ascent = True)
-        minuend = self.children['minuend'].return_ascent(state_copy)
-        return ascent.sub_(minuend).neg_()
-
-class Subtract(OptimizerModule):
-    """Calculates `minuend - subtrahend`"""
-    def __init__(
-        self,
-        minuend: OptimizerModule | Iterable[OptimizerModule],
-        subtrahend: OptimizerModule | Iterable[OptimizerModule],
-    ):
-        super().__init__({})
-        self._set_child_('minuend', minuend)
-        self._set_child_('subtrahend', subtrahend)
-
-    @torch.no_grad
-    def step(self, vars):
-        state_copy = vars.copy(clone_ascent = True)
-        minuend = self.children['minuend'].return_ascent(state_copy)
-        vars.update_attrs_(state_copy)
-        subtrahend = self.children['subtrahend'].return_ascent(vars)
-
-        vars.ascent = minuend.sub_(subtrahend)
-        return self._update_params_or_step_with_next(vars)
-
-class Mul(OptimizerModule):
-    """multiplies update by `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
-    def __init__(self, value: _Value):
-        super().__init__({})
-
-        if not isinstance(value, (int, float)):
-            self._set_child_('value', value)
-
-        self.value = value
-
-    @torch.no_grad()
-    def _update(self, vars, ascent):
-        if isinstance(self.value, (int, float)):
-            return ascent.mul_(self.value)
-
-        state_copy = vars.copy(clone_ascent = True)
-        v = self.children['value'].return_ascent(state_copy)
-        return ascent.mul_(v)
-
-
-class Div(OptimizerModule):
-    """divides update by `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
-    def __init__(self, denominator: _Value):
-        super().__init__({})
-
-        if not isinstance(denominator, (int, float)):
-            self._set_child_('denominator', denominator)
-
-        self.denominator = denominator
-
-    @torch.no_grad()
-    def _update(self, vars, ascent):
-        if isinstance(self.denominator, (int, float)):
-            return ascent.div_(self.denominator)
-
-        state_copy = vars.copy(clone_ascent = True)
-        denominator = self.children['denominator'].return_ascent(state_copy)
-        return ascent.div_(denominator)
-
-class RDiv(OptimizerModule):
-    """divides `value` by update. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
-    def __init__(self, numerator: _Value):
-        super().__init__({})
-
-        if not isinstance(numerator, (int, float)):
-            self._set_child_('numerator', numerator)
-
-        self.numerator = numerator
-
-    @torch.no_grad()
-    def _update(self, vars, ascent):
-        if isinstance(self.numerator, (int, float)):
-            return ascent.reciprocal_().mul_(self.numerator)
-
-        state_copy = vars.copy(clone_ascent = True)
-        numerator = self.children['numerator'].return_ascent(state_copy)
-        return ascent.reciprocal_().mul_(numerator)
-
-class Divide(OptimizerModule):
-    """calculates *numerator / denominator*"""
-    def __init__(
-        self,
-        numerator: OptimizerModule | Iterable[OptimizerModule],
-        denominator: OptimizerModule | Iterable[OptimizerModule],
-    ):
-        super().__init__({})
-        self._set_child_('numerator', numerator)
-        self._set_child_('denominator', denominator)
-
-    @torch.no_grad
-    def step(self, vars):
-        state_copy = vars.copy(clone_ascent = True)
-        numerator = self.children['numerator'].return_ascent(state_copy)
-        vars.update_attrs_(state_copy)
-        denominator = self.children['denominator'].return_ascent(vars)
-
-        vars.ascent = numerator.div_(denominator)
-        return self._update_params_or_step_with_next(vars)
-
-
-class Pow(OptimizerModule):
-    """takes ascent to the power of `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
-    def __init__(self, power: _Value):
-        super().__init__({})
-
-        if not isinstance(power, (int, float)):
-            self._set_child_('power', power)
-
-        self.power = power
-
-    @torch.no_grad()
-    def _update(self, vars, ascent):
-        if isinstance(self.power, (int, float)):
-            return ascent.pow_(self.power)
-
-        state_copy = vars.copy(clone_ascent = True)
-        power = self.children['power'].return_ascent(state_copy)
-        return ascent.pow_(power)
-
-class RPow(OptimizerModule):
-    """takes `value` to the power of ascent. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
-    def __init__(self, base: _Value):
-        super().__init__({})
-
-        if not isinstance(base, (int, float)):
-            self._set_child_('base', base)
-
-        self.base = base
-
-    @torch.no_grad()
-    def _update(self, vars, ascent):
-        if isinstance(self.base, (int, float)):
-            return self.base ** ascent
-
-        state_copy = vars.copy(clone_ascent = True)
-        base = self.children['base'].return_ascent(state_copy)
-        return base.pow_(ascent)
-
-class Power(OptimizerModule):
-    """calculates *base ^ power*"""
-    def __init__(
-        self,
-        base: OptimizerModule | Iterable[OptimizerModule],
-        power: OptimizerModule | Iterable[OptimizerModule],
-    ):
-        super().__init__({})
-        self._set_child_('base', base)
-        self._set_child_('power', power)
-
-    @torch.no_grad
-    def step(self, vars):
-        state_copy = vars.copy(clone_ascent = True)
-        base = self.children['base'].return_ascent(state_copy)
-        vars.update_attrs_(state_copy)
-        power = self.children['power'].return_ascent(vars)
-
-        vars.ascent = base.pow_(power)
-        return self._update_params_or_step_with_next(vars)
-
-
-class Lerp(OptimizerModule):
-    """Linear interpolation between update and `end` based on scalar `weight`.
-
-    `out = update + weight * (end - update)`"""
-    def __init__(self, end: OptimizerModule | Iterable[OptimizerModule], weight: float):
-        super().__init__({})
-
-        self._set_child_('end', end)
-        self.weight = weight
-
-    @torch.no_grad()
-    def _update(self, vars, ascent):
-
-        state_copy = vars.copy(clone_ascent = True)
-        end = self.children['end'].return_ascent(state_copy)
-        return ascent.lerp_(end, self.weight)
-
-
-class Interpolate(OptimizerModule):
-    """Does a linear interpolation of two module's updates - `start` (given by input), and `end`, based on a scalar
-    `weight`.
-
-    `out = input + weight * (end - input)`"""
-    def __init__(
-        self,
-        input: OptimizerModule | Iterable[OptimizerModule],
-        end: OptimizerModule | Iterable[OptimizerModule],
-        weight: float,
-    ):
-        super().__init__({})
-        self._set_child_('input', input)
-        self._set_child_('end', end)
-        self.weight = weight
-
-    @torch.no_grad
-    def step(self, vars):
-        state_copy = vars.copy(clone_ascent = True)
-        input = self.children['input'].return_ascent(state_copy)
-        vars.update_attrs_(state_copy)
-        end = self.children['end'].return_ascent(vars)
-
-        vars.ascent = input.lerp_(end, weight = self.weight)
-
-        return self._update_params_or_step_with_next(vars)
-
-class AddMagnitude(OptimizerModule):
-    """Add `value` multiplied by sign of the ascent, i.e. this adds `value` to the magnitude of the update.
-
-    Args:
-        value (Value): value to add to magnitude, either a float or an OptimizerModule.
-        add_to_zero (bool, optional):
-            if True, adds `value` to 0s. Otherwise, zeros remain zero.
-            Only has effect if value is a float. Defaults to True.
-    """
-    def __init__(self, value: _Value, add_to_zero=True):
-        super().__init__({})
-
-        if not isinstance(value, (int, float)):
-            self._set_child_('value', value)
-
-        self.value = value
-        self.add_to_zero = add_to_zero
-
-    @torch.no_grad()
-    def _update(self, vars, ascent):
-        if isinstance(self.value, (int, float)):
-            if self.add_to_zero: return ascent.add_(ascent.clamp_magnitude(min=1).sign_().mul_(self.value))
-            return ascent.add_(ascent.sign_().mul_(self.value))
-
-        state_copy = vars.copy(clone_ascent = True)
-        v = self.children['value'].return_ascent(state_copy)
-        return ascent.add_(v.abs_().mul_(ascent.sign()))
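These arithmetic modules all do the same thing at heart: produce two candidate updates and combine them elementwise. As a final illustration, the Interpolate/Lerp rule `out = input + weight * (end - input)` on plain tensors, here blending a raw gradient with its sign purely as an example (not torchzero code):

import torch

grad = torch.tensor([0.5, -2.0, 0.1])

update_a = grad.clone()         # "input" update: the raw gradient
update_b = grad.sign()          # "end" update: sign of the gradient
weight = 0.3

# out = input + weight * (end - input), exactly the interpolation formula above
blended = torch.lerp(update_a, update_b, weight)
print(blended)                  # tensor([ 0.6500, -1.7000,  0.3700])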