torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
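The file list shows the linear-algebra helpers moving out of `torchzero/utils/linalg` into a new top-level `torchzero/linalg` package (with new `eigh.py`, `matrix_power.py`, `orthogonalize.py`, `svd.py` and `torch_linalg.py` modules). A minimal sketch of how imports change under that move, assuming the renamed submodules (`qr.py`, `solve.py`, `linear_operator.py`) keep their names; the exact re-exports in `torchzero/linalg/__init__.py` are not visible in this diff:

```python
# Hypothetical illustration of the package move in the file list above.

# torchzero 0.3.15 (old layout, now removed):
# from torchzero.utils.linalg import qr, solve, linear_operator

# torchzero 0.4.0 (new layout):
from torchzero.linalg import qr, solve, linear_operator
```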

torchzero/modules/wrappers/optim_wrapper.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any
 import torch
 
 from ...core.module import Module
-from ...utils import Params, _copy_param_groups, _make_param_groups
+from ...utils.params import Params, _copy_param_groups, _make_param_groups
 
 
 class Wrap(Module):
@@ -66,8 +66,8 @@ class Wrap(Module):
         return super().set_param_groups(param_groups)
 
     @torch.no_grad
-    def step(self, var):
-        params = var.params
+    def apply(self, objective):
+        params = objective.params
 
         # initialize opt on 1st step
         if self.optimizer is None:
@@ -76,7 +76,7 @@ class Wrap(Module):
             self.optimizer = self._opt_fn(param_groups, *self._opt_args, **self._opt_kwargs)
 
         # set optimizer per-parameter settings
-        if self.defaults["use_param_groups"] and var.modular is not None:
+        if self.defaults["use_param_groups"] and objective.modular is not None:
             for group in self.optimizer.param_groups:
                 first_param = group['params'][0]
                 setting = self.settings[first_param]
@@ -91,19 +91,19 @@ class Wrap(Module):
 
         # set grad to update
         orig_grad = [p.grad for p in params]
-        for p, u in zip(params, var.get_update()):
+        for p, u in zip(params, objective.get_updates()):
             p.grad = u
 
         # if this is last module, simply use optimizer to update parameters
-        if var.modular is not None and self is var.modular.modules[-1]:
+        if objective.modular is not None and self is objective.modular.modules[-1]:
             self.optimizer.step()
 
             # restore grad
             for p, g in zip(params, orig_grad):
                 p.grad = g
 
-            var.stop = True; var.skip_update = True
-            return var
+            objective.stop = True; objective.skip_update = True
+            return objective
 
         # this is not the last module, meaning update is difference in parameters
         # and passed to next module
@@ -111,11 +111,11 @@ class Wrap(Module):
         self.optimizer.step() # step and update params
         for p, g in zip(params, orig_grad):
             p.grad = g
-        var.update = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
+        objective.updates = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
         for p, o in zip(params, params_before_step):
            p.set_(o) # pyright: ignore[reportArgumentType]
 
-        return var
+        return objective
 
     def reset(self):
         super().reset()
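These hunks reflect the API change visible throughout this release: the `Var` object (removed `torchzero/core/var.py`) is replaced by an `Objective` (new `torchzero/core/objective.py`), module entry points move from `step(var)` to `update`/`apply(objective)`, and `var.update`/`var.get_update()` become `objective.updates`/`objective.get_updates()`. A rough sketch of a custom module after the rename, based only on the attribute names seen in this diff (the real `Module`/`Objective` signatures may differ):

```python
import torch
from torchzero.core.module import Module  # import path taken from the hunks above

class NegateUpdate(Module):
    """Toy illustration only: flips the sign of the current update."""

    def __init__(self):
        super().__init__({})  # CD above passes a defaults dict; empty here

    @torch.no_grad
    def apply(self, objective):
        # 0.3.15 code wrote to `var.update`; the 0.4.0 code in this diff writes `objective.updates`
        objective.updates = [u.neg() for u in objective.get_updates()]
        return objective
```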

torchzero/modules/zeroth_order/cd.py CHANGED
@@ -33,13 +33,16 @@ class CD(Module):
         defaults = dict(h=h, grad=grad, adaptive=adaptive, index=index, threepoint=threepoint)
         super().__init__(defaults)
 
+    def update(self, objective): raise RuntimeError
+    def apply(self, objective): raise RuntimeError
+
     @torch.no_grad
-    def step(self, var):
-        closure = var.closure
+    def step(self, objective):
+        closure = objective.closure
         if closure is None:
             raise RuntimeError("CD requires closure")
 
-        params = TensorList(var.params)
+        params = TensorList(objective.params)
         ndim = params.global_numel()
 
         grad_step_size = self.defaults['grad']
@@ -79,7 +82,7 @@ class CD(Module):
         else:
             warnings.warn("CD adaptive=True only works with threepoint=True")
 
-        f_0 = var.get_loss(False)
+        f_0 = objective.get_loss(False)
         params.flat_set_lambda_(idx, lambda x: x + h)
         f_p = closure(False)
 
@@ -117,6 +120,6 @@ class CD(Module):
         # ----------------------------- create the update ---------------------------- #
         update = params.zeros_like()
         update.flat_set_(idx, alpha)
-        var.update = update
-        return var
+        objective.updates = update
+        return objective
 
torchzero/optim/root.py CHANGED
@@ -3,7 +3,7 @@ from collections.abc import Callable
 
 from abc import abstractmethod
 import torch
-from ..modules.higher_order.multipoint import sixth_order_im1, sixth_order_p6, _solve
+from ..modules.second_order.multipoint import sixth_order_3p, sixth_order_5p, two_point_newton, sixth_order_3pm2, _solve
 
 def make_evaluate(f: Callable[[torch.Tensor], torch.Tensor]):
     def evaluate(x, order) -> tuple[torch.Tensor, ...]:
@@ -53,7 +53,7 @@ class Newton(RootBase):
     def one_iteration(self, x, evaluate): return newton(x, evaluate, self.lstsq)
 
 
-class SixthOrderP6(RootBase):
+class SixthOrder3P(RootBase):
     """sixth-order iterative method
 
     Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
@@ -62,4 +62,4 @@ class SixthOrderP6(RootBase):
     def one_iteration(self, x, evaluate):
         def f(x): return evaluate(x, 0)[0]
         def f_j(x): return evaluate(x, 1)
-        return sixth_order_p6(x, f, f_j, self.lstsq)
+        return sixth_order_3p(x, f, f_j, self.lstsq)

torchzero/optim/utility/split.py CHANGED
@@ -3,7 +3,8 @@ from collections.abc import Callable, Iterable
 
 import torch
 
-from ...utils import flatten, get_params
+from ...utils import flatten
+from ...utils.optimizer import get_params
 
 class Split(torch.optim.Optimizer):
     """Steps will all `optimizers`, also has a check that they have no duplicate parameters.

torchzero/optim/wrappers/directsearch.py CHANGED
@@ -7,24 +7,13 @@ import numpy as np
 import torch
 from directsearch.ds import DEFAULT_PARAMS
 
-from ...utils import Optimizer, TensorList
-
-
-def _ensure_float(x):
-    if isinstance(x, torch.Tensor): return x.detach().cpu().item()
-    if isinstance(x, np.ndarray): return x.item()
-    return float(x)
-
-def _ensure_numpy(x):
-    if isinstance(x, torch.Tensor): return x.detach().cpu()
-    if isinstance(x, np.ndarray): return x
-    return np.array(x)
-
+from ...utils import TensorList
+from .wrapper import WrapperBase
 
 Closure = Callable[[bool], Any]
 
 
-class DirectSearch(Optimizer):
+class DirectSearch(WrapperBase):
     """Use directsearch as pytorch optimizer.
 
     Note that this performs full minimization on each step,
@@ -96,28 +85,23 @@ class DirectSearch(Optimizer):
         del kwargs['self'], kwargs['params'], kwargs['__class__']
         self._kwargs = kwargs
 
-    def _objective(self, x: np.ndarray, params: TensorList, closure):
-        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-        return _ensure_float(closure(False))
-
     @torch.no_grad
     def step(self, closure: Closure):
-        params = self.get_params()
-
-        x0 = params.to_vec().detach().cpu().numpy()
+        params = TensorList(self._get_params())
+        x0 = params.to_vec().numpy(force=True)
 
         res = directsearch.solve(
-            partial(self._objective, params = params, closure = closure),
+            partial(self._f, params=params, closure=closure),
             x0 = x0,
             **self._kwargs
         )
 
-        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
         return res.f
 
 
 
-class DirectSearchDS(Optimizer):
+class DirectSearchDS(WrapperBase):
     def __init__(
         self,
         params,
@@ -139,26 +123,21 @@ class DirectSearchDS(Optimizer):
         del kwargs['self'], kwargs['params'], kwargs['__class__']
         self._kwargs = kwargs
 
-    def _objective(self, x: np.ndarray, params: TensorList, closure):
-        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-        return _ensure_float(closure(False))
-
     @torch.no_grad
     def step(self, closure: Closure):
-        params = self.get_params()
-
-        x0 = params.to_vec().detach().cpu().numpy()
+        params = TensorList(self._get_params())
+        x0 = params.to_vec().numpy(force=True)
 
         res = directsearch.solve_directsearch(
-            partial(self._objective, params = params, closure = closure),
+            partial(self._f, params = params, closure = closure),
             x0 = x0,
             **self._kwargs
        )
 
-        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
         return res.f
 
-class DirectSearchProbabilistic(Optimizer):
+class DirectSearchProbabilistic(WrapperBase):
     def __init__(
         self,
         params,
@@ -179,27 +158,22 @@ class DirectSearchProbabilistic(Optimizer):
         del kwargs['self'], kwargs['params'], kwargs['__class__']
         self._kwargs = kwargs
 
-    def _objective(self, x: np.ndarray, params: TensorList, closure):
-        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-        return _ensure_float(closure(False))
-
     @torch.no_grad
     def step(self, closure: Closure):
-        params = self.get_params()
-
-        x0 = params.to_vec().detach().cpu().numpy()
+        params = TensorList(self._get_params())
+        x0 = params.to_vec().numpy(force=True)
 
         res = directsearch.solve_probabilistic_directsearch(
-            partial(self._objective, params = params, closure = closure),
+            partial(self._f, params = params, closure = closure),
             x0 = x0,
             **self._kwargs
        )
 
-        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
         return res.f
 
 
-class DirectSearchSubspace(Optimizer):
+class DirectSearchSubspace(WrapperBase):
     def __init__(
         self,
         params,
@@ -223,28 +197,23 @@ class DirectSearchSubspace(Optimizer):
         del kwargs['self'], kwargs['params'], kwargs['__class__']
         self._kwargs = kwargs
 
-    def _objective(self, x: np.ndarray, params: TensorList, closure):
-        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-        return _ensure_float(closure(False))
-
     @torch.no_grad
     def step(self, closure: Closure):
-        params = self.get_params()
-
-        x0 = params.to_vec().detach().cpu().numpy()
+        params = TensorList(self._get_params())
+        x0 = params.to_vec().numpy(force=True)
 
         res = directsearch.solve_subspace_directsearch(
-            partial(self._objective, params = params, closure = closure),
+            partial(self._f, params = params, closure = closure),
             x0 = x0,
             **self._kwargs
        )
 
-        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
         return res.f
 
 
 
-class DirectSearchSTP(Optimizer):
+class DirectSearchSTP(WrapperBase):
     def __init__(
         self,
         params,
@@ -260,21 +229,16 @@ class DirectSearchSTP(Optimizer):
         del kwargs['self'], kwargs['params'], kwargs['__class__']
         self._kwargs = kwargs
 
-    def _objective(self, x: np.ndarray, params: TensorList, closure):
-        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-        return _ensure_float(closure(False))
-
     @torch.no_grad
     def step(self, closure: Closure):
-        params = self.get_params()
-
-        x0 = params.to_vec().detach().cpu().numpy()
+        params = TensorList(self._get_params())
+        x0 = params.to_vec().numpy(force=True)
 
         res = directsearch.solve_stp(
-            partial(self._objective, params = params, closure = closure),
+            partial(self._f, params = params, closure = closure),
             x0 = x0,
             **self._kwargs
        )
 
-        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
         return res.f
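Across the directsearch, fcmaes, mads, moors and nevergrad wrappers in this diff, the per-class `_objective` helpers and the duplicated `_ensure_float`/`_ensure_numpy` utilities are replaced by a shared `WrapperBase` (new `torchzero/optim/wrappers/wrapper.py`, +121 lines) whose `_f` method is handed to the backend solvers via `partial(self._f, params=params, closure=closure)`. The new file itself is not shown here; a plausible sketch of what `_f` does, inferred from the old `_objective` bodies it replaces (the name, signature and behaviour are assumptions):

```python
import numpy as np
import torch

def _f(self, x: np.ndarray, params, closure) -> float:
    """Assumed behaviour of WrapperBase._f: load the candidate vector into the
    parameters, evaluate the closure without a backward pass, return a float."""
    params.from_vec_(torch.as_tensor(x, device=params[0].device, dtype=params[0].dtype))
    loss = closure(False)
    return loss.item() if isinstance(loss, torch.Tensor) else float(loss)
```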

torchzero/optim/wrappers/fcmaes.py CHANGED
@@ -9,20 +9,15 @@ import fcmaes
 import fcmaes.optimizer
 import fcmaes.retry
 
-from ...utils import Optimizer, TensorList
+from ...utils import TensorList
+from .wrapper import WrapperBase
 
 Closure = Callable[[bool], Any]
 
-
-def _ensure_float(x) -> float:
-    if isinstance(x, torch.Tensor): return x.detach().cpu().item()
-    if isinstance(x, np.ndarray): return float(x.item())
-    return float(x)
-
 def silence_fcmaes():
     fcmaes.retry.logger.disable('fcmaes')
 
-class FcmaesWrapper(Optimizer):
+class FcmaesWrapper(WrapperBase):
     """Use fcmaes as pytorch optimizer. Particularly fcmaes has BITEOPT which appears to win in many benchmarks.
 
     Note that this performs full minimization on each step, so only perform one step with this.
@@ -42,7 +37,7 @@ class FcmaesWrapper(Optimizer):
             CMA-ES population size used for all CMA-ES runs.
             Not used for differential evolution.
             Ignored if parameter optimizer is defined. Defaults to 31.
-        capacity (int | None, optional): capacity of the evaluation store.. Defaults to 500.
+        capacity (int | None, optional): capacity of the evaluation store. Defaults to 500.
         stop_fitness (float | None, optional):
             Limit for fitness value. optimization runs terminate if this value is reached. Defaults to -np.inf.
         statistic_num (int | None, optional):
@@ -61,46 +56,30 @@ class FcmaesWrapper(Optimizer):
         popsize: int | None = 31,
         capacity: int | None = 500,
         stop_fitness: float | None = -np.inf,
-        statistic_num: int | None = 0
+        statistic_num: int | None = 0,
+        silence: bool = True,
     ):
-        super().__init__(params, lb=lb, ub=ub)
-        silence_fcmaes()
+        super().__init__(params, dict(lb=lb,ub=ub))
+        if silence:
+            silence_fcmaes()
         kwargs = locals().copy()
-        del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
+        del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__'], kwargs["silence"]
         self._kwargs = kwargs
         self._kwargs['workers'] = 1
 
-    def _objective(self, x: np.ndarray, params: TensorList, closure) -> float:
-        if self.raised: return np.inf
-        try:
-            params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-            return _ensure_float(closure(False))
-        except Exception as e:
-            # ha ha, I found a way to make exceptions work in fcmaes and scipy direct
-            self.e = e
-            self.raised = True
-            return np.inf
 
     @torch.no_grad
     def step(self, closure: Closure):
-        self.raised = False
-        self.e = None
 
-        params = self.get_params()
-
-        lb, ub = self.group_vals('lb', 'ub', cls=list)
-        bounds = []
-        for p, l, u in zip(params, lb, ub):
-            bounds.extend([[l, u]] * p.numel())
+        params = TensorList(self._get_params())
+        bounds = self._get_bounds()
 
         res = fcmaes.retry.minimize(
-            partial(self._objective, params=params, closure=closure), # pyright:ignore[reportArgumentType]
+            partial(self._f, params=params, closure=closure), # pyright:ignore[reportArgumentType]
             bounds=bounds, # pyright:ignore[reportArgumentType]
             **self._kwargs
        )
 
-        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-
-        if self.e is not None: raise self.e from None
+        params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
        return res.fun
84
  return res.fun
106
85
 
@@ -6,24 +6,13 @@ import numpy as np
6
6
  import torch
7
7
  from mads.mads import orthomads
8
8
 
9
- from ...utils import Optimizer, TensorList
10
-
11
-
12
- def _ensure_float(x):
13
- if isinstance(x, torch.Tensor): return x.detach().cpu().item()
14
- if isinstance(x, np.ndarray): return x.item()
15
- return float(x)
16
-
17
- def _ensure_numpy(x):
18
- if isinstance(x, torch.Tensor): return x.detach().cpu()
19
- if isinstance(x, np.ndarray): return x
20
- return np.array(x)
21
-
9
+ from ...utils import TensorList
10
+ from .wrapper import WrapperBase
22
11
 
23
12
  Closure = Callable[[bool], Any]
24
13
 
25
14
 
26
- class MADS(Optimizer):
15
+ class MADS(WrapperBase):
27
16
  """Use mads.orthomads as pytorch optimizer.
28
17
 
29
18
  Note that this performs full minimization on each step,
@@ -53,37 +42,28 @@ class MADS(Optimizer):
53
42
  displog = False,
54
43
  savelog = False,
55
44
  ):
56
- super().__init__(params, lb=lb, ub=ub)
45
+ super().__init__(params, dict(lb=lb, ub=ub))
57
46
 
58
47
  kwargs = locals().copy()
59
48
  del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
60
49
  self._kwargs = kwargs
61
50
 
62
- def _objective(self, x: np.ndarray, params: TensorList, closure):
63
- params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
64
- return _ensure_float(closure(False))
65
51
 
66
52
  @torch.no_grad
67
53
  def step(self, closure: Closure):
68
- params = self.get_params()
69
-
70
- x0 = params.to_vec().detach().cpu().numpy()
54
+ params = TensorList(self._get_params())
55
+ x0 = params.to_vec().numpy(force=True)
56
+ lb, ub = self._get_lb_ub()
71
57
 
72
- lb, ub = self.group_vals('lb', 'ub', cls=list)
73
- bounds_lower = []
74
- bounds_upper = []
75
- for p, l, u in zip(params, lb, ub):
76
- bounds_lower.extend([l] * p.numel())
77
- bounds_upper.extend([u] * p.numel())
78
58
 
79
59
  f, x = orthomads(
80
60
  design_variables=x0,
81
- bounds_upper=np.asarray(bounds_upper),
82
- bounds_lower=np.asarray(bounds_lower),
83
- objective_function=partial(self._objective, params = params, closure = closure),
61
+ bounds_upper=np.asarray(ub),
62
+ bounds_lower=np.asarray(lb),
63
+ objective_function=partial(self._f, params=params, closure=closure),
84
64
  **self._kwargs
85
65
  )
86
66
 
87
- params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
67
+ params.from_vec_(torch.as_tensor(x, device = params[0].device, dtype=params[0].dtype,))
88
68
  return f
89
69
 

torchzero/optim/wrappers/moors.py ADDED
@@ -0,0 +1,66 @@
+from collections.abc import Callable
+from functools import partial
+from typing import Any, Literal
+
+import numpy as np
+import torch
+
+from ...utils import TensorList
+from .wrapper import WrapperBase
+
+Closure = Callable[[bool], Any]
+
+class MoorsWrapper(WrapperBase):
+    """Use moo-rs (pymoors) is PyTorch optimizer.
+
+    Note that this performs full minimization on each step,
+    so usually you would want to perform a single step.
+
+    To use this, define a function that accepts fitness function and number of variables and returns a pymoors algorithm:
+
+    ```python
+    alg_fn = lambda fitness_fn, num_vars: pymoors.Nsga2(
+        fitness_fn=fitness_fn,
+        num_vars=num_vars,
+        num_iterations=100,
+        sampler = pymoors.RandomSamplingFloat(min=-3, max=3),
+        crossover = pymoors.SinglePointBinaryCrossover(),
+        mutation = pymoors.GaussianMutation(gene_mutation_rate=1e-2, sigma=0.1),
+        population_size = 32,
+        num_offsprings = 32,
+    )
+
+    optimizer = MoorsWrapper(model.parameters(), alg_fn)
+    ```
+
+    All algorithms in pymoors have slightly different APIs, refer to their docs.
+
+    """
+    def __init__(
+        self,
+        params,
+        algorithm_fn: Callable[[Callable[[np.ndarray], np.ndarray], int], Any]
+    ):
+        super().__init__(params, {})
+        self._algorithm_fn = algorithm_fn
+
+    def _objective(self, x: np.ndarray, params, closure):
+        fs = []
+        for x_i in x:
+            f_i = self._fs(x_i, params=params, closure=closure)
+            fs.append(f_i)
+        return np.stack(fs, dtype=np.float64) # pymoors needs float64
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = TensorList(self._get_params())
+        objective = partial(self._objective, params=params, closure=closure)
+
+        algorithm = self._algorithm_fn(objective, params.global_numel())
+
+        algorithm.run()
+        pop = algorithm.population
+
+        params.from_vec_(torch.as_tensor(pop.best[0].genes, device = params[0].device, dtype=params[0].dtype,))
+        return pop.best[0].fitness
+
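pymoors evaluates whole populations at once, so `_objective` above receives a 2-D array of candidate genomes and returns one float64 fitness row per candidate, delegating the per-candidate evaluation to the inherited `self._fs` (presumably a `WrapperBase` helper; it is not shown in this diff). A minimal sketch of that batching contract with a stand-in evaluator (names and shapes are assumptions, not the library's code):

```python
import numpy as np

def batched_objective(population: np.ndarray, evaluate_one) -> np.ndarray:
    """population has shape (n_candidates, n_vars); return one fitness row per
    candidate as a float64 array, which is what pymoors-style fitness_fns expect."""
    fitnesses = [np.atleast_1d(evaluate_one(genes)) for genes in population]
    return np.stack(fitnesses).astype(np.float64)
```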

torchzero/optim/wrappers/nevergrad.py CHANGED
@@ -6,7 +6,7 @@ import torch
 
 import nevergrad as ng
 
-from ...utils import Optimizer
+from .wrapper import WrapperBase
 
 
 def _ensure_float(x) -> float:
@@ -14,7 +14,7 @@ def _ensure_float(x) -> float:
     if isinstance(x, np.ndarray): return float(x.item())
     return float(x)
 
-class NevergradWrapper(Optimizer):
+class NevergradWrapper(WrapperBase):
     """Use nevergrad optimizer as pytorch optimizer.
     Note that it is recommended to specify `budget` to the number of iterations you expect to run,
     as some nevergrad optimizers will error without it.
@@ -72,7 +72,7 @@ class NevergradWrapper(Optimizer):
 
     @torch.no_grad
     def step(self, closure): # pylint:disable=signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
-        params = self.get_params()
+        params = self._get_params()
         if self.opt is None:
             ng_params = []
             for group in self.param_groups:
@@ -95,7 +95,7 @@ class NevergradWrapper(Optimizer):
 
         x: ng.p.Tuple = self.opt.ask() # type:ignore
         for cur, new in zip(params, x):
-            cur.set_(torch.from_numpy(new.value).to(dtype=cur.dtype, device=cur.device, copy=False).reshape_as(cur)) # type:ignore
+            cur.set_(torch.as_tensor(new.value, dtype=cur.dtype, device=cur.device).reshape_as(cur)) # type:ignore
 
         loss = closure(False)
         self.opt.tell(x, _ensure_float(loss))
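A pattern repeated across these wrapper diffs is replacing `torch.from_numpy(x).to(device=..., dtype=..., copy=False)` with `torch.as_tensor(x, device=..., dtype=...)`. Both avoid a copy when the NumPy array already matches the target dtype and lives on CPU, and `as_tensor` also accepts non-NumPy inputs. A small illustration of standard PyTorch behaviour (not torchzero code):

```python
import numpy as np
import torch

x = np.zeros(3, dtype=np.float32)

a = torch.from_numpy(x).to(dtype=torch.float32, copy=False)  # 0.3.15 style
b = torch.as_tensor(x, dtype=torch.float32)                  # 0.4.0 style

# no dtype/device conversion was needed, so both tensors share memory with `x`
a[0] = 1.0
assert x[0] == 1.0 and b[0] == 1.0
```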