torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those package versions.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -494
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -132
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.7.dist-info/METADATA +0 -120
- torchzero-0.1.7.dist-info/RECORD +0 -104
- torchzero-0.1.7.dist-info/top_level.txt +0 -1
torchzero/core/tensorlist_optimizer.py (deleted)
@@ -1,219 +0,0 @@
-from typing import Literal, Any, overload, TypeVar
-from abc import ABC
-from collections.abc import Callable, Sequence, Iterable, Mapping, MutableSequence
-import numpy as np
-import torch
-import torch.optim.optimizer
-from torch.optim.optimizer import ParamsT
-
-from ..tensorlist import TensorList, NumberList
-from ..utils.torch_tools import totensor, tofloat
-from ..utils.python_tools import _ScalarLoss
-
-_StateInit = Literal['params', 'grad'] | Callable | TensorList
-
-_ClosureType = Callable[..., _ScalarLoss]
-"""
-
-Closure example:
-
-.. code-block:: python
-
-    def closure(backward = True):
-        loss = model(inputs)
-        if backward:
-            optimizer.zero_grad()
-            loss.backward()
-        return loss
-
-This closure will also work with all built in pytorch optimizers including LBFGS, as well as and most custom ones.
-"""
-
-def _maybe_pass_backward(closure: _ClosureType, backward: bool) -> _ScalarLoss:
-    """not passing backward when it is true makes this work with closures with no `backward` argument"""
-    if backward:
-        with torch.enable_grad(): return closure()
-    return closure(False)
-
-CLS = TypeVar('CLS')
-class TensorListOptimizer(torch.optim.Optimizer, ABC):
-    """torch.optim.Optimizer with some additional methods related to TensorList.
-
-    Args:
-        params (ParamsT): iterable of parameters.
-        defaults (_type_): dictionary with default parameters for the optimizer.
-    """
-    def __init__(self, params: ParamsT, defaults):
-        super().__init__(params, defaults)
-        self._params: list[torch.Tensor] = [param for group in self.param_groups for param in group['params']]
-        self.has_complex = any(torch.is_complex(x) for x in self._params)
-        """True if any of the params are complex"""
-
-    def add_param_group(self, param_group: dict[str, Any]) -> None:
-        super().add_param_group(param_group)
-        self._params: list[torch.Tensor] = [param for group in self.param_groups for param in group['params']]
-        self.has_complex = any(torch.is_complex(x) for x in self._params)
-
-    # def get_params[CLS: Any](self, cls: type[CLS] = TensorList) -> CLS:
-    def get_params(self, cls: type[CLS] = TensorList) -> CLS:
-        """returns all params with `requires_grad = True` as a TensorList."""
-        return cls(p for p in self._params if p.requires_grad) # type:ignore
-
-    def ensure_grad_(self):
-        """Replaces None grad attribute with zeroes for all parameters that require grad."""
-        for p in self.get_params():
-            if p.requires_grad and p.grad is None: p.grad = torch.zeros_like(p)
-
-    # def get_state_key[CLS: MutableSequence](self, key: str, init: _StateInit = torch.zeros_like, params=None, cls: type[CLS] = TensorList) -> CLS:
-    def get_state_key(self, key: str, init: _StateInit = torch.zeros_like, params=None, cls: type[CLS] = TensorList) -> CLS:
-        """Returns a tensorlist of all `key` states of all params with `requires_grad = True`.
-
-        Args:
-            key (str): key to create/access.
-            init: Initial value if key doesn't exist. Can be `params`, `grad`, or callable such as `torch.zeros_like`.
-                Defaults to torch.zeros_like.
-            params (optional): optionally pass params if you already created them. Defaults to None.
-            cls (optional): optionally specify any other MutableSequence subclass to use instead of TensorList.
-
-        Returns:
-            TensorList: TensorList with the `key` state. Those tensors are stored in the optimizer, so modify them in-place.
-        """
-        value = cls()
-        if params is None: params = self.get_params()
-        for pi, p in enumerate(params):
-            state = self.state[p]
-            if key not in state:
-                if callable(init): state[key] = init(p)
-                elif isinstance(init, TensorList): state[key] = init[pi].clone()
-                elif init == 'params': state[key] = p.clone().detach()
-                elif init == 'grad': state[key] = p.grad.clone().detach() if p.grad is not None else torch.zeros_like(p)
-                else: raise ValueError(f'unknown init - {init}')
-            value.append(state[key]) # type:ignore
-        return value
-
-    # def get_state_keys[CLS: MutableSequence](
-    def get_state_keys(
-        self,
-        *keys: str,
-        inits: _StateInit | Sequence[_StateInit] = torch.zeros_like,
-        params=None,
-        cls: type[CLS] = TensorList,
-    ) -> list[CLS]:
-        """Returns a TensorList with the `key` states of all `params`. Creates the states if they don't exist."""
-
-        values = [cls() for _ in range(len(keys))]
-        if params is None: params = self.get_params()
-        if callable(inits) or isinstance(inits, str): inits = [inits] * len(keys) # type:ignore
-
-        for pi, p in enumerate(params):
-            state = self.state[p]
-            for i, (key, init) in enumerate(zip(keys, inits)): # type:ignore
-                if key not in state:
-                    if callable(init): state[key] = init(p)
-                    elif isinstance(init, TensorList): state[key] = init[pi].clone()
-                    elif init == 'params': state[key] = p.clone().detach()
-                    elif init == 'grad': state[key] = p.grad.clone().detach() if p.grad is not None else torch.zeros_like(p)
-                    else: raise ValueError(f'unknown init - {init}')
-                values[i].append(state[key]) # type:ignore
-        return values
-
-    def _yield_groups_key(self, key: str):
-        for group in self.param_groups:
-            value = group[key]
-            for p in group['params']:
-                if p.requires_grad: yield value
-
-
-    # def get_group_key[CLS: Any](self, key: str, cls: type[CLS] = NumberList) -> CLS:
-    def get_group_key(self, key: str, cls: type[CLS] = NumberList) -> CLS:
-        """Returns a TensorList with the param_groups `key` setting of each param."""
-        return cls(self._yield_groups_key(key)) # type:ignore
-
-    def get_first_group_key(self, key:str) -> Any:
-        """Returns the param_groups `key` setting of the first param."""
-        return next(iter(self._yield_groups_key(key)))
-
-    # def get_all_group_keys[CLS: Any](self, cls: type[CLS] = NumberList) -> dict[str, CLS]:
-    def get_all_group_keys(self, cls: type[CLS] = NumberList) -> dict[str, CLS]:
-        all_values: dict[str, CLS] = {}
-        for group in self.param_groups:
-
-            n_params = len([p for p in group['params'] if p.requires_grad])
-
-            for key, value in group.items():
-                if key != 'params':
-                    if key not in all_values: all_values[key] = cls(value for _ in range(n_params)) # type:ignore
-                    else: all_values[key].extend([value for _ in range(n_params)]) # type:ignore
-
-        return all_values
-
-    # def get_group_keys[CLS: MutableSequence](self, *keys: str, cls: type[CLS] = NumberList) -> list[CLS]:
-    def get_group_keys(self, *keys: str, cls: type[CLS] = NumberList) -> list[CLS]:
-        """Returns a list with the param_groups `key` setting of each param."""
-
-        all_values: list[CLS] = [cls() for _ in keys]
-        for group in self.param_groups:
-
-            n_params = len([p for p in group['params'] if p.requires_grad])
-
-            for i, key in enumerate(keys):
-                value = group[key]
-                all_values[i].extend([value for _ in range(n_params)]) # type:ignore
-
-        return all_values
-
-    @torch.no_grad
-    def evaluate_loss_at_vec(self, vec, closure=None, params = None, backward=False, ensure_float=False):
-        """_summary_
-
-        Args:
-            vec (_type_): _description_
-            closure (_type_, optional): _description_. Defaults to None.
-            params (_type_, optional): _description_. Defaults to None.
-            backward (bool, optional): _description_. Defaults to False.
-            ensure_float (bool, optional): _description_. Defaults to False.
-
-        Returns:
-            _type_: _description_
-        """
-        vec = totensor(vec)
-        if closure is None: closure = self._closure # type:ignore # pylint:disable=no-member
-        if params is None: params = self.get_params()
-
-        params.from_vec_(vec.to(params[0]))
-        loss = _maybe_pass_backward(closure, backward)
-
-        if ensure_float: return tofloat(loss)
-        return _maybe_pass_backward(closure, backward)
-
-    @overload
-    def evaluate_loss_grad_at_vec(self, vec, closure=None, params = None, to_numpy: Literal[True] = False) -> tuple[float, np.ndarray]: ... # type:ignore
-    @overload
-    def evaluate_loss_grad_at_vec(self, vec, closure=None, params = None, to_numpy: Literal[False] = False) -> tuple[_ScalarLoss, torch.Tensor]: ...
-    @torch.no_grad
-    def evaluate_loss_grad_at_vec(self, vec, closure=None, params = None, to_numpy: Literal[True] | Literal[False]=False):
-        """_summary_
-
-        Args:
-            vec (_type_): _description_
-            closure (_type_, optional): _description_. Defaults to None.
-            params (_type_, optional): _description_. Defaults to None.
-            to_numpy (Literal[True] | Literal[False], optional): _description_. Defaults to False.
-
-        Returns:
-            _type_: _description_
-        """
-        if params is None: params = self.get_params()
-        loss = self.evaluate_loss_at_vec(vec, closure, params, backward = True, ensure_float = to_numpy)
-        grad = params.grad.to_vec()
-
-        if to_numpy: return tofloat(loss), grad.detach().cpu().numpy()
-        return loss, grad
-
-
-    @torch.no_grad
-    def _maybe_evaluate_closure(self, closure, backward=True):
-        loss = None
-        if closure is not None:
-            loss = _maybe_pass_backward(closure, backward)
-        return loss
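
For orientation, here is a minimal sketch of how the helpers of the removed TensorListOptimizer (get_params, ensure_grad_, get_group_keys, get_state_key, _maybe_evaluate_closure) were meant to combine in a subclass. MomentumSGD is a hypothetical name, the import path targets the 0.1.7 layout shown above, and the in-place TensorList/NumberList arithmetic is assumed to behave the way the removed adaptive.py (below) uses it; this is an illustration, not code from the package.

    # Hypothetical subclass for illustration only; written against torchzero 0.1.7,
    # where torchzero/core/tensorlist_optimizer.py still existed. Method names are the
    # ones defined in the deleted file above; the TensorList arithmetic is assumed.
    import torch
    from torchzero.core.tensorlist_optimizer import TensorListOptimizer  # 0.1.7 module path

    class MomentumSGD(TensorListOptimizer):
        def __init__(self, params, lr=1e-2, momentum=0.9):
            super().__init__(params, dict(lr=lr, momentum=momentum))

        @torch.no_grad
        def step(self, closure=None):
            loss = self._maybe_evaluate_closure(closure)          # runs closure with backward=True if given
            self.ensure_grad_()                                   # replace None grads with zeros
            params = self.get_params()                            # TensorList of params with requires_grad=True
            lr, momentum = self.get_group_keys('lr', 'momentum')  # per-parameter group settings as NumberLists
            buf = self.get_state_key('momentum_buffer')           # per-parameter state, zeros_like on first access
            buf.mul_(momentum).add_(params.grad)                  # assumed TensorList/NumberList elementwise ops
            params.sub_(buf * lr)
            return loss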

torchzero/modules/adaptive/adaptive.py (deleted)
@@ -1,192 +0,0 @@
-import typing
-import torch
-
-from ...core import OptimizerModule
-
-class Cautious(OptimizerModule):
-    """Negates update for parameters where update and gradient sign is inconsistent.
-    Optionally normalizes the update by the number of parameters that are not masked.
-    This is meant to be used after any momentum-based modules.
-
-    Args:
-        normalize (bool, optional):
-            renormalize update after masking.
-            only has effect when mode is 'zero'. Defaults to False.
-        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
-        mode (str, optional):
-            what to do with updates with inconsistent signs.
-
-            "zero" - set them to zero (as in paper)
-
-            "grad" - set them to the gradient
-
-            "negate" - negate them (same as using update magnitude and gradient sign)
-
-    reference
-        *Cautious Optimizers: Improving Training with One Line of Code.
-        Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu*
-    """
-    def __init__(self, normalize = False, eps=1e-6, mode: typing.Literal['zero', 'grad', 'backtrack'] = 'zero'):
-        super().__init__({})
-        self.eps = eps
-        self.normalize = normalize
-        self.mode: typing.Literal['zero', 'grad', 'backtrack'] = mode
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        params = self.get_params()
-        grad = vars.maybe_compute_grad_(params)
-
-        # mask will be > 0 for parameters where both signs are the same
-        mask = (ascent * grad) > 0
-        if self.mode in ('zero', 'grad'):
-            if self.normalize and self.mode == 'zero':
-                fmask = mask.to(ascent[0].dtype)
-                fmask /= fmask.total_mean() + self.eps # type:ignore
-            else:
-                fmask = mask
-
-            ascent *= fmask
-
-            if self.mode == 'grad':
-                ascent += grad * mask.logical_not_()
-
-            return ascent
-
-        # mode = 'backtrack'
-        ascent -= ascent.mul(2).mul_(mask.logical_not_())
-        return ascent
-
-
-class UseGradSign(OptimizerModule):
-    """
-    Uses update magnitude but gradient sign.
-    """
-    def __init__(self):
-        super().__init__({})
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        params = self.get_params()
-        grad = vars.maybe_compute_grad_(params)
-
-        return ascent.abs_().mul_(grad.sign())
-
-class UseGradMagnitude(OptimizerModule):
-    """
-    Uses update sign but gradient magnitude.
-    """
-    def __init__(self):
-        super().__init__({})
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        params = self.get_params()
-        grad = vars.maybe_compute_grad_(params)
-
-        return ascent.sign_().mul_(grad.abs())
-
-
-class ScaleLRBySignChange(OptimizerModule):
-    """
-    learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign,
-    or `nminus` if it did.
-
-    This is part of RProp update rule.
-
-    Args:
-        nplus (float): learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign
-        nminus (float): learning rate gets multiplied by `nminus` if ascent/gradient changed the sign
-        lb (float): lower bound for lr.
-        ub (float): upper bound for lr.
-        alpha (float): initial learning rate.
-
-    """
-    def __init__(self, nplus: float = 1.2, nminus: float = 0.5, lb = 1e-6, ub = 50, alpha=1, use_grad=False):
-        defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
-        super().__init__(defaults)
-        self.current_step = 0
-        self.use_grad = use_grad
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        params = self.get_params()
-
-        if self.use_grad: cur = vars.maybe_compute_grad_(params)
-        else: cur = ascent
-
-        nplus, nminus, lb, ub = self.get_group_keys('nplus', 'nminus', 'lb', 'ub')
-        prev, lrs = self.get_state_keys('prev_ascent', 'lrs', params=params)
-
-        # initialize on 1st step
-        if self.current_step == 0:
-            lrs.fill_(self.get_group_key('alpha'))
-            ascent.mul_(lrs)
-            prev.copy_(ascent)
-            self.current_step += 1
-            return ascent
-
-        mask = cur * prev
-        sign_changed = mask < 0
-        sign_same = mask > 0
-
-        # multiply magnitudes where sign didn't change
-        lrs.masked_set_(sign_same, lrs * nplus)
-        # multiply magnitudes where sign changed
-        lrs.masked_set_(sign_changed, lrs * nminus)
-        # bounds
-        lrs.clamp_(lb, ub)
-
-        ascent.mul_(lrs)
-        prev.copy_(cur)
-        self.current_step += 1
-        return ascent
-
-
-
-class NegateOnSignChange(OptimizerModule):
-    """Negates or undoes update for parameters where where gradient or update sign changes.
-
-    This is part of RProp update rule.
-
-    Args:
-        normalize (bool, optional): renormalize update after masking. Defaults to False.
-        eps (_type_, optional): epsilon for normalization. Defaults to 1e-6.
-        use_grad (bool, optional): if True, tracks sign change of the gradient,
-            otherwise track sign change of the update. Defaults to True.
-        backtrack (bool, optional): if True, undoes the update when sign changes, otherwise negates it.
-            Defaults to True.
-
-    """
-    # todo: add momentum to negation (to cautious as well and rprop negation as well)
-    def __init__(self, normalize = False, eps=1e-6, use_grad = False, backtrack = True):
-        super().__init__({})
-        self.eps = eps
-        self.normalize = normalize
-        self.use_grad = use_grad
-        self.backtrack = backtrack
-        self.current_step = 0
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        params = self.get_params()
-
-        if self.use_grad: cur = vars.maybe_compute_grad_(params)
-        else: cur = ascent
-
-        prev = self.get_state_key('prev')
-
-        # initialize on first step
-        if self.current_step == 0:
-            prev.set_(cur)
-            self.current_step += 1
-            return ascent
-
-        # mask will be > 0 for parameters where both signs are the same
-        mask = (cur * prev) < 0
-        if self.backtrack: ascent.masked_set_(mask, prev)
-        else: ascent.select_set_(mask, 0)
-
-        prev.set_(cur)
-        self.current_step += 1
-        return ascent
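
The removed Cautious module's 'zero' mode is easier to see on plain tensors than through the TensorList plumbing. Below is a self-contained sketch of that rule under the same masking and normalization described in its docstring; the function name and example values are illustrative only, not library API.

    # Re-implementation of the cautious rule from the removed adaptive.py above (mode='zero')
    # on plain torch tensors; cautious_zero is an illustrative name, not part of torchzero.
    import torch

    def cautious_zero(update: torch.Tensor, grad: torch.Tensor, normalize: bool = True, eps: float = 1e-6):
        mask = (update * grad) > 0                # keep components whose sign agrees with the gradient
        fmask = mask.to(update.dtype)
        if normalize:
            fmask = fmask / (fmask.mean() + eps)  # rescale by the fraction of surviving components
        return update * fmask

    update = torch.tensor([0.5, -0.3, 0.2])
    grad = torch.tensor([1.0, 0.4, -0.1])
    print(cautious_zero(update, grad))            # only the first component survives, scaled up by ~3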