torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140) hide show
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,397 @@
1
+ """Trust region API is currently experimental, it will probably change completely"""
2
+ # pylint:disable=not-callable
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, Literal, cast, final
5
+ from collections.abc import Sequence, Mapping
6
+
7
+ import numpy as np
8
+ import torch
9
+ from scipy.optimize import lsq_linear
10
+
11
+ from ...core import Chainable, Module, apply_transform, Var
12
+ from ...utils import TensorList, vec_to_tensors
13
+ from ...utils.derivatives import (
14
+ hessian_list_to_mat,
15
+ jacobian_and_hessian_wrt,
16
+ )
17
+ from .quasi_newton import HessianUpdateStrategy
18
+ from ...utils.linalg import steihaug_toint_cg
19
+
20
+
21
def trust_lstsq(H: torch.Tensor, g: torch.Tensor, trust_region: float):
    """Solve the box-constrained least-squares problem ``min ||Hx - g||`` s.t. ``|x_i| <= trust_region``.

    Uses :func:`scipy.optimize.lsq_linear` in float64 on CPU and moves the
    solution back to the device/dtype of ``H``. Returns ``(x, cost)`` where
    ``cost`` is scipy's final objective value.
    """
    A = H.numpy(force=True).astype(np.float64)
    b = g.numpy(force=True).astype(np.float64)
    result = lsq_linear(A, b, bounds=(-trust_region, trust_region))
    solution = torch.from_numpy(result.x).to(H)
    return solution, result.cost
25
+
26
+ def _flatten_tensors(tensors: list[torch.Tensor]):
27
+ return torch.cat([t.ravel() for t in tensors])
28
+
29
+
30
class TrustRegionBase(Module, ABC):
    """Base class for trust-region modules.

    Maintains a curvature matrix ``B`` in ``global_state`` — either the exact
    Hessian computed with autograd, or an approximation produced by a
    ``HessianUpdateStrategy`` child — recomputed every ``update_freq`` steps,
    and delegates the actual step computation to :meth:`trust_region_step`.

    Args:
        defaults (dict | None): per-parameter default settings forwarded to ``Module``.
        hess_module (HessianUpdateStrategy | None): optional Hessian approximation
            module. If None, the exact Hessian is computed with autograd.
        update_freq (int): recompute the curvature matrix every this many steps.
        inner (Chainable | None): optional module(s) applied to the update
            before the trust-region step.
    """
    def __init__(
        self,
        defaults: dict | None = None,
        hess_module: HessianUpdateStrategy | None = None,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        self._update_freq = update_freq
        super().__init__(defaults)

        if hess_module is not None:
            self.set_child('hess_module', hess_module)

        if inner is not None:
            self.set_child('inner', inner)

    @abstractmethod
    def trust_region_step(self, var: Var, tensors: list[torch.Tensor], P: torch.Tensor, is_inverse: bool) -> Var:
        """Compute the trust-region step from curvature matrix ``P`` and store it in ``var``."""

    @final
    @torch.no_grad
    def update(self, var):
        # ---------------------------------- update ---------------------------------- #
        closure = var.closure
        if closure is None: raise RuntimeError("Trust region requires closure")
        params = var.params

        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        P = None
        is_inverse = None
        if step % self._update_freq == 0:

            if 'hess_module' not in self.children:
                # no approximation module - compute the exact hessian with autograd
                with torch.enable_grad():
                    loss = var.loss = var.loss_approx = closure(False)
                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=True)
                g_list = [t[0] for t in g_list] # remove leading dim from loss
                var.grad = g_list
                P = hessian_list_to_mat(H_list)
                is_inverse = False

            else:
                hessian_module = cast(HessianUpdateStrategy, self.children['hess_module'])
                hessian_module.update(var)
                P, is_inverse = hessian_module.get_B()

        # BUG FIX: this was previously guarded by `if self._update_freq != 0`,
        # which is always true for update_freq >= 1, so on steps where the
        # hessian was NOT recomputed (step % update_freq != 0) the cached
        # matrix got overwritten with None and `apply` failed. Only store a
        # freshly computed matrix; otherwise keep the cached one.
        if P is not None:
            self.global_state['B'] = P
            self.global_state['is_inverse'] = is_inverse

    @final
    @torch.no_grad
    def apply(self, var):
        P = self.global_state['B']
        is_inverse = self.global_state['is_inverse']

        # -------------------------------- inner step -------------------------------- #
        update = var.get_update()
        if 'inner' in self.children:
            update = apply_transform(self.children['inner'], update, params=var.params, grads=var.grad, var=var)

        # ----------------------------------- apply ---------------------------------- #
        return self.trust_region_step(var=var, tensors=update, P=P, is_inverse=is_inverse)
105
+
106
def _update_tr_radius(update_vec:torch.Tensor, params: Sequence[torch.Tensor], closure,
                      loss, g:torch.Tensor, H:torch.Tensor, trust_region:float, settings: Mapping):
    """Adapt the trust-region radius from the agreement of actual vs. predicted reduction.

    Returns ``(new_trust_region, success)`` where ``success`` is True when the
    ratio of actual to predicted reduction exceeds ``settings["eta"]`` and the
    step should be accepted.

    Args:
        update_vec (torch.Tensor): update vector which is SUBTRACTED from parameters
        params (Sequence[torch.Tensor]): parameter tensors; temporarily modified in place to
            evaluate the trial point, then restored
        closure: closure re-evaluating the loss (called with ``False`` - no backward)
        loss: loss at x0
        g (torch.Tensor): gradient vector
        H (torch.Tensor): hessian matrix, or its diagonal when 1D
        trust_region (float): current trust region value
        settings (Mapping): must provide "nplus", "nminus" and "eta"
    """
    # evaluate actual loss reduction at the trial point x0 - update_vec
    update_unflattned = vec_to_tensors(update_vec, params)
    params = TensorList(params)
    params -= update_unflattned
    loss_star = closure(False)
    params += update_unflattned
    reduction = loss - loss_star

    # expected reduction is g.T @ p + 0.5 * p.T @ B @ p
    if H.ndim == 1: Hu = H * update_vec
    else: Hu = H @ update_vec
    pred_reduction = - (g.dot(update_vec) + 0.5 * update_vec.dot(Hu))
    # clip guards against division by ~0 / negative predicted reduction
    rho = reduction / (pred_reduction.clip(min=1e-8))

    # failed step - shrink the radius
    if rho < 0.25:
        trust_region *= settings["nminus"]

    # very good step - consider growing the radius
    elif rho > 0.75:
        diff = trust_region - update_vec.abs()
        # NOTE(review): this condition is true when every component stays strictly
        # INSIDE the box, so the original "hits boundary" label looks inverted -
        # the classic rule grows the radius only when the step reaches the
        # boundary. Confirm intended condition before relying on it.
        if (diff.amin() / trust_region) > 1e-4: # hits boundary
            trust_region *= settings["nplus"]

    return trust_region, rho > settings["eta"]
151
+
152
class TrustCG(TrustRegionBase):
    """Trust region via Steihaug-Toint Conjugate Gradient method. This is mainly useful for quasi-newton methods.
    If you don't use :code:`hess_module`, use the matrix-free :code:`tz.m.NewtonCGSteihaug` which only uses hessian-vector products.

    Args:
        hess_module (HessianUpdateStrategy | None):
            Hessian update strategy, must be one of the :code:`HessianUpdateStrategy` modules. Make sure to set :code:`inverse=False`. If None, uses autograd to calculate the hessian.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            When :code:`hess_module` is None, this can be set to 0. Defaults to 0.15.
        nplus (float, optional): increase factor on successful steps. Defaults to 2.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        init (float, optional): Initial trust region value. Defaults to 1.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        reg (float, optional): hessian regularization. Defaults to 0.
        max_attempts (int, optional):
            maximum number of radius shrink-and-retry attempts per step before giving up. Defaults to 10.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.

    Examples:
        Trust-SR1

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
            )
    """
    def __init__(
        self,
        hess_module: HessianUpdateStrategy | None,
        eta: float= 0.15,
        nplus: float = 2,
        nminus: float = 0.25,
        init: float = 1,
        update_freq: int = 1,
        reg: float = 0,
        max_attempts: int = 10,
        inner: Chainable | None = None,
    ):
        defaults = dict(init=init, nplus=nplus, nminus=nminus, eta=eta, reg=reg, max_attempts=max_attempts)
        super().__init__(defaults, hess_module=hess_module, update_freq=update_freq, inner=inner)

    @torch.no_grad
    def trust_region_step(self, var, tensors, P, is_inverse):
        params = TensorList(var.params)
        settings = self.settings[params[0]]
        g = _flatten_tensors(tensors)

        reg = settings['reg']
        max_attempts = settings['max_attempts']

        loss = var.loss
        closure = var.closure
        if closure is None: raise RuntimeError("Trust region requires closure")
        if loss is None: loss = closure(False)

        # steihaug-toint CG needs B, not B^-1; only the diagonal case can be
        # inverted cheaply here
        if is_inverse:
            if P.ndim == 1: P = P.reciprocal()
            else: raise NotImplementedError()

        success = False
        update_vec = None
        while not success:
            max_attempts -= 1
            if max_attempts < 0: break

            trust_region = self.global_state.get('trust_region', settings['init'])

            # reset a degenerate (collapsed or exploded) radius
            if trust_region < 1e-8 or trust_region > 1e8:
                trust_region = self.global_state['trust_region'] = settings['init']

            update_vec = steihaug_toint_cg(P, g, trust_region, reg=reg)

            # adapt radius; retry with the shrunk radius if the step was rejected
            self.global_state['trust_region'], success = _update_tr_radius(
                update_vec=update_vec, params=params, closure=closure,
                loss=loss, g=g, H=P, trust_region=trust_region, settings = settings,
            )

        assert update_vec is not None
        # rejected after all attempts -> emit a zero update
        if success: var.update = vec_to_tensors(update_vec, params)
        else: var.update = params.zeros_like()

        return var
235
+
236
+
237
# code from https://github.com/konstmish/opt_methods/blob/master/optmethods/second_order/cubic.py
# ported to torch
def ls_cubic_solver(f, g:torch.Tensor, H:torch.Tensor, M: float, is_inverse: bool, loss_plus, it_max=100, epsilon=1e-8, ):
    """Solve the cubic-regularized model ``min_s <g, s> + 1/2 <s, H s> + M/3 ||s||^3``.

    Returns ``(step, n_solver_iterations)``. With ``M == 0`` this reduces to a
    plain Newton step. Otherwise the subproblem is solved by bisection on the
    step radius, bracketed between the Cauchy radius and the Newton-step norm.
    See "Gradient Descent Efficiently Finds the Cubic-Regularized Non-Convex
    Newton Step" (https://arxiv.org/pdf/1612.00547.pdf) and "Adaptive cubic
    regularisation methods" (https://people.maths.ox.ac.uk/cartis/papers/ARCpI.pdf).

    Args:
        f: loss at the current point.
        g (torch.Tensor): gradient vector.
        H (torch.Tensor): hessian matrix, or the inverse hessian when ``is_inverse``.
        M (float): cubic regularization strength.
        is_inverse (bool): whether ``H`` holds the inverse hessian.
        loss_plus: callable evaluating the loss at (current point + step).
        it_max: maximum bisection iterations.
        epsilon: convergence tolerance on the bracket and criterion.
    """
    n_iters = 1

    # newton direction; afterwards H always holds the (non-inverted) hessian
    if is_inverse:
        newton_step = -(H @ g)
        H = torch.linalg.inv(H)
    else:
        solution, info = torch.linalg.solve_ex(H, g)
        if info != 0:  # singular system - fall back to least squares
            solution = torch.linalg.lstsq(H, g).solution
        newton_step = solution.neg_()

    if M == 0:
        return newton_step, n_iters

    def _cauchy_point():
        # minimizer of the cubic model along the steepest-descent direction
        g_norm = torch.linalg.vector_norm(g)
        if g_norm == 0 or M == 0:
            return 0 * g
        direction = g / g_norm
        curvature = H @ direction @ direction
        radius = -curvature / (2 * M) + torch.sqrt((curvature / M) ** 2 / 4 + g_norm / M)
        return -radius * direction

    def _criterion(step, radius):
        # increasing and concave in `radius`; equals zero only at the solution
        # of the cubic problem
        return 1 / torch.linalg.vector_norm(step) - 1 / radius

    # the solution satisfies ||s|| >= Cauchy radius
    r_min = torch.linalg.vector_norm(_cauchy_point())

    # newton step already achieves sufficient decrease - take it
    if f > loss_plus(newton_step):
        return newton_step, n_iters

    r_max = torch.linalg.vector_norm(newton_step)
    if r_max - r_min < epsilon:
        return newton_step, n_iters

    eye = torch.eye(g.size(0), device=g.device, dtype=g.dtype)
    shifted_step = None
    for _ in range(it_max):
        mid = (r_min + r_max) / 2
        lam = mid * M
        shifted_step = -torch.linalg.solve(H + lam * eye, g)
        n_iters += 1
        crit = _criterion(shifted_step, mid)
        if abs(crit) < epsilon:
            return shifted_step, n_iters
        if crit < 0:
            r_min = mid
        else:
            r_max = mid
        if r_max - r_min < epsilon:
            break
    assert shifted_step is not None
    return shifted_step, n_iters
304
+
305
class CubicRegularization(TrustRegionBase):
    """Cubic regularization.

    .. note::
        by default this functions like a trust region, set nplus and nminus = 1 to make regularization parameter fixed.
        :code:`init` sets 1/regularization.

    Args:
        hess_module (HessianUpdateStrategy | None, optional):
            Hessian update strategy, must be one of the :code:`HessianUpdateStrategy` modules. This works better with true hessian though. Make sure to set :code:`inverse=False`. If None, uses autograd to calculate the hessian. Defaults to None.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            When :code:`hess_module` is None, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 2.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        init (float, optional): Initial trust region value. Defaults to 1.
        maxiter (int, optional): maximum iterations when solving the cubic subproblem. Defaults to 100.
        eps (float, optional): epsilon for the solver, defaults to 1e-8.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        max_attempts (int, optional):
            maximum number of radius shrink-and-retry attempts per step before giving up. Defaults to 10.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.

    Examples:
        Cubic regularized newton

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.CubicRegularization(),
            )

    """
    def __init__(
        self,
        hess_module: HessianUpdateStrategy | None = None,
        eta: float= 0.0,
        nplus: float = 2,
        nminus: float = 0.25,
        init: float = 1,
        maxiter: int = 100,
        eps: float = 1e-8,
        update_freq: int = 1,
        max_attempts: int = 10,
        inner: Chainable | None = None,
    ):
        defaults = dict(init=init, nplus=nplus, nminus=nminus, eta=eta, maxiter=maxiter, eps=eps, max_attempts=max_attempts)
        super().__init__(defaults, hess_module=hess_module, update_freq=update_freq, inner=inner)

    @torch.no_grad
    def trust_region_step(self, var, tensors, P, is_inverse):
        params = TensorList(var.params)
        settings = self.settings[params[0]]
        g = _flatten_tensors(tensors)

        maxiter = settings['maxiter']
        max_attempts = settings['max_attempts']
        eps = settings['eps']

        loss = var.loss
        closure = var.closure
        if closure is None: raise RuntimeError("Trust region requires closure")
        if loss is None: loss = closure(False)

        # evaluates loss at (params + x); params are perturbed in place and restored
        def loss_plus(x):
            x_unflat = vec_to_tensors(x, params)
            params.add_(x_unflat)
            loss_x = closure(False)
            params.sub_(x_unflat)
            return loss_x

        success = False
        update_vec = None
        while not success:
            max_attempts -= 1
            if max_attempts < 0: break

            # reset a degenerate (collapsed or exploded) regularization radius
            trust_region = self.global_state.get('trust_region', settings['init'])
            if trust_region < 1e-8 or trust_region > 1e16: trust_region = self.global_state['trust_region'] = settings['init']

            # M = 1/trust_region: a larger radius means weaker cubic regularization
            update_vec, _ = ls_cubic_solver(f=loss, g=g, H=P, M=1/trust_region, is_inverse=is_inverse,
                                            loss_plus=loss_plus, it_max=maxiter, epsilon=eps)
            # solver returns a step to ADD; _update_tr_radius expects a vector
            # that is SUBTRACTED from the parameters
            update_vec.neg_()

            self.global_state['trust_region'], success = _update_tr_radius(
                update_vec=update_vec, params=params, closure=closure,
                loss=loss, g=g, H=P, trust_region=trust_region, settings = settings,
            )

        assert update_vec is not None
        # rejected after all attempts -> emit a zero update
        if success: var.update = vec_to_tensors(update_vec, params)
        else: var.update = params.zeros_like()

        return var
@@ -1,3 +1,3 @@
1
- from .newton import Newton
2
- from .newton_cg import NewtonCG
1
+ from .newton import Newton, InverseFreeNewton
2
+ from .newton_cg import NewtonCG, TruncatedNewtonCG
3
3
  from .nystrom import NystromSketchAndSolve, NystromPCG