torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/weight_decay/reinit.py ADDED
@@ -0,0 +1,83 @@
+from functools import partial
+
+import torch
+
+from ...core import Module
+from ...utils import NumberList, TensorList
+
+
+def _reset_except_self(objective, modules, self: Module):
+    for m in modules:
+        if m is not self:
+            m.reset()
+
+class RandomReinitialize(Module):
+    """On each step, with probability ``p_reinit``, triggers reinitialization,
+    whereby each weight is reset to its initial value with probability ``p_weights``.
+
+    This modifies the parameters directly. Place it as the first module.
+
+    Args:
+        p_reinit (float, optional): probability to trigger reinitialization on each step. Defaults to 0.01.
+        p_weights (float, optional): probability for each weight to be set to initial value when reinitialization is triggered. Defaults to 0.1.
+        store_every (int | None, optional): if set, stores new initial values every this many steps. Defaults to None.
+        beta (float, optional):
+            whenever ``store_every`` is triggered, uses linear interpolation with this beta.
+            If ``store_every=1``, this can be set to some value close to 1 such as 0.999
+            to reinitialize to a slow parameter EMA. Defaults to 0.
+        reset (bool, optional): whether to reset states of other modules on reinitialization. Defaults to False.
+        seed (int | None, optional): random seed.
+    """
+
+    def __init__(
+        self,
+        p_reinit: float = 0.01,
+        p_weights: float = 0.1,
+        store_every: int | None = None,
+        beta: float = 0,
+        reset: bool = False,
+        seed: int | None = None,
+    ):
+        defaults = dict(p_weights=p_weights, p_reinit=p_reinit, store_every=store_every, beta=beta, reset=reset, seed=seed)
+        super().__init__(defaults)
+
+    def update(self, objective):
+        # this stores initial values to per-parameter states
+        p_init = self.get_state(objective.params, "p_init", init="params", cls=TensorList)
+
+        # store new params every store_every steps
+        step = self.global_state.get("step", 0)
+        self.global_state["step"] = step + 1
+
+        store_every = self.defaults["store_every"]
+        if (store_every is not None and step % store_every == 0):
+            beta = self.get_settings(objective.params, "beta", cls=NumberList)
+            p_init.lerp_(objective.params, weight=(1 - beta))
+
+    @torch.no_grad
+    def apply(self, objective):
+        p_reinit = self.defaults["p_reinit"]
+        device = objective.params[0].device
+        generator = self.get_generator(device, self.defaults["seed"])
+
+        # determine whether to trigger reinitialization
+        reinitialize = torch.rand(1, generator=generator, device=device) < p_reinit
+
+        # reinitialize
+        if reinitialize:
+            params = TensorList(objective.params)
+            p_init = self.get_state(params, "p_init", init=params)
+
+
+            # mask with p_weights entries being True
+            p_weights = self.get_settings(params, "p_weights")
+            mask = params.bernoulli_like(p_weights, generator=generator).as_bool()
+
+            # set weights at mask to their initialization
+            params.masked_set_(mask, p_init)
+
+            # reset
+            if self.defaults["reset"]:
+                objective.post_step_hooks.append(partial(_reset_except_self, self=self))
+
+        return objective
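For orientation, the new module is meant to sit at the head of a module chain, since it mutates parameters directly. A minimal usage sketch, assuming ``RandomReinitialize`` is exported under ``tz.m`` like the other modules in this release, and using ``tz.Optimizer``, the renamed entry point seen elsewhere in this diff:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)  # placeholder model

# reinitialization runs first, then Adam and a learning rate module
opt = tz.Optimizer(
    model.parameters(),
    tz.m.RandomReinitialize(p_reinit=0.01, p_weights=0.1, reset=True),  # assumed export path
    tz.m.Adam(),
    tz.m.LR(1e-3),
)
```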
torchzero/modules/weight_decay/weight_decay.py CHANGED
@@ -3,7 +3,7 @@ from typing import Literal
 
 import torch
 
-from ...core import Module, Target, Transform
+from ...core import Module, TensorTransform
 from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states, Metrics
 
 
@@ -21,7 +21,7 @@ def weight_decay_(
     return grad_.add_(params.pow(ord-1).copysign_(params).mul_(weight_decay))
 
 
-class WeightDecay(Transform):
+class WeightDecay(TensorTransform):
     """Weight decay.
 
     Args:
@@ -33,7 +33,7 @@ class WeightDecay(Transform):
 
     Adam with non-decoupled weight decay
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.WeightDecay(1e-3),
        tz.m.Adam(),
@@ -44,7 +44,7 @@ class WeightDecay(Transform):
    Adam with decoupled weight decay that still scales with learning rate
    ```python
 
-    opt = tz.Modular(
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.WeightDecay(1e-3),
@@ -54,7 +54,7 @@ class WeightDecay(Transform):
 
    Adam with fully decoupled weight decay that doesn't scale with learning rate
    ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.LR(1e-3),
@@ -63,19 +63,19 @@ class WeightDecay(Transform):
    ```
 
    """
-    def __init__(self, weight_decay: float, ord: int = 2, target: Target = 'update'):
+    def __init__(self, weight_decay: float, ord: int = 2):
 
        defaults = dict(weight_decay=weight_decay, ord=ord)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults)
 
    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        weight_decay = NumberList(s['weight_decay'] for s in settings)
        ord = settings[0]['ord']
 
        return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)
 
-class RelativeWeightDecay(Transform):
+class RelativeWeightDecay(TensorTransform):
    """Weight decay relative to the mean absolute value of update, gradient or parameters depending on value of ``norm_input`` argument.
 
    Args:
@@ -93,7 +93,7 @@ class RelativeWeightDecay(Transform):
 
    Adam with non-decoupled relative weight decay
    ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.RelativeWeightDecay(1e-1),
        tz.m.Adam(),
@@ -103,7 +103,7 @@ class RelativeWeightDecay(Transform):
 
    Adam with decoupled relative weight decay
    ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.RelativeWeightDecay(1e-1),
@@ -117,13 +117,12 @@ class RelativeWeightDecay(Transform):
        ord: int = 2,
        norm_input: Literal["update", "grad", "params"] = "update",
        metric: Metrics = 'mad',
-        target: Target = "update",
    ):
        defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input, metric=metric)
-        super().__init__(defaults, uses_grad=norm_input == 'grad', target=target)
+        super().__init__(defaults, uses_grad=norm_input == 'grad')
 
    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        weight_decay = NumberList(s['weight_decay'] for s in settings)
 
        ord = settings[0]['ord']
@@ -161,9 +160,9 @@ class DirectWeightDecay(Module):
        super().__init__(defaults)
 
    @torch.no_grad
-    def step(self, var):
-        weight_decay = self.get_settings(var.params, 'weight_decay', cls=NumberList)
+    def apply(self, objective):
+        weight_decay = self.get_settings(objective.params, 'weight_decay', cls=NumberList)
        ord = self.defaults['ord']
 
-        decay_weights_(var.params, weight_decay, ord)
-        return var
+        decay_weights_(objective.params, weight_decay, ord)
+        return objective
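The changes above also illustrate the 0.4 API shift: transforms now derive from ``TensorTransform`` instead of ``Transform``, the ``target`` argument is gone, and the per-tensor hook is ``multi_tensor_apply`` rather than ``apply_tensors``. A sketch of a custom transform under the new base class, based only on the signatures visible in this diff (the scaling module itself is hypothetical):

```python
import torch

from torchzero.core import TensorTransform
from torchzero.utils import NumberList, as_tensorlist


class ScaleBy(TensorTransform):
    """Hypothetical transform that multiplies the update by a per-parameter factor."""
    def __init__(self, factor: float):
        super().__init__(dict(factor=factor))

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        # one settings dict per parameter, same convention as WeightDecay above
        factor = NumberList(s['factor'] for s in settings)
        return as_tensorlist(tensors).mul_(factor)
```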
torchzero/modules/wrappers/optim_wrapper.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any
 import torch
 
 from ...core.module import Module
-from ...utils import Params, _copy_param_groups, _make_param_groups
+from ...utils.params import Params, _copy_param_groups, _make_param_groups
 
 
 class Wrap(Module):
@@ -11,7 +11,7 @@ class Wrap(Module):
    Wraps a pytorch optimizer to use it as a module.
 
    Note:
-        Custom param groups are supported only by ``set_param_groups``, settings passed to Modular will be applied to all parameters.
+        Custom param groups are supported only by ``set_param_groups``, settings passed to Optimizer will be applied to all parameters.
 
    Args:
        opt_fn (Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer):
@@ -21,7 +21,7 @@ class Wrap(Module):
        **kwargs:
            Extra args to be passed to opt_fn. The function is called as ``opt_fn(parameters, *args, **kwargs)``.
        use_param_groups:
-            Whether to pass settings passed to Modular to the wrapped optimizer.
+            Whether to pass settings passed to Optimizer to the wrapped optimizer.
 
            Note that settings to the first parameter are used for all parameters,
            so if you specified per-parameter settings, they will be ignored.
@@ -32,7 +32,7 @@ class Wrap(Module):
    ```python
 
    from pytorch_optimizer import StableAdamW
-    opt = tz.Modular(
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Wrap(StableAdamW, lr=1),
        tz.m.Cautious(),
@@ -66,8 +66,8 @@ class Wrap(Module):
        return super().set_param_groups(param_groups)
 
    @torch.no_grad
-    def step(self, var):
-        params = var.params
+    def apply(self, objective):
+        params = objective.params
 
        # initialize opt on 1st step
        if self.optimizer is None:
@@ -76,14 +76,14 @@ class Wrap(Module):
            self.optimizer = self._opt_fn(param_groups, *self._opt_args, **self._opt_kwargs)
 
        # set optimizer per-parameter settings
-        if self.defaults["use_param_groups"] and var.modular is not None:
+        if self.defaults["use_param_groups"] and objective.modular is not None:
            for group in self.optimizer.param_groups:
                first_param = group['params'][0]
                setting = self.settings[first_param]
 
                # settings passed in `set_param_groups` are the highest priority
                # schedulers will override defaults but not settings passed in `set_param_groups`
-                # this is consistent with how Modular does it.
+                # this is consistent with how Optimizer does it.
                if self._custom_param_groups is not None:
                    setting = {k:v for k,v in setting if k not in self._custom_param_groups[0]}
 
@@ -91,19 +91,19 @@
 
        # set grad to update
        orig_grad = [p.grad for p in params]
-        for p, u in zip(params, var.get_update()):
+        for p, u in zip(params, objective.get_updates()):
            p.grad = u
 
        # if this is last module, simply use optimizer to update parameters
-        if var.modular is not None and self is var.modular.modules[-1]:
+        if objective.modular is not None and self is objective.modular.modules[-1]:
            self.optimizer.step()
 
            # restore grad
            for p, g in zip(params, orig_grad):
                p.grad = g
 
-            var.stop = True; var.skip_update = True
-            return var
+            objective.stop = True; objective.skip_update = True
+            return objective
 
        # this is not the last module, meaning update is difference in parameters
        # and passed to next module
@@ -111,11 +111,11 @@
        self.optimizer.step() # step and update params
        for p, g in zip(params, orig_grad):
            p.grad = g
-        var.update = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
+        objective.updates = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
        for p, o in zip(params, params_before_step):
            p.set_(o) # pyright: ignore[reportArgumentType]
 
-        return var
+        return objective
 
    def reset(self):
        super().reset()
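For reference, ``Wrap`` is driven the same way after the ``step`` → ``apply`` rename; only the internals changed. A usage sketch with a built-in torch optimizer, mirroring the ``StableAdamW`` example in the docstring above (the model and the trailing ``LR`` module are illustrative):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)  # placeholder model

# wrap torch.optim.SGD as a module and post-process its update
opt = tz.Optimizer(
    model.parameters(),
    tz.m.Wrap(torch.optim.SGD, lr=1, momentum=0.9),
    tz.m.Cautious(),
    tz.m.LR(1e-2),
)
```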
torchzero/modules/zeroth_order/cd.py CHANGED
@@ -29,17 +29,20 @@ class CD(Module):
            whether to use three points (three function evaluations) to determine descent direction.
            if False, uses two points, but then ``adaptive`` can't be used. Defaults to True.
    """
-    def __init__(self, h:float=1e-3, grad:bool=True, adaptive:bool=True, index:Literal['cyclic', 'cyclic2', 'random']="cyclic2", threepoint:bool=True,):
+    def __init__(self, h:float=1e-3, grad:bool=False, adaptive:bool=True, index:Literal['cyclic', 'cyclic2', 'random']="cyclic2", threepoint:bool=True,):
        defaults = dict(h=h, grad=grad, adaptive=adaptive, index=index, threepoint=threepoint)
        super().__init__(defaults)
 
+    def update(self, objective): raise RuntimeError
+    def apply(self, objective): raise RuntimeError
+
    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
+    def step(self, objective):
+        closure = objective.closure
        if closure is None:
            raise RuntimeError("CD requires closure")
 
-        params = TensorList(var.params)
+        params = TensorList(objective.params)
        ndim = params.global_numel()
 
        grad_step_size = self.defaults['grad']
@@ -79,7 +82,7 @@
        else:
            warnings.warn("CD adaptive=True only works with threepoint=True")
 
-        f_0 = var.get_loss(False)
+        f_0 = objective.get_loss(False)
        params.flat_set_lambda_(idx, lambda x: x + h)
        f_p = closure(False)
 
@@ -117,6 +120,6 @@ class CD(Module):
        # ----------------------------- create the update ---------------------------- #
        update = params.zeros_like()
        update.flat_set_(idx, alpha)
-        var.update = update
-        return var
+        objective.updates = update
+        return objective
 
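``CD`` is a closure-based module: as the diff shows, it raises if ``objective.closure`` is missing, so it has to be driven with a closure. A sketch, assuming ``CD`` is exported as ``tz.m.CD`` and that ``tz.Optimizer.step`` accepts the usual torchzero-style closure taking a ``backward`` flag (the diff itself only shows the module calling ``closure(False)``):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(2, 1)
X, y = torch.randn(32, 2), torch.randn(32, 1)

opt = tz.Optimizer(model.parameters(), tz.m.CD(h=1e-3), tz.m.LR(1e-1))

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:  # CD is zeroth-order, so it only ever calls closure(False)
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(100):
    opt.step(closure)
```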
torchzero/optim/mbs.py ADDED
@@ -0,0 +1,291 @@
+from typing import NamedTuple
+import math
+from collections.abc import Iterable
+from decimal import ROUND_HALF_UP, Decimal
+
+import numpy as np
+
+
+def format_number(number, n):
+    """Rounds to n significant digits after the decimal point."""
+    if number == 0: return 0
+    if math.isnan(number) or math.isinf(number) or (not math.isfinite(number)): return number
+    if n <= 0: raise ValueError("n must be positive")
+
+    dec = Decimal(str(number))
+    if dec.is_zero(): return 0
+    if number > 10**n or dec % 1 == 0: return int(dec)
+
+    if abs(dec) >= 1:
+        places = n
+    else:
+        frac_str = format(abs(dec), 'f').split('.')[1]
+        leading_zeros = len(frac_str) - len(frac_str.lstrip('0'))
+        places = leading_zeros + n
+
+    quantizer = Decimal('1e-' + str(places))
+    rounded_dec = dec.quantize(quantizer, rounding=ROUND_HALF_UP)
+
+    if rounded_dec % 1 == 0: return int(rounded_dec)
+    return float(rounded_dec)
+
+def _nonfinite_to_inf(x):
+    if not math.isfinite(x): return math.inf
+    return x
+
+def _tofloatlist(x) -> list[float]:
+    if isinstance(x, (int,float)): return [x]
+    if isinstance(x, np.ndarray) and x.size == 1: return [float(x.item())]
+    return [float(i) for i in x]
+
+class Trial(NamedTuple):
+    x: float
+    f: tuple[float, ...]
+
+class Solution(NamedTuple):
+    x: float
+    f: tuple[float, ...]
+    trials: list[Trial]
+
+class MBS:
+    """Univariate minimization via grid search followed by refining, supports multi-objective functions.
+
+    This tends to outperform bayesian optimization for learning rate tuning, it is also good for plotting.
+
+    First it evaluates all points defined in ``grid``. The grid doesn't have to be dense and the solution doesn't
+    have to be between the endpoints.
+
+    Then it picks ``num_candidates`` best points per each objective. If any of those points are endpoints,
+    it expands the search space by ``step`` in that direction and evaluates the new endpoint.
+
+    Otherwise it keeps picking points between best points and evaluating them, until ``num_binary`` evaluations
+    have been performed.
+
+    Args:
+        grid (Iterable[float], optional): values for initial grid search. If ``log_scale=True``, should be in log10 scale.
+        step (float, optional): expansion step size. Defaults to 1.
+        num_candidates (int, optional): number of best points to sample new points around on each iteration. Defaults to 3.
+        num_binary (int, optional): maximum number of new points sampled via binary search. Defaults to 20.
+        num_expansions (int, optional): maximum number of expansions (not counted towards binary search points). Defaults to 20.
+        rounding (int | None, optional): rounds points to this many significant digits, which avoids evaluating points that are too close. Defaults to 2.
+        lb (float | None, optional): lower bound. If ``log_scale=True``, should be in log10 scale.
+        ub (float | None, optional): upper bound. If ``log_scale=True``, should be in log10 scale.
+        log_scale (bool, optional):
+            whether to minimize in log10 scale. If true, it is assumed that
+            ``grid``, ``lb`` and ``ub`` are given in log10 scale.
+
+    Example:
+
+    ```python
+    def objective(x: float):
+        x = x * 4
+        return -(np.sin(x) * (x / 3) + np.cos(x*2.5) * 2 - 0.05 * (x-5)**2)
+
+    mbs = MBS(grid=[-1, 0, 1, 2, 3, 4], step=1, num_binary=10, num_expansions=10)
+
+    x, f, trials = mbs.run(objective)
+    # x - solution
+    # f - value at solution x
+    # trials - list of trials, each trial is a named tuple: Trial(x, f)
+    """
+
+    def __init__(
+        self,
+        grid: Iterable[float],
+        step: float,
+        num_candidates: int = 3,
+        num_binary: int = 20,
+        num_expansions: int = 20,
+        rounding: int| None = 2,
+        lb = None,
+        ub = None,
+        log_scale: bool = False,
+    ):
+        self.objectives: dict[int, dict[float,float]] = {}
+        """dictionary of objectives, each maps point (x) to value (v)"""
+
+        self.evaluated: set[float] = set()
+        """set of evaluated points (x)"""
+
+        grid = tuple(grid)
+        if len(grid) == 0: raise ValueError("At least one grid search point must be specified")
+        self.grid = sorted(grid)
+
+        self.step = step
+        self.num_candidates = num_candidates
+        self.num_binary = num_binary
+        self.num_expansions = num_expansions
+        self.rounding = rounding
+        self.log_scale = log_scale
+        self.lb = lb
+        self.ub = ub
+
+    def _get_best_x(self, n: int, objective: int):
+        """n best points"""
+        obj = self.objectives[objective]
+        v_to_x = [(v,x) for x,v in obj.items()]
+        v_to_x.sort(key = lambda vx: vx[0])
+        xs = [x for v,x in v_to_x]
+        return xs[:n]
+
+    def _suggest_points_around(self, x: float, objective: int):
+        """suggests points around x"""
+        points = list(self.objectives[objective].keys())
+        points.sort()
+        if x not in points: raise RuntimeError(f"{x} not in {points}")
+
+        expansions = []
+        if x == points[0]:
+            expansions.append((x-self.step, 'expansion'))
+
+        if x == points[-1]:
+            expansions.append((x+self.step, 'expansion'))
+
+        if len(expansions) != 0: return expansions
+
+        idx = points.index(x)
+        xm = points[idx-1]
+        xp = points[idx+1]
+
+        x1 = (x - (x - xm)/2)
+        x2 = (x + (xp - x)/2)
+
+        return [(x1, 'binary'), (x2, 'binary')]
+
+    def _out_of_bounds(self, x):
+        if self.lb is not None and x < self.lb: return True
+        if self.ub is not None and x > self.ub: return True
+        return False
+
+    def _evaluate(self, fn, x):
+        """Evaluate a point, returns False if point is already in history"""
+        if self.rounding is not None: x = format_number(x, self.rounding)
+        if x in self.evaluated: return False
+        if self._out_of_bounds(x): return False
+
+        self.evaluated.add(x)
+
+        if self.log_scale: vals = _tofloatlist(fn(10 ** x))
+        else: vals = _tofloatlist(fn(x))
+        vals = [_nonfinite_to_inf(v) for v in vals]
+
+        for idx, v in enumerate(vals):
+            if idx not in self.objectives: self.objectives[idx] = {}
+            self.objectives[idx][x] = v
+
+        return True
+
+    def run(self, fn) -> Solution:
+        # step 1 - grid search
+        for x in self.grid:
+            self._evaluate(fn, x)
+
+        # step 2 - binary search
+        while True:
+            if (self.num_candidates <= 0) or (self.num_expansions <= 0 and self.num_binary <= 0): break
+
+            # suggest candidates
+            candidates: list[tuple[float, str]] = []
+
+            # sample around best points
+            for objective in self.objectives:
+                best_points = self._get_best_x(self.num_candidates, objective)
+                for p in best_points:
+                    candidates.extend(self._suggest_points_around(p, objective=objective))
+
+            # filter
+            if self.num_expansions <= 0:
+                candidates = [(x,t) for x,t in candidates if t != 'expansion']
+
+            if self.num_candidates <= 0:
+                candidates = [(x,t) for x,t in candidates if t != 'binary']
+
+            # if expansion was suggested, discard anything else
+            types = [t for x, t in candidates]
+            if any(t == 'expansion' for t in types):
+                candidates = [(x,t) for x,t in candidates if t == 'expansion']
+
+            # evaluate candidates
+            terminate = False
+            at_least_one_evaluated = False
+            for x, t in candidates:
+                evaluated = self._evaluate(fn, x)
+                if not evaluated: continue
+                at_least_one_evaluated = True
+
+                if t == 'expansion': self.num_expansions -= 1
+                elif t == 'binary': self.num_binary -= 1
+
+                if self.num_binary < 0:
+                    terminate = True
+                    break
+
+            if terminate: break
+            if not at_least_one_evaluated:
+                if self.rounding is None: break
+                self.rounding += 1
+                if self.rounding == 100: break
+
+        # create dict[float, tuple[float,...]]
+        ret = {}
+        for i, objective in enumerate(self.objectives.values()):
+            for x, v in objective.items():
+                if self.log_scale: x = 10 ** x
+                if x not in ret: ret[x] = [None for _ in self.objectives]
+                ret[x][i] = v
+
+        for v in ret.values():
+            assert len(v) == len(self.objectives), v
+            assert all(i is not None for i in v), v
+
+        # ret maps x to list of per-objective values, e.g. {1: [0.1, 0.3], ...}
+        # now make a list of trials as they are easier to work with
+        trials: list[Trial] = []
+        for x, values in ret.items():
+            trials.append(Trial(x=x, f=values))
+
+        # sort trials by sum of values
+        trials.sort(key = lambda trial: sum(trial.f))
+        return Solution(x=trials[0].x, f=trials[0].f, trials=trials)
+
+def mbs_minimize(
+    fn,
+    grid: Iterable[float],
+    step: float,
+    num_candidates: int = 3,
+    num_binary: int = 20,
+    num_expansions: int = 20,
+    rounding=2,
+    lb:float | None = None,
+    ub:float | None = None,
+    log_scale=False,
+) -> Solution:
+    """minimize univariate function via MBS.
+
+    Args:
+        fn (function): objective function that accepts a float and returns a float or a sequence of floats to minimize.
+        step (float, optional): expansion step size. Defaults to 1.
+        num_candidates (int, optional): number of best points to sample new points around on each iteration. Defaults to 3.
+        num_binary (int, optional): maximum number of new points sampled via binary search. Defaults to 20.
+        num_expansions (int, optional): maximum number of expansions (not counted towards binary search points). Defaults to 20.
+        rounding (int, optional): rounds points to this many significant digits, which avoids evaluating points that are too close. Defaults to 2.
+        lb (float | None, optional): lower bound. If ``log_scale=True``, should be in log10 scale.
+        ub (float | None, optional): upper bound. If ``log_scale=True``, should be in log10 scale.
+        log_scale (bool, optional):
+            whether to minimize in log10 scale. If true, it is assumed that
+            ``grid``, ``lb`` and ``ub`` are given in log10 scale.
+
+    Example:
+
+    ```python
+    def objective(x: float):
+        x = x * 4
+        return -(np.sin(x) * (x / 3) + np.cos(x*2.5) * 2 - 0.05 * (x-5)**2)
+
+    x, f, trials = mbs_minimize(objective, grid=[-1, 0, 1, 2, 3, 4], step=1, num_binary=10, num_expansions=10)
+    # x - solution
+    # f - value at solution x
+    # trials - list of trials, each trial is a named tuple: Trial(x, f)
+    """
+    mbs = MBS(grid, step=step, num_candidates=num_candidates, num_binary=num_binary, num_expansions=num_expansions, rounding=rounding, lb=lb, ub=ub, log_scale=log_scale)
+    return mbs.run(fn)
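As a usage note for the new module: with ``log_scale=True``, ``mbs_minimize`` passes ``10**x`` to the objective and reports the solution back in linear scale, so it maps directly onto learning-rate tuning. A sketch (the import path follows the file location above, and ``train_and_eval`` is a placeholder for a real training run):

```python
import numpy as np
from torchzero.optim.mbs import mbs_minimize

def train_and_eval(lr: float) -> float:
    # placeholder objective, minimized near lr = 10**-2.5
    return (np.log10(lr) + 2.5) ** 2

# grid, step, lb and ub are in log10 scale because log_scale=True
sol = mbs_minimize(train_and_eval, grid=[-4, -3, -2, -1, 0], step=1, log_scale=True)
print(sol.x, sol.f)  # best learning rate (linear scale) and its objective value(s)
```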
torchzero/optim/root.py CHANGED
@@ -3,7 +3,7 @@ from collections.abc import Callable
 
 from abc import abstractmethod
 import torch
-from ..modules.higher_order.multipoint import sixth_order_im1, sixth_order_p6, _solve
+from ..modules.second_order.multipoint import sixth_order_3p, sixth_order_5p, two_point_newton, sixth_order_3pm2, _solve
 
 def make_evaluate(f: Callable[[torch.Tensor], torch.Tensor]):
     def evaluate(x, order) -> tuple[torch.Tensor, ...]:
@@ -53,7 +53,7 @@ class Newton(RootBase):
    def one_iteration(self, x, evaluate): return newton(x, evaluate, self.lstsq)
 
 
-class SixthOrderP6(RootBase):
+class SixthOrder3P(RootBase):
    """sixth-order iterative method
 
    Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
@@ -62,4 +62,4 @@ class SixthOrderP6(RootBase):
    def one_iteration(self, x, evaluate):
        def f(x): return evaluate(x, 0)[0]
        def f_j(x): return evaluate(x, 1)
-        return sixth_order_p6(x, f, f_j, self.lstsq)
+        return sixth_order_3p(x, f, f_j, self.lstsq)
torchzero/optim/utility/split.py CHANGED
@@ -3,7 +3,8 @@ from collections.abc import Callable, Iterable
 
 import torch
 
-from ...utils import flatten, get_params
+from ...utils import flatten
+from ...utils.optimizer import get_params
 
 class Split(torch.optim.Optimizer):
    """Steps with all `optimizers`, also has a check that they have no duplicate parameters.