torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/optim/zeroth_order/rfdm.py
@@ -1,217 +0,0 @@
- import torch
-
- from ...modules import SGD, WrapClosure, LR
- from ...modules import RandomizedFDM as _RandomizedFDM
- from ...modules import WeightDecay
- from ...modules.gradient_approximation._fd_formulas import _FD_Formulas
- from ...tensorlist import Distributions
- from ..modular import Modular
-
- class RandomizedFDM(Modular):
-     """Randomized finite difference gradient approximation (e.g. SPSA, RDSA, Nesterov random search).
-
-     With `forward` and `backward` formulas performs `1 + n_samples` evaluations per step;
-     with `central` formula performs `2 * n_samples` evaluations per step.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float, optional): learning rate. Defaults to 1e-3.
-         eps (float, optional): finite difference epsilon. Defaults to 1e-3.
-         formula (_FD_Formulas, optional): finite difference formula. Defaults to "forward".
-         n_samples (int, optional): number of random gradient approximations that will be averaged. Defaults to 1.
-         distribution (Distributions, optional): distribution for random perturbations. Defaults to "normal".
-         randomize_every (int, optional): number of steps between randomizing perturbations. Defaults to 1.
-         momentum (float, optional): momentum. Defaults to 0.
-         dampening (float, optional): momentum dampening. Defaults to 0.
-         nesterov (bool, optional):
-             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         eps: float = 1e-3,
-         formula: _FD_Formulas = "forward",
-         n_samples: int = 1,
-         distribution: Distributions = "normal",
-         momentum: float = 0,
-         dampening: float = 0,
-         nesterov: bool = False,
-         weight_decay: float = 0,
-         decoupled=False,
-     ):
-         modules: list = [
-             _RandomizedFDM(
-                 eps=eps,
-                 formula=formula,
-                 n_samples=n_samples,
-                 distribution=distribution,
-             ),
-             SGD(momentum = momentum, dampening = dampening, weight_decay = weight_decay if not decoupled else 0, nesterov = nesterov),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
-
- class SPSA(RandomizedFDM):
-     """Simultaneous perturbation stochastic approximation method.
-     This is the same as a randomized finite difference method with central formula
-     and perturbations taken from rademacher distribution.
-     Due to rademacher having values -1 or 1, the original formula divides by the perturbation,
-     but that is equivalent to multiplying by it, which is the same as central difference formula.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float, optional): learning rate. Defaults to 1e-3.
-         eps (float, optional): finite difference epsilon. Defaults to 1e-3.
-         momentum (float, optional): momentum factor. Defaults to 0.
-         weight_decay (float, optional): weight decay (L2 penalty). Defaults to 0.
-         dampening (float, optional): dampening for momentum. Defaults to 0.
-         nesterov (bool, optional): enables Nesterov momentum (supports dampening). Defaults to False.
-         formula (_FD_Formulas, optional): finite difference formula. Defaults to "central".
-         n_samples (int, optional): number of random gradient approximations that will be averaged. Defaults to 1.
-         distribution (Distributions, optional): distribution for random perturbations. Defaults to "rademacher".
-         randomize_every (int, optional): number of steps between randomizing perturbations. Defaults to 1.
-         momentum (float, optional): momentum. Defaults to 0.
-         dampening (float, optional): momentum dampening. Defaults to 0.
-         nesterov (bool, optional):
-             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-
-     reference
-         *Spall, J. C. (1992), “Multivariate Stochastic Approximation Using a Simultaneous Perturbation
-         Gradient Approximation,” IEEE Transactions on Automatic Control, vol. 37(3), pp. 332–341.*
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         eps: float = 1e-3,
-         formula: _FD_Formulas = "central",
-         n_samples: int = 1,
-         distribution: Distributions = 'rademacher',
-         momentum: float = 0,
-         dampening: float = 0,
-         nesterov: bool = False,
-         weight_decay: float = 0,
-         decoupled=False, ):
-         super().__init__(
-             params = params,
-             lr = lr,
-             eps = eps,
-             formula = formula,
-             n_samples = n_samples,
-             distribution = distribution,
-             momentum = momentum,
-             dampening = dampening,
-             nesterov = nesterov,
-             weight_decay = weight_decay,
-             decoupled = decoupled,
-         )
-
-
- class RandomGaussianSmoothing(RandomizedFDM):
-     """Random search with gaussian smoothing.
-     This is similar to forward randomized finite difference method, and it
-     approximates and averages the gradient with multiple random perturbations taken from normal distribution,
-     which is an approximation for the gradient of a gaussian smoothed version of the objective function.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float, optional): learning rate. Defaults to 1e-2.
-         eps (float, optional): finite difference epsilon. Defaults to 1e-2.
-         momentum (float, optional): momentum factor. Defaults to 0.
-         weight_decay (float, optional): weight decay (L2 penalty). Defaults to 0.
-         dampening (float, optional): dampening for momentum. Defaults to 0.
-         nesterov (bool, optional): enables Nesterov momentum (supports dampening). Defaults to False.
-         formula (_FD_Formulas, optional): finite difference formula. Defaults to "forward".
-         n_samples (int, optional): number of random gradient approximations that will be averaged. Defaults to 1.
-         distribution (Distributions, optional): distribution for random perturbations. Defaults to "normal".
-         randomize_every (int, optional): number of steps between randomizing perturbations. Defaults to 1.
-         momentum (float, optional): momentum. Defaults to 0.
-         dampening (float, optional): momentum dampening. Defaults to 0.
-         nesterov (bool, optional):
-             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-
-     reference
-         *Nesterov, Y., & Spokoiny, V. (2017).
-         Random gradient-free minimization of convex functions.
-         Foundations of Computational Mathematics, 17(2), 527-566.*
-
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-2,
-         eps: float = 1e-2,
-         formula: _FD_Formulas = "forward",
-         n_samples: int = 10,
-         distribution: Distributions = 'normal',
-         momentum: float = 0,
-         dampening: float = 0,
-         nesterov: bool = False,
-         weight_decay: float = 0,
-         decoupled=False
-     ):
-         super().__init__(
-             params = params,
-             lr = lr,
-             eps = eps,
-             formula = formula,
-             n_samples = n_samples,
-             distribution = distribution,
-             momentum = momentum,
-             dampening = dampening,
-             nesterov = nesterov,
-             weight_decay = weight_decay,
-             decoupled = decoupled,
-         )
-
- class RandomizedFDMWrapper(Modular):
-     """Randomized finite difference gradient approximation (e.g. SPSA, RDSA, Nesterov random search).
-
-     With `forward` and `backward` formulas performs `1 + n_samples` evaluations per step;
-     with `central` formula performs `2 * n_samples` evaluations per step.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         optimizer (torch.optim.Optimizer): optimizer that will perform optimization using RFDM-approximated gradients.
-         eps (float, optional): finite difference epsilon. Defaults to 1e-3.
-         formula (_FD_Formulas, optional): finite difference formula. Defaults to "forward".
-         n_samples (int, optional): number of random gradient approximations that will be averaged. Defaults to 1.
-         distribution (Distributions, optional): distribution for random perturbations. Defaults to "normal".
-         randomize_every (int, optional): number of steps between randomizing perturbations. Defaults to 1.
-         randomize_closure (bool, optional): whether to generate a new random perturbation each time closure
-             is evaluated with `backward=True` (this ignores `randomize_every`). Defaults to False.
-     """
-     def __init__(
-         self,
-         optimizer: torch.optim.Optimizer,
-         eps: float = 1e-3,
-         formula: _FD_Formulas = "forward",
-         n_samples: int = 1,
-         distribution: Distributions = "normal",
-     ):
-         modules = [
-             _RandomizedFDM(
-                 eps=eps,
-                 formula=formula,
-                 n_samples=n_samples,
-                 distribution=distribution,
-                 target = 'closure',
-             ),
-             WrapClosure(optimizer)
-         ]
-
-         # some optimizers have `eps` setting in param groups too.
-         # it should not be passed to FDM
-         super().__init__([p for g in optimizer.param_groups.copy() for p in g['params']], modules)
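For context on the estimator these removed classes describe: the docstrings above only give evaluation counts, so here is a minimal standalone sketch of the forward randomized finite-difference gradient estimate in plain PyTorch. This is not torchzero's implementation; `rfdm_grad`, `f`, and all parameter names are illustrative.

import torch

def rfdm_grad(f, x: torch.Tensor, eps: float = 1e-3, n_samples: int = 1) -> torch.Tensor:
    # Forward randomized finite differences: 1 + n_samples evaluations of f,
    # matching the count quoted in the RandomizedFDM docstring.
    f0 = f(x)                                   # one evaluation at the current point
    g = torch.zeros_like(x)
    for _ in range(n_samples):
        u = torch.randn_like(x)                 # perturbation from the "normal" distribution
        g += (f(x + eps * u) - f0) / eps * u    # directional slope times the direction
    return g / n_samples

# usage: estimate the gradient of a simple quadratic and take one descent step
x = torch.tensor([1.0, -2.0, 3.0])
g = rfdm_grad(lambda v: (v ** 2).sum(), x, eps=1e-3, n_samples=8)
x = x - 1e-2 * g

With the central formula, `f0` would be replaced by an evaluation at `x - eps * u` inside the loop, which is where the `2 * n_samples` evaluation count in the docstring comes from.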
torchzero/optim/zeroth_order/rs.py
@@ -1,85 +0,0 @@
- import torch
-
- from ...core import TensorListOptimizer, _ClosureType
-
-
- class RandomSearch(TensorListOptimizer):
-     """Pure random search.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         min (float, optional): lower bound of the search space. Defaults to -10.
-         max (float, optional): upper bound of the search space. Defaults to 10.
-         stochastic (bool, optional):
-             evaluate function twice per step,
-             and only accept new params if they decreased the loss.
-             Defaults to False.
-     """
-     def __init__(self, params, min:float = -10, max:float = 10, stochastic = False):
-         defaults = dict(min=min, max = max)
-         super().__init__(params, defaults)
-         self.lowest_loss = float('inf')
-         self.stochastic = stochastic
-
-     @torch.no_grad
-     def step(self, closure: _ClosureType): # type:ignore # pylint:disable=W0222
-         if self.stochastic: self.lowest_loss = closure()
-
-         settings = self.get_all_group_keys()
-         params = self.get_params()
-
-         old_params = params.clone()
-         new_params = params.uniform_like(settings['min'], settings['max'])
-         params.set_(new_params)
-         loss = closure(False)
-
-         if loss < self.lowest_loss: self.lowest_loss = loss
-         else: params.set_(old_params)
-         return loss
-
- class CyclicRS(TensorListOptimizer):
-     """Performs random search cycling through each coordinate.
-     Works surprisingly well on up to ~100 dimensional problems.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         min (float, optional): lower bound of the search space. Defaults to -10.
-         max (float, optional): upper bound of the search space. Defaults to 10.
-         stochastic (bool, optional):
-             evaluate function twice per step,
-             and only accept new params if they decreased the loss.
-             Defaults to False.
-     """
-     def __init__(self, params, min:float = -10, max:float = 10, stochastic = False):
-         defaults = dict(min=min, max = max)
-         super().__init__(params, defaults)
-         self.lowest_loss = float('inf')
-         self.cur_param = 0
-         self.cur_value = 0
-         self.stochastic = stochastic
-
-     @torch.no_grad
-     def step(self, closure: _ClosureType): # type:ignore # pylint:disable=W0222
-         if self.stochastic: self.lowest_loss = closure()
-         settings = self.get_all_group_keys()
-         params = self.get_params()
-
-         if self.cur_param >= len(params): self.cur_param = 0
-         param = params[self.cur_param]
-         if self.cur_value >= param.numel():
-             self.cur_value = 0
-             self.cur_param += 1
-             if self.cur_param >= len(params): self.cur_param = 0
-             param = params[self.cur_param]
-
-         flat = param.view(-1)
-         old_value = flat[self.cur_value].clone()
-         flat[self.cur_value] = torch.rand(1).uniform_(settings['min'][self.cur_param], settings['max'][self.cur_param]) # type:ignore
-
-         loss = closure(False)
-         if loss < self.lowest_loss: self.lowest_loss = loss
-         else:
-             flat[self.cur_value] = old_value
-
-         self.cur_value += 1
-         return loss
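The removed RandomSearch.step above proposes uniformly random parameters and keeps them only if the loss decreases. A hedged, self-contained sketch of that accept-if-improved loop in plain PyTorch, independent of the removed TensorListOptimizer base class (all names are illustrative):

import torch

@torch.no_grad()
def random_search_step(params, closure, lowest_loss, low=-10.0, high=10.0):
    # Propose uniform random parameters; keep them only if the loss improved.
    old = [p.clone() for p in params]
    for p in params:
        p.uniform_(low, high)
    loss = closure()
    if loss < lowest_loss:
        return loss                       # accept the proposal
    for p, o in zip(params, old):
        p.copy_(o)                        # reject: restore previous parameters
    return lowest_loss

# usage: minimize a quadratic over one tensor of parameters
w = torch.zeros(3)
best = float('inf')
for _ in range(100):
    best = random_search_step([w], lambda: (w ** 2).sum().item(), best)

CyclicRS applies the same accept/reject test but perturbs a single coordinate per step, cycling through the parameters in order.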
torchzero/random/__init__.py
@@ -1 +0,0 @@
- from .random import randmask, rademacher, uniform, sphere, sample, sample_like, Distributions
torchzero/random/random.py
@@ -1,46 +0,0 @@
- from typing import Literal
-
- import torch
-
- Distributions = Literal['normal', 'uniform', 'sphere', 'rademacher']
-
- def rademacher(shape, p: float=0.5, device=None, requires_grad = False, dtype=None, generator=None):
-     """Returns a tensor filled with random numbers from Rademacher distribution.
-
-     *p* chance to draw a -1 and 1-*p* chance to draw a 1. Looks like this:
-
-     ```
-     [-1, 1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, 1, -1, 1, -1, -1, 1, -1, 1]
-     ```
-     """
-     if isinstance(shape, int): shape = (shape, )
-     return torch.bernoulli(torch.full(shape, p, dtype=dtype, device=device, requires_grad=requires_grad), generator=generator) * 2 - 1
-
- def randmask(shape, p: float=0.5, device=None, requires_grad = False, generator=None):
-     """Returns a tensor randomly filled with True and False.
-
-     *p* chance to draw `True` and 1-*p* to draw `False`."""
-     return torch.rand(shape, device=device, requires_grad=requires_grad, generator=generator) < p
-
- def uniform(shape, low: float, high: float, device=None, requires_grad=None, dtype=None):
-     """Returns a tensor filled with random numbers from a uniform distribution between `low` and `high`."""
-     return torch.empty(shape, device=device, dtype=dtype, requires_grad=requires_grad).uniform_(low, high)
-
- def sphere(shape, radius: float, device=None, requires_grad=None, dtype=None, generator = None):
-     """Returns a tensor filled with random numbers sampled on a unit sphere with center at 0."""
-     r = torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad, generator=generator)
-     return (r / torch.linalg.vector_norm(r)) * radius # pylint:disable=not-callable
-
- def sample(shape, eps: float = 1, distribution: Distributions = 'normal', generator=None, device=None, dtype=None, requires_grad=False):
-     """generic random sampling function for different distributions."""
-     if distribution == 'normal': return torch.randn(shape,dtype=dtype,device=device,requires_grad=requires_grad, generator=generator) * eps
-     if distribution == 'uniform':
-         return torch.empty(size=shape,dtype=dtype,device=device,requires_grad=requires_grad).uniform_(-eps/2, eps/2, generator=generator)
-
-     if distribution == 'sphere': return sphere(shape, eps,dtype=dtype,device=device,requires_grad=requires_grad, generator=generator)
-     if distribution == 'rademacher':
-         return rademacher(shape, eps,dtype=dtype,device=device,requires_grad=requires_grad, generator=generator) * eps
-     raise ValueError(f'Unknown distribution {distribution}')
-
- def sample_like(x: torch.Tensor, eps: float = 1, distribution: Distributions = 'normal', generator=None):
-     return sample(x, eps=eps, distribution=distribution, generator=generator, device=x.device, dtype=x.dtype, requires_grad=x.requires_grad)
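The removed rademacher helper above builds the ±1 perturbations that the SPSA class relies on by rescaling a Bernoulli sample from {0, 1} to {-1, +1}. A short self-contained sketch of just that mapping (illustrative name, not part of the package):

import torch

def rademacher_like(x: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    # Bernoulli(p) gives entries in {0, 1}; "* 2 - 1" rescales them to {-1, +1}.
    return torch.bernoulli(torch.full_like(x, p)) * 2 - 1

print(rademacher_like(torch.empty(10)))   # e.g. tensor([ 1., -1., -1.,  1., ...])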