torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
|
@@ -1,134 +0,0 @@
|
|
|
1
|
-
from collections.abc import Callable, Iterable
|
|
2
|
-
import numpy as np
|
|
3
|
-
import torch
|
|
4
|
-
|
|
5
|
-
from ...core import OptimizerModule
|
|
6
|
-
|
|
7
|
-
# Type of each positional argument accepted by Sum / Mean / Product: a scalar, a single
# module, or an iterable of modules (which their docstrings describe as "to chain").
_Value = int | float | OptimizerModule | Iterable[OptimizerModule]
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class Sum(OptimizerModule):
    """Calculates the sum of multiple updates.

    Every module argument computes its own update, the updates are added together and
    the sum of all scalar arguments (if any) is added on top. The result becomes the new
    ascent, which is then passed to ``_update_params_or_step_with_next``.

    Args:
        *modules:
            either OptimizerModules or iterables of OptimizerModules to chain.
            Scalars (``int`` / ``float``) are also allowed; they are summed and added
            to the result.

    Raises:
        ValueError: if no module argument is given (only scalars, or nothing at all),
            because there would be no update to add the scalars to.
    """
    def __init__(
        self,
        *modules: _Value,
    ):
        super().__init__({})

        scalars = [m for m in modules if isinstance(m, (int, float))]
        self.scalar = sum(scalars) if len(scalars) > 0 else None

        submodules = [m for m in modules if not isinstance(m, (int, float))]
        if len(submodules) == 0:
            raise ValueError("Sum requires at least one module argument, scalars alone are not enough.")
        for i, module in enumerate(submodules):
            self._set_child_(i, module)

    @torch.no_grad()
    def step(self, vars):
        last = len(self.children) - 1
        total = None
        for i, child in sorted(self.children.items(), key=lambda x: x[0]):
            # every child except the last works on a copy with a cloned ascent, so one child's
            # in-place changes are not seen by the next; the last child can use `vars` itself.
            # With a single child this is exactly `vars`, so no special case is needed.
            cur_state = vars if i == last else vars.copy(clone_ascent = True)

            update = child.return_ascent(cur_state)
            if total is None: total = update
            else: total += update

            # merge attributes the child set on the copy back into `vars`
            # NOTE(review): presumably things like a computed loss/grad - confirm in `update_attrs_`
            if i != last: vars.update_attrs_(cur_state)

        if self.scalar is not None: total += self.scalar
        vars.ascent = total
        return self._update_params_or_step_with_next(vars)
|
|
49
|
-
|
|
50
|
-
class Mean(OptimizerModule):
    """Calculates the mean of multiple updates.

    The updates produced by the module arguments are added together, the sum of all
    scalar arguments (if any) is added, and the total is divided by the number of
    arguments passed (modules and scalars alike). The result becomes the new ascent,
    which is then passed to ``_update_params_or_step_with_next``.

    Args:
        *modules:
            either OptimizerModules or iterables of OptimizerModules to chain.
            Scalars (``int`` / ``float``) are also allowed and count towards the mean.

    Raises:
        ValueError: if no module argument is given (only scalars, or nothing at all),
            because there would be no update to average.
    """

    def __init__(
        self,
        *modules: _Value,
    ):
        super().__init__({})

        scalars = [m for m in modules if isinstance(m, (int, float))]
        self.scalar = sum(scalars) if len(scalars) > 0 else None

        # divisor of the mean: counts modules, module chains and scalars
        self.n_values = len(modules)

        submodules = [m for m in modules if not isinstance(m, (int, float))]
        if len(submodules) == 0:
            raise ValueError("Mean requires at least one module argument, scalars alone are not enough.")
        for i, module in enumerate(submodules):
            self._set_child_(i, module)

    @torch.no_grad()
    def step(self, vars):
        last = len(self.children) - 1
        total = None
        for i, child in sorted(self.children.items(), key=lambda x: x[0]):
            # every child except the last works on a copy with a cloned ascent, so one child's
            # in-place changes are not seen by the next; the last child can use `vars` itself.
            # With a single child this is exactly `vars`, so no special case is needed.
            cur_state = vars if i == last else vars.copy(clone_ascent = True)

            update = child.return_ascent(cur_state)
            if total is None: total = update
            else: total += update

            # merge attributes the child set on the copy back into `vars`
            # NOTE(review): presumably things like a computed loss/grad - confirm in `update_attrs_`
            if i != last: vars.update_attrs_(cur_state)

        if self.scalar is not None: total += self.scalar
        if self.n_values > 1: total /= self.n_values
        vars.ascent = total
        return self._update_params_or_step_with_next(vars)
|
|
94
|
-
|
|
95
|
-
class Product(OptimizerModule):
    """Calculates the element-wise product of multiple updates.

    Every module argument computes its own update and the updates are multiplied
    together; the product of all scalar arguments (if any) is multiplied in as well.
    The result becomes the new ascent, which is then passed to
    ``_update_params_or_step_with_next``.

    Args:
        *modules:
            either OptimizerModules or iterables of OptimizerModules to chain.
            Scalars (``int`` / ``float``) are also allowed; they are multiplied
            together and applied as a factor.

    Raises:
        ValueError: if no module argument is given (only scalars, or nothing at all),
            because there would be no update to scale.
    """

    def __init__(
        self,
        *modules: _Value,
    ):
        super().__init__({})

        scalars = [m for m in modules if isinstance(m, (int, float))]
        self.scalar = np.prod(scalars).item() if len(scalars) > 0 else None

        submodules = [m for m in modules if not isinstance(m, (int, float))]
        if len(submodules) == 0:
            raise ValueError("Product requires at least one module argument, scalars alone are not enough.")
        for i, module in enumerate(submodules):
            self._set_child_(i, module)

    @torch.no_grad()
    def step(self, vars):
        last = len(self.children) - 1
        product = None
        for i, child in sorted(self.children.items(), key=lambda x: x[0]):
            # every child except the last works on a copy with a cloned ascent, so one child's
            # in-place changes are not seen by the next; the last child can use `vars` itself.
            # With a single child this is exactly `vars`, so no special case is needed.
            cur_state = vars if i == last else vars.copy(clone_ascent = True)

            update = child.return_ascent(cur_state)
            if product is None: product = update
            else: product *= update

            # merge attributes the child set on the copy back into `vars`
            # NOTE(review): presumably things like a computed loss/grad - confirm in `update_attrs_`
            if i != last: vars.update_attrs_(cur_state)

        if self.scalar is not None: product *= self.scalar
        vars.ascent = product
        return self._update_params_or_step_with_next(vars)
|
|
@@ -1,113 +0,0 @@
|
|
|
1
|
-
from collections.abc import Iterable
|
|
2
|
-
from operator import methodcaller
|
|
3
|
-
|
|
4
|
-
import torch
|
|
5
|
-
|
|
6
|
-
from ...core import OptimizerModule
|
|
7
|
-
from ...tensorlist import TensorList
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class Operation(OptimizerModule):
    """Applies a named in-place operation to the ascent.

    ``operation`` is turned into the method name ``f'{operation}_'``, which is called on
    the ascent with no arguments (e.g. ``Operation('abs')`` calls ``ascent.abs_()``).
    Listed operations:

    `abs`, `sign`, `sin`, `cos`, `tan`, `asin`, `acos`, `atan`, `sinh`, `cosh`,
    `tanh`, `log`, `log1p`, `log2`, `log10`, `erf`, `erfc`, `exp`, `neg`, `reciprocal`,
    `copy`, `zero`, `sqrt`, `floor`, `ceil`, `round`.

    NOTE(review): the method is called without arguments, so `copy` (whose torch
    in-place form ``copy_`` takes a source tensor) presumably fails here - confirm.
    """
    def __init__(self, operation: str):
        super().__init__({})
        method_name = f'{operation}_'
        self.operation = methodcaller(method_name)

    @torch.no_grad
    def _update(self, vars, ascent):
        apply_inplace = self.operation
        return apply_inplace(ascent)
|
|
22
|
-
|
|
23
|
-
class Reciprocal(OptimizerModule):
    """Replaces the update with its element-wise reciprocal, *1 / update* (in-place)."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad()
    def _update(self, vars, ascent):
        inverted = ascent.reciprocal_()
        return inverted
|
|
30
|
-
|
|
31
|
-
class Negate(OptimizerModule):
    """Flips the sign of the update (``-update``), in-place."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad()
    def _update(self, vars, ascent):
        negated = ascent.neg_()
        return negated
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
def sign_grad_(params: Iterable[torch.Tensor]):
    """Apply the sign function in-place to the gradients of an iterable of parameters.

    Parameters whose ``.grad`` is ``None`` are skipped. Returns ``None``.

    Args:
        params (abc.Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
    """
    # a lone tensor would otherwise be iterated row by row (rows have no `.grad`),
    # so wrap it to honour the "or a single Tensor" contract above
    if isinstance(params, torch.Tensor):
        params = [params]

    for p in params:
        if p.grad is not None:
            p.grad.sign_()
|
|
47
|
-
|
|
48
|
-
class Sign(OptimizerModule):
    """Replaces the update with its element-wise sign, in-place."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def _update(self, vars, ascent):
        signs = ascent.sign_()
        return signs
|
|
55
|
-
|
|
56
|
-
class Abs(OptimizerModule):
    """Replaces the update with its element-wise absolute value, in-place."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def _update(self, vars, ascent):
        magnitudes = ascent.abs_()
        return magnitudes
|
|
63
|
-
|
|
64
|
-
class Sin(OptimizerModule):
    """Replaces the ascent with its element-wise sine, in-place."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def _update(self, vars, ascent):
        sines = ascent.sin_()
        return sines
|
|
71
|
-
|
|
72
|
-
class Cos(OptimizerModule):
    """Replaces the ascent with its element-wise cosine, in-place."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def _update(self, vars, ascent):
        cosines = ascent.cos_()
        return cosines
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
class NanToNum(OptimizerModule):
    """Replaces `nan`, `inf` and `-inf` entries of the update with numbers, in-place.

    Args:
        nan (optional): replacement for NaNs. Default is zero.
        posinf (optional): replacement for positive infinity. If None, the greatest
            finite value representable by the input's dtype is used. Default is None.
        neginf (optional): replacement for negative infinity. If None, the lowest
            finite value representable by the input's dtype is used. Default is None.
    """
    def __init__(self, nan=None, posinf=None, neginf=None):
        super().__init__({})
        self.nan, self.posinf, self.neginf = nan, posinf, neginf

    @torch.no_grad()
    def _update(self, vars, ascent):
        # positional order matches nan_to_num_(nan, posinf, neginf)
        replacements = (self.nan, self.posinf, self.neginf)
        return ascent.nan_to_num_(*replacements)
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
class MagnitudePower(OptimizerModule):
    """Raises the magnitude of the update to the `value` power while keeping its sign.

    For non-zero elements the result is ``sign(x) * |x| ** value``, for any exponent.
    When `value` is an odd integer (``value % 2 == 1``, which also holds for negative
    odd values such as -1) plain ``x ** value`` already keeps the sign, so the power is
    applied in-place; otherwise a new tensor is built from ``|x| ** value`` and ``sign(x)``.

    Args:
        value (int | float): the exponent.
    """
    def __init__(self, value: int | float):
        super().__init__({})
        self.value = value

    @torch.no_grad()
    def _update(self, vars, ascent):
        if self.value % 2 == 1:
            # odd integer exponent: the plain power is already sign-preserving
            return ascent.pow_(self.value)
        # even / non-integer exponent: raise the magnitude, then restore the sign
        return ascent.abs().pow_(self.value) * ascent.sign()
|
|
113
|
-
|
|
@@ -1,54 +0,0 @@
|
|
|
1
|
-
import typing as T
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
|
|
5
|
-
from ...core import OptimizerModule
|
|
6
|
-
from ..momentum.momentum import _heavyball_step, _nesterov_step_
|
|
7
|
-
|
|
8
|
-
class SGD(OptimizerModule):
    """`torch.optim.SGD` as an optimizer module.

    Meant to reproduce `torch.optim.SGD`, with two extensions: nesterov momentum also
    supports dampening, and negative momentum is allowed.

    Args:
        momentum (float, optional): momentum factor. Defaults to 0.
        dampening (float, optional): momentum dampening. Defaults to 0.
        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
        nesterov (bool, optional):
            use nesterov momentum instead of heavyball momentum. Defaults to False.
        alpha (float, optional): learning rate. Defaults to 1.
    """
    def __init__(
        self,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov: bool = False,
        alpha: float = 1,
    ):
        defaults = dict(
            alpha=alpha,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
        )
        super().__init__(defaults)
        self.nesterov = nesterov
        self.current_step = 0

    @torch.no_grad
    def _update(self, vars, ascent):
        params = self.get_params()
        settings = self.get_all_group_keys()

        # L2 penalty: add weight_decay * params to the update
        if any(wd != 0 for wd in settings['weight_decay']):
            ascent += params * settings['weight_decay']

        # scale by the learning rate
        if any(lr != 1 for lr in settings['alpha']):
            ascent *= settings['alpha']

        if any(m != 0 for m in settings['momentum']):
            # nesterov velocity starts at zero; heavyball velocity is initialised from the current update
            velocity = self.get_state_key('velocity', init = torch.zeros_like if self.nesterov else ascent)
            if self.nesterov:
                # nesterov step modifies the update in-place
                _nesterov_step_(ascent, velocity, settings['momentum'], settings['dampening'])
            elif self.current_step > 0:
                # like pytorch, the first heavyball step only initialises the buffer and
                # leaves the update as-is; later steps return a new direction
                ascent = _heavyball_step(ascent, velocity, settings['momentum'], settings['dampening'])

        self.current_step += 1
        return ascent
|
|
@@ -1,159 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Newton-Schulz iteration code is taken from https://github.com/KellerJordan/Muon
|
|
3
|
-
|
|
4
|
-
Keller Jordan and Yuchen Jin and Vlado Boza and You Jiacheng and Franz Cecista and Laker Newhouse and Jeremy Bernstein.
|
|
5
|
-
Muon: An optimizer for hidden layers in neural networks (2024). URL: https://kellerjordan.github.io/posts/muon
|
|
6
|
-
"""
|
|
7
|
-
from collections.abc import Iterable
|
|
8
|
-
|
|
9
|
-
import torch
|
|
10
|
-
|
|
11
|
-
from ...core import OptimizerModule, _Targets
|
|
12
|
-
# from ...utils.compile import maybe_compile
|
|
13
|
-
|
|
14
|
-
def _zeropower_via_newtonschulz5(G, steps):
|
|
15
|
-
"""
|
|
16
|
-
code from https://github.com/KellerJordan/Muon
|
|
17
|
-
|
|
18
|
-
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
|
19
|
-
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
|
20
|
-
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
|
21
|
-
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
|
22
|
-
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
|
23
|
-
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
|
24
|
-
performance at all relative to UV^T, where USV^T = G is the SVD.
|
|
25
|
-
"""
|
|
26
|
-
assert len(G.shape) == 2
|
|
27
|
-
a, b, c = (3.4445, -4.7750, 2.0315)
|
|
28
|
-
X = G.bfloat16()
|
|
29
|
-
if G.size(0) > G.size(1):
|
|
30
|
-
X = X.T
|
|
31
|
-
|
|
32
|
-
# Ensure spectral norm is at most 1
|
|
33
|
-
X = X / (X.norm() + 1e-7)
|
|
34
|
-
# Perform the NS iterations
|
|
35
|
-
for _ in range(steps):
|
|
36
|
-
A = X @ X.T
|
|
37
|
-
B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
|
|
38
|
-
X = a * X + B @ X
|
|
39
|
-
|
|
40
|
-
if G.size(0) > G.size(1):
|
|
41
|
-
X = X.T
|
|
42
|
-
|
|
43
|
-
return X
|
|
44
|
-
|
|
45
|
-
# torch.compile-wrapped Newton-Schulz iteration; used in place of the eager
# `_zeropower_via_newtonschulz5` when `compiled=True` is requested.
_compiled_zeropower_via_newtonschulz5 = torch.compile(_zeropower_via_newtonschulz5)
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
def zeropower_via_newtonschulz_(params: Iterable[torch.Tensor], steps: int = 6, adaptive = False, compiled = True):
    """Uses Newton-Schulz iteration to compute the zeroth power / orthogonalization of gradients of an iterable of parameters.

    This sets gradients in-place: every ``p.grad`` that exists, has at least 2 dimensions
    and has every dimension >= 2 is flattened to a ``(shape[0], -1)`` matrix,
    orthogonalized, and assigned back to ``p.grad`` with its original shape, dtype and
    device. Other gradients are left untouched.

    Note that the Muon page says that embeddings and classifier heads should not be orthogonalized.

    The orthogonalization code is taken from https://github.com/KellerJordan/Muon

    Args:
        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to orthogonalize.
        steps (int, optional):
            The number of Newton-Schulz iterations to run. (6 is probably always enough). Defaults to 6.
        adaptive (bool, optional):
            Enables adaptation to scale of gradients (from https://github.com/leloykun/adaptive-muon). Defaults to False.
        compiled (bool, optional):
            Uses compiled newton-Schulz iteration function. Faster but won't work on windows. Defaults to True.
    """
    fn = _compiled_zeropower_via_newtonschulz5 if compiled else _zeropower_via_newtonschulz5

    for p in params:
        grad = p.grad
        if grad is None or grad.ndim < 2 or min(grad.shape) < 2:
            continue

        # reshape rather than view: gradients may be non-contiguous (e.g. channels_last
        # conv weights), and view() raises for those
        G = grad.reshape(grad.shape[0], -1)
        X = fn(G, steps)

        if adaptive:
            # this is from https://github.com/leloykun/adaptive-muon
            # Adaptive scaling, `(G * X).sum() * X` == (G.T @ X).trace() * X
            X = torch.einsum('ij,ij,ab->ab', G.type_as(X), X, X)

        # the iteration works in bfloat16; cast back to the gradient's dtype/device
        p.grad = X.reshape_as(grad).to(grad, copy=False)
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
class ZeropowerViaNewtonSchulz(OptimizerModule):
    """Uses Newton-Schulz iteration to compute the zeroth power / orthogonalization of gradients of an iterable of parameters.

    To disable orthogonalization for a parameter, put it into a parameter group with "newtonshultz" = False.
    The Muon page says that embeddings and classifier heads should not be orthogonalized.

    Each update tensor with at least one dimension larger than 1 (this includes 1-D vectors,
    which are treated as ``(n, 1)`` matrices) is flattened to ``(shape[0], -1)`` and
    orthogonalized in-place. Tensors whose dimensions are all <= 1 and empty tensors are
    left unchanged.

    The orthogonalization code is taken from https://github.com/KellerJordan/Muon.

    Note that unlike this module, Muon also uses Adam for gradients that are not orthogonalized,
    so I'd still recommend using it. Maybe use `Wrap` to wrap it into a module (I will make muon
    with selectable modules to optimize non-muon params soon)

    However not using Adam, or putting Adam module after this to apply it to ALL updates, both seem
    to work quite well too.

    Args:
        ns_steps (int, optional):
            The number of Newton-Schulz iterations to run. (6 is probably always enough). Defaults to 6.
        adaptive (bool, optional):
            Enables adaptation to scale of gradients (from https://github.com/leloykun/adaptive-muon). Defaults to False.
        compiled (bool, optional):
            Uses compiled newton-Schulz iteration function. Faster but won't work on windows. Defaults to True.
        target (str, optional):
            determines what this module updates.

            "ascent" - it updates the ascent

            "grad" - it updates the gradient (and sets `.grad` attributes to updated gradient).

            "closure" - it makes a new closure that sets the updated ascent to the .`grad` attributes.
    """
    def __init__(self, ns_steps = 6, adaptive = False, compiled=True, target:_Targets='ascent'):
        defaults = dict(newtonshultz = True, ns_steps=ns_steps, adaptive=adaptive)
        super().__init__(defaults, target=target)

        if compiled: self._zeropower_via_newtonschulz5 = _compiled_zeropower_via_newtonschulz5
        else: self._zeropower_via_newtonschulz5 = _zeropower_via_newtonschulz5

    def _update(self, vars, ascent):
        toggle, ns_steps, adaptive = self.get_group_keys('newtonshultz', 'ns_steps', 'adaptive', cls=list)

        for asc, enable, steps, ada in zip(ascent, toggle, ns_steps, adaptive):
            # skip disabled tensors, tensors whose dimensions are all <= 1 (incl. 0-d), and
            # empty tensors, which cannot be flattened to a (shape[0], -1) matrix
            if not enable or asc.numel() == 0 or not any(dim > 1 for dim in asc.shape):
                continue

            # reshape rather than view: the update may be non-contiguous
            G = asc.reshape(asc.shape[0], -1)
            X = self._zeropower_via_newtonschulz5(G, steps)

            if ada:
                # this is from https://github.com/leloykun/adaptive-muon
                # Adaptive scaling, `(G * X).sum() * X` == (G.T @ X).trace() * X
                X = torch.einsum('ij,ij,ab->ab', G.type_as(X), X, X)

            # write back in-place with the original shape/dtype/device
            asc.set_(X.reshape_as(asc).to(asc, copy=False)) # type:ignore

        return ascent
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
class DualNormCorrection(OptimizerModule):
    """Dual norm correction from https://github.com/leloykun/adaptive-muon.

    Description from the page:

    Single-line modification to any (dualizer-based) optimizer that allows the optimizer to adapt to the scale of the gradients as they change during training.
    This is done by scaling the dualized gradient by the clipped dual norm of the original gradient.

    Here, every update tensor with at least one dimension larger than 1 is multiplied in-place
    by ``sum(grad * update)`` (the inner product of the gradient and the update), clipped to
    ``[adaptive_scale_min, adaptive_scale_max]`` when either bound is set.

    Args:
        adaptive_scale_min (float | None, optional):
            lower clipping bound for the scale, None for no lower bound. Defaults to -1.
        adaptive_scale_max (float | None, optional):
            upper clipping bound for the scale, None for no upper bound. Defaults to 1.
    """
    def __init__(self, adaptive_scale_min: float | None = -1, adaptive_scale_max: float | None = 1):
        defaults = dict(adaptive_scale_min = adaptive_scale_min, adaptive_scale_max = adaptive_scale_max)
        super().__init__(defaults)

    def _update(self, vars, ascent):
        params = self.get_params()
        scale_mins, scale_maxs = self.get_group_keys('adaptive_scale_min', 'adaptive_scale_max')
        # NOTE(review): presumably computes the gradient on demand if it is not available yet
        grads = vars.maybe_compute_grad_(params)

        for asc, grad, lo, hi in zip(ascent, grads, scale_mins, scale_maxs):
            # tensors whose dimensions are all <= 1 (incl. 0-d) are left unchanged
            if not any(dim > 1 for dim in asc.shape):
                continue

            # inner product <grad, update>; reshape (not view) so non-contiguous tensors work
            scale = torch.einsum('ij,ij->', grad.reshape(grad.shape[0], -1), asc.reshape(asc.shape[0], -1))
            if lo is not None or hi is not None:
                scale = scale.clip(lo, hi)
            asc *= scale

        return ascent
|
|
@@ -1,86 +0,0 @@
|
|
|
1
|
-
"""Orthogonalization code adapted from https://github.com/MarkTuddenham/Orthogonal-Optimisers
|
|
2
|
-
|
|
3
|
-
Tuddenham, M., Prügel-Bennett, A., & Hare, J. (2022).
|
|
4
|
-
Orthogonalising gradients to speed up neural network optimisation. arXiv preprint arXiv:2202.07052.
|
|
5
|
-
"""
|
|
6
|
-
import logging
|
|
7
|
-
from collections.abc import Iterable, Sequence
|
|
8
|
-
|
|
9
|
-
import torch
|
|
10
|
-
|
|
11
|
-
from ...core import OptimizerModule, _Targets
|
|
12
|
-
|
|
13
|
-
@torch.no_grad()
def _orthogonalize_update_(updates: Sequence[torch.Tensor], toggle = None, warn_fail=True) -> Sequence[torch.Tensor]:
    """Orthogonalizes update tensors in-place using SVD.

    adapted from https://github.com/MarkTuddenham/Orthogonal-Optimisers

    Each tensor with ``ndim > 1`` whose ``toggle`` entry is truthy is flattened to a matrix
    ``G`` of shape ``(shape[0], -1)`` and replaced by ``U @ Vh`` from its reduced SVD.
    If ``torch.linalg.svd`` raises ``RuntimeError``, a rank-1 ``torch.svd_lowrank`` with a
    small noise-based ``M`` argument is tried instead. If that also fails, the tensor is left unchanged.

    Args:
        updates: tensors to orthogonalize in-place.
        toggle: per-tensor flags, ``None`` means orthogonalize every tensor. Tensors and flags
            are paired with ``zip``, so extra entries in the longer sequence are ignored.
        warn_fail: whether to log an error when both SVD attempts fail.

    Returns:
        ``updates`` - the same object that was passed in.
    """
    if toggle is None: toggle = [True] * len(updates)

    # Orthogonalise the gradients using SVD
    for grad, orth in zip(updates, toggle):
        if not (orth and grad.ndim > 1):
            continue

        # reshape (not view): view raises on non-contiguous tensors whose trailing dims can't be merged
        G: torch.Tensor = grad.reshape(grad.shape[0], -1)
        orth_G: torch.Tensor | None = None
        try:
            u, _, vt = torch.linalg.svd(G, full_matrices=False) # pylint:disable=not-callable
            orth_G = u @ vt
        except RuntimeError:
            # full SVD failed, fall back to a rank-1 low-rank SVD with a small noise-based M
            try:
                u, _, v = torch.svd_lowrank(
                    G,
                    q=1, # assume rank is at least 1
                    M=1e-4 * G.mean() * torch.randn_like(G))
                orth_G = u @ v.T
            except RuntimeError:
                if warn_fail: logging.error(('Failed to perform SVD with noise,'
                                ' skipping gradient orthogonalisation'))

        if orth_G is not None:
            grad.set_(orth_G.reshape_as(grad)) # type:ignore

    return updates
|
|
41
|
-
|
|
42
|
-
def orthogonalize_grad_(params: Iterable[torch.Tensor], warn_fail=False):
    """orthogonalizes gradients of an iterable of parameters.

    This updates gradients in-place. Parameters whose ``.grad`` is ``None`` are skipped, and
    only gradients with more than one dimension are orthogonalized.

    The orthogonalization code is adapted from https://github.com/MarkTuddenham/Orthogonal-Optimisers

    Args:
        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to orthogonalize.
        warn_fail (bool, optional):
            whether to log an error when orthogonalization fails, and gradients are not
            orthogonalized. Defaults to False.
    """
    grads = [p.grad for p in params if p.grad is not None]
    _orthogonalize_update_(grads, warn_fail=warn_fail)
|
|
56
|
-
|
|
57
|
-
class Orthogonalize(OptimizerModule):
    """Orthogonalizes the update using SVD.

    Each update tensor with more than one dimension is flattened to ``(shape[0], -1)`` and
    replaced in-place by ``U @ Vh`` from its reduced SVD. Tensors with a single dimension
    are left unchanged.

    To disable orthogonalization for a parameter, put it into a parameter group with "orth" = False.

    The orthogonalization code is adapted from https://github.com/MarkTuddenham/Orthogonal-Optimisers

    Tip: :py:class:`tz.m.ZeropowerViaNewtonSchulz` is a significantly faster version of this.

    Args:
        warn_fail (bool, optional):
            whether to log an error when orthogonalization fails, and updates are not
            orthogonalized. Defaults to True.
        target (str, optional):
            determines what this module updates.

            "ascent" - it updates the ascent

            "grad" - it updates the gradient (and sets `.grad` attributes to updated gradient).

            "closure" - it makes a new closure that sets the updated ascent to the `.grad` attributes.
    """
    def __init__(self, warn_fail=True, target: _Targets = 'ascent'):
        defaults = dict(orth = True)
        super().__init__(defaults, target = target)
        self.warn_fail = warn_fail

    def _update(self, vars, ascent):
        # per-parameter "orth" flags, a False flag skips that parameter
        toggle = self.get_group_key('orth', cls=list)
        _orthogonalize_update_(ascent, toggle, self.warn_fail)
        return ascent
|
|
@@ -1,22 +0,0 @@
|
|
|
1
|
-
r"""
|
|
2
|
-
This includes regularization modules like weight decay.
|
|
3
|
-
"""
|
|
4
|
-
from .dropout import Dropout
|
|
5
|
-
from .noise import AddNoise, Random, add_noise_
|
|
6
|
-
from .normalization import (
|
|
7
|
-
Centralize,
|
|
8
|
-
ClipNorm,
|
|
9
|
-
ClipValue,
|
|
10
|
-
Normalize,
|
|
11
|
-
centralize_grad_,
|
|
12
|
-
clip_grad_norm_,
|
|
13
|
-
clip_grad_value_,
|
|
14
|
-
normalize_grad_,
|
|
15
|
-
)
|
|
16
|
-
from .weight_decay import (
|
|
17
|
-
WeightDecay,
|
|
18
|
-
l1_regularize_,
|
|
19
|
-
l2_regularize_,
|
|
20
|
-
weight_decay_penalty,
|
|
21
|
-
)
|
|
22
|
-
from .ortho_grad import OrthoGrad, orthograd_
|
|
@@ -1,34 +0,0 @@
|
|
|
1
|
-
import typing as T
|
|
2
|
-
from collections import abc
|
|
3
|
-
|
|
4
|
-
import torch
|
|
5
|
-
|
|
6
|
-
from ...tensorlist import Distributions, TensorList
|
|
7
|
-
from ...core import OptimizerModule
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class Dropout(OptimizerModule):
    """
    Randomly sets elements of the update to 0.

    Placed after other modules, this acts as learning rate dropout. Placed first, it acts as
    gradient dropout.

    Args:
        p (float, optional): probability to replace update value with zero. Defaults to 0.5.

    reference
        *Lin, H., Zeng, W., Zhuang, Y., Ding, X., Huang, Y., & Paisley, J. (2022).
        Learning rate dropout. IEEE Transactions on Neural Networks and Learning Systems,
        34(11), 9029-9039.*
    """
    def __init__(self, p: float = 0.5):
        super().__init__(dict(p = p))

    @torch.no_grad
    def _update(self, vars, ascent):
        # NOTE(review): `p` is passed directly to `bernoulli_like`. If that follows
        # `torch.bernoulli` semantics (p = probability of drawing 1), elements are *kept* with
        # probability p rather than zeroed with probability p as documented. Both readings
        # coincide at the default p=0.5. Confirm against TensorList.bernoulli_like.
        mask = ascent.bernoulli_like(self.get_group_key('p'))
        ascent *= mask
        return ascent
|
|
@@ -1,77 +0,0 @@
|
|
|
1
|
-
from collections import abc
|
|
2
|
-
from typing import Literal
|
|
3
|
-
|
|
4
|
-
import torch
|
|
5
|
-
|
|
6
|
-
from ...core import OptimizerModule
|
|
7
|
-
from ...tensorlist import Distributions, TensorList, _Scalar, _ScalarSequence
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
def add_noise_(
    grads: abc.Iterable[torch.Tensor],
    alpha: "_Scalar | _ScalarSequence" = 1e-2,
    distribution: Distributions = "normal",
    mode: Literal["absolute", "global", "param", "channel"] = "param",
):
    """Adds random noise to ``grads`` in-place.

    Args:
        grads: tensors to add noise to. Wrapped in a ``TensorList`` if they aren't one already.
        alpha: noise magnitude, scaled according to ``mode``. Defaults to 1e-2.
        distribution: distribution to sample noise from. Defaults to "normal".
        mode: how the noise magnitude is computed:

            - "absolute": ``alpha`` is used as-is.
            - "global": ``alpha`` times the order-1 total vector norm of all tensors divided by
              their total number of elements, i.e. one magnitude shared by every tensor.
            - "param": ``alpha`` times the mean absolute value of each tensor.
            - "channel": like "param", but applied to the slices returned by
              ``unbind_channels`` (presumably one per channel).

    Raises:
        ValueError: if ``mode`` is not one of the values listed above.
    """
    # previously an unknown mode silently added no noise at all
    if mode not in ("absolute", "global", "param", "channel"):
        raise ValueError(
            f"Unknown noise mode {mode!r}, expected 'absolute', 'global', 'param' or 'channel'"
        )

    if not isinstance(grads, TensorList): grads = TensorList(grads)
    if mode == 'absolute':
        grads += grads.sample_like(alpha, distribution)

    elif mode == 'global':
        # a single python scalar magnitude shared by all tensors
        grads += grads.sample_like((grads.total_vector_norm(1)/grads.total_numel() * alpha).detach().cpu().item(), distribution) # type:ignore

    elif mode == 'param':
        grads += grads.sample_like(grads.abs().mean()*alpha, distribution)

    elif mode == 'channel':
        grads = grads.unbind_channels()
        grads += grads.sample_like(grads.abs().mean()*alpha, distribution)
|
|
29
|
-
|
|
30
|
-
class AddNoise(OptimizerModule):
    """Add noise to update. By default noise magnitude is relative to the mean absolute value of each parameter.

    Args:
        alpha (float, optional): magnitude of noise. Defaults to 1.0.
        distribution (Distributions, optional): distribution of noise. Defaults to 'normal'.
        mode (str, optional):
            how to calculate noise magnitude.

            - "absolute": ignores gradient magnitude and always uses `alpha` as magnitude.

            - "global": multiplies `alpha` by the mean absolute value of the entire update, as if it was a single vector.

            - "param": multiplies `alpha` by the mean absolute value of each individual parameter (default).

            - "channel": multiplies `alpha` by the mean absolute value of each channel of each parameter.
    """

    def __init__(
        self,
        alpha: float = 1.,
        distribution: Distributions = "normal",
        mode: Literal["absolute", "global", "param", "channel"] = "param",
    ):
        # only alpha is stored per parameter group, distribution and mode are module-wide
        defaults = dict(alpha = alpha)
        super().__init__(defaults)
        self.distribution: Distributions = distribution
        self.mode: Literal["absolute", "global", "param", "channel"] = mode

    @torch.no_grad
    def _update(self, vars, ascent):
        alpha = self.get_group_key('alpha')

        # adds noise to ascent in-place
        add_noise_(ascent, alpha, self.distribution, self.mode)
        return ascent
|
|
65
|
-
|
|
66
|
-
class Random(OptimizerModule):
    """Replaces the update with a random vector shaped like the current update.

    The random vector isn't checked to be a descent direction, so this module is mainly
    useful in combination with other modules like Sum, Multiply, etc.

    Args:
        alpha (float, optional): magnitude passed to ``sample_like``. Defaults to 1.
        distribution (Distributions, optional): distribution to sample from. Defaults to "normal".
    """
    def __init__(self, alpha: float = 1, distribution: Distributions = "normal"):
        super().__init__(dict(alpha = alpha))
        self.distribution: Distributions = distribution

    @torch.no_grad
    def _update(self, vars, ascent):
        magnitude = self.get_group_key('alpha')
        noise = ascent.sample_like(magnitude, self.distribution)
        return noise
|