torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/wrappers/optim_wrapper.py

@@ -7,7 +7,35 @@ from ...utils import Params, _copy_param_groups, _make_param_groups
 
 
 class Wrap(Module):
-    """Custom param groups are supported only by `set_param_groups`. Settings passed to Modular will be ignored."""
+    """
+    Wraps a pytorch optimizer to use it as a module.
+
+    .. note::
+        Custom param groups are supported only by `set_param_groups`, settings passed to Modular will be ignored.
+
+    Args:
+        opt_fn (Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer):
+            function that takes in parameters and returns the optimizer, for example :code:`torch.optim.Adam`
+            or :code:`lambda parameters: torch.optim.Adam(parameters, lr=1e-3)`
+        *args:
+        **kwargs:
+            Extra args to be passed to opt_fn. The function is called as :code:`opt_fn(parameters, *args, **kwargs)`.
+
+    Example:
+        wrapping pytorch_optimizer.StableAdamW
+
+        .. code-block:: py
+
+            from pytorch_optimizer import StableAdamW
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Wrap(StableAdamW, lr=1),
+                tz.m.Cautious(),
+                tz.m.LR(1e-2)
+            )
+
+
+    """
     def __init__(self, opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer, *args, **kwargs):
         super().__init__()
         self._opt_fn = opt_fn
@@ -24,8 +52,8 @@ class Wrap(Module):
         return super().set_param_groups(param_groups)
 
     @torch.no_grad
-    def step(self, vars):
-        params = vars.params
+    def step(self, var):
+        params = var.params
 
         # initialize opt on 1st step
         if self.optimizer is None:
@@ -35,18 +63,18 @@
 
         # set grad to update
         orig_grad = [p.grad for p in params]
-        for p, u in zip(params, vars.get_update()):
+        for p, u in zip(params, var.get_update()):
             p.grad = u
 
         # if this module is last, can step with _opt directly
         # direct step can't be applied if next module is LR but _opt doesn't support lr,
         # and if there are multiple different per-parameter lrs (would be annoying to support)
-        if vars.is_last and (
-            (vars.last_module_lrs is None)
+        if var.is_last and (
+            (var.last_module_lrs is None)
             or
-            (('lr' in self.optimizer.defaults) and (len(set(vars.last_module_lrs)) == 1))
+            (('lr' in self.optimizer.defaults) and (len(set(var.last_module_lrs)) == 1))
         ):
-            lr = 1 if vars.last_module_lrs is None else vars.last_module_lrs[0]
+            lr = 1 if var.last_module_lrs is None else var.last_module_lrs[0]
 
             # update optimizer lr with desired lr
             if lr != 1:
@@ -68,19 +96,19 @@ class Wrap(Module):
             for p, g in zip(params, orig_grad):
                 p.grad = g
 
-            vars.stop = True; vars.skip_update = True
-            return vars
+            var.stop = True; var.skip_update = True
+            return var
 
         # this is not the last module, meaning update is difference in parameters
         params_before_step = [p.clone() for p in params]
         self.optimizer.step() # step and update params
         for p, g in zip(params, orig_grad):
             p.grad = g
-        vars.update = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
+        var.update = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
         for p, o in zip(params, params_before_step):
             p.set_(o) # pyright: ignore[reportArgumentType]
 
-        return vars
+        return var
 
     def reset(self):
         super().reset()
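
For context, a minimal usage sketch of the updated Wrap module documented above, assuming the `tz` alias from the docstring example (`import torchzero as tz`) and an existing `model`; it mirrors the `torch.optim.Adam` case mentioned in the `opt_fn` description rather than anything beyond the docstring:

    import torch
    import torchzero as tz  # alias assumed from the docstring example

    opt = tz.Modular(
        model.parameters(),                    # `model` is assumed to exist
        tz.m.Wrap(torch.optim.Adam, lr=1e-3),  # opt_fn is called as opt_fn(parameters, *args, **kwargs)
        tz.m.LR(1e-2),                         # when Wrap is not last, its update is the parameter difference
    )
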
torchzero/optim/wrappers/directsearch.py (new file)

@@ -0,0 +1,281 @@
+from collections.abc import Callable
+from functools import partial
+from typing import Any, Literal
+
+import directsearch
+import numpy as np
+import torch
+from directsearch.ds import DEFAULT_PARAMS
+
+from ...modules.second_order.newton import tikhonov_
+from ...utils import Optimizer, TensorList
+
+
+def _ensure_float(x):
+    if isinstance(x, torch.Tensor): return x.detach().cpu().item()
+    if isinstance(x, np.ndarray): return x.item()
+    return float(x)
+
+def _ensure_numpy(x):
+    if isinstance(x, torch.Tensor): return x.detach().cpu()
+    if isinstance(x, np.ndarray): return x
+    return np.array(x)
+
+
+Closure = Callable[[bool], Any]
+
+
+class DirectSearch(Optimizer):
+    """Use directsearch as pytorch optimizer.
+
+    Note that this performs full minimization on each step,
+    so usually you would want to perform a single step, although performing multiple steps will refine the
+    solution.
+
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+
+        rho: Choice of the forcing function.
+
+        sketch_dim: Reduced dimension to generate polling directions in.
+
+        sketch_type: Sketching technique to be used.
+
+        maxevals: Maximum number of calls to f performed by the algorithm.
+
+        poll_type: Type of polling directions generated in the reduced spaces.
+
+        alpha0: Initial value for the stepsize parameter.
+
+        alpha_max: Maximum value for the stepsize parameter.
+
+        alpha_min: Minimum value for the stepsize parameter.
+
+        gamma_inc: Increase factor for the stepsize update.
+
+        gamma_dec: Decrease factor for the stepsize update.
+
+        verbose:
+            Boolean indicating whether information should be displayed during an algorithmic run.
+
+        print_freq:
+            Value indicating how frequently information should be displayed.
+
+        use_stochastic_three_points:
+            Boolean indicating whether the specific stochastic three points method should be used.
+
+        poll_scale_prob: Probability of scaling the polling directions.
+
+        poll_scale_factor: Factor used to scale the polling directions.
+
+        rho_uses_normd:
+            Boolean indicating whether the forcing function should account for the norm of the direction.
+
+
+    """
+    def __init__(
+        self,
+        params,
+        maxevals = DEFAULT_PARAMS['maxevals'], # Maximum number of function evaluations
+        rho = DEFAULT_PARAMS['rho'], # Forcing function
+        sketch_dim = DEFAULT_PARAMS['sketch_dim'], # Target dimension for sketching
+        sketch_type = DEFAULT_PARAMS['sketch_type'], # Sketching technique
+        poll_type = DEFAULT_PARAMS['poll_type'], # Polling direction type
+        alpha0 = DEFAULT_PARAMS['alpha0'], # Original stepsize value
+        alpha_max = DEFAULT_PARAMS['alpha_max'], # Maximum value for the stepsize
+        alpha_min = DEFAULT_PARAMS['alpha_min'], # Minimum value for the stepsize
+        gamma_inc = DEFAULT_PARAMS['gamma_inc'], # Increasing factor for the stepsize
+        gamma_dec = DEFAULT_PARAMS['gamma_dec'], # Decreasing factor for the stepsize
+        verbose = DEFAULT_PARAMS['verbose'], # Display information about the method
+        print_freq = DEFAULT_PARAMS['print_freq'], # How frequently to display information
+        use_stochastic_three_points = DEFAULT_PARAMS['use_stochastic_three_points'], # Boolean for a specific method
+        rho_uses_normd = DEFAULT_PARAMS['rho_uses_normd'], # Forcing function based on direction norm
+    ):
+        super().__init__(params, {})
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return _ensure_float(closure(False))
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = self.get_params()
+
+        x0 = params.to_vec().detach().cpu().numpy()
+
+        res = directsearch.solve(
+            partial(self._objective, params = params, closure = closure),
+            x0 = x0,
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return res.f
+
+
+
+class DirectSearchDS(Optimizer):
+    def __init__(
+        self,
+        params,
+        maxevals = DEFAULT_PARAMS['maxevals'], # Maximum number of function evaluations
+        rho = DEFAULT_PARAMS['rho'], # Forcing function
+        poll_type = DEFAULT_PARAMS['poll_type'], # Polling direction type
+        alpha0 = DEFAULT_PARAMS['alpha0'], # Original stepsize value
+        alpha_max = DEFAULT_PARAMS['alpha_max'], # Maximum value for the stepsize
+        alpha_min = DEFAULT_PARAMS['alpha_min'], # Minimum value for the stepsize
+        gamma_inc = DEFAULT_PARAMS['gamma_inc'], # Increasing factor for the stepsize
+        gamma_dec = DEFAULT_PARAMS['gamma_dec'], # Decreasing factor for the stepsize
+        verbose = DEFAULT_PARAMS['verbose'], # Display information about the method
+        print_freq = DEFAULT_PARAMS['print_freq'], # How frequently to display information
+        rho_uses_normd = DEFAULT_PARAMS['rho_uses_normd'], # Forcing function based on direction norm
+    ):
+        super().__init__(params, {})
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return _ensure_float(closure(False))
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = self.get_params()
+
+        x0 = params.to_vec().detach().cpu().numpy()
+
+        res = directsearch.solve_directsearch(
+            partial(self._objective, params = params, closure = closure),
+            x0 = x0,
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return res.f
+
+class DirectSearchProbabilistic(Optimizer):
+    def __init__(
+        self,
+        params,
+        maxevals = DEFAULT_PARAMS['maxevals'], # Maximum number of function evaluations
+        rho = DEFAULT_PARAMS['rho'], # Forcing function
+        alpha0 = DEFAULT_PARAMS['alpha0'], # Original stepsize value
+        alpha_max = DEFAULT_PARAMS['alpha_max'], # Maximum value for the stepsize
+        alpha_min = DEFAULT_PARAMS['alpha_min'], # Minimum value for the stepsize
+        gamma_inc = DEFAULT_PARAMS['gamma_inc'], # Increasing factor for the stepsize
+        gamma_dec = DEFAULT_PARAMS['gamma_dec'], # Decreasing factor for the stepsize
+        verbose = DEFAULT_PARAMS['verbose'], # Display information about the method
+        print_freq = DEFAULT_PARAMS['print_freq'], # How frequently to display information
+        rho_uses_normd = DEFAULT_PARAMS['rho_uses_normd'], # Forcing function based on direction norm
+    ):
+        super().__init__(params, {})
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return _ensure_float(closure(False))
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = self.get_params()
+
+        x0 = params.to_vec().detach().cpu().numpy()
+
+        res = directsearch.solve_probabilistic_directsearch(
+            partial(self._objective, params = params, closure = closure),
+            x0 = x0,
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return res.f
+
+
+class DirectSearchSubspace(Optimizer):
+    def __init__(
+        self,
+        params,
+        maxevals = DEFAULT_PARAMS['maxevals'], # Maximum number of function evaluations
+        rho = DEFAULT_PARAMS['rho'], # Forcing function
+        sketch_dim = DEFAULT_PARAMS['sketch_dim'], # Target dimension for sketching
+        sketch_type = DEFAULT_PARAMS['sketch_type'], # Sketching technique
+        poll_type = DEFAULT_PARAMS['poll_type'], # Polling direction type
+        alpha0 = DEFAULT_PARAMS['alpha0'], # Original stepsize value
+        alpha_max = DEFAULT_PARAMS['alpha_max'], # Maximum value for the stepsize
+        alpha_min = DEFAULT_PARAMS['alpha_min'], # Minimum value for the stepsize
+        gamma_inc = DEFAULT_PARAMS['gamma_inc'], # Increasing factor for the stepsize
+        gamma_dec = DEFAULT_PARAMS['gamma_dec'], # Decreasing factor for the stepsize
+        verbose = DEFAULT_PARAMS['verbose'], # Display information about the method
+        print_freq = DEFAULT_PARAMS['print_freq'], # How frequently to display information
+        rho_uses_normd = DEFAULT_PARAMS['rho_uses_normd'], # Forcing function based on direction norm
+    ):
+        super().__init__(params, {})
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return _ensure_float(closure(False))
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = self.get_params()
+
+        x0 = params.to_vec().detach().cpu().numpy()
+
+        res = directsearch.solve_subspace_directsearch(
+            partial(self._objective, params = params, closure = closure),
+            x0 = x0,
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return res.f
+
+
+
+class DirectSearchSTP(Optimizer):
+    def __init__(
+        self,
+        params,
+        maxevals = DEFAULT_PARAMS['maxevals'], # Maximum number of function evaluations
+        alpha0 = DEFAULT_PARAMS['alpha0'], # Original stepsize value
+        alpha_min = DEFAULT_PARAMS['alpha_min'], # Minimum value for the stepsize
+        verbose = DEFAULT_PARAMS['verbose'], # Display information about the method
+        print_freq = DEFAULT_PARAMS['print_freq'], # How frequently to display information
+    ):
+        super().__init__(params, {})
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return _ensure_float(closure(False))
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = self.get_params()
+
+        x0 = params.to_vec().detach().cpu().numpy()
+
+        res = directsearch.solve_stp(
+            partial(self._objective, params = params, closure = closure),
+            x0 = x0,
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return res.f
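
The DirectSearch wrappers above run a full directsearch minimization inside a single `step(closure)` call, and the closure is only ever invoked as `closure(False)`, so gradients are never required. A hedged usage sketch, assuming the class is importable from `torchzero.optim.wrappers.directsearch` and that `model`, `inputs`, `targets` and `loss_fn` exist:

    from torchzero.optim.wrappers.directsearch import DirectSearch

    opt = DirectSearch(model.parameters(), maxevals=2_000)

    def closure(backward: bool = True):
        # gradient-free: the `backward` flag is ignored, only the loss value is used
        return loss_fn(model(inputs), targets)

    final_loss = opt.step(closure)  # one step performs the full minimization and returns res.f
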
torchzero/optim/wrappers/fcmaes.py (new file)

@@ -0,0 +1,105 @@
+from collections.abc import Callable
+from functools import partial
+from typing import Any, Literal
+
+import fcmaes
+import fcmaes.optimizer
+import fcmaes.retry
+import numpy as np
+import torch
+
+from ...utils import Optimizer, TensorList
+
+Closure = Callable[[bool], Any]
+
+
+def _ensure_float(x) -> float:
+    if isinstance(x, torch.Tensor): return x.detach().cpu().item()
+    if isinstance(x, np.ndarray): return float(x.item())
+    return float(x)
+
+def silence_fcmaes():
+    fcmaes.retry.logger.disable('fcmaes')
+
+class FcmaesWrapper(Optimizer):
+    """Use fcmaes as pytorch optimizer. Particularly fcmaes has BITEOPT which appears to win in many benchmarks.
+
+    Note that this performs full minimization on each step, so only perform one step with this.
+
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+        lb (float): lower bounds, this can also be specified in param_groups.
+        ub (float): upper bounds, this can also be specified in param_groups.
+        optimizer (fcmaes.optimizer.Optimizer | None, optional):
+            optimizer to use. Default is a sequence of differential evolution and CMA-ES.
+        max_evaluations (int | None, optional):
+            Forced termination of all optimization runs after `max_evaluations` function evaluations.
+            Only used if optimizer is undefined, otherwise this setting is defined in the optimizer. Defaults to 50000.
+        value_limit (float | None, optional): Upper limit for optimized function values to be stored. Defaults to np.inf.
+        num_retries (int | None, optional): Number of optimization retries. Defaults to 1.
+        popsize (int | None, optional):
+            CMA-ES population size used for all CMA-ES runs.
+            Not used for differential evolution.
+            Ignored if parameter optimizer is defined. Defaults to 31.
+        capacity (int | None, optional): capacity of the evaluation store.. Defaults to 500.
+        stop_fitness (float | None, optional):
+            Limit for fitness value. optimization runs terminate if this value is reached. Defaults to -np.inf.
+        statistic_num (int | None, optional):
+            if > 0 stores the progress of the optimization. Defines the size of this store. Defaults to 0.
+    """
+    def __init__(
+        self,
+        params,
+        lb: float,
+        ub: float,
+        optimizer: fcmaes.optimizer.Optimizer | None = None,
+        max_evaluations: int | None = 50000,
+        value_limit: float | None = np.inf,
+        num_retries: int | None = 1,
+        # workers: int = 1,
+        popsize: int | None = 31,
+        capacity: int | None = 500,
+        stop_fitness: float | None = -np.inf,
+        statistic_num: int | None = 0
+    ):
+        super().__init__(params, lb=lb, ub=ub)
+        silence_fcmaes()
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
+        self._kwargs = kwargs
+        self._kwargs['workers'] = 1
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure) -> float:
+        if self.raised: return np.inf
+        try:
+            params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+            return _ensure_float(closure(False))
+        except Exception as e:
+            # ha ha, I found a way to make exceptions work in fcmaes and scipy direct
+            self.e = e
+            self.raised = True
+            return np.inf
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        self.raised = False
+        self.e = None
+
+        params = self.get_params()
+
+        lb, ub = self.group_vals('lb', 'ub', cls=list)
+        bounds = []
+        for p, l, u in zip(params, lb, ub):
+            bounds.extend([[l, u]] * p.numel())
+
+        res = fcmaes.retry.minimize(
+            partial(self._objective, params=params, closure=closure), # pyright:ignore[reportArgumentType]
+            bounds=bounds, # pyright:ignore[reportArgumentType]
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+
+        if self.e is not None: raise self.e from None
+        return res.fun
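
FcmaesWrapper additionally requires `lb`/`ub` bounds (optionally set per param group) and re-raises any exception thrown inside the closure once the fcmaes run finishes. A hedged sketch under the same assumed names as above:

    from torchzero.optim.wrappers.fcmaes import FcmaesWrapper

    opt = FcmaesWrapper(model.parameters(), lb=-5.0, ub=5.0, max_evaluations=10_000)

    def closure(backward: bool = True):
        return loss_fn(model(inputs), targets)

    final_loss = opt.step(closure)  # a single step runs the whole fcmaes retry loop
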
torchzero/optim/wrappers/mads.py (new file)

@@ -0,0 +1,89 @@
+from collections.abc import Callable
+from functools import partial
+from typing import Any, Literal
+
+import numpy as np
+import torch
+from mads.mads import orthomads
+
+from ...utils import Optimizer, TensorList
+
+
+def _ensure_float(x):
+    if isinstance(x, torch.Tensor): return x.detach().cpu().item()
+    if isinstance(x, np.ndarray): return x.item()
+    return float(x)
+
+def _ensure_numpy(x):
+    if isinstance(x, torch.Tensor): return x.detach().cpu()
+    if isinstance(x, np.ndarray): return x
+    return np.array(x)
+
+
+Closure = Callable[[bool], Any]
+
+
+class MADS(Optimizer):
+    """Use mads.orthomads as pytorch optimizer.
+
+    Note that this performs full minimization on each step,
+    so usually you would want to perform a single step, although performing multiple steps will refine the
+    solution.
+
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+        lb (float): lower bounds, this can also be specified in param_groups.
+        ub (float): upper bounds, this can also be specified in param_groups.
+        dp (float, optional): Initial poll size as percent of bounds. Defaults to 0.1.
+        dm (float, optional): Initial mesh size as percent of bounds. Defaults to 0.01.
+        dp_tol (float, optional): Minimum poll size stopping criteria. Defaults to -float('inf').
+        nitermax (float, optional): Maximum objective function evaluations. Defaults to float('inf').
+        displog (bool, optional): whether to show log. Defaults to False.
+        savelog (bool, optional): whether to save log. Defaults to False.
+    """
+    def __init__(
+        self,
+        params,
+        lb: float,
+        ub: float,
+        dp = 0.1,
+        dm = 0.01,
+        dp_tol = -float('inf'),
+        nitermax = float('inf'),
+        displog = False,
+        savelog = False,
+    ):
+        super().__init__(params, lb=lb, ub=ub)
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return _ensure_float(closure(False))
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = self.get_params()
+
+        x0 = params.to_vec().detach().cpu().numpy()
+
+        lb, ub = self.group_vals('lb', 'ub', cls=list)
+        bounds_lower = []
+        bounds_upper = []
+        for p, l, u in zip(params, lb, ub):
+            bounds_lower.extend([l] * p.numel())
+            bounds_upper.extend([u] * p.numel())
+
+        f, x = orthomads(
+            design_variables=x0,
+            bounds_upper=np.asarray(bounds_upper),
+            bounds_lower=np.asarray(bounds_lower),
+            objective_function=partial(self._objective, params = params, closure = closure),
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return f
+
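
The MADS wrapper expands the scalar `lb`/`ub` settings into per-element bound vectors and hands them to `mads.mads.orthomads`. A hedged sketch under the same assumed names:

    from torchzero.optim.wrappers.mads import MADS

    opt = MADS(model.parameters(), lb=-1.0, ub=1.0, nitermax=5_000)

    def closure(backward: bool = True):
        return loss_fn(model(inputs), targets)

    final_loss = opt.step(closure)  # returns the best objective value found by orthomads
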
torchzero/optim/wrappers/nevergrad.py

@@ -9,12 +9,12 @@ import nevergrad as ng
 from ...utils import Optimizer
 
 
-def _ensure_float(x):
+def _ensure_float(x) -> float:
     if isinstance(x, torch.Tensor): return x.detach().cpu().item()
-    if isinstance(x, np.ndarray): return x.item()
+    if isinstance(x, np.ndarray): return float(x.item())
     return float(x)
 
-class NevergradOptimizer(Optimizer):
+class NevergradWrapper(Optimizer):
     """Use nevergrad optimizer as pytorch optimizer.
     Note that it is recommended to specify `budget` to the number of iterations you expect to run,
     as some nevergrad optimizers will error without it.
@@ -29,6 +29,12 @@ class NevergradOptimizer(Optimizer):
            use certain rule for first 50% of the steps, and then switch to another rule.
            This parameter doesn't actually limit the maximum number of steps!
            But it doesn't have to be exact. Defaults to None.
+        lb (float | None, optional):
+            lower bounds, this can also be specified in param_groups. Bounds are optional, however
+            some nevergrad algorithms will raise an exception of bounds are not specified.
+        ub (float, optional):
+            upper bounds, this can also be specified in param_groups. Bounds are optional, however
+            some nevergrad algorithms will raise an exception of bounds are not specified.
         mutable_sigma (bool, optional):
            nevergrad parameter, sets whether the mutation standard deviation must mutate as well
            (for mutation based algorithms). Defaults to False.
@@ -44,11 +50,20 @@ class NevergradOptimizer(Optimizer):
         params,
         opt_cls:"type[ng.optimizers.base.Optimizer] | abc.Callable[..., ng.optimizers.base.Optimizer]",
         budget: int | None = None,
-        mutable_sigma = False,
         lb: float | None = None,
         ub: float | None = None,
+        mutable_sigma = False,
         use_init = True,
     ):
+        """_summary_
+
+        Args:
+            params (_type_): _description_
+            opt_cls (type[ng.optimizers.base.Optimizer] | abc.Callable[..., ng.optimizers.base.Optimizer]): _description_
+            budget (int | None, optional): _description_. Defaults to None.
+            mutable_sigma (bool, optional): _description_. Defaults to False.
+            use_init (bool, optional): _description_. Defaults to True.
+        """
         defaults = dict(lb=lb, ub=ub, use_init=use_init, mutable_sigma=mutable_sigma)
         super().__init__(params, defaults)
         self.opt_cls = opt_cls
@@ -56,7 +71,7 @@ class NevergradOptimizer(Optimizer):
         self.budget = budget
 
     @torch.no_grad
-    def step(self, closure): # type:ignore # pylint:disable=signature-differs
+    def step(self, closure): # pylint:disable=signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
         params = self.get_params()
         if self.opt is None:
             ng_params = []
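
The nevergrad wrapper was renamed from NevergradOptimizer to NevergradWrapper and gained optional `lb`/`ub` bounds. A hedged sketch, assuming nevergrad's standard `NGOpt` and the same `model`/`loss_fn` names as above; per the docstring, `budget` should roughly match the number of `step()` calls you expect to make:

    import nevergrad as ng
    from torchzero.optim.wrappers.nevergrad import NevergradWrapper

    opt = NevergradWrapper(model.parameters(), ng.optimizers.NGOpt, budget=1_000, lb=-2.0, ub=2.0)

    for _ in range(1_000):
        # the docstring suggests this wrapper is stepped iteratively, unlike the full-minimization wrappers above
        loss = opt.step(lambda backward=True: loss_fn(model(inputs), targets))
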