torchzero-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchzero/__init__.py +4 -0
- torchzero/core/__init__.py +13 -0
- torchzero/core/module.py +471 -0
- torchzero/core/tensorlist_optimizer.py +219 -0
- torchzero/modules/__init__.py +21 -0
- torchzero/modules/adaptive/__init__.py +4 -0
- torchzero/modules/adaptive/adaptive.py +192 -0
- torchzero/modules/experimental/__init__.py +19 -0
- torchzero/modules/experimental/experimental.py +294 -0
- torchzero/modules/experimental/quad_interp.py +104 -0
- torchzero/modules/experimental/subspace.py +259 -0
- torchzero/modules/gradient_approximation/__init__.py +7 -0
- torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
- torchzero/modules/gradient_approximation/base_approximator.py +110 -0
- torchzero/modules/gradient_approximation/fdm.py +125 -0
- torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
- torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
- torchzero/modules/gradient_approximation/rfdm.py +125 -0
- torchzero/modules/line_search/__init__.py +30 -0
- torchzero/modules/line_search/armijo.py +56 -0
- torchzero/modules/line_search/base_ls.py +139 -0
- torchzero/modules/line_search/directional_newton.py +217 -0
- torchzero/modules/line_search/grid_ls.py +158 -0
- torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
- torchzero/modules/meta/__init__.py +12 -0
- torchzero/modules/meta/alternate.py +65 -0
- torchzero/modules/meta/grafting.py +195 -0
- torchzero/modules/meta/optimizer_wrapper.py +173 -0
- torchzero/modules/meta/return_overrides.py +46 -0
- torchzero/modules/misc/__init__.py +10 -0
- torchzero/modules/misc/accumulate.py +43 -0
- torchzero/modules/misc/basic.py +115 -0
- torchzero/modules/misc/lr.py +96 -0
- torchzero/modules/misc/multistep.py +51 -0
- torchzero/modules/misc/on_increase.py +53 -0
- torchzero/modules/momentum/__init__.py +4 -0
- torchzero/modules/momentum/momentum.py +106 -0
- torchzero/modules/operations/__init__.py +29 -0
- torchzero/modules/operations/multi.py +298 -0
- torchzero/modules/operations/reduction.py +134 -0
- torchzero/modules/operations/singular.py +113 -0
- torchzero/modules/optimizers/__init__.py +10 -0
- torchzero/modules/optimizers/adagrad.py +49 -0
- torchzero/modules/optimizers/adam.py +118 -0
- torchzero/modules/optimizers/lion.py +28 -0
- torchzero/modules/optimizers/rmsprop.py +51 -0
- torchzero/modules/optimizers/rprop.py +99 -0
- torchzero/modules/optimizers/sgd.py +54 -0
- torchzero/modules/orthogonalization/__init__.py +2 -0
- torchzero/modules/orthogonalization/newtonschulz.py +159 -0
- torchzero/modules/orthogonalization/svd.py +86 -0
- torchzero/modules/quasi_newton/__init__.py +4 -0
- torchzero/modules/regularization/__init__.py +22 -0
- torchzero/modules/regularization/dropout.py +34 -0
- torchzero/modules/regularization/noise.py +77 -0
- torchzero/modules/regularization/normalization.py +328 -0
- torchzero/modules/regularization/ortho_grad.py +78 -0
- torchzero/modules/regularization/weight_decay.py +92 -0
- torchzero/modules/scheduling/__init__.py +2 -0
- torchzero/modules/scheduling/lr_schedulers.py +131 -0
- torchzero/modules/scheduling/step_size.py +80 -0
- torchzero/modules/second_order/__init__.py +4 -0
- torchzero/modules/second_order/newton.py +165 -0
- torchzero/modules/smoothing/__init__.py +5 -0
- torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
- torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
- torchzero/modules/weight_averaging/__init__.py +2 -0
- torchzero/modules/weight_averaging/ema.py +72 -0
- torchzero/modules/weight_averaging/swa.py +171 -0
- torchzero/optim/__init__.py +10 -0
- torchzero/optim/experimental/__init__.py +20 -0
- torchzero/optim/experimental/experimental.py +343 -0
- torchzero/optim/experimental/ray_search.py +83 -0
- torchzero/optim/first_order/__init__.py +18 -0
- torchzero/optim/first_order/cautious.py +158 -0
- torchzero/optim/first_order/forward_gradient.py +70 -0
- torchzero/optim/first_order/optimizers.py +570 -0
- torchzero/optim/modular.py +132 -0
- torchzero/optim/quasi_newton/__init__.py +1 -0
- torchzero/optim/quasi_newton/directional_newton.py +58 -0
- torchzero/optim/second_order/__init__.py +1 -0
- torchzero/optim/second_order/newton.py +94 -0
- torchzero/optim/wrappers/__init__.py +0 -0
- torchzero/optim/wrappers/nevergrad.py +113 -0
- torchzero/optim/wrappers/nlopt.py +165 -0
- torchzero/optim/wrappers/scipy.py +439 -0
- torchzero/optim/zeroth_order/__init__.py +4 -0
- torchzero/optim/zeroth_order/fdm.py +87 -0
- torchzero/optim/zeroth_order/newton_fdm.py +146 -0
- torchzero/optim/zeroth_order/rfdm.py +217 -0
- torchzero/optim/zeroth_order/rs.py +85 -0
- torchzero/random/__init__.py +1 -0
- torchzero/random/random.py +46 -0
- torchzero/tensorlist.py +819 -0
- torchzero/utils/__init__.py +0 -0
- torchzero/utils/compile.py +39 -0
- torchzero/utils/derivatives.py +99 -0
- torchzero/utils/python_tools.py +25 -0
- torchzero/utils/torch_tools.py +92 -0
- torchzero-0.0.1.dist-info/LICENSE +21 -0
- torchzero-0.0.1.dist-info/METADATA +118 -0
- torchzero-0.0.1.dist-info/RECORD +104 -0
- torchzero-0.0.1.dist-info/WHEEL +5 -0
- torchzero-0.0.1.dist-info/top_level.txt +1 -0
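The listing above splits the package into composable building blocks (torchzero/modules, exposed as `tz.m` in the docstrings further down) and ready-made optimizers (torchzero/optim). As a rough orientation sketch only, using names that appear in the docstring examples later in this diff; the model is a placeholder:

import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)   # placeholder model

# building blocks from torchzero/modules are chained through torchzero/optim/modular.py
opt = tz.Modular(model.parameters(), [tz.m.RMSProp(), tz.m.LR(1e-2)])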
torchzero/optim/modular.py
@@ -0,0 +1,132 @@

from collections import abc
import warnings
from inspect import cleandoc
import torch

from ..core import OptimizerModule, TensorListOptimizer, OptimizationState, _Chain, _Chainable
from ..utils.python_tools import flatten

def _unroll_modules(flat_modules: list[OptimizerModule], nested) -> list[OptimizerModule]:
    """returns a list of all modules, including all nested ones"""
    unrolled = []
    for m in flat_modules:
        unrolled.append(m)
        if len(m.children) > 0:
            unrolled.extend(_unroll_modules(list(m.children.values()), nested=True))
        if nested:
            if m.next_module is not None:
                unrolled.extend(_unroll_modules([m.next_module], nested=True))
    return unrolled


class Modular(TensorListOptimizer):
    """Creates a modular optimizer by chaining together a sequence of optimizer modules.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        *modules (Iterable[OptimizerModule] | OptimizerModule):
            A sequence of optimizer modules to chain together. This argument will be flattened."""
    def __init__(self, params, *modules: _Chainable):
        flat_modules = flatten(modules)
        self.modules: list[OptimizerModule] = flat_modules
        self.chain = _Chain(flat_modules)

        # save unrolled modules and make sure there is only 1 LR module.
        self.unrolled_modules = _unroll_modules(flat_modules, nested=False)
        num_lr_modules = len([m for m in self.unrolled_modules if m.IS_LR_MODULE])
        if num_lr_modules > 1:
            warnings.warn(cleandoc(
                f"""More than one lr module has been added.
                This may lead to incorrect behaviour with learning rate scheduling and per-parameter learning rates.
                Make sure there is a single `LR` module; use the `Alpha` module instead of it where needed.
                \nList of modules: {self.unrolled_modules}; \nlist of lr modules: {[m for m in self.unrolled_modules if m.IS_LR_MODULE]}"""
            ))

        if isinstance(params, torch.nn.Module):
            self.model = params
            params = list(params.parameters())
        else:
            self.model = None
            params = list(params)

        # if there is an `lr` setting, make sure there is an LR module that can use it
        for p in params:
            if isinstance(p, dict):
                if 'lr' in p:
                    if num_lr_modules == 0:
                        warnings.warn(cleandoc(
                            """Passed "lr" setting in a parameter group, but there is no LR module that can use that setting.
                            Add an `LR` module to make the per-layer "lr" setting work."""
                        ))

        super().__init__(params, {})
        self.chain._initialize_(params, set_passed_params=True)

        # run post-init hooks
        for module in self.unrolled_modules:
            for hook in module.post_init_hooks:
                hook(self, module)

    def get_lr_module(self, last=True) -> OptimizerModule:
        """
        Retrieves the module in the chain that controls the learning rate.

        This method is useful for setting up a learning rate scheduler. By default, it retrieves
        the last module in the chain that has an `lr` group parameter.

        Args:
            last (bool, optional):
                If multiple modules have an `lr` parameter, this argument controls which one is returned.
                - If `True` (default), the last module is returned.
                - If `False`, the first module is returned.

        Returns:
            OptimizerModule: The module that controls the learning rate.

        Raises:
            ValueError: If no modules in the chain have an `lr` parameter. To fix this, add an `LR` module.

        Example:

        .. code:: py

            from torch.optim.lr_scheduler import OneCycleLR
            import torchzero as tz

            opt = tz.Modular(model.parameters(), [tz.m.RMSProp(), tz.m.LR(1e-2), tz.m.DirectionalNewton()])
            lr_scheduler = OneCycleLR(opt.get_lr_module(), max_lr = 1e-1, total_steps = 1000, cycle_momentum=False)

        """
        modules = list(reversed(self.unrolled_modules)) if last else self.unrolled_modules
        for m in modules:
            if 'lr' in m.param_groups[0]: return m

        raise ValueError(f'No modules out of {", ".join(m.__class__.__name__ for m in modules)} support an `lr` parameter. The easiest way to fix this is to add an `LR(1)` module at the end.')

    def get_module_by_name(self, name: str | type, last=True) -> OptimizerModule:
        """Returns the first or last module in the chain that matches the provided name or type.

        Args:
            name (str | type): the name (as a string) or the type of the module to search for.
            last (bool, optional):
                If multiple modules match, this argument controls which one is returned.
                - If `True` (default), the last matching module is returned.
                - If `False`, the first matching module is returned.

        Returns:
            OptimizerModule: The matching optimizer module.

        Raises:
            ValueError: If no modules in the chain match the provided name or type.
        """
        modules = list(reversed(self.unrolled_modules)) if last else self.unrolled_modules
        for m in modules:
            if isinstance(name, str) and m.__class__.__name__ == name: return m
            if isinstance(name, type) and isinstance(m, name): return m

        raise ValueError(f'No modules out of {", ".join(m.__class__.__name__ for m in modules)} match "{name}".')

    def step(self, closure=None): # type:ignore
        state = OptimizationState(closure, self.model)
        res = self.chain.step(state)
        for hook in state.post_step_hooks: hook(self, state)
        return res
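To make the chaining concrete, here is a minimal usage sketch built from the docstring example above. The model, data, and loss are placeholders, and the closure signature (a `backward` flag controlling whether gradients are computed) is an assumption inferred from how the wrapper optimizers later in this diff call `closure()` and `closure(False)`.

import torch
from torch.optim.lr_scheduler import OneCycleLR
import torchzero as tz

model = torch.nn.Linear(10, 1)                    # placeholder model
X, y = torch.randn(64, 10), torch.randn(64, 1)    # placeholder data

# chain modules: RMSProp preconditioning -> single LR module -> directional Newton line step
opt = tz.Modular(model.parameters(), [tz.m.RMSProp(), tz.m.LR(1e-2), tz.m.DirectionalNewton()])

# the scheduler attaches to the module that owns the `lr` setting
sched = OneCycleLR(opt.get_lr_module(), max_lr=1e-1, total_steps=1000, cycle_momentum=False)

def closure(backward=True):
    # assumed closure protocol: return the loss, optionally populating .grad
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(1000):
    opt.step(closure)
    sched.step()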
torchzero/optim/quasi_newton/__init__.py
@@ -0,0 +1 @@

from .directional_newton import DirectionalNewton
torchzero/optim/quasi_newton/directional_newton.py
@@ -0,0 +1,58 @@

from ...modules import (
    SGD,
)
from ...modules import DirectionalNewton as _DirectionalNewton, LR
from ..modular import Modular


class DirectionalNewton(Modular):
    """Minimizes a parabola in the direction of the gradient (or of the update, if momentum or weight decay is enabled)
    via one additional forward pass, and uses another forward pass to make sure it didn't overstep.
    So in total this performs three forward passes and one backward.

    The first forward and backward pass is used to calculate the value and gradient at the initial parameters.
    Then a gradient descent step is performed with the `lr` learning rate, and the loss is recalculated
    with the new parameters. A quadratic is fitted to the two points and the gradient;
    if it has positive curvature, a step is made towards its minimum, and an additional forward pass
    checks that the loss decreased.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional):
            learning rate. Since you shouldn't put this module after LR(), you have to specify
            the learning rate in this argument. Defaults to 1e-4.
        max_dist (float | None, optional):
            maximum distance to step when minimizing the quadratic.
            If the minimum is further than this distance, minimization is not performed. Defaults to 1e5.
        validate_step (bool, optional):
            uses an additional forward pass to check
            if the step towards the minimum actually decreased the loss. Defaults to True.
        momentum (float, optional): momentum. Defaults to 0.
        dampening (float, optional): momentum dampening. Defaults to 0.
        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
        nesterov (bool, optional):
            enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.

    Note:
        While lr scheduling is supported, this uses the lr of the first parameter for all parameters.
    """
    def __init__(
        self,
        params,
        lr: float = 1e-4,
        max_dist: float | None = 1e5,
        validate_step: bool = True,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov: bool = False,
    ):
        modules = [
            SGD(momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov),
            LR(lr),
            _DirectionalNewton(max_dist, validate_step)
        ]
        super().__init__(params, modules)
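A brief usage sketch for this preassembled optimizer, assuming the same closure convention as above (a flag controlling whether the backward pass runs); the quadratic fit relies on the extra forward evaluations that the closure provides. The objective and data are placeholders.

import torch
from torchzero.optim.quasi_newton import DirectionalNewton

model = torch.nn.Linear(10, 1)                    # placeholder model
X, y = torch.randn(32, 10), torch.randn(32, 1)    # fixed placeholder batch
opt = DirectionalNewton(model.parameters(), lr=1e-4, momentum=0.9)

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

loss = opt.step(closure)   # one backward plus up to two extra forward passes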
torchzero/optim/second_order/__init__.py
@@ -0,0 +1 @@

from .newton import ExactNewton
torchzero/optim/second_order/newton.py
@@ -0,0 +1,94 @@

from typing import Any, Literal

import torch

from ...modules import (
    LR,
    ClipNorm,
    FallbackLinearSystemSolvers,
    LinearSystemSolvers,
    LineSearches,
    get_line_search,
)
from ...modules import ExactNewton as _ExactNewton
from ..modular import Modular


class ExactNewton(Modular):
    """Performs an exact Newton step using batched autograd. Note that torch.func would be way more efficient
    but much more restrictive in what operations are allowed (I will add it at some point).

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate. Defaults to 1.
        tikhonov (float, optional):
            tikhonov regularization (constant value added to the diagonal of the hessian). Defaults to 0.
        solver (LinearSystemSolvers, optional):
            solver for Hx = g. Defaults to "cholesky_lu" (cholesky, or LU if it fails).
        fallback (FallbackLinearSystemSolvers, optional):
            what to do if the solver fails. Defaults to "safe_diag"
            (takes nonzero diagonal elements, or falls back to gradient descent if all elements are 0).
        max_norm (float, optional):
            clips the newton step to this L2 norm to avoid instability from giant steps.
            A much better way is to use trust region methods. I haven't implemented any,
            but you can use `tz.optim.wrappers.scipy.ScipyMinimize` with one of the trust region methods.
            Defaults to None.
        validate (bool, optional):
            validate with an additional forward pass that the step didn't increase the loss by `loss * tol`.
            If it did, undo the step and perform a gradient descent step.
        tol (float, optional):
            only has effect if `validate` is enabled.
            If the loss increased by `loss * tol`, perform a gradient descent step.
            Set this to 0 to guarantee that loss always decreases. Defaults to 1.
        gd_lr (float, optional):
            only has effect if `validate` is enabled.
            Gradient descent step learning rate. Defaults to 1e-2.
        line_search (OptimizerModule | None, optional): line search module, can be None. Defaults to None.
        batched_hessian (bool, optional):
            whether to use the experimental pytorch vmap-vectorized hessian calculation. As per the pytorch docs,
            it should be faster, but this feature being experimental, there may be performance cliffs.
            Defaults to True.
        diag (bool, optional):
            only use the diagonal of the hessian. This will still calculate the full hessian!
            This is mainly useful for benchmarking.
    """
    def __init__(
        self,
        params,
        lr: float = 1,
        tikhonov: float | Literal['eig'] = 0.0,
        solver: LinearSystemSolvers = "cholesky_lu",
        fallback: FallbackLinearSystemSolvers = "safe_diag",
        max_norm: float | None = None,
        validate=False,
        tol: float = 1,
        gd_lr = 1e-2,
        line_search: LineSearches | None = None,
        batched_hessian = True,
        diag: bool = False,
    ):
        modules: list[Any] = [
            _ExactNewton(
                tikhonov=tikhonov,
                batched_hessian=batched_hessian,
                solver=solver,
                fallback=fallback,
                validate=validate,
                tol=tol,
                gd_lr=gd_lr,
                diag=diag,
            ),
        ]

        if max_norm is not None:
            modules.append(ClipNorm(max_norm))

        modules.append(LR(lr))

        if line_search is not None:
            modules.append(get_line_search(line_search))

        super().__init__(params, modules)
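A short sketch of how ExactNewton might be invoked, best suited to small parameter counts since the full Hessian is formed. The objective is a placeholder and the closure convention is the same assumption as in the earlier sketches.

import torch
from torchzero.optim.second_order import ExactNewton

w = torch.nn.Parameter(torch.randn(8))            # small problem: the Hessian is only 8x8
opt = ExactNewton([w], lr=1.0, tikhonov=1e-4, max_norm=10.0)

def closure(backward=True):
    # smooth placeholder objective with a positive definite Hessian
    loss = (w ** 4).sum() + (w - 1).pow(2).sum()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(20):
    opt.step(closure)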
torchzero/optim/wrappers/__init__.py
File without changes
torchzero/optim/wrappers/nevergrad.py
@@ -0,0 +1,113 @@

import typing
from collections import abc

import numpy as np
import torch

import nevergrad as ng

from ...core import TensorListOptimizer


def _ensure_float(x):
    if isinstance(x, torch.Tensor): return x.detach().cpu().item()
    if isinstance(x, np.ndarray): return x.item()
    return float(x)

class NevergradOptimizer(TensorListOptimizer):
    """Use a nevergrad optimizer as a pytorch optimizer.
    Note that it is recommended to set `budget` to the number of iterations you expect to run,
    as some nevergrad optimizers will error without it.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        opt_cls (type[ng.optimizers.base.Optimizer]):
            nevergrad optimizer class. For example, `ng.optimizers.NGOpt`.
        budget (int | None, optional):
            nevergrad parameter which sets the allowed number of function evaluations (forward passes).
            This only affects the behaviour of many nevergrad optimizers, for example some
            use a certain rule for the first 50% of the steps, and then switch to another rule.
            This parameter doesn't actually limit the maximum number of steps,
            and it doesn't have to be exact. Defaults to None.
        mutable_sigma (bool, optional):
            nevergrad parameter, sets whether the mutation standard deviation must mutate as well
            (for mutation based algorithms). Defaults to False.
        use_init (bool, optional):
            whether to use the initial model parameters as the initial parameters for the nevergrad parametrization.
            The reason you might want to set this to False is that True seems to break some optimizers
            (mainly portfolio ones, by initializing them all to the same parameters so they all perform exactly the same steps).
            However if you are fine-tuning something, you have to set this to True, otherwise it will start from
            new random parameters. Defaults to True.
    """
    def __init__(
        self,
        params,
        opt_cls: "type[ng.optimizers.base.Optimizer] | abc.Callable[..., ng.optimizers.base.Optimizer]",
        budget: int | None = None,
        mutable_sigma = False,
        lb: float | None = None,
        ub: float | None = None,
        use_init = True,
    ):
        defaults = dict(lb=lb, ub=ub, use_init=use_init, mutable_sigma=mutable_sigma)
        super().__init__(params, defaults)
        self.opt_cls = opt_cls
        self.opt = None
        self.budget = budget

    @torch.no_grad
    def step(self, closure): # type:ignore # pylint:disable=signature-differs
        params = self.get_params()
        if self.opt is None:
            ng_params = []
            for group in self.param_groups:
                params = group['params']
                mutable_sigma = group['mutable_sigma']
                use_init = group['use_init']
                lb = group['lb']
                ub = group['ub']
                for p in params:
                    if p.requires_grad:
                        if use_init:
                            ng_params.append(
                                ng.p.Array(init = p.detach().cpu().numpy(), lower=lb, upper=ub, mutable_sigma=mutable_sigma))
                        else:
                            ng_params.append(
                                ng.p.Array(shape = p.shape, lower=lb, upper=ub, mutable_sigma=mutable_sigma))

            parametrization = ng.p.Tuple(*ng_params)
            self.opt = self.opt_cls(parametrization, budget=self.budget)

        x: ng.p.Tuple = self.opt.ask() # type:ignore
        for cur, new in zip(params, x):
            cur.set_(torch.from_numpy(new.value).to(dtype=cur.dtype, device=cur.device, copy=False).reshape_as(cur)) # type:ignore

        loss = closure(False)
        self.opt.tell(x, _ensure_float(loss))
        return loss


# class NevergradSubspace(ModularOptimizer):
#     def __init__(
#         self,
#         params,
#         opt_cls:"type[ng.optimizers.base.Optimizer] | abc.Callable[..., ng.optimizers.base.Optimizer]",
#         budget=None,
#         mutable_sigma = False,
#         use_init = True,
#         projections = Proj2Masks(5),
#     ):

#         modules = [
#             Subspace(projections, update_every=100),
#             UninitializedClosureOptimizerWrapper(
#                 NevergradOptimizer,
#                 opt_cls = opt_cls,
#                 budget = budget,
#                 mutable_sigma = mutable_sigma,
#                 use_init = use_init,
#             ),
#         ]

#         super().__init__(params, modules)
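A minimal gradient-free usage sketch: each `step` asks nevergrad for a candidate, writes it into the parameters, and tells nevergrad the loss returned by `closure(False)`, so the closure must work without a backward pass. The objective is a placeholder; `ng.optimizers.NGOpt` comes from the docstring above.

import torch
import nevergrad as ng
from torchzero.optim.wrappers.nevergrad import NevergradOptimizer

w = torch.nn.Parameter(torch.randn(16))           # placeholder parameters
opt = NevergradOptimizer([w], ng.optimizers.NGOpt, budget=500)

def closure(backward=True):
    # gradient-free: the backward flag is accepted but ignored
    return (w - 3).pow(2).sum()

for _ in range(500):
    opt.step(closure)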
torchzero/optim/wrappers/nlopt.py
@@ -0,0 +1,165 @@

from typing import Literal
from collections.abc import Mapping, Callable
from functools import partial
import numpy as np
import torch

import nlopt
from ...core import TensorListOptimizer, _ClosureType
from ...tensorlist import TensorList

_ALGOS_LITERAL = Literal[
    "GN_DIRECT", # = _nlopt.GN_DIRECT
    "GN_DIRECT_L", # = _nlopt.GN_DIRECT_L
    "GN_DIRECT_L_RAND", # = _nlopt.GN_DIRECT_L_RAND
    "GN_DIRECT_NOSCAL", # = _nlopt.GN_DIRECT_NOSCAL
    "GN_DIRECT_L_NOSCAL", # = _nlopt.GN_DIRECT_L_NOSCAL
    "GN_DIRECT_L_RAND_NOSCAL", # = _nlopt.GN_DIRECT_L_RAND_NOSCAL
    "GN_ORIG_DIRECT", # = _nlopt.GN_ORIG_DIRECT
    "GN_ORIG_DIRECT_L", # = _nlopt.GN_ORIG_DIRECT_L
    "GD_STOGO", # = _nlopt.GD_STOGO
    "GD_STOGO_RAND", # = _nlopt.GD_STOGO_RAND
    "LD_LBFGS_NOCEDAL", # = _nlopt.LD_LBFGS_NOCEDAL
    "LD_LBFGS", # = _nlopt.LD_LBFGS
    "LN_PRAXIS", # = _nlopt.LN_PRAXIS
    "LD_VAR1", # = _nlopt.LD_VAR1
    "LD_VAR2", # = _nlopt.LD_VAR2
    "LD_TNEWTON", # = _nlopt.LD_TNEWTON
    "LD_TNEWTON_RESTART", # = _nlopt.LD_TNEWTON_RESTART
    "LD_TNEWTON_PRECOND", # = _nlopt.LD_TNEWTON_PRECOND
    "LD_TNEWTON_PRECOND_RESTART", # = _nlopt.LD_TNEWTON_PRECOND_RESTART
    "GN_CRS2_LM", # = _nlopt.GN_CRS2_LM
    "GN_MLSL", # = _nlopt.GN_MLSL
    "GD_MLSL", # = _nlopt.GD_MLSL
    "GN_MLSL_LDS", # = _nlopt.GN_MLSL_LDS
    "GD_MLSL_LDS", # = _nlopt.GD_MLSL_LDS
    "LD_MMA", # = _nlopt.LD_MMA
    "LN_COBYLA", # = _nlopt.LN_COBYLA
    "LN_NEWUOA", # = _nlopt.LN_NEWUOA
    "LN_NEWUOA_BOUND", # = _nlopt.LN_NEWUOA_BOUND
    "LN_NELDERMEAD", # = _nlopt.LN_NELDERMEAD
    "LN_SBPLX", # = _nlopt.LN_SBPLX
    "LN_AUGLAG", # = _nlopt.LN_AUGLAG
    "LD_AUGLAG", # = _nlopt.LD_AUGLAG
    "LN_AUGLAG_EQ", # = _nlopt.LN_AUGLAG_EQ
    "LD_AUGLAG_EQ", # = _nlopt.LD_AUGLAG_EQ
    "LN_BOBYQA", # = _nlopt.LN_BOBYQA
    "GN_ISRES", # = _nlopt.GN_ISRES
    "AUGLAG", # = _nlopt.AUGLAG
    "AUGLAG_EQ", # = _nlopt.AUGLAG_EQ
    "G_MLSL", # = _nlopt.G_MLSL
    "G_MLSL_LDS", # = _nlopt.G_MLSL_LDS
    "LD_SLSQP", # = _nlopt.LD_SLSQP
    "LD_CCSAQ", # = _nlopt.LD_CCSAQ
    "GN_ESCH", # = _nlopt.GN_ESCH
    "GN_AGS", # = _nlopt.GN_AGS
]

def _ensure_float(x):
    if isinstance(x, torch.Tensor): return x.detach().cpu().item()
    if isinstance(x, np.ndarray): return x.item()
    return float(x)

def _ensure_tensor(x):
    if isinstance(x, np.ndarray):
        x.setflags(write=True)
        return torch.from_numpy(x)
    return torch.tensor(x, dtype=torch.float32)

inf = float('inf')
class NLOptOptimizer(TensorListOptimizer):
    """Use nlopt as pytorch optimizer, with gradient supplied by pytorch autograd.
    Note that this performs full minimization on each step,
    so usually you would want to perform a single step, although performing multiple steps will refine the
    solution.

    Some algorithms are buggy with numpy>=2.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        algorithm (int | _ALGOS_LITERAL): optimization algorithm from https://nlopt.readthedocs.io/en/latest/NLopt_Algorithms/
        maxeval (int | None):
            maximum allowed function evaluations, set to None to disable. But some stopping criterion
            must be set otherwise nlopt will run forever.
        lb (float | None, optional): optional lower bounds, some algorithms require this. Defaults to None.
        ub (float | None, optional): optional upper bounds, some algorithms require this. Defaults to None.
        stopval (float | None, optional): stop minimizing when an objective value ≤ stopval is found. Defaults to None.
        ftol_rel (float | None, optional): set relative tolerance on function value. Defaults to None.
        ftol_abs (float | None, optional): set absolute tolerance on function value. Defaults to None.
        xtol_rel (float | None, optional): set relative tolerance on optimization parameters. Defaults to None.
        xtol_abs (float | None, optional): set absolute tolerances on optimization parameters. Defaults to None.
        maxtime (float | None, optional): stop when the optimization time (in seconds) exceeds maxtime. Defaults to None.
    """
    def __init__(
        self,
        params,
        algorithm: int | _ALGOS_LITERAL,
        maxeval: int | None,
        lb: float | None = None,
        ub: float | None = None,
        stopval: float | None = None,
        ftol_rel: float | None = None,
        ftol_abs: float | None = None,
        xtol_rel: float | None = None,
        xtol_abs: float | None = None,
        maxtime: float | None = None,
    ):
        defaults = dict(lb=lb, ub=ub)
        super().__init__(params, defaults)

        self.opt: nlopt.opt | None = None
        if isinstance(algorithm, str): algorithm = getattr(nlopt, algorithm.upper())
        self.algorithm: int = algorithm # type:ignore
        self.algorithm_name: str | None = None

        self.maxeval = maxeval; self.stopval = stopval
        self.ftol_rel = ftol_rel; self.ftol_abs = ftol_abs
        self.xtol_rel = xtol_rel; self.xtol_abs = xtol_abs
        self.maxtime = maxtime

        self._last_loss = None

    def _f(self, x: np.ndarray, grad: np.ndarray, closure: _ClosureType, params: TensorList):
        params.from_vec_(_ensure_tensor(x).to(params[0], copy=False))
        if grad.size > 0:
            with torch.enable_grad(): loss = closure()
            self._last_loss = _ensure_float(loss)
            grad[:] = params.ensure_grad_().grad.to_vec().reshape(grad.shape).detach().cpu().numpy()
            return self._last_loss

        self._last_loss = _ensure_float(closure(False))
        return self._last_loss

    @torch.no_grad
    def step(self, closure: _ClosureType): # pylint: disable = signature-differs
        params = self.get_params()

        # make bounds
        lb, ub = self.get_group_keys('lb', 'ub', cls=list)
        lower = []
        upper = []
        for p, l, u in zip(params, lb, ub):
            if l is None: l = -inf
            if u is None: u = inf
            lower.extend([l] * p.numel())
            upper.extend([u] * p.numel())

        x0 = params.to_vec().detach().cpu().numpy()

        self.opt = nlopt.opt(self.algorithm, x0.size)
        self.opt.set_min_objective(partial(self._f, closure = closure, params = params))
        self.opt.set_lower_bounds(lower)
        self.opt.set_upper_bounds(upper)

        if self.maxeval is not None: self.opt.set_maxeval(self.maxeval)
        if self.stopval is not None: self.opt.set_stopval(self.stopval)
        if self.ftol_rel is not None: self.opt.set_ftol_rel(self.ftol_rel)
        if self.ftol_abs is not None: self.opt.set_ftol_abs(self.ftol_abs)
        if self.xtol_rel is not None: self.opt.set_xtol_rel(self.xtol_rel)
        if self.xtol_abs is not None: self.opt.set_xtol_abs(self.xtol_abs)
        if self.maxtime is not None: self.opt.set_maxtime(self.maxtime)

        x = self.opt.optimize(x0)
        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        return self._last_loss
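Finally, a sketch for the NLopt wrapper: a single `step` call runs the full NLopt minimization, with the closure supplying both the loss and, for gradient-based algorithms such as "LD_LBFGS" from the literal above, the autograd gradient. The closure flag convention is again an assumption based on the `closure()` / `closure(False)` calls in `_f`, and the objective is a placeholder.

import torch
from torchzero.optim.wrappers.nlopt import NLOptOptimizer

w = torch.nn.Parameter(torch.randn(16))           # placeholder parameters
opt = NLOptOptimizer([w], algorithm="LD_LBFGS", maxeval=1000)

def closure(backward=True):
    loss = (w - 3).pow(2).sum()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

final_loss = opt.step(closure)   # runs the full minimization in one call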