torchzero 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. torchzero/__init__.py +4 -0
  2. torchzero/core/__init__.py +13 -0
  3. torchzero/core/module.py +471 -0
  4. torchzero/core/tensorlist_optimizer.py +219 -0
  5. torchzero/modules/__init__.py +21 -0
  6. torchzero/modules/adaptive/__init__.py +4 -0
  7. torchzero/modules/adaptive/adaptive.py +192 -0
  8. torchzero/modules/experimental/__init__.py +19 -0
  9. torchzero/modules/experimental/experimental.py +294 -0
  10. torchzero/modules/experimental/quad_interp.py +104 -0
  11. torchzero/modules/experimental/subspace.py +259 -0
  12. torchzero/modules/gradient_approximation/__init__.py +7 -0
  13. torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
  14. torchzero/modules/gradient_approximation/base_approximator.py +110 -0
  15. torchzero/modules/gradient_approximation/fdm.py +125 -0
  16. torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
  17. torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
  18. torchzero/modules/gradient_approximation/rfdm.py +125 -0
  19. torchzero/modules/line_search/__init__.py +30 -0
  20. torchzero/modules/line_search/armijo.py +56 -0
  21. torchzero/modules/line_search/base_ls.py +139 -0
  22. torchzero/modules/line_search/directional_newton.py +217 -0
  23. torchzero/modules/line_search/grid_ls.py +158 -0
  24. torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
  25. torchzero/modules/meta/__init__.py +12 -0
  26. torchzero/modules/meta/alternate.py +65 -0
  27. torchzero/modules/meta/grafting.py +195 -0
  28. torchzero/modules/meta/optimizer_wrapper.py +173 -0
  29. torchzero/modules/meta/return_overrides.py +46 -0
  30. torchzero/modules/misc/__init__.py +10 -0
  31. torchzero/modules/misc/accumulate.py +43 -0
  32. torchzero/modules/misc/basic.py +115 -0
  33. torchzero/modules/misc/lr.py +96 -0
  34. torchzero/modules/misc/multistep.py +51 -0
  35. torchzero/modules/misc/on_increase.py +53 -0
  36. torchzero/modules/momentum/__init__.py +4 -0
  37. torchzero/modules/momentum/momentum.py +106 -0
  38. torchzero/modules/operations/__init__.py +29 -0
  39. torchzero/modules/operations/multi.py +298 -0
  40. torchzero/modules/operations/reduction.py +134 -0
  41. torchzero/modules/operations/singular.py +113 -0
  42. torchzero/modules/optimizers/__init__.py +10 -0
  43. torchzero/modules/optimizers/adagrad.py +49 -0
  44. torchzero/modules/optimizers/adam.py +118 -0
  45. torchzero/modules/optimizers/lion.py +28 -0
  46. torchzero/modules/optimizers/rmsprop.py +51 -0
  47. torchzero/modules/optimizers/rprop.py +99 -0
  48. torchzero/modules/optimizers/sgd.py +54 -0
  49. torchzero/modules/orthogonalization/__init__.py +2 -0
  50. torchzero/modules/orthogonalization/newtonschulz.py +159 -0
  51. torchzero/modules/orthogonalization/svd.py +86 -0
  52. torchzero/modules/quasi_newton/__init__.py +4 -0
  53. torchzero/modules/regularization/__init__.py +22 -0
  54. torchzero/modules/regularization/dropout.py +34 -0
  55. torchzero/modules/regularization/noise.py +77 -0
  56. torchzero/modules/regularization/normalization.py +328 -0
  57. torchzero/modules/regularization/ortho_grad.py +78 -0
  58. torchzero/modules/regularization/weight_decay.py +92 -0
  59. torchzero/modules/scheduling/__init__.py +2 -0
  60. torchzero/modules/scheduling/lr_schedulers.py +131 -0
  61. torchzero/modules/scheduling/step_size.py +80 -0
  62. torchzero/modules/second_order/__init__.py +4 -0
  63. torchzero/modules/second_order/newton.py +165 -0
  64. torchzero/modules/smoothing/__init__.py +5 -0
  65. torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
  66. torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
  67. torchzero/modules/weight_averaging/__init__.py +2 -0
  68. torchzero/modules/weight_averaging/ema.py +72 -0
  69. torchzero/modules/weight_averaging/swa.py +171 -0
  70. torchzero/optim/__init__.py +10 -0
  71. torchzero/optim/experimental/__init__.py +20 -0
  72. torchzero/optim/experimental/experimental.py +343 -0
  73. torchzero/optim/experimental/ray_search.py +83 -0
  74. torchzero/optim/first_order/__init__.py +18 -0
  75. torchzero/optim/first_order/cautious.py +158 -0
  76. torchzero/optim/first_order/forward_gradient.py +70 -0
  77. torchzero/optim/first_order/optimizers.py +570 -0
  78. torchzero/optim/modular.py +132 -0
  79. torchzero/optim/quasi_newton/__init__.py +1 -0
  80. torchzero/optim/quasi_newton/directional_newton.py +58 -0
  81. torchzero/optim/second_order/__init__.py +1 -0
  82. torchzero/optim/second_order/newton.py +94 -0
  83. torchzero/optim/wrappers/__init__.py +0 -0
  84. torchzero/optim/wrappers/nevergrad.py +113 -0
  85. torchzero/optim/wrappers/nlopt.py +165 -0
  86. torchzero/optim/wrappers/scipy.py +439 -0
  87. torchzero/optim/zeroth_order/__init__.py +4 -0
  88. torchzero/optim/zeroth_order/fdm.py +87 -0
  89. torchzero/optim/zeroth_order/newton_fdm.py +146 -0
  90. torchzero/optim/zeroth_order/rfdm.py +217 -0
  91. torchzero/optim/zeroth_order/rs.py +85 -0
  92. torchzero/random/__init__.py +1 -0
  93. torchzero/random/random.py +46 -0
  94. torchzero/tensorlist.py +819 -0
  95. torchzero/utils/__init__.py +0 -0
  96. torchzero/utils/compile.py +39 -0
  97. torchzero/utils/derivatives.py +99 -0
  98. torchzero/utils/python_tools.py +25 -0
  99. torchzero/utils/torch_tools.py +92 -0
  100. torchzero-0.0.1.dist-info/LICENSE +21 -0
  101. torchzero-0.0.1.dist-info/METADATA +118 -0
  102. torchzero-0.0.1.dist-info/RECORD +104 -0
  103. torchzero-0.0.1.dist-info/WHEEL +5 -0
  104. torchzero-0.0.1.dist-info/top_level.txt +1 -0
torchzero/optim/modular.py
@@ -0,0 +1,132 @@
+ from collections import abc
+ import warnings
+ from inspect import cleandoc
+ import torch
+
+ from ..core import OptimizerModule, TensorListOptimizer, OptimizationState, _Chain, _Chainable
+ from ..utils.python_tools import flatten
+
+ def _unroll_modules(flat_modules: list[OptimizerModule], nested) -> list[OptimizerModule]:
+     """Returns a list of all modules, including all nested ones."""
+     unrolled = []
+     for m in flat_modules:
+         unrolled.append(m)
+         if len(m.children) > 0:
+             unrolled.extend(_unroll_modules(list(m.children.values()), nested=True))
+         if nested:
+             if m.next_module is not None:
+                 unrolled.extend(_unroll_modules([m.next_module], nested=True))
+     return unrolled
+
+
+ class Modular(TensorListOptimizer):
+     """Creates a modular optimizer by chaining together a sequence of optimizer modules.
+
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         *modules (Iterable[OptimizerModule] | OptimizerModule):
+             A sequence of optimizer modules to chain together. This argument will be flattened."""
+     def __init__(self, params, *modules: _Chainable):
+         flat_modules = flatten(modules)
+         self.modules: list[OptimizerModule] = flat_modules
+         self.chain = _Chain(flat_modules)
+
+         # save unrolled modules and make sure there is only 1 LR module.
+         self.unrolled_modules = _unroll_modules(flat_modules, nested=False)
+         num_lr_modules = len([m for m in self.unrolled_modules if m.IS_LR_MODULE])
+         if num_lr_modules > 1:
+             warnings.warn(cleandoc(
+                 f"""More than one LR module has been added.
+                 This may lead to incorrect behaviour with learning rate scheduling and per-parameter learning rates.
+                 Make sure there is a single `LR` module, and use the `Alpha` module instead of it where needed.
+                 \nList of modules: {self.unrolled_modules}; \nlist of LR modules: {[m for m in self.unrolled_modules if m.IS_LR_MODULE]}"""
+             ))
+
+         if isinstance(params, torch.nn.Module):
+             self.model = params
+             params = list(params.parameters())
+         else:
+             self.model = None
+             params = list(params)
+
+         # if there is an `lr` setting, make sure there is an LR module that can use it
+         for p in params:
+             if isinstance(p, dict):
+                 if 'lr' in p:
+                     if num_lr_modules == 0:
+                         warnings.warn(cleandoc(
+                             """Passed an "lr" setting in a parameter group, but there is no LR module that can use that setting.
+                             Add an `LR` module to make the per-layer "lr" setting work."""
+                         ))
+
+         super().__init__(params, {})
+         self.chain._initialize_(params, set_passed_params=True)
+
+         # run post-init hooks
+         for module in self.unrolled_modules:
+             for hook in module.post_init_hooks:
+                 hook(self, module)
+
+     def get_lr_module(self, last=True) -> OptimizerModule:
+         """
+         Retrieves the module in the chain that controls the learning rate.
+
+         This method is useful for setting up a learning rate scheduler. By default, it retrieves the last module
+         in the chain that has an `lr` group parameter.
+
+         Args:
+             last (bool, optional):
+                 If multiple modules have an `lr` parameter, this argument controls which one is returned.
+                 - If `True` (default), the last module is returned.
+                 - If `False`, the first module is returned.
+
+         Returns:
+             OptimizerModule: The module that controls the learning rate.
+
+         Raises:
+             ValueError: If no modules in the chain have an `lr` parameter. To fix this, add an `LR` module.
+
+         Example:
+
+         .. code:: py
+
+             from torch.optim.lr_scheduler import OneCycleLR
+             import torchzero as tz
+
+             opt = tz.Modular(model.parameters(), [tz.m.RMSProp(), tz.m.LR(1e-2), tz.m.DirectionalNewton()])
+             lr_scheduler = OneCycleLR(opt.get_lr_module(), max_lr = 1e-1, total_steps = 1000, cycle_momentum=False)
+
+         """
+         modules = list(reversed(self.unrolled_modules)) if last else self.unrolled_modules
+         for m in modules:
+             if 'lr' in m.param_groups[0]: return m
+
+         raise ValueError(f'No modules out of {", ".join(m.__class__.__name__ for m in modules)} support an `lr` parameter. The easiest way to fix this is to add an `LR(1)` module at the end.')
+
+     def get_module_by_name(self, name: str | type, last=True) -> OptimizerModule:
+         """Returns the first or last module in the chain that matches the provided name or type.
+
+         Args:
+             name (str | type): the name (as a string) or the type of the module to search for.
+             last (bool, optional):
+                 If multiple modules match, this argument controls which one is returned.
+                 - If `True` (default), the last matching module is returned.
+                 - If `False`, the first matching module is returned.
+
+         Returns:
+             OptimizerModule: The matching optimizer module.
+
+         Raises:
+             ValueError: If no modules in the chain match the provided name or type.
+         """
+         modules = list(reversed(self.unrolled_modules)) if last else self.unrolled_modules
+         for m in modules:
+             if isinstance(name, str) and m.__class__.__name__ == name: return m
+             if isinstance(name, type) and isinstance(m, name): return m
+
+         raise ValueError(f'No modules out of {", ".join(m.__class__.__name__ for m in modules)} match "{name}".')
+
+     def step(self, closure=None): # type:ignore
+         state = OptimizationState(closure, self.model)
+         res = self.chain.step(state)
+         for hook in state.post_step_hooks: hook(self, state)
+         return res
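For orientation, a hypothetical usage sketch of `Modular` (editor's illustration based on the docstring example above; `model`, `inputs`, `targets` and `loss_fn` are placeholders, and the closure's optional `backward` flag follows the convention used by the wrapper optimizers later in this diff):

    import torchzero as tz
    from torch.optim.lr_scheduler import OneCycleLR

    # chain an RMSProp update rule with an LR module, then attach a scheduler to the LR module
    opt = tz.Modular(model.parameters(), [tz.m.RMSProp(), tz.m.LR(1e-2)])
    lr_scheduler = OneCycleLR(opt.get_lr_module(), max_lr=1e-1, total_steps=1000, cycle_momentum=False)

    def closure(backward=True):
        loss = loss_fn(model(inputs), targets)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    for _ in range(1000):
        opt.step(closure)
        lr_scheduler.step()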
torchzero/optim/quasi_newton/__init__.py
@@ -0,0 +1 @@
+ from .directional_newton import DirectionalNewton
torchzero/optim/quasi_newton/directional_newton.py
@@ -0,0 +1,58 @@
+ from ...modules import (
+     SGD,
+ )
+ from ...modules import DirectionalNewton as _DirectionalNewton, LR
+ from ..modular import Modular
+
+
+ class DirectionalNewton(Modular):
+     """Minimizes a parabola in the direction of the gradient (or of the update, if momentum or weight decay is enabled)
+     via one additional forward pass, and uses another forward pass to make sure it didn't overstep.
+     In total this performs three forward passes and one backward pass.
+
+     The first forward and backward pass is used to calculate the value and gradient at the initial parameters.
+     Then a gradient descent step is performed with the `lr` learning rate, and the loss is recalculated
+     with the new parameters. A quadratic is fitted to the two points and the gradient;
+     if it has positive curvature, a step is made towards its minimum, and an additional forward pass
+     checks that the loss decreased.
+
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         lr (float, optional):
+             learning rate. Since you shouldn't put this module after LR(), you have to specify
+             the learning rate in this argument. Defaults to 1e-4.
+         max_dist (float | None, optional):
+             maximum distance to step when minimizing the quadratic.
+             If the minimum is further than this distance, minimization is not performed. Defaults to 1e5.
+         validate_step (bool, optional):
+             uses an additional forward pass to check
+             whether the step towards the minimum actually decreased the loss. Defaults to True.
+         momentum (float, optional): momentum. Defaults to 0.
+         dampening (float, optional): momentum dampening. Defaults to 0.
+         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
+         nesterov (bool, optional):
+             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
+
+     Note:
+         While lr scheduling is supported, this uses the lr of the first parameter for all parameters.
+     """
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-4,
+         max_dist: float | None = 1e5,
+         validate_step: bool = True,
+         momentum: float = 0,
+         dampening: float = 0,
+         weight_decay: float = 0,
+         nesterov: bool = False,
+
+     ):
+
+         modules = [
+             SGD(momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov),
+             LR(lr),
+             _DirectionalNewton(max_dist, validate_step)
+         ]
+         super().__init__(params, modules)
+
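To make the quadratic fit described in the docstring concrete, here is a small illustrative sketch of the one-dimensional fit (an editor's illustration of the math, not code from the package): given the loss `f0` and directional derivative `g0` at the current point, and the loss `f1` after moving a distance `h` along the update direction, fit q(t) = a*t^2 + b*t + c and step to its minimum only if the curvature is positive.

    def quadratic_min_step(f0: float, g0: float, f1: float, h: float):
        """Fit q(t) = a*t**2 + b*t + c with q(0)=f0, q'(0)=g0, q(h)=f1
        and return the distance to its minimum, or None if curvature <= 0."""
        a = (f1 - f0 - g0 * h) / (h * h)  # curvature implied by the two points and the slope
        b = g0
        if a <= 0:
            return None  # no minimum along this direction; skip the step
        return -b / (2 * a)  # argmin of the fitted parabola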
torchzero/optim/second_order/__init__.py
@@ -0,0 +1 @@
+ from .newton import ExactNewton
torchzero/optim/second_order/newton.py
@@ -0,0 +1,94 @@
+ from typing import Any, Literal
+
+ import torch
+
+ from ...modules import (
+     LR,
+     ClipNorm,
+     FallbackLinearSystemSolvers,
+     LinearSystemSolvers,
+     LineSearches,
+     get_line_search,
+ )
+ from ...modules import ExactNewton as _ExactNewton
+ from ..modular import Modular
+
+
+ class ExactNewton(Modular):
+     """Performs an exact Newton step using batched autograd. Note that torch.func would be way more efficient,
+     but it is much more restrictive in which operations are allowed (I will add it at some point).
+
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         lr (float, optional): learning rate. Defaults to 1.
+         tikhonov (float, optional):
+             Tikhonov regularization (constant value added to the diagonal of the hessian). Defaults to 0.
+         solver (LinearSystemSolvers, optional):
+             solver for Hx = g. Defaults to "cholesky_lu" (Cholesky, or LU if it fails).
+         fallback (FallbackLinearSystemSolvers, optional):
+             what to do if the solver fails. Defaults to "safe_diag"
+             (takes nonzero diagonal elements, or falls back to gradient descent if all elements are 0).
+         max_norm (float, optional):
+             clips the Newton step to this L2 norm to avoid instability from giant steps.
+             A much better way is to use trust region methods. I haven't implemented any,
+             but you can use `tz.optim.wrappers.scipy.ScipyMinimize` with one of the trust region methods.
+             Defaults to None.
+         validate (bool, optional):
+             validate that the step didn't increase the loss by more than `loss * tol`, using an additional forward pass.
+             If it did, undo the step and perform a gradient descent step instead. Defaults to False.
+         tol (float, optional):
+             only has effect if `validate` is enabled.
+             If the loss increased by more than `loss * tol`, a gradient descent step is performed instead.
+             Set this to 0 to guarantee that the loss always decreases. Defaults to 1.
+         gd_lr (float, optional):
+             only has effect if `validate` is enabled.
+             Gradient descent step learning rate. Defaults to 1e-2.
+         line_search (OptimizerModule | None, optional): line search module, can be None. Defaults to None.
+         batched_hessian (bool, optional):
+             whether to use the experimental pytorch vmap-vectorized hessian calculation. As per the pytorch docs,
+             it should be faster, but since this feature is experimental, there may be performance cliffs.
+             Defaults to True.
+         diag (bool, optional):
+             only use the diagonal of the hessian. This will still calculate the full hessian!
+             This is mainly useful for benchmarking. Defaults to False.
+     """
+     def __init__(
+         self,
+         params,
+         lr: float = 1,
+         tikhonov: float | Literal['eig'] = 0.0,
+         solver: LinearSystemSolvers = "cholesky_lu",
+         fallback: FallbackLinearSystemSolvers = "safe_diag",
+         max_norm: float | None = None,
+         validate=False,
+         tol: float = 1,
+         gd_lr = 1e-2,
+         line_search: LineSearches | None = None,
+         batched_hessian = True,
+
+         diag: bool = False,
+     ):
+         modules: list[Any] = [
+             _ExactNewton(
+                 tikhonov=tikhonov,
+                 batched_hessian=batched_hessian,
+                 solver=solver,
+                 fallback=fallback,
+                 validate=validate,
+                 tol = tol,
+                 gd_lr=gd_lr,
+                 diag = diag,
+             ),
+         ]
+
+         if max_norm is not None:
+             modules.append(ClipNorm(max_norm))
+
+         modules.append(LR(lr))
+
+         if line_search is not None:
+             modules.append(get_line_search(line_search))
+
+         super().__init__(params, modules)
+
+
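A hypothetical usage sketch (editor's illustration, not part of the package; `model`, `inputs`, `targets` and `loss_fn` are placeholders, and the closure convention with an optional `backward` flag is inferred from the wrapper optimizers later in this diff; how the module re-evaluates the closure internally to build the Hessian is not shown in this hunk):

    from torchzero.optim.second_order import ExactNewton

    # Newton step with Tikhonov damping, a norm clip on the step, and loss validation
    opt = ExactNewton(model.parameters(), lr=1.0, tikhonov=1e-4, max_norm=10.0, validate=True)

    def closure(backward=True):
        loss = loss_fn(model(inputs), targets)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    loss = opt.step(closure)  # one full Newton step plus any validation forward passes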
torchzero/optim/wrappers/__init__.py
File without changes
torchzero/optim/wrappers/nevergrad.py
@@ -0,0 +1,113 @@
+ import typing
+ from collections import abc
+
+ import numpy as np
+ import torch
+
+ import nevergrad as ng
+
+ from ...core import TensorListOptimizer
+
+
+ def _ensure_float(x):
+     if isinstance(x, torch.Tensor): return x.detach().cpu().item()
+     if isinstance(x, np.ndarray): return x.item()
+     return float(x)
+
+ class NevergradOptimizer(TensorListOptimizer):
+     """Use a nevergrad optimizer as a pytorch optimizer.
+     Note that it is recommended to set `budget` to the number of iterations you expect to run,
+     as some nevergrad optimizers will error without it.
+
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         opt_cls (type[ng.optimizers.base.Optimizer]):
+             nevergrad optimizer class. For example, `ng.optimizers.NGOpt`.
+         budget (int | None, optional):
+             nevergrad parameter which sets the allowed number of function evaluations (forward passes).
+             This only affects the behaviour of many nevergrad optimizers; for example, some
+             use a certain rule for the first 50% of the steps and then switch to another rule.
+             This parameter doesn't actually limit the maximum number of steps,
+             and it doesn't have to be exact. Defaults to None.
+         mutable_sigma (bool, optional):
+             nevergrad parameter, sets whether the mutation standard deviation must mutate as well
+             (for mutation-based algorithms). Defaults to False.
+         lb (float | None, optional): optional lower bound passed to the nevergrad parametrization. Defaults to None.
+         ub (float | None, optional): optional upper bound passed to the nevergrad parametrization. Defaults to None.
+         use_init (bool, optional):
+             whether to use the initial model parameters as the initial parameters of the nevergrad parametrization.
+             The reason you might want to set this to False is that True seems to break some optimizers
+             (mainly portfolio ones, by initializing them all to the same parameters so they all perform exactly the same steps).
+             However, if you are fine-tuning something, you have to set this to True, otherwise it will start from
+             new random parameters. Defaults to True.
+     """
+     def __init__(
+         self,
+         params,
+         opt_cls: "type[ng.optimizers.base.Optimizer] | abc.Callable[..., ng.optimizers.base.Optimizer]",
+         budget: int | None = None,
+         mutable_sigma = False,
+         lb: float | None = None,
+         ub: float | None = None,
+         use_init = True,
+     ):
+         defaults = dict(lb=lb, ub=ub, use_init=use_init, mutable_sigma=mutable_sigma)
+         super().__init__(params, defaults)
+         self.opt_cls = opt_cls
+         self.opt = None
+         self.budget = budget
+
+     @torch.no_grad
+     def step(self, closure): # type:ignore # pylint:disable=signature-differs
+         params = self.get_params()
+         if self.opt is None:
+             ng_params = []
+             for group in self.param_groups:
+                 params = group['params']
+                 mutable_sigma = group['mutable_sigma']
+                 use_init = group['use_init']
+                 lb = group['lb']
+                 ub = group['ub']
+                 for p in params:
+                     if p.requires_grad:
+                         if use_init:
+                             ng_params.append(
+                                 ng.p.Array(init = p.detach().cpu().numpy(), lower=lb, upper=ub, mutable_sigma=mutable_sigma))
+                         else:
+                             ng_params.append(
+                                 ng.p.Array(shape = p.shape, lower=lb, upper=ub, mutable_sigma=mutable_sigma))
+
+             parametrization = ng.p.Tuple(*ng_params)
+             self.opt = self.opt_cls(parametrization, budget=self.budget)
+
+         x: ng.p.Tuple = self.opt.ask() # type:ignore
+         for cur, new in zip(params, x):
+             cur.set_(torch.from_numpy(new.value).to(dtype=cur.dtype, device=cur.device, copy=False).reshape_as(cur)) # type:ignore
+
+         loss = closure(False)
+         self.opt.tell(x, _ensure_float(loss))
+         return loss
+
+
+
+ # class NevergradSubspace(ModularOptimizer):
+ #     def __init__(
+ #         self,
+ #         params,
+ #         opt_cls:"type[ng.optimizers.base.Optimizer] | abc.Callable[..., ng.optimizers.base.Optimizer]",
+ #         budget=None,
+ #         mutable_sigma = False,
+ #         use_init = True,
+ #         projections = Proj2Masks(5),
+ #     ):
+
+ #         modules = [
+ #             Subspace(projections, update_every=100),
+ #             UninitializedClosureOptimizerWrapper(
+ #                 NevergradOptimizer,
+ #                 opt_cls = opt_cls,
+ #                 budget = budget,
+ #                 mutable_sigma = mutable_sigma,
+ #                 use_init = use_init,
+ #             ),
+ #         ]
+
+ #         super().__init__(params, modules)
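A hypothetical usage sketch (editor's illustration, not part of the package; `model`, `inputs`, `targets` and `loss_fn` are placeholders). Since `step` calls `closure(False)`, the closure only needs to return the loss and never has to call `backward`:

    import nevergrad as ng
    from torchzero.optim.wrappers.nevergrad import NevergradOptimizer

    opt = NevergradOptimizer(model.parameters(), ng.optimizers.NGOpt, budget=1000)

    def closure(backward=True):
        # gradient-free: the backward flag is accepted for interface compatibility but ignored
        return loss_fn(model(inputs), targets)

    for _ in range(1000):
        opt.step(closure)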
torchzero/optim/wrappers/nlopt.py
@@ -0,0 +1,165 @@
+ from typing import Literal
+ from collections.abc import Mapping, Callable
+ from functools import partial
+ import numpy as np
+ import torch
+
+ import nlopt
+ from ...core import TensorListOptimizer, _ClosureType
+ from ...tensorlist import TensorList
+
+ _ALGOS_LITERAL = Literal[
+     "GN_DIRECT", # = _nlopt.GN_DIRECT
+     "GN_DIRECT_L", # = _nlopt.GN_DIRECT_L
+     "GN_DIRECT_L_RAND", # = _nlopt.GN_DIRECT_L_RAND
+     "GN_DIRECT_NOSCAL", # = _nlopt.GN_DIRECT_NOSCAL
+     "GN_DIRECT_L_NOSCAL", # = _nlopt.GN_DIRECT_L_NOSCAL
+     "GN_DIRECT_L_RAND_NOSCAL", # = _nlopt.GN_DIRECT_L_RAND_NOSCAL
+     "GN_ORIG_DIRECT", # = _nlopt.GN_ORIG_DIRECT
+     "GN_ORIG_DIRECT_L", # = _nlopt.GN_ORIG_DIRECT_L
+     "GD_STOGO", # = _nlopt.GD_STOGO
+     "GD_STOGO_RAND", # = _nlopt.GD_STOGO_RAND
+     "LD_LBFGS_NOCEDAL", # = _nlopt.LD_LBFGS_NOCEDAL
+     "LD_LBFGS", # = _nlopt.LD_LBFGS
+     "LN_PRAXIS", # = _nlopt.LN_PRAXIS
+     "LD_VAR1", # = _nlopt.LD_VAR1
+     "LD_VAR2", # = _nlopt.LD_VAR2
+     "LD_TNEWTON", # = _nlopt.LD_TNEWTON
+     "LD_TNEWTON_RESTART", # = _nlopt.LD_TNEWTON_RESTART
+     "LD_TNEWTON_PRECOND", # = _nlopt.LD_TNEWTON_PRECOND
+     "LD_TNEWTON_PRECOND_RESTART", # = _nlopt.LD_TNEWTON_PRECOND_RESTART
+     "GN_CRS2_LM", # = _nlopt.GN_CRS2_LM
+     "GN_MLSL", # = _nlopt.GN_MLSL
+     "GD_MLSL", # = _nlopt.GD_MLSL
+     "GN_MLSL_LDS", # = _nlopt.GN_MLSL_LDS
+     "GD_MLSL_LDS", # = _nlopt.GD_MLSL_LDS
+     "LD_MMA", # = _nlopt.LD_MMA
+     "LN_COBYLA", # = _nlopt.LN_COBYLA
+     "LN_NEWUOA", # = _nlopt.LN_NEWUOA
+     "LN_NEWUOA_BOUND", # = _nlopt.LN_NEWUOA_BOUND
+     "LN_NELDERMEAD", # = _nlopt.LN_NELDERMEAD
+     "LN_SBPLX", # = _nlopt.LN_SBPLX
+     "LN_AUGLAG", # = _nlopt.LN_AUGLAG
+     "LD_AUGLAG", # = _nlopt.LD_AUGLAG
+     "LN_AUGLAG_EQ", # = _nlopt.LN_AUGLAG_EQ
+     "LD_AUGLAG_EQ", # = _nlopt.LD_AUGLAG_EQ
+     "LN_BOBYQA", # = _nlopt.LN_BOBYQA
+     "GN_ISRES", # = _nlopt.GN_ISRES
+     "AUGLAG", # = _nlopt.AUGLAG
+     "AUGLAG_EQ", # = _nlopt.AUGLAG_EQ
+     "G_MLSL", # = _nlopt.G_MLSL
+     "G_MLSL_LDS", # = _nlopt.G_MLSL_LDS
+     "LD_SLSQP", # = _nlopt.LD_SLSQP
+     "LD_CCSAQ", # = _nlopt.LD_CCSAQ
+     "GN_ESCH", # = _nlopt.GN_ESCH
+     "GN_AGS", # = _nlopt.GN_AGS
+ ]
+
+ def _ensure_float(x):
+     if isinstance(x, torch.Tensor): return x.detach().cpu().item()
+     if isinstance(x, np.ndarray): return x.item()
+     return float(x)
+
+ def _ensure_tensor(x):
+     if isinstance(x, np.ndarray):
+         x.setflags(write=True)
+         return torch.from_numpy(x)
+     return torch.tensor(x, dtype=torch.float32)
+
+ inf = float('inf')
+ class NLOptOptimizer(TensorListOptimizer):
+     """Use nlopt as a pytorch optimizer, with the gradient supplied by pytorch autograd.
+     Note that this performs a full minimization on each step,
+     so usually you would want to perform a single step, although performing multiple steps will refine the
+     solution.
+
+     Some algorithms are buggy with numpy>=2.
+
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         algorithm (int | _ALGOS_LITERAL): optimization algorithm from https://nlopt.readthedocs.io/en/latest/NLopt_Algorithms/
+         maxeval (int | None):
+             maximum allowed number of function evaluations; set to None to disable. Note that some stopping criterion
+             must be set, otherwise nlopt will run forever.
+         lb (float | None, optional): optional lower bounds, some algorithms require this. Defaults to None.
+         ub (float | None, optional): optional upper bounds, some algorithms require this. Defaults to None.
+         stopval (float | None, optional): stop minimizing when an objective value ≤ stopval is found. Defaults to None.
+         ftol_rel (float | None, optional): set relative tolerance on function value. Defaults to None.
+         ftol_abs (float | None, optional): set absolute tolerance on function value. Defaults to None.
+         xtol_rel (float | None, optional): set relative tolerance on optimization parameters. Defaults to None.
+         xtol_abs (float | None, optional): set absolute tolerances on optimization parameters. Defaults to None.
+         maxtime (float | None, optional): stop when the optimization time (in seconds) exceeds maxtime. Defaults to None.
+     """
+     def __init__(
+         self,
+         params,
+         algorithm: int | _ALGOS_LITERAL,
+         maxeval: int | None,
+         lb: float | None = None,
+         ub: float | None = None,
+         stopval: float | None = None,
+         ftol_rel: float | None = None,
+         ftol_abs: float | None = None,
+         xtol_rel: float | None = None,
+         xtol_abs: float | None = None,
+         maxtime: float | None = None,
+     ):
+         defaults = dict(lb=lb, ub=ub)
+         super().__init__(params, defaults)
+
+         self.opt: nlopt.opt | None = None
+         if isinstance(algorithm, str): algorithm = getattr(nlopt, algorithm.upper())
+         self.algorithm: int = algorithm # type:ignore
+         self.algorithm_name: str | None = None
+
+         self.maxeval = maxeval; self.stopval = stopval
+         self.ftol_rel = ftol_rel; self.ftol_abs = ftol_abs
+         self.xtol_rel = xtol_rel; self.xtol_abs = xtol_abs
+         self.maxtime = maxtime
+
+         self._last_loss = None
+
+     def _f(self, x: np.ndarray, grad: np.ndarray, closure: _ClosureType, params: TensorList):
+         params.from_vec_(_ensure_tensor(x).to(params[0], copy=False))
+         if grad.size > 0:
+             with torch.enable_grad(): loss = closure()
+             self._last_loss = _ensure_float(loss)
+             grad[:] = params.ensure_grad_().grad.to_vec().reshape(grad.shape).detach().cpu().numpy()
+             return self._last_loss
+
+         self._last_loss = _ensure_float(closure(False))
+         return self._last_loss
+
+     @torch.no_grad
+     def step(self, closure: _ClosureType): # pylint: disable = signature-differs
+
+         params = self.get_params()
+
+         # make bounds
+         lb, ub = self.get_group_keys('lb', 'ub', cls=list)
+         lower = []
+         upper = []
+         for p, l, u in zip(params, lb, ub):
+             if l is None: l = -inf
+             if u is None: u = inf
+             lower.extend([l] * p.numel())
+             upper.extend([u] * p.numel())
+
+         x0 = params.to_vec().detach().cpu().numpy()
+
+         self.opt = nlopt.opt(self.algorithm, x0.size)
+         self.opt.set_min_objective(partial(self._f, closure = closure, params = params))
+         self.opt.set_lower_bounds(lower)
+         self.opt.set_upper_bounds(upper)
+
+         if self.maxeval is not None: self.opt.set_maxeval(self.maxeval)
+         if self.stopval is not None: self.opt.set_stopval(self.stopval)
+         if self.ftol_rel is not None: self.opt.set_ftol_rel(self.ftol_rel)
+         if self.ftol_abs is not None: self.opt.set_ftol_abs(self.ftol_abs)
+         if self.xtol_rel is not None: self.opt.set_xtol_rel(self.xtol_rel)
+         if self.xtol_abs is not None: self.opt.set_xtol_abs(self.xtol_abs)
+         if self.maxtime is not None: self.opt.set_maxtime(self.maxtime)
+
+         x = self.opt.optimize(x0)
+         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+         return self._last_loss
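A hypothetical usage sketch (editor's illustration, not part of the package; `model`, `inputs`, `targets` and `loss_fn` are placeholders). The closure convention follows the `closure()` / `closure(False)` calls in `_f` above: when called with the backward flag enabled it must populate the parameter gradients.

    from torchzero.optim.wrappers.nlopt import NLOptOptimizer

    opt = NLOptOptimizer(model.parameters(), algorithm="LD_LBFGS", maxeval=500)

    def closure(backward=True):
        loss = loss_fn(model(inputs), targets)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    # a single step runs a full nlopt minimization (up to maxeval evaluations)
    loss = opt.step(closure)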