torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1

torchzero/modules/smoothing/gaussian.py
ADDED
@@ -0,0 +1,164 @@

import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from functools import partial
from typing import Literal

import torch

from ...core import Modular, Module, Vars
from ...utils import NumberList, TensorList
from ...utils.derivatives import jacobian_wrt
from ..grad_approximation import GradApproximator, GradTarget


class Reformulation(Module, ABC):
    def __init__(self, defaults):
        super().__init__(defaults)

    @abstractmethod
    def closure(self, backward: bool, closure: Callable, params: list[torch.Tensor], vars: Vars) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
        """returns loss and gradient, if backward is False then gradient can be None"""

    def pre_step(self, vars: Vars) -> Vars | None:
        """This runs once before each step, whereas `closure` may run multiple times per step if further modules
        evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
        return vars

    def step(self, vars):
        ret = self.pre_step(vars)
        if isinstance(ret, Vars): vars = ret

        if vars.closure is None: raise RuntimeError("Reformulation requires closure")
        params, closure = vars.params, vars.closure


        def modified_closure(backward=True):
            loss, grad = self.closure(backward, closure, params, vars)

            if grad is not None:
                for p, g in zip(params, grad):
                    p.grad = g

            return loss

        vars.closure = modified_closure
        return vars


def _decay_sigma_(self: Module, params):
    for p in params:
        state = self.state[p]
        settings = self.settings[p]
        state['sigma'] *= settings['decay']

def _generate_perturbations_to_state_(self: Module, params: TensorList, n_samples, sigmas, generator):
    perturbations = [params.sample_like(generator=generator) for _ in range(n_samples)]
    torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in sigmas for v in [vv]*n_samples])
    for param, prt in zip(params, zip(*perturbations)):
        self.state[param]['perturbations'] = prt

def _clear_state_hook(optimizer: Modular, vars: Vars, self: Module):
    for m in optimizer.unrolled_modules:
        if m is not self:
            m.reset()

class GaussianHomotopy(Reformulation):
    def __init__(
        self,
        n_samples: int,
        init_sigma: float,
        tol: float | None = 1e-4,
        decay=0.5,
        max_steps: int | None = None,
        clear_state=True,
        seed: int | None = None,
    ):
        defaults = dict(n_samples=n_samples, init_sigma=init_sigma, tol=tol, decay=decay, max_steps=max_steps, clear_state=clear_state, seed=seed)
        super().__init__(defaults)


    def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
        if 'generator' not in self.global_state:
            if isinstance(seed, torch.Generator): self.global_state['generator'] = seed
            elif seed is not None: self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
            else: self.global_state['generator'] = None
        return self.global_state['generator']

    def pre_step(self, vars):
        params = TensorList(vars.params)
        settings = self.settings[params[0]]
        n_samples = settings['n_samples']
        init_sigma = self.get_settings('init_sigma', params=params)
        sigmas = self.get_state('sigma', params = params, init=init_sigma)

        if any('perturbations' not in self.state[p] for p in params):
            generator = self._get_generator(settings['seed'], params)
            _generate_perturbations_to_state_(self, params=params, n_samples=n_samples, sigmas=sigmas, generator=generator)

        # sigma decay rules
        max_steps = settings['max_steps']
        decayed = False
        if max_steps is not None and max_steps > 0:
            level_steps = self.global_state['level_steps'] = self.global_state.get('level_steps', 0) + 1
            if level_steps > max_steps:
                self.global_state['level_steps'] = 0
                _decay_sigma_(self, params)
                decayed = True

        tol = settings['tol']
        if tol is not None and not decayed:
            if not any('prev_params' in self.state[p] for p in params):
                prev_params = self.get_state('prev_params', params=params, cls=TensorList, init='param')
            else:
                prev_params = self.get_state('prev_params', params=params, cls=TensorList, init='param')
                s = params - prev_params

                if s.abs().global_max() <= tol:
                    _decay_sigma_(self, params)
                    decayed = True

                prev_params.copy_(params)

        if decayed:
            generator = self._get_generator(settings['seed'], params)
            _generate_perturbations_to_state_(self, params=params, n_samples=n_samples, sigmas=sigmas, generator=generator)
            if settings['clear_state']:
                vars.post_step_hooks.append(partial(_clear_state_hook, self=self))

    @torch.no_grad
    def closure(self, backward, closure, params, vars):
        params = TensorList(params)

        settings = self.settings[params[0]]
        n_samples = settings['n_samples']

        perturbations = list(zip(*(self.state[p]['perturbations'] for p in params)))

        loss = None
        grad = None
        for i in range(n_samples):
            prt = perturbations[i]

            params.add_(prt)
            if backward:
                with torch.enable_grad(): l = closure()
                if grad is None: grad = params.grad
                else: grad += params.grad

            else:
                l = closure(False)

            if loss is None: loss = l
            else: loss = loss + l

            params.sub_(prt)

        assert loss is not None
        if n_samples > 1:
            loss = loss / n_samples
            if backward:
                assert grad is not None
                grad.div_(n_samples)

        return loss, grad
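
The GaussianHomotopy module above optimizes a smoothed surrogate f_sigma(x) = E_{u ~ N(0, sigma^2 I)}[f(x + u)], estimated by averaging loss and gradient over `n_samples` fixed perturbations, and shrinks sigma when progress stalls. The following standalone sketch of that estimator is for illustration only (the helper name and toy objective are made up, none of it is library code):

import torch

def smoothed_loss_and_grad(f, x, perturbations):
    # average f and grad(f) over x + u for each fixed perturbation u
    total_loss = 0.0
    total_grad = torch.zeros_like(x)
    for u in perturbations:
        xu = (x + u).detach().requires_grad_(True)
        loss = f(xu)
        loss.backward()
        total_loss += loss.item()
        total_grad += xu.grad
    n = len(perturbations)
    return total_loss / n, total_grad / n

f = lambda z: (z ** 2).sum()                              # toy objective
x = torch.tensor([1.0, -2.0])
sigma = 0.5
perts = [sigma * torch.randn_like(x) for _ in range(4)]   # fixed for one homotopy level
loss, grad = smoothed_loss_and_grad(f, x, perts)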

torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py}
RENAMED
@@ -1,128 +1,115 @@

Removed (old laplacian_smoothing.py). Old lines 1-69 match the new file below except for these lines, some of which are truncated in this view:

-from ...tensorlist import TensorList
-from ...core import
-    grads = TensorList(params).
-            g.set_(vector_laplacian_smoothing(g, sigma).
-class LaplacianSmoothing(

Old lines 70-115 were not captured by this view. The captured tail of the removed file:

-        return smoothed_direction
-
-        # else
-        # full laplacian smoothing
-        # precompute full denominator
-        if self.full_denominator is None:
-            self.full_denominator = _precompute_denominator(ascent.to_vec(), self.sigma)
-
-        # apply the smoothing
-        vec = ascent.to_vec()
-        return ascent.from_vec(torch.fft.ifft(torch.fft.fft(vec) / self.full_denominator).real) # pylint: disable = not-callable

Added (new laplacian.py):

from typing import Literal
from collections.abc import Iterable

import torch

from ...utils.tensorlist import TensorList
from ...core import Transform, Target


def vector_laplacian_smoothing(input: torch.Tensor, sigma: float = 1) -> torch.Tensor:
    """Returns a new vector with laplacian smoothing applied to it. This flattens the input!"""
    vec = input.view(-1)
    v = torch.zeros_like(vec)
    v[0] = -2
    v[1] = 1
    v[-1] = 1
    numerator = torch.fft.fft(vec) # pylint: disable = not-callable
    denominator = 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable
    return torch.fft.ifft(numerator / denominator).real # pylint: disable = not-callable

def gradient_laplacian_smoothing_(params: Iterable[torch.Tensor], sigma: float = 1, layerwise=True, min_numel = 4):
    """Applies laplacian smoothing to gradients of an iterable of parameters.

    This updates gradients in-place.

    Args:
        params (abc.Iterable[torch.Tensor]): an iterable of Tensors that will have gradients smoothed.
        sigma (float, optional): controls the amount of smoothing. Defaults to 1.
        layerwise (bool, optional):
            If True, applies smoothing to each parameter's gradient separately,
            Otherwise applies it to all gradients, concatenated into a single vector. Defaults to True.
        min_numel (int, optional):
            minimum number of elements in a parameter to apply laplacian smoothing to.
            Only has effect if `layerwise` is True. Defaults to 4.

    Reference:
        *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
        Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*
    """
    grads = TensorList(params).get_grad()
    if layerwise:
        for g in grads:
            if g.numel() >= min_numel:
                g.set_(vector_laplacian_smoothing(g, sigma).view_as(g)) # pyright:ignore[reportArgumentType]
    else:
        vec = grads.to_vec()
        grads.from_vec_(vector_laplacian_smoothing(vec, sigma))


def _precompute_denominator(tensor: torch.Tensor, sigma) -> torch.Tensor:
    """Denominator will always be the same and depends on the size of the vector and the sigma."""
    v = torch.zeros_like(tensor.view(-1))
    v[0] = -2
    v[1] = 1
    v[-1] = 1
    return 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable

class LaplacianSmoothing(Transform):
    """Applies laplacian smoothing via a fast Fourier transform solver.

    Args:
        sigma (float, optional): controls the amount of smoothing. Defaults to 1.
        layerwise (bool, optional):
            If True, applies smoothing to each parameter's gradient separately,
            Otherwise applies it to all gradients, concatenated into a single vector. Defaults to True.
        min_numel (int, optional):
            minimum number of elements in a parameter to apply laplacian smoothing to.
            Only has effect if `layerwise` is True. Defaults to 4.
        target (str, optional):
            what to set on vars.

    Reference:
        *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
        Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*

    """
    def __init__(self, sigma: float = 1, layerwise=True, min_numel = 4, target: Target = 'update'):
        defaults = dict(sigma = sigma, layerwise=layerwise, min_numel=min_numel)
        super().__init__(defaults, uses_grad=False, target=target)
        # precomputed denominator for when layerwise=False
        self.global_state['full_denominator'] = None


    @torch.no_grad
    def transform(self, tensors, params, grads, vars):
        layerwise = self.settings[params[0]]['layerwise']

        # layerwise laplacian smoothing
        if layerwise:

            # precompute the denominator for each layer and store it in each parameters state
            smoothed_target = TensorList()
            for p, t in zip(params, tensors):
                settings = self.settings[p]
                if p.numel() > settings['min_numel']:
                    state = self.state[p]
                    if 'denominator' not in state: state['denominator'] = _precompute_denominator(p, settings['sigma'])
                    smoothed_target.append(torch.fft.ifft(torch.fft.fft(t.view(-1)) / state['denominator']).real.view_as(t)) #pylint:disable=not-callable
                else:
                    smoothed_target.append(t)

            return smoothed_target

        # else
        # full laplacian smoothing
        # precompute full denominator
        tensors = TensorList(tensors)
        if self.global_state.get('full_denominator', None) is None:
            self.global_state['full_denominator'] = _precompute_denominator(tensors.to_vec(), self.settings[params[0]]['sigma'])

        # apply the smoothing
        vec = tensors.to_vec()
        return tensors.from_vec(torch.fft.ifft(torch.fft.fft(vec) / self.global_state['full_denominator']).real)#pylint:disable=not-callable
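
For reference, the FFT expression in `vector_laplacian_smoothing` solves the circulant linear system (I - sigma * A) y = g, where A is the periodic 1-D Laplacian with stencil [1, -2, 1]. A standalone check of that equivalence (the dense construction below is illustrative, not library code):

import torch

def dense_laplacian_smoothing(g, sigma=1.0):
    # build the periodic 1-D Laplacian explicitly and solve (I - sigma * A) y = g
    n = g.numel()
    A = torch.zeros(n, n)
    idx = torch.arange(n)
    A[idx, idx] = -2.0
    A[idx, (idx + 1) % n] = 1.0
    A[idx, (idx - 1) % n] = 1.0
    return torch.linalg.solve(torch.eye(n) - sigma * A, g)

g = torch.randn(8)
v = torch.zeros(8)
v[0], v[1], v[-1] = -2.0, 1.0, 1.0
smoothed = torch.fft.ifft(torch.fft.fft(g) / (1 - 1.0 * torch.fft.fft(v))).real
assert torch.allclose(smoothed, dense_laplacian_smoothing(g, sigma=1.0), atol=1e-5)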

torchzero/modules/weight_decay/__init__.py
ADDED
@@ -0,0 +1 @@

from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_

torchzero/modules/weight_decay/weight_decay.py
ADDED
@@ -0,0 +1,52 @@

from collections.abc import Iterable, Sequence

import torch

from ...core import Module, Target, Transform
from ...utils import NumberList, TensorList, as_tensorlist

@torch.no_grad
def weight_decay_(
    grad_: TensorList,
    params: TensorList,
    weight_decay: float | NumberList,
    ord: int = 2
):
    """returns `grad_`."""
    if ord == 1: return grad_.add_(params.sign().mul_(weight_decay))
    if ord == 2: return grad_.add_(params.mul(weight_decay))
    if ord - 1 % 2 != 0: return grad_.add_(params.pow(ord-1).mul_(weight_decay))
    return grad_.add_(params.pow(ord-1).copysign_(params).mul_(weight_decay))


class WeightDecay(Transform):
    def __init__(self, weight_decay: float, ord: int = 2, target: Target = 'update'):
        defaults = dict(weight_decay=weight_decay, ord=ord)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def transform(self, tensors, params, grads, vars):
        weight_decay = self.get_settings('weight_decay', params=params, cls=NumberList)
        ord = self.settings[params[0]]['ord']

        return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)

@torch.no_grad
def decay_weights_(params: Iterable[torch.Tensor], weight_decay: float | NumberList, ord: int = 2):
    """directly decays weights in-place"""
    params = TensorList(params)
    weight_decay_(params, params, -weight_decay, ord)

class DirectWeightDecay(Module):
    """directly decays weights in-place"""
    def __init__(self, weight_decay: float, ord: int = 2,):
        defaults = dict(weight_decay=weight_decay, ord=ord)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, vars):
        weight_decay = self.get_settings('weight_decay', params=vars.params, cls=NumberList)
        ord = self.settings[vars.params[0]]['ord']

        decay_weights_(vars.params, weight_decay, ord)
        return vars
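
The `ord=2` branch of `weight_decay_` above adds `weight_decay * p` to the update, which is exactly the gradient of the usual L2 penalty. A quick standalone check on a toy tensor (not library code):

import torch

p = torch.randn(5, requires_grad=True)
wd = 0.01
penalty = 0.5 * wd * (p ** 2).sum()            # L2 penalty (wd / 2) * ||p||^2
penalty.backward()
assert torch.allclose(p.grad, wd * p.detach()) # matches grad_.add_(params.mul(weight_decay))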

torchzero/modules/wrappers/__init__.py
ADDED
@@ -0,0 +1 @@

from .optim_wrapper import Wrap

torchzero/modules/wrappers/optim_wrapper.py
ADDED
@@ -0,0 +1,91 @@

from collections.abc import Iterable, Mapping, Sequence, Callable
from typing import Any
import torch

from ...core.module import Module
from ...utils import Params, _copy_param_groups, _make_param_groups


class Wrap(Module):
    """Custom param groups are supported only by `set_param_groups`. Settings passed to Modular will be ignored."""
    def __init__(self, opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer, *args, **kwargs):
        super().__init__()
        self._opt_fn = opt_fn
        self._opt_args = args
        self._opt_kwargs = kwargs
        self._custom_param_groups = None

        self.optimizer: torch.optim.Optimizer | None = None
        if isinstance(self._opt_fn, torch.optim.Optimizer) or not callable(self._opt_fn):
            self.optimizer = self._opt_fn

    def set_param_groups(self, param_groups):
        self._custom_param_groups = param_groups
        return super().set_param_groups(param_groups)

    @torch.no_grad
    def step(self, vars):
        params = vars.params

        # initialize opt on 1st step
        if self.optimizer is None:
            assert callable(self._opt_fn)
            param_groups = params if self._custom_param_groups is None else self._custom_param_groups
            self.optimizer = self._opt_fn(param_groups, *self._opt_args, **self._opt_kwargs)

        # set grad to update
        orig_grad = [p.grad for p in params]
        for p, u in zip(params, vars.get_update()):
            p.grad = u

        # if this module is last, can step with _opt directly
        # direct step can't be applied if next module is LR but _opt doesn't support lr,
        # and if there are multiple different per-parameter lrs (would be annoying to support)
        if vars.is_last and (
            (vars.last_module_lrs is None)
            or
            (('lr' in self.optimizer.defaults) and (len(set(vars.last_module_lrs)) == 1))
        ):
            lr = 1 if vars.last_module_lrs is None else vars.last_module_lrs[0]

            # update optimizer lr with desired lr
            if lr != 1:
                self.optimizer.defaults['__original_lr__'] = self.optimizer.defaults['lr']
                for g in self.optimizer.param_groups:
                    g['__original_lr__'] = g['lr']
                    g['lr'] = g['lr'] * lr

            # step
            self.optimizer.step()

            # restore original lr
            if lr != 1:
                self.optimizer.defaults['lr'] = self.optimizer.defaults.pop('__original_lr__')
                for g in self.optimizer.param_groups:
                    g['lr'] = g.pop('__original_lr__')

            # restore grad
            for p, g in zip(params, orig_grad):
                p.grad = g

            vars.stop = True; vars.skip_update = True
            return vars

        # this is not the last module, meaning update is difference in parameters
        params_before_step = [p.clone() for p in params]
        self.optimizer.step() # step and update params
        for p, g in zip(params, orig_grad):
            p.grad = g
        vars.update = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
        for p, o in zip(params, params_before_step):
            p.set_(o) # pyright: ignore[reportArgumentType]

        return vars

    def reset(self):
        super().reset()
        assert self.optimizer is not None
        for g in self.optimizer.param_groups:
            for p in g['params']:
                state = self.optimizer.state[p]
                state.clear()
torchzero/optim/__init__.py
CHANGED
@@ -1,10 +1,2 @@
-
-
-"""
-from .modular import Modular
-from .quasi_newton import *
-from .zeroth_order import *
-from .second_order import *
-from .first_order import *
-# from .wrappers.scipy import ScipyMinimize
-from . import experimental
+from .utility import *
+from .wrappers import *

torchzero/optim/utility/__init__.py
ADDED
@@ -0,0 +1 @@

from .split import Split

torchzero/optim/utility/split.py
ADDED
@@ -0,0 +1,45 @@

import warnings
from collections.abc import Callable, Iterable

import torch

from ...utils import flatten, get_params

class Split(torch.optim.Optimizer):
    """Steps will all `optimizers`, also has a check that they have no duplicate parameters.
    Doesn't support closure based optimizers.

    Example:

    .. code:: py

        opt = Split(
            torch.optim.Adam(model.encoder.parameters(), lr=0.001),
            torch.optim.SGD(model.decoder.parameters(), lr=0.1)
        )
    """
    def __init__(self, *optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer]):
        all_params = []
        self.optimizers: list[torch.optim.Optimizer] = flatten(optimizers)

        # gather all params in case user tries to access them from this object
        for i, opt in enumerate(self.optimizers):
            for p in get_params(opt.param_groups, 'all', list):
                if p not in all_params: all_params.append(p)
                else: warnings.warn(
                    f'optimizers[{i}] {opt.__class__.__name__} has some duplicate parameters '
                    'that are also in previous optimizers. They will be updated multiple times.')

        super().__init__(all_params, {})

    def step(self, closure: Callable | None = None):
        loss = None

        # if closure provided, populate grad, otherwise each optimizer will call closure separately
        if closure is not None:
            with torch.enable_grad(): loss = closure()

        for opt in self.optimizers:
            opt.step() # closure not passed as grad is already evaluated

        return loss
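
A short sketch of how `Split.step` treats a closure: it runs the closure once under `enable_grad` to populate `.grad`, then steps every wrapped optimizer without passing the closure along. The encoder/decoder modules below are placeholders, not part of the package:

import torch

enc, dec = torch.nn.Linear(4, 4), torch.nn.Linear(4, 2)
opts = [torch.optim.Adam(enc.parameters(), lr=1e-3),
        torch.optim.SGD(dec.parameters(), lr=1e-1)]
x, y = torch.randn(8, 4), torch.randn(8, 2)

def closure():
    for opt in opts:
        opt.zero_grad()
    loss = torch.nn.functional.mse_loss(dec(enc(x)), y)
    loss.backward()
    return loss

# what Split(*opts).step(closure) does internally:
with torch.enable_grad():
    loss = closure()
for opt in opts:
    opt.step()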

torchzero/optim/wrappers/nevergrad.py
CHANGED
@@ -6,7 +6,7 @@ import torch

 import nevergrad as ng

-from ...
+from ...utils import Optimizer


 def _ensure_float(x):

@@ -14,7 +14,7 @@ def _ensure_float(x):
     if isinstance(x, np.ndarray): return x.item()
     return float(x)

-class NevergradOptimizer(
+class NevergradOptimizer(Optimizer):
     """Use nevergrad optimizer as pytorch optimizer.
     Note that it is recommended to specify `budget` to the number of iterations you expect to run,
     as some nevergrad optimizers will error without it.

@@ -85,29 +85,3 @@ class NevergradOptimizer(TensorListOptimizer):
             loss = closure(False)
         self.opt.tell(x, _ensure_float(loss))
         return loss
-
-
-
-# class NevergradSubspace(ModularOptimizer):
-#     def __init__(
-#         self,
-#         params,
-#         opt_cls:"type[ng.optimizers.base.Optimizer] | abc.Callable[..., ng.optimizers.base.Optimizer]",
-#         budget=None,
-#         mutable_sigma = False,
-#         use_init = True,
-#         projections = Proj2Masks(5),
-#     ):
-
-#         modules = [
-#             Subspace(projections, update_every=100),
-#             UninitializedClosureOptimizerWrapper(
-#                 NevergradOptimizer,
-#                 opt_cls = opt_cls,
-#                 budget = budget,
-#                 mutable_sigma = mutable_sigma,
-#                 use_init = use_init,
-#             ),
-#         ]
-
-#         super().__init__(params, modules)