torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
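Most of the churn in the list above is a rename of the torchzero/modules/optimizers/ subpackage to torchzero/modules/adaptive/, plus a new torchzero/modules/trust_region/ package that replaces the single quasi_newton/trust_region.py file. The two hunks below (397 and 198 removed lines) correspond to the deleted torchzero/modules/quasi_newton/trust_region.py and torchzero/modules/smoothing/gaussian.py. As a rough sketch of the trust-region usage from the removed docstring, assuming tz.Modular, tz.m.TrustCG and tz.m.SR1 keep the names shown in this diff (the 0.3.13 API may differ in detail):

    import torch
    import torchzero as tz

    model = torch.nn.Linear(2, 1)  # placeholder model, for illustration only

    # Trust-region step solved with Steihaug-Toint CG on top of an SR1 Hessian estimate,
    # as in the docstring example of the removed trust_region.py shown below.
    opt = tz.Modular(
        model.parameters(),
        tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
    )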
@@ -1,397 +0,0 @@
1
- """Trust region API is currently experimental, it will probably change completely"""
2
- # pylint:disable=not-callable
3
- from abc import ABC, abstractmethod
4
- from typing import Any, Literal, cast, final
5
- from collections.abc import Sequence, Mapping
6
-
7
- import numpy as np
8
- import torch
9
- from scipy.optimize import lsq_linear
10
-
11
- from ...core import Chainable, Module, apply_transform, Var
12
- from ...utils import TensorList, vec_to_tensors
13
- from ...utils.derivatives import (
14
- hessian_list_to_mat,
15
- jacobian_and_hessian_wrt,
16
- )
17
- from .quasi_newton import HessianUpdateStrategy
18
- from ...utils.linalg import steihaug_toint_cg
19
-
20
-
21
- def trust_lstsq(H: torch.Tensor, g: torch.Tensor, trust_region: float):
22
- res = lsq_linear(H.numpy(force=True).astype(np.float64), g.numpy(force=True).astype(np.float64), bounds=(-trust_region, trust_region))
23
- x = torch.from_numpy(res.x).to(H)
24
- return x, res.cost
25
-
26
- def _flatten_tensors(tensors: list[torch.Tensor]):
27
- return torch.cat([t.ravel() for t in tensors])
28
-
29
-
30
- class TrustRegionBase(Module, ABC):
31
- def __init__(
32
- self,
33
- defaults: dict | None = None,
34
- hess_module: HessianUpdateStrategy | None = None,
35
- update_freq: int = 1,
36
- inner: Chainable | None = None,
37
- ):
38
- self._update_freq = update_freq
39
- super().__init__(defaults)
40
-
41
- if hess_module is not None:
42
- self.set_child('hess_module', hess_module)
43
-
44
- if inner is not None:
45
- self.set_child('inner', inner)
46
-
47
- @abstractmethod
48
- def trust_region_step(self, var: Var, tensors:list[torch.Tensor], P: torch.Tensor, is_inverse:bool) -> Var:
49
- """trust region logic"""
50
-
51
-
52
- @final
53
- @torch.no_grad
54
- def update(self, var):
55
- # ---------------------------------- update ---------------------------------- #
56
- closure = var.closure
57
- if closure is None: raise RuntimeError("Trust region requires closure")
58
- params = var.params
59
-
60
- step = self.global_state.get('step', 0)
61
- self.global_state['step'] = step + 1
62
-
63
- P = None
64
- is_inverse=None
65
- g_list = var.grad
66
- loss = var.loss
67
- if step % self._update_freq == 0:
68
-
69
- if 'hess_module' not in self.children:
70
- params=var.params
71
- closure=var.closure
72
- if closure is None: raise ValueError('Closure is required for trust region')
73
- with torch.enable_grad():
74
- loss = var.loss = var.loss_approx = closure(False)
75
- g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=True)
76
- g_list = [t[0] for t in g_list] # remove leading dim from loss
77
- var.grad = g_list
78
- P = hessian_list_to_mat(H_list)
79
- is_inverse=False
80
-
81
-
82
- else:
83
- hessian_module = cast(HessianUpdateStrategy, self.children['hess_module'])
84
- hessian_module.update(var)
85
- P, is_inverse = hessian_module.get_B()
86
-
87
- if self._update_freq != 0:
88
- self.global_state['B'] = P
89
- self.global_state['is_inverse'] = is_inverse
90
-
91
-
92
- @final
93
- @torch.no_grad
94
- def apply(self, var):
95
- P = self.global_state['B']
96
- is_inverse = self.global_state['is_inverse']
97
-
98
- # -------------------------------- inner step -------------------------------- #
99
- update = var.get_update()
100
- if 'inner' in self.children:
101
- update = apply_transform(self.children['inner'], update, params=var.params, grads=var.grad, var=var)
102
-
103
- # ----------------------------------- apply ---------------------------------- #
104
- return self.trust_region_step(var=var, tensors=update, P=P, is_inverse=is_inverse)
105
-
106
- def _update_tr_radius(update_vec:torch.Tensor, params: Sequence[torch.Tensor], closure,
107
- loss, g:torch.Tensor, H:torch.Tensor, trust_region:float, settings: Mapping):
108
- """returns (update, new_trust_region)
109
-
110
- Args:
111
- update_vec (torch.Tensor): update vector which is SUBTRACTED from parameters
112
- params (_type_): params tensor list
113
- closure (_type_): closure
114
- loss (_type_): loss at x0
115
- g (torch.Tensor): gradient vector
116
- H (torch.Tensor): hessian
117
- trust_region (float): current trust region value
118
- """
119
- # evaluate actual loss reduction
120
- update_unflattned = vec_to_tensors(update_vec, params)
121
- params = TensorList(params)
122
- params -= update_unflattned
123
- loss_star = closure(False)
124
- params += update_unflattned
125
- reduction = loss - loss_star
126
-
127
- # expected reduction is g.T @ p + 0.5 * p.T @ B @ p
128
- if H.ndim == 1: Hu = H * update_vec
129
- else: Hu = H @ update_vec
130
- pred_reduction = - (g.dot(update_vec) + 0.5 * update_vec.dot(Hu))
131
- rho = reduction / (pred_reduction.clip(min=1e-8))
132
-
133
- # failed step
134
- if rho < 0.25:
135
- trust_region *= settings["nminus"]
136
-
137
- # very good step
138
- elif rho > 0.75:
139
- diff = trust_region - update_vec.abs()
140
- if (diff.amin() / trust_region) > 1e-4: # hits boundary
141
- trust_region *= settings["nplus"]
142
-
143
- # # if the ratio is high enough then accept the proposed step
144
- # if rho > settings["eta"]:
145
- # update = vec_to_tensors(update_vec, params)
146
-
147
- # else:
148
- # update = params.zeros_like()
149
-
150
- return trust_region, rho > settings["eta"]
151
-
152
- class TrustCG(TrustRegionBase):
153
- """Trust region via Steihaug-Toint Conjugate Gradient method. This is mainly useful for quasi-newton methods.
154
- If you don't use :code:`hess_module`, use the matrix-free :code:`tz.m.NewtonCGSteihaug` which only uses hessian-vector products.
155
-
156
- Args:
157
- hess_module (HessianUpdateStrategy | None, optional):
158
- Hessian update strategy, must be one of the :code:`HessianUpdateStrategy` modules. Make sure to set :code:`inverse=False`. If None, uses autograd to calculate the hessian. Defaults to None.
159
- eta (float, optional):
160
- if the ratio of actual to predicted reduction is larger than this, the step is accepted.
161
- When :code:`hess_module` is None, this can be set to 0. Defaults to 0.15.
162
- nplus (float, optional): increase factor on successful steps. Defaults to 2.
163
- nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
164
- init (float, optional): Initial trust region value. Defaults to 1.
165
- update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
166
- reg (float, optional): hessian regularization. Defaults to 0.
167
- inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
168
-
169
- Examples:
170
- Trust-SR1
171
-
172
- .. code-block:: python
173
-
174
- opt = tz.Modular(
175
- model.parameters(),
176
- tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
177
- )
178
- """
179
- def __init__(
180
- self,
181
- hess_module: HessianUpdateStrategy | None,
182
- eta: float= 0.15,
183
- nplus: float = 2,
184
- nminus: float = 0.25,
185
- init: float = 1,
186
- update_freq: int = 1,
187
- reg: float = 0,
188
- max_attempts: int = 10,
189
- inner: Chainable | None = None,
190
- ):
191
- defaults = dict(init=init, nplus=nplus, nminus=nminus, eta=eta, reg=reg, max_attempts=max_attempts)
192
- super().__init__(defaults, hess_module=hess_module, update_freq=update_freq, inner=inner)
193
-
194
- @torch.no_grad
195
- def trust_region_step(self, var, tensors, P, is_inverse):
196
- params = TensorList(var.params)
197
- settings = self.settings[params[0]]
198
- g = _flatten_tensors(tensors)
199
-
200
- reg = settings['reg']
201
- max_attempts = settings['max_attempts']
202
-
203
- loss = var.loss
204
- closure = var.closure
205
- if closure is None: raise RuntimeError("Trust region requires closure")
206
- if loss is None: loss = closure(False)
207
-
208
- if is_inverse:
209
- if P.ndim == 1: P = P.reciprocal()
210
- else: raise NotImplementedError()
211
-
212
- success = False
213
- update_vec = None
214
- while not success:
215
- max_attempts -= 1
216
- if max_attempts < 0: break
217
-
218
- trust_region = self.global_state.get('trust_region', settings['init'])
219
-
220
- if trust_region < 1e-8 or trust_region > 1e8:
221
- trust_region = self.global_state['trust_region'] = settings['init']
222
-
223
- update_vec = steihaug_toint_cg(P, g, trust_region, reg=reg)
224
-
225
- self.global_state['trust_region'], success = _update_tr_radius(
226
- update_vec=update_vec, params=params, closure=closure,
227
- loss=loss, g=g, H=P, trust_region=trust_region, settings = settings,
228
- )
229
-
230
- assert update_vec is not None
231
- if success: var.update = vec_to_tensors(update_vec, params)
232
- else: var.update = params.zeros_like()
233
-
234
- return var
235
-
236
-
237
- # code from https://github.com/konstmish/opt_methods/blob/master/optmethods/second_order/cubic.py
238
- # ported to torch
239
- def ls_cubic_solver(f, g:torch.Tensor, H:torch.Tensor, M: float, is_inverse: bool, loss_plus, it_max=100, epsilon=1e-8, ):
240
- """
241
- Solve min_z <g, z-x> + 1/2<z-x, H(z-x)> + M/3 ||z-x||^3
242
-
243
- For explanation of Cauchy point, see "Gradient Descent
244
- Efficiently Finds the Cubic-Regularized Non-Convex Newton Step"
245
- https://arxiv.org/pdf/1612.00547.pdf
246
- Other potential implementations can be found in paper
247
- "Adaptive cubic regularisation methods"
248
- https://people.maths.ox.ac.uk/cartis/papers/ARCpI.pdf
249
- """
250
- solver_it = 1
251
- if is_inverse:
252
- newton_step = - H @ g
253
- H = torch.linalg.inv(H)
254
- else:
255
- newton_step, info = torch.linalg.solve_ex(H, g)
256
- if info != 0:
257
- newton_step = torch.linalg.lstsq(H, g).solution
258
- newton_step.neg_()
259
- if M == 0:
260
- return newton_step, solver_it
261
- def cauchy_point(g, H, M):
262
- if torch.linalg.vector_norm(g) == 0 or M == 0:
263
- return 0 * g
264
- g_dir = g / torch.linalg.vector_norm(g)
265
- H_g_g = H @ g_dir @ g_dir
266
- R = -H_g_g / (2*M) + torch.sqrt((H_g_g/M)**2/4 + torch.linalg.vector_norm(g)/M)
267
- return -R * g_dir
268
-
269
- def conv_criterion(s, r):
270
- """
271
- The convergence criterion is an increasing and concave function in r
272
- and it is equal to 0 only if r is the solution to the cubic problem
273
- """
274
- s_norm = torch.linalg.vector_norm(s)
275
- return 1/s_norm - 1/r
276
-
277
- # Solution s satisfies ||s|| >= Cauchy_radius
278
- r_min = torch.linalg.vector_norm(cauchy_point(g, H, M))
279
-
280
- if f > loss_plus(newton_step):
281
- return newton_step, solver_it
282
-
283
- r_max = torch.linalg.vector_norm(newton_step)
284
- if r_max - r_min < epsilon:
285
- return newton_step, solver_it
286
- id_matrix = torch.eye(g.size(0), device=g.device, dtype=g.dtype)
287
- s_lam = None
288
- for _ in range(it_max):
289
- r_try = (r_min + r_max) / 2
290
- lam = r_try * M
291
- s_lam = -torch.linalg.solve(H + lam*id_matrix, g)
292
- solver_it += 1
293
- crit = conv_criterion(s_lam, r_try)
294
- if np.abs(crit) < epsilon:
295
- return s_lam, solver_it
296
- if crit < 0:
297
- r_min = r_try
298
- else:
299
- r_max = r_try
300
- if r_max - r_min < epsilon:
301
- break
302
- assert s_lam is not None
303
- return s_lam, solver_it
304
-
305
- class CubicRegularization(TrustRegionBase):
306
- """Cubic regularization.
307
-
308
- .. note::
309
- By default this behaves like a trust region; set nplus and nminus to 1 to keep the regularization parameter fixed.
310
- :code:`init` sets 1/regularization.
311
-
312
- Args:
313
- hess_module (HessianUpdateStrategy | None, optional):
314
- Hessian update strategy, must be one of the :code:`HessianUpdateStrategy` modules. This works better with true hessian though. Make sure to set :code:`inverse=False`. If None, uses autograd to calculate the hessian. Defaults to None.
315
- eta (float, optional):
316
- if the ratio of actual to predicted reduction is larger than this, the step is accepted.
317
- When :code:`hess_module` is None, this can be set to 0. Defaults to 0.0.
318
- nplus (float, optional): increase factor on successful steps. Defaults to 2.
319
- nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
320
- init (float, optional): Initial trust region value. Defaults to 1.
321
- maxiter (int, optional): maximum number of iterations when solving the cubic subproblem. Defaults to 100.
322
- eps (float, optional): epsilon for the solver. Defaults to 1e-8.
323
- update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
324
- inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
325
-
326
- Examples:
327
- Cubic regularized newton
328
-
329
- .. code-block:: python
330
-
331
- opt = tz.Modular(
332
- model.parameters(),
333
- tz.m.CubicRegularization(),
334
- )
335
-
336
- """
337
- def __init__(
338
- self,
339
- hess_module: HessianUpdateStrategy | None = None,
340
- eta: float= 0.0,
341
- nplus: float = 2,
342
- nminus: float = 0.25,
343
- init: float = 1,
344
- maxiter: int = 100,
345
- eps: float = 1e-8,
346
- update_freq: int = 1,
347
- max_attempts: int = 10,
348
- inner: Chainable | None = None,
349
- ):
350
- defaults = dict(init=init, nplus=nplus, nminus=nminus, eta=eta, maxiter=maxiter, eps=eps, max_attempts=max_attempts)
351
- super().__init__(defaults, hess_module=hess_module, update_freq=update_freq, inner=inner)
352
-
353
- @torch.no_grad
354
- def trust_region_step(self, var, tensors, P, is_inverse):
355
- params = TensorList(var.params)
356
- settings = self.settings[params[0]]
357
- g = _flatten_tensors(tensors)
358
-
359
- maxiter = settings['maxiter']
360
- max_attempts = settings['max_attempts']
361
- eps = settings['eps']
362
-
363
- loss = var.loss
364
- closure = var.closure
365
- if closure is None: raise RuntimeError("Trust region requires closure")
366
- if loss is None: loss = closure(False)
367
-
368
- def loss_plus(x):
369
- x_unflat = vec_to_tensors(x, params)
370
- params.add_(x_unflat)
371
- loss_x = closure(False)
372
- params.sub_(x_unflat)
373
- return loss_x
374
-
375
- success = False
376
- update_vec = None
377
- while not success:
378
- max_attempts -= 1
379
- if max_attempts < 0: break
380
-
381
- trust_region = self.global_state.get('trust_region', settings['init'])
382
- if trust_region < 1e-8 or trust_region > 1e16: trust_region = self.global_state['trust_region'] = settings['init']
383
-
384
- update_vec, _ = ls_cubic_solver(f=loss, g=g, H=P, M=1/trust_region, is_inverse=is_inverse,
385
- loss_plus=loss_plus, it_max=maxiter, epsilon=eps)
386
- update_vec.neg_()
387
-
388
- self.global_state['trust_region'], success = _update_tr_radius(
389
- update_vec=update_vec, params=params, closure=closure,
390
- loss=loss, g=g, H=P, trust_region=trust_region, settings = settings,
391
- )
392
-
393
- assert update_vec is not None
394
- if success: var.update = vec_to_tensors(update_vec, params)
395
- else: var.update = params.zeros_like()
396
-
397
- return var
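The removed _update_tr_radius helper above implements the standard trust-region ratio test: for an update p that is subtracted from the parameters, the predicted reduction is -(g^T p + 1/2 p^T H p), rho is the ratio of actual to predicted reduction, the radius shrinks by nminus when rho < 0.25 and grows by nplus on very good steps (the removed code additionally checks the step's margin to the boundary, trust_region - |p|, before growing), and the step is accepted when rho > eta. A condensed, standalone sketch of that rule, with illustrative names rather than the library's API:

    import torch

    def predicted_reduction(g: torch.Tensor, H: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        # model decrease for an update p that is SUBTRACTED from the parameters
        Hp = H * p if H.ndim == 1 else H @ p
        return -(g.dot(p) + 0.5 * p.dot(Hp))

    def ratio_test(actual_reduction, pred_reduction, radius, hits_boundary,
                   eta=0.15, nplus=2.0, nminus=0.25):
        rho = actual_reduction / max(float(pred_reduction), 1e-8)
        if rho < 0.25:
            radius *= nminus          # poor agreement with the quadratic model: shrink
        elif rho > 0.75 and hits_boundary:
            radius *= nplus           # very good step constrained by the region: grow
        accepted = rho > eta          # accept only if the achieved decrease is large enough
        return radius, accepted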
@@ -1,198 +0,0 @@
1
- import warnings
2
- from abc import ABC, abstractmethod
3
- from collections.abc import Callable, Sequence
4
- from functools import partial
5
- from typing import Literal
6
-
7
- import torch
8
-
9
- from ...core import Modular, Module, Var
10
- from ...utils import NumberList, TensorList
11
- from ...utils.derivatives import jacobian_wrt
12
- from ..grad_approximation import GradApproximator, GradTarget
13
-
14
-
15
- class Reformulation(Module, ABC):
16
- def __init__(self, defaults):
17
- super().__init__(defaults)
18
-
19
- @abstractmethod
20
- def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], var: Var) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
21
- """returns loss and gradient, if backward is False then gradient can be None"""
22
-
23
- def pre_step(self, var: Var) -> Var | None:
24
- """This runs once before each step, whereas `closure` may run multiple times per step if further modules
25
- evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
26
- return var
27
-
28
- def step(self, var):
29
- ret = self.pre_step(var)
30
- if isinstance(ret, Var): var = ret
31
-
32
- if var.closure is None: raise RuntimeError("Reformulation requires closure")
33
- params, closure = var.params, var.closure
34
-
35
-
36
- def modified_closure(backward=True):
37
- loss, grad = self.closure(backward, closure, params, var)
38
-
39
- if grad is not None:
40
- for p,g in zip(params, grad):
41
- p.grad = g
42
-
43
- return loss
44
-
45
- var.closure = modified_closure
46
- return var
47
-
48
-
49
- def _decay_sigma_(self: Module, params):
50
- for p in params:
51
- state = self.state[p]
52
- settings = self.settings[p]
53
- state['sigma'] *= settings['decay']
54
-
55
- def _generate_perturbations_to_state_(self: Module, params: TensorList, n_samples, sigmas, generator):
56
- perturbations = [params.sample_like(generator=generator) for _ in range(n_samples)]
57
- torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in sigmas for v in [vv]*n_samples])
58
- for param, prt in zip(params, zip(*perturbations)):
59
- self.state[param]['perturbations'] = prt
60
-
61
- def _clear_state_hook(optimizer: Modular, var: Var, self: Module):
62
- for m in optimizer.unrolled_modules:
63
- if m is not self:
64
- m.reset()
65
-
66
- class GaussianHomotopy(Reformulation):
67
- """Approximately smoothes the function with a gaussian kernel by sampling it at random perturbed points around current point. Both function values and gradients are averaged over all samples. The perturbed points are generated before each
68
- step and remain the same throughout the step.
69
-
70
- .. note::
71
- This module reformulates the objective, it modifies the closure to evaluate value and gradients of a smoothed function. All modules after this will operate on the modified objective.
72
-
73
- .. note::
74
- This module requires the a closure passed to the optimizer step,
75
- as it needs to re-evaluate the loss and gradients at perturbed points.
76
-
77
- Args:
78
- n_samples (int): number of points to sample, larger values lead to a more accurate smoothing.
79
- init_sigma (float): initial scale of perturbations.
80
- tol (float | None, optional):
81
- if maximal parameters change value is smaller than this, sigma is reduced by :code:`decay`. Defaults to 1e-4.
82
- decay (float, optional): multiplier to sigma when converged on a smoothed function. Defaults to 0.5.
83
- max_steps (int | None, optional): maximum number of steps before decaying sigma. Defaults to None.
84
- clear_state (bool, optional):
85
- whether to clear all other module states when sigma is decayed, because the objective function changes. Defaults to True.
86
- seed (int | None, optional): seed for random perturbationss. Defaults to None.
87
-
88
- Examples:
89
- Gaussian-smoothed NewtonCG
90
-
91
- .. code-block:: python
92
-
93
- opt = tz.Modular(
94
- model.parameters(),
95
- tz.m.GaussianHomotopy(100),
96
- tz.m.NewtonCG(maxiter=20),
97
- tz.m.AdaptiveBacktracking(),
98
- )
99
-
100
- """
101
- def __init__(
102
- self,
103
- n_samples: int,
104
- init_sigma: float,
105
- tol: float | None = 1e-4,
106
- decay=0.5,
107
- max_steps: int | None = None,
108
- clear_state=True,
109
- seed: int | None = None,
110
- ):
111
- defaults = dict(n_samples=n_samples, init_sigma=init_sigma, tol=tol, decay=decay, max_steps=max_steps, clear_state=clear_state, seed=seed)
112
- super().__init__(defaults)
113
-
114
-
115
- def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
116
- if 'generator' not in self.global_state:
117
- if isinstance(seed, torch.Generator): self.global_state['generator'] = seed
118
- elif seed is not None: self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
119
- else: self.global_state['generator'] = None
120
- return self.global_state['generator']
121
-
122
- def pre_step(self, var):
123
- params = TensorList(var.params)
124
- settings = self.settings[params[0]]
125
- n_samples = settings['n_samples']
126
- init_sigma = [self.settings[p]['init_sigma'] for p in params]
127
- sigmas = self.get_state(params, 'sigma', init=init_sigma)
128
-
129
- if any('perturbations' not in self.state[p] for p in params):
130
- generator = self._get_generator(settings['seed'], params)
131
- _generate_perturbations_to_state_(self, params=params, n_samples=n_samples, sigmas=sigmas, generator=generator)
132
-
133
- # sigma decay rules
134
- max_steps = settings['max_steps']
135
- decayed = False
136
- if max_steps is not None and max_steps > 0:
137
- level_steps = self.global_state['level_steps'] = self.global_state.get('level_steps', 0) + 1
138
- if level_steps > max_steps:
139
- self.global_state['level_steps'] = 0
140
- _decay_sigma_(self, params)
141
- decayed = True
142
-
143
- tol = settings['tol']
144
- if tol is not None and not decayed:
145
- if not any('prev_params' in self.state[p] for p in params):
146
- prev_params = self.get_state(params, 'prev_params', cls=TensorList, init='param')
147
- else:
148
- prev_params = self.get_state(params, 'prev_params', cls=TensorList, init='param')
149
- s = params - prev_params
150
-
151
- if s.abs().global_max() <= tol:
152
- _decay_sigma_(self, params)
153
- decayed = True
154
-
155
- prev_params.copy_(params)
156
-
157
- if decayed:
158
- generator = self._get_generator(settings['seed'], params)
159
- _generate_perturbations_to_state_(self, params=params, n_samples=n_samples, sigmas=sigmas, generator=generator)
160
- if settings['clear_state']:
161
- var.post_step_hooks.append(partial(_clear_state_hook, self=self))
162
-
163
- @torch.no_grad
164
- def closure(self, backward, closure, params, var):
165
- params = TensorList(params)
166
-
167
- settings = self.settings[params[0]]
168
- n_samples = settings['n_samples']
169
-
170
- perturbations = list(zip(*(self.state[p]['perturbations'] for p in params)))
171
-
172
- loss = None
173
- grad = None
174
- for i in range(n_samples):
175
- prt = perturbations[i]
176
-
177
- params.add_(prt)
178
- if backward:
179
- with torch.enable_grad(): l = closure()
180
- if grad is None: grad = params.grad
181
- else: grad += params.grad
182
-
183
- else:
184
- l = closure(False)
185
-
186
- if loss is None: loss = l
187
- else: loss = loss+l
188
-
189
- params.sub_(prt)
190
-
191
- assert loss is not None
192
- if n_samples > 1:
193
- loss = loss / n_samples
194
- if backward:
195
- assert grad is not None
196
- grad.div_(n_samples)
197
-
198
- return loss, grad
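The removed GaussianHomotopy module above approximates the Gaussian-smoothed objective by averaging the loss and gradient over a fixed set of perturbed points, roughly f_sigma(x) ~= (1/N) * sum_i f(x + sigma * u_i) with u_i ~ N(0, I). A minimal self-contained sketch of that averaging, assuming the usual closure convention (the closure recomputes the loss and writes fresh gradients into p.grad); this is an illustration, not the module's exact internals:

    import torch

    def smoothed_loss_and_grad(closure, params, sigma, n_samples, generator=None):
        # average the loss and gradient over n_samples Gaussian perturbations of the parameters
        total_loss = 0.0
        total_grad = [torch.zeros_like(p) for p in params]
        for _ in range(n_samples):
            noise = [sigma * torch.randn(p.shape, generator=generator,
                                         device=p.device, dtype=p.dtype) for p in params]
            with torch.no_grad():
                for p, n in zip(params, noise):
                    p.add_(n)                 # move to the perturbed point
            loss = closure()                  # evaluates the loss and fills p.grad
            with torch.no_grad():
                for p, n in zip(params, noise):
                    p.sub_(n)                 # move back to the original point
            total_loss += float(loss)
            for acc, p in zip(total_grad, params):
                if p.grad is not None:
                    acc.add_(p.grad)
        return total_loss / n_samples, [g / n_samples for g in total_grad]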