torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
@@ -1,294 +0,0 @@
- from collections import abc
-
- import torch
-
- from ...tensorlist import TensorList, where, Distributions
- from ...core import OptimizerModule
- from ...utils.derivatives import jacobian
-
- def _bool_ones_like(x):
- return torch.ones_like(x, dtype=torch.bool)
-
-
- class MinibatchRprop(OptimizerModule):
- """
- for experiments, unlikely to work well on most problems.
-
- explanation: does 2 steps per batch, applies rprop rule on the second step.
- """
- def __init__(
- self,
- nplus: float = 1.2,
- nminus: float = 0.5,
- lb: float | None = 1e-6,
- ub: float | None = 50,
- backtrack=True,
- next_mode = 'continue',
- increase_mul = 0.5,
- alpha: float = 1,
- ):
- defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub, increase_mul=increase_mul)
- super().__init__(defaults)
- self.current_step = 0
- self.backtrack = backtrack
-
- self.next_mode = next_mode
-
- @torch.no_grad
- def step(self, vars):
- if vars.closure is None: raise ValueError("Minibatch Rprop requires closure")
- if vars.ascent is not None: raise ValueError("Minibatch Rprop must be the first module.")
- params = self.get_params()
-
- nplus, nminus, lb, ub = self.get_group_keys('nplus', 'nminus', 'lb', 'ub')
- allowed, magnitudes = self.get_state_keys(
- 'allowed', 'magnitudes',
- inits = [_bool_ones_like, torch.zeros_like],
- params=params
- )
-
- g1_sign = vars.maybe_compute_grad_(params).sign() # no inplace to not modify grads
- # initialize on 1st iteration
- if self.current_step == 0:
- magnitudes.fill_(self.get_group_key('alpha')).clamp_(lb, ub)
- # ascent = magnitudes * g1_sign
- # self.current_step += 1
- # return ascent
-
- # first step
- ascent = g1_sign.mul_(magnitudes).mul_(allowed)
- params -= ascent
- with torch.enable_grad(): vars.fx0_approx = vars.closure()
- f0 = vars.fx0; f1 = vars.fx0_approx
- assert f0 is not None and f1 is not None
-
- # if loss increased, reduce all lrs and undo the update
- if f1 > f0:
- increase_mul = self.get_group_key('increase_mul')
- magnitudes.mul_(increase_mul).clamp_(lb, ub)
- params += ascent
- self.current_step += 1
- return f0
-
- # on `continue` we move to params after 1st update
- # therefore state must be updated to have all attributes after 1st update
- if self.next_mode == 'continue':
- vars.fx0 = vars.fx0_approx
- vars.grad = params.ensure_grad_().grad
- sign = vars.grad.sign()
-
- else:
- sign = params.ensure_grad_().grad.sign_() # can use in-place as this is not fx0 grad
-
- # compare 1st and 2nd gradients via rprop rule
- prev = ascent
- mul = sign * prev # prev is already multiuplied by `allowed`
-
- sign_changed = mul < 0
- sign_same = mul > 0
- zeroes = mul == 0
-
- mul.fill_(1)
- mul.masked_fill_(sign_changed, nminus)
- mul.masked_fill_(sign_same, nplus)
-
- # multiply magnitudes based on sign change and clamp to bounds
- magnitudes.mul_(mul).clamp_(lb, ub)
-
- # revert update if sign changed
- if self.backtrack:
- ascent2 = sign.mul_(magnitudes)
- ascent2.masked_set_(sign_changed, prev.neg_())
- else:
- ascent2 = sign.mul_(magnitudes * ~sign_changed)
-
- # update allowed to only have weights where last update wasn't reverted
- allowed.set_(sign_same | zeroes)
-
- self.current_step += 1
-
- # update params or step
- if self.next_mode == 'continue' or (self.next_mode == 'add' and self.next_module is None):
- vars.ascent = ascent2
- return self._update_params_or_step_with_next(vars, params)
-
- if self.next_mode == 'add':
- # undo 1st step
- params += ascent
- vars.ascent = ascent + ascent2
- return self._update_params_or_step_with_next(vars, params)
-
- if self.next_mode == 'undo':
- params += ascent
- vars.ascent = ascent2
- return self._update_params_or_step_with_next(vars, params)
-
- raise ValueError(f'invalid next_mode: {self.next_mode}')
-
-
-
- class GradMin(OptimizerModule):
- """
- for experiments, unlikely to work well on most problems.
-
- explanation: calculate grads wrt sum of grads + loss.
- """
- def __init__(self, loss_term: float = 1, square=False, maximize_grad = False, create_graph = False):
- super().__init__(dict(loss_term=loss_term))
- self.square = square
- self.maximize_grad = maximize_grad
- self.create_graph = create_graph
-
- @torch.no_grad
- def step(self, vars):
- if vars.closure is None: raise ValueError()
- if vars.ascent is not None:
- raise ValueError("GradMin doesn't accept ascent_direction")
-
- params = self.get_params()
- loss_term = self.get_group_key('loss_term')
-
- self.zero_grad()
- with torch.enable_grad():
- vars.fx0 = vars.closure(False)
- grads = jacobian([vars.fx0], params, create_graph=True, batched=False) # type:ignore
- grads = TensorList(grads).squeeze_(0)
- if self.square:
- grads = grads ** 2
- else:
- grads = grads.abs()
-
- if self.maximize_grad: grads: TensorList = grads - (vars.fx0 * loss_term) # type:ignore
- else: grads = grads + (vars.fx0 * loss_term)
- grad_mean = torch.sum(torch.stack(grads.sum())) / grads.total_numel()
-
- if self.create_graph: grad_mean.backward(create_graph=True)
- else: grad_mean.backward(retain_graph=False)
-
- if self.maximize_grad: vars.grad = params.ensure_grad_().grad.neg_()
- else: vars.grad = params.ensure_grad_().grad
-
- vars.maybe_use_grad_(params)
- return self._update_params_or_step_with_next(vars)
-
-
- class HVPDiagNewton(OptimizerModule):
- """
- for experiments, unlikely to work well on most problems.
-
- explanation: may or may not approximate newton step if hessian is diagonal with 2 backward passes. Probably not.
- """
- def __init__(self, eps=1e-3):
- super().__init__(dict(eps=eps))
-
- @torch.no_grad
- def step(self, vars):
- if vars.closure is None: raise ValueError()
- if vars.ascent is not None:
- raise ValueError("HVPDiagNewton doesn't accept ascent_direction")
-
- params = self.get_params()
- eps = self.get_group_key('eps')
- grad_fx0 = vars.maybe_compute_grad_(params).clone()
- vars.grad = grad_fx0 # set state grad to the cloned version, since it will be overwritten
-
- params += grad_fx0 * eps
- with torch.enable_grad(): _ = vars.closure()
-
- params -= grad_fx0 * eps
-
- newton = grad_fx0 * ((grad_fx0 * eps) / (params.grad - grad_fx0))
- newton.nan_to_num_(0,0,0)
-
- vars.ascent = newton
- return self._update_params_or_step_with_next(vars)
-
-
-
- class ReduceOutwardLR(OptimizerModule):
- """
- When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
-
- This means updates that move weights towards zero have higher learning rates.
- """
- def __init__(self, mul = 0.5, use_grad=False, invert=False):
- defaults = dict(mul = mul)
- super().__init__(defaults)
-
- self.use_grad = use_grad
- self.invert = invert
-
- @torch.no_grad
- def _update(self, vars, ascent):
- params = self.get_params()
- mul = self.get_group_key('mul')
-
- if self.use_grad: cur = vars.maybe_compute_grad_(params)
- else: cur = ascent
-
- # mask of weights where sign matches with update sign (minus ascent sign), multiplied by `mul`.
- if self.invert: mask = (params * cur) > 0
- else: mask = (params * cur) < 0
- ascent.masked_set_(mask, ascent*mul)
-
- return ascent
-
- class NoiseSign(OptimizerModule):
- """uses random vector with ascent sign"""
- def __init__(self, distribution:Distributions = 'normal', alpha = 1):
- super().__init__({})
- self.alpha = alpha
- self.distribution:Distributions = distribution
-
-
- def _update(self, vars, ascent):
- return ascent.sample_like(self.alpha, self.distribution).copysign_(ascent)
-
- class ParamSign(OptimizerModule):
- """uses params with ascent sign"""
- def __init__(self):
- super().__init__({})
-
-
- def _update(self, vars, ascent):
- params = self.get_params()
-
- return params.copysign(ascent)
-
- class NegParamSign(OptimizerModule):
- """uses max(params_abs) - params_abs with ascent sign"""
- def __init__(self):
- super().__init__({})
-
-
- def _update(self, vars, ascent):
- neg_params = self.get_params().abs()
- max = neg_params.total_max()
- neg_params = neg_params.neg_().add(max)
- return neg_params.copysign_(ascent)
-
- class InvParamSign(OptimizerModule):
- """uses 1/(params_abs+eps) with ascent sign"""
- def __init__(self, eps=1e-2):
- super().__init__({})
- self.eps = eps
-
-
- def _update(self, vars, ascent):
- inv_params = self.get_params().abs().add_(self.eps).reciprocal_()
- return inv_params.copysign(ascent)
-
-
- class ParamWhereConsistentSign(OptimizerModule):
- """where ascent and param signs are the same, it sets ascent to param value"""
- def __init__(self, eps=1e-2):
- super().__init__({})
- self.eps = eps
-
-
- def _update(self, vars, ascent):
- params = self.get_params()
- same_sign = params.sign() == ascent.sign()
- ascent.masked_set_(same_sign, params)
-
- return ascent
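Note: the removed MinibatchRprop above applies the classic Rprop step-size rule between its two per-batch steps: grow each per-weight magnitude by nplus when the gradient sign repeats, shrink it by nminus when the sign flips, and clamp to [lb, ub]. A minimal standalone sketch of that rule, assuming plain torch tensors rather than the package's TensorList (the function name is illustrative, not part of torchzero):

import torch

def rprop_magnitude_update(magnitudes, prev_step, new_grad,
                           nplus=1.2, nminus=0.5, lb=1e-6, ub=50.0):
    # agreement > 0: the new gradient points the same way as the previous step
    # agreement < 0: the sign flipped, i.e. the previous step likely overshot
    agreement = new_grad.sign() * prev_step.sign()
    mul = torch.ones_like(magnitudes)
    mul[agreement > 0] = nplus
    mul[agreement < 0] = nminus
    # scale the per-weight step sizes and keep them within the allowed bounds
    return (magnitudes * mul).clamp_(lb, ub)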
@@ -1,104 +0,0 @@
- import bisect
-
- import numpy as np
- import torch
-
- from ...tensorlist import TensorList
- from ...core import OptimizationVars
- from ..line_search.base_ls import LineSearchBase
-
- _FloatOrTensor = float | torch.Tensor
-
- def _ensure_float(x):
- if isinstance(x, torch.Tensor): return x.detach().cpu().item()
- elif isinstance(x, np.ndarray): return x.item()
- return float(x)
-
- class Point:
- def __init__(self, x, fx, dfx = None):
- self.x = x
- self.fx = fx
- self.dfx = dfx
-
- def __repr__(self):
- return f'Point(x={self.x:.2f}, fx={self.fx:.2f})'
-
- def _step_2poins(x1, f1, df1, x2, f2):
- # we have two points and one derivative
- # minimize the quadratic to obtain 3rd point and perform bracketing
- a = (df1 * x2 - f2 - df1*x1 + f1) / (x1**2 - x2**2 - 2*x1**2 + 2*x1*x2)
- b = df1 - 2*a*x1
- # c = -(a*x1**2 + b*x1 - y1)
- return -b / (2 * a), a
-
- class QuadraticInterpolation2Point(LineSearchBase):
- """This is WIP, please don't use yet!
- Use `torchzero.modules.MinimizeQuadraticLS` and `torchzero.modules.MinimizeQuadratic3PointsLS` instead.
-
- Args:
- lr (_type_, optional): _description_. Defaults to 1e-2.
- log_lrs (bool, optional): _description_. Defaults to False.
- max_evals (int, optional): _description_. Defaults to 2.
- min_dist (_type_, optional): _description_. Defaults to 1e-2.
- """
- def __init__(self, lr=1e-2, log_lrs = False, max_evals = 2, min_dist = 1e-2,):
- super().__init__({"lr": lr}, maxiter=None, log_lrs=log_lrs)
- self.max_evals = max_evals
- self.min_dist = min_dist
-
- @torch.no_grad
- def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
- if vars.closure is None: raise ValueError('QuardaticLS requires closure')
- closure = vars.closure
- if vars.fx0 is None: vars.fx0 = vars.closure(False)
- grad = vars.grad
- if grad is None: grad = vars.ascent # in case we used FDM
- if grad is None: raise ValueError('QuardaticLS requires gradients.')
-
- params = self.get_params()
- lr: float = self.get_first_group_key('lr') # this doesn't support variable lrs but we still want to support schedulers
-
- # directional f'(x0)
- # for each lr we step by this much
- dfx0 = magn = grad.total_vector_norm(2)
-
- # f(x1)
- fx1 = self._evaluate_lr_(lr, closure, grad, params)
-
- # make min_dist relative
- min_dist = abs(lr) * self.min_dist
- points = sorted([Point(0, _ensure_float(vars.fx0), dfx0), Point(lr, _ensure_float(fx1))], key = lambda x: x.fx)
-
- for i in range(self.max_evals):
- # find new point
- p1, p2 = points
- if p1.dfx is None: p1, p2 = p2, p1
- xmin, curvature = _step_2poins(p1.x * magn, p1.fx, -p1.dfx, p2.x * magn, p2.fx) # type:ignore
- xmin = _ensure_float(xmin/magn)
- print(f'{xmin = }', f'{curvature = }, n_evals = {i+1}')
-
- # if max_evals = 1, we just minimize a quadratic once
- if i == self.max_evals - 1:
- if curvature > 0: return xmin
- return lr
-
- # TODO: handle negative curvature
- # if curvature < 0:
- # if points[0].x == 0: return lr
- # return points[0].x
-
- # evaluate value and gradients at new point
- fxmin = self._evaluate_lr_(xmin, closure, grad, params, backward=True)
- dfxmin = -(params.grad * grad).total_sum()
-
- # insort new point
- bisect.insort(points, Point(xmin, _ensure_float(fxmin), dfxmin), key = lambda x: x.fx)
-
- # pick 2 best points to find the new bracketing interval
- points = sorted(points, key = lambda x: x.fx)[:2]
- # TODO: new point might be worse than 2 existing ones which would lead to stagnation
-
- # if points are too close, end the loop
- if abs(points[0].x - points[1].x) < min_dist: break
-
- return points[0].x
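Note: `_step_2poins` above fits a parabola through one point with a known value and derivative and a second point with only a value, then proposes the parabola's vertex as the next trial step. An algebraically equivalent, more explicit sketch of the same computation (names are illustrative, not part of the package):

def quadratic_vertex(x1, f1, df1, x2, f2):
    # fit f(x) ~= a*x**2 + b*x + c with f(x1) = f1, f'(x1) = df1, f(x2) = f2
    a = (f2 - f1 - df1 * (x2 - x1)) / (x2 - x1) ** 2
    b = df1 - 2 * a * x1
    # vertex of the fitted parabola; a > 0 means positive curvature, i.e. a true minimum
    return -b / (2 * a), a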
@@ -1,259 +0,0 @@
- import typing as T
- from abc import ABC, abstractmethod
- from collections import abc
-
- import torch
-
- from ... import tensorlist as tl
- from ...core import OptimizationVars, OptimizerModule, _Chain, _maybe_pass_backward
- # this whole thing can also be implemented via parameter vectors.
- # Need to test which one is more efficient...
-
- class Projection(ABC):
- n = 1
- @abstractmethod
- def sample(self, params: tl.TensorList, vars: OptimizationVars) -> list[tl.TensorList]:
- """Generate a projection.
-
- Args:
- params (tl.TensorList): tensor list of parameters.
- state (OptimizationState): optimization state object.
-
- Returns:
- projection.
- """
-
- class ProjRandom(Projection):
- def __init__(self, n = 1, distribution: tl.Distributions = 'normal', ):
- self.distribution: tl.Distributions = distribution
- self.n = n
-
- def sample(self, params: tl.TensorList, vars: OptimizationVars):
- return [params.sample_like(distribution=self.distribution) for _ in range(self.n)]
-
-
- class Proj2Masks(Projection):
- def __init__(self, n_pairs = 1):
- """Similar to ProjRandom, but generates pairs of two random masks of 0s and 1s,
- where second mask is an inverse of the first mask."""
- self.n_pairs = n_pairs
-
- @property
- def n(self):
- return self.n_pairs * 2
-
- def sample(self, params: tl.TensorList, vars: OptimizationVars):
- projections = []
- for i in range(self.n_pairs):
- mask = params.bernoulli_like(0.5)
- mask2 = 1 - mask
- projections.append(mask)
- projections.append(mask2)
-
- return projections
-
-
- class ProjAscent(Projection):
- """Use ascent direction as the projection."""
- def sample(self, params: tl.TensorList, vars: OptimizationVars):
- if vars.ascent is None: raise ValueError
- return [vars.ascent]
-
- class ProjAscentRay(Projection):
- def __init__(self, eps = 0.1, n = 1, distribution: tl.Distributions = 'normal', ):
- self.eps = eps
- self.distribution: tl.Distributions = distribution
- self.n = n
-
- def sample(self, params: tl.TensorList, vars: OptimizationVars):
- if vars.ascent is None: raise ValueError
- mean = params.total_mean().detach().cpu().item()
- return [vars.ascent + vars.ascent.sample_like(mean * self.eps, distribution=self.distribution) for _ in range(self.n)]
-
- class ProjGrad(Projection):
- def sample(self, params: tl.TensorList, vars: OptimizationVars):
- grad = vars.maybe_compute_grad_(params)
- return [grad]
-
- class ProjGradRay(Projection):
- def __init__(self, eps = 0.1, n = 1, distribution: tl.Distributions = 'normal', ):
- self.eps = eps
- self.distribution: tl.Distributions = distribution
- self.n = n
-
- def sample(self, params: tl.TensorList, vars: OptimizationVars):
- grad = vars.maybe_compute_grad_(params)
- mean = params.total_mean().detach().cpu().item()
- return [grad + grad.sample_like(mean * self.eps, distribution=self.distribution) for _ in range(self.n)]
-
- class ProjGradAscentDifference(Projection):
- def __init__(self, normalize=False):
- """Use difference between gradient and ascent direction as projection.
-
- Args:
- normalize (bool, optional): normalizes grads and ascent projection to have norm = 1. Defaults to False.
- """
- self.normalize = normalize
-
- def sample(self, params: tl.TensorList, vars: OptimizationVars):
- grad = vars.maybe_compute_grad_(params)
- if self.normalize:
- return [vars.ascent / vars.ascent.total_vector_norm(2) - grad / grad.total_vector_norm(2)] # type:ignore
-
- return [vars.ascent - grad] # type:ignore
-
- class ProjLastGradDifference(Projection):
- def __init__(self):
- """Use difference between last two gradients as the projection."""
- self.last_grad = None
- def sample(self, params: tl.TensorList, vars: OptimizationVars):
- if self.last_grad is None:
- self.last_grad = vars.maybe_compute_grad_(params)
- return [self.last_grad]
-
- grad = vars.maybe_compute_grad_(params)
- diff = grad - self.last_grad
- self.last_grad = grad
- return [diff]
-
- class ProjLastAscentDifference(Projection):
- def __init__(self):
- """Use difference between last two ascent directions as the projection."""
- self.last_direction = T.cast(tl.TensorList, None)
-
- def sample(self, params: tl.TensorList, vars: OptimizationVars):
- if self.last_direction is None:
- self.last_direction: tl.TensorList = vars.ascent # type:ignore
- return [self.last_direction]
-
- diff = vars.ascent - self.last_direction # type:ignore
- self.last_direction = vars.ascent # type:ignore
- return [diff]
-
- class ProjNormalize(Projection):
- def __init__(self, *projections: Projection):
- """Normalizes all projections to have norm = 1."""
- self.projections = projections
-
- @property
- def n(self):
- return sum(proj.n for proj in self.projections)
-
- def sample(self, params: tl.TensorList, vars: OptimizationVars): # type:ignore
- vecs = [proj for obj in self.projections for proj in obj.sample(params, vars)]
- norms = [v.total_vector_norm(2) for v in vecs]
- return [v/norm if norm!=0 else v.randn_like() for v,norm in zip(vecs,norms)] # type:ignore
-
- class Subspace(OptimizerModule):
- """This is pretty inefficient, I thought of a much better way to do this via jvp and I will rewrite this soon.
-
- Optimizes parameters projected into a lower (or higher) dimensional subspace.
-
- The subspace is a bunch of projections that go through the current point. Projections can be random,
- or face in the direction of the gradient, or difference between last two gradients, etc. The projections
- are updated every `update_every` steps.
-
- Notes:
- This doesn't work with anything that directly calculates the hessian or other quantities via `torch.autograd.grad`,
- like `ExactNewton`. I will have to manually implement a subspace version for it.
-
- This also zeroes parameters after each step, meaning it won't work with some integrations like nevergrad
- (as they store their own parameters which don't get zeroed). It does however work with integrations like
- `scipy.optimize` because they performs a full minimization on each step.
- Another version of this which doesn't zero the params is under way.
-
- Args:
- projections (Projection | Iterable[Projection]):
- list of projections - `Projection` objects that define the directions of the projections.
- Each Projection object may generate one or multiple directions.
- update_every (int, optional): generates new projections every n steps. Defaults to 1.
- """
- def __init__(
- self,
- modules: OptimizerModule | abc.Iterable[OptimizerModule],
- projections: Projection | abc.Iterable[Projection],
- update_every: int | None = 1,
- ):
- super().__init__({})
- if isinstance(projections, Projection): projections = [projections]
- self.projections = list(projections)
- self._set_child_('subspace', modules)
- self.update_every = update_every
- self.current_step = 0
-
- # cast them because they are guaranteed to be assigned on 1st step.
- self.projection_vectors = T.cast(list[tl.TensorList], None)
- self.projected_params = T.cast(torch.Tensor, None)
-
-
- def _update_child_params_(self, child: "OptimizerModule"):
- dtype = self._params[0].dtype
- device = self._params[0].device
- params = [torch.zeros(sum(proj.n for proj in self.projections), dtype = dtype, device = device, requires_grad=True)]
- if child._has_custom_params: raise RuntimeError(f"Subspace child {child.__class__.__name__} can't have custom params.")
- if not child._initialized:
- child._initialize_(params, set_passed_params=False)
- else:
- child.param_groups = []
- child.add_param_group({"params": params})
-
- @torch.no_grad
- def step(self, vars):
- #if self.next_module is None: raise ValueError('RandomProjection needs a child')
- if vars.closure is None: raise ValueError('RandomProjection needs a closure')
- closure = vars.closure
- params = self.get_params()
-
- # every `regenerate_every` steps we generate new random projections.
- if self.current_step == 0 or (self.update_every is not None and self.current_step % self.update_every == 0):
-
- # generate n projection vetors
- self.projection_vectors = [sample for proj in self.projections for sample in proj.sample(params, vars)]
-
- # child params is n scalars corresponding to each projection vector
- self.projected_params = self.children['subspace']._params[0] # type:ignore
-
- # closure that takes the projected params from the child, puts them into full space params, and evaluates the loss
- def projected_closure(backward = True):
- residual = sum(vec * p for vec, p in zip(self.projection_vectors, self.projected_params))
-
- # this in-place operation prevents autodiff from working
- # we manually calculate the gradients as they are just a product
- # therefore we need torch.no_grad here because optimizers call closure under torch.enabled_grad
- with torch.no_grad(): params.add_(residual)
-
- loss = _maybe_pass_backward(closure, backward)
-
- if backward:
- self.projected_params.grad = torch.cat([(params.grad * vec).total_sum().unsqueeze(0) for vec in self.projection_vectors])
- with torch.no_grad(): params.sub_(residual)
- return loss
-
- # # if ascent direction is provided,
- # # project the ascent direction into the projection space (need to test if this works)
- # if ascent_direction is not None:
- # ascent_direction = tl.sum([ascent_direction*v for v in self.projection_vectors])
-
- # perform a step with the child
- subspace_state = vars.copy(False)
- subspace_state.closure = projected_closure
- subspace_state.ascent = None
- if subspace_state.grad is not None:
- subspace_state.grad = tl.TensorList([torch.cat([(params.grad * vec).total_sum().unsqueeze(0) for vec in self.projection_vectors])])
- self.children['subspace'].step(subspace_state) # type:ignore
-
- # that is going to update child's paramers, which we now project back to the full parameter space
- residual = tl.sum([vec * p for vec, p in zip(self.projection_vectors, self.projected_params)])
- vars.ascent = residual.neg_()
-
- # move fx0 and fx0 approx to state
- if subspace_state.fx0 is not None: vars.fx0 = subspace_state.fx0
- if subspace_state.fx0_approx is not None: vars.fx0 = subspace_state.fx0_approx
- # projected_params are residuals that have been applied to actual params on previous step in some way
- # therefore they need to now become zero (otherwise they work like momentum with no decay).
- # note: THIS WON'T WORK WITH INTEGRATIONS, UNLESS THEY PERFORM FULL MINIMIZATION EACH STEP
- # because their params won't be zeroed.
- self.projected_params.zero_()
-
- self.current_step += 1
- return self._update_params_or_step_with_next(vars)
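Note: the core arithmetic the removed Subspace module relies on is the mapping between subspace coordinates and the full parameter space: a full-space update is a linear combination of the projection vectors, and by the chain rule the gradient with respect to each subspace coordinate is the inner product of the full-space gradient with the corresponding projection vector. A minimal sketch with plain torch tensors (function names are illustrative, not part of the package):

import torch

def to_full_space(coords, projection_vectors):
    # full-space update: sum_i c_i * v_i
    return sum(c * v for c, v in zip(coords, projection_vectors))

def subspace_grad(full_grad, projection_vectors):
    # d(loss)/d(c_i) = <full_grad, v_i>
    return torch.stack([(full_grad * v).sum() for v in projection_vectors])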
@@ -1,7 +0,0 @@
- r"""
- Gradient approximation methods.
- """
- from .fdm import FDM
- from .rfdm import RandomizedFDM
- from .newton_fdm import NewtonFDM
- from .forward_gradient import ForwardGradient
@@ -1,3 +0,0 @@
- import typing as T
-
- _FD_Formulas = T.Literal['central', 'forward', 'backward']
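Note: the 'central', 'forward' and 'backward' options named by the removed _FD_Formulas alias refer to the standard finite-difference approximations of a derivative. A minimal sketch for a scalar function (illustrative only, not the package's implementation):

def forward_diff(f, x, h=1e-6):
    return (f(x + h) - f(x)) / h            # O(h) truncation error

def backward_diff(f, x, h=1e-6):
    return (f(x) - f(x - h)) / h            # O(h) truncation error

def central_diff(f, x, h=1e-6):
    return (f(x + h) - f(x - h)) / (2 * h)  # O(h**2) truncation error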