torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
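Note on the file moves above: entries 19, 22, 23, and 161 show `torchzero/utils/linalg/` being promoted to a top-level `torchzero/linalg` package. Downstream imports would change roughly as follows (a sketch inferred from the renamed paths alone, not from release notes):

    # torchzero 0.3.15 (old layout, per the paths above)
    from torchzero.utils.linalg.solve import cg, minres

    # torchzero 0.4.0 (new layout)
    from torchzero.linalg.solve import cg, minres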
@@ -1,16 +1,16 @@
-import warnings
-import math
-from typing import Literal, cast
+
 from operator import itemgetter
+from typing import Literal, cast
+
 import torch
 
-from ...core import Chainable, Module, apply_transform
-from ...utils import TensorList, as_tensorlist, tofloat
-from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
-from ...utils.linalg.solve import cg, minres, find_within_trust_radius
+from ...core import Chainable, Transform, HVPMethod
+from ...utils import TensorList, tofloat, unpack_dicts, unpack_states
+from ...linalg.solve import cg, find_within_trust_radius, minres
 from ..trust_region.trust_region import default_radius
 
-class NewtonCG(Module):
+
+class NewtonCG(Transform):
     """Newton's method with a matrix-free conjugate gradient or minimal-residual solver.
 
     Notes:
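The hunk above migrates NewtonCG from `Module` to the new `Transform` base class. End-user code should be unaffected; a minimal usage sketch, assuming torchzero's public modular API (`tz.Modular`, `tz.m.NewtonCG`, `tz.m.Backtracking`) and its closure-with-backward convention, taken from the library's README pattern rather than from this diff:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    X, y = torch.randn(64, 10), torch.randn(64, 1)

    # NewtonCG produces a search direction; a line search module such as
    # Backtracking is the usual companion for Newton-type methods.
    opt = tz.Modular(model.parameters(), tz.m.NewtonCG(), tz.m.Backtracking())

    def closure(backward=True):
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    for _ in range(20):
        opt.step(closure)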
@@ -37,17 +37,14 @@ class NewtonCG(Module):
         hvp_method (str, optional):
             Determines how Hessian-vector products are evaluated.
 
-            ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-            This requires creating a graph for the gradient.
-            ``"forward"``: Use a forward finite difference formula to
-            approximate the HVP. This requires one extra gradient evaluation.
-            ``"central"``: Use a central finite difference formula for a
-            more accurate HVP approximation. This requires two extra
-            gradient evaluations.
-            Defaults to "autograd".
+            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop.
+            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+            For NewtonCG ``"batched_autograd"`` is equivalent to ``"autograd"``. Defaults to ``"autograd"``.
         h (float, optional):
-            The step size for finite differences if :code:`hvp_method` is
-            ``"forward"`` or ``"central"``. Defaults to 1e-3.
+            The step size for finite difference if ``hvp_method`` is
+            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
         warm_start (bool, optional):
             If ``True``, the conjugate gradient solver is initialized with the
             solution from the previous optimization step. This can accelerate
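The three `hvp_method` options documented above are standard Hessian-vector product constructions. A self-contained sketch on a toy objective, with the `normalize=True`-style scaling the finite-difference helpers appear to use; this is illustrative, not torchzero's internal code:

    import torch

    def f(x):
        return (x ** 4).sum()  # toy objective; H = diag(12 * x**2)

    def grad(x):
        x = x.detach().requires_grad_(True)
        (g,) = torch.autograd.grad(f(x), x)
        return g

    def hvp_autograd(x, v):
        # exact H @ v via double backward (needs a graph for the gradient)
        x = x.detach().requires_grad_(True)
        (g,) = torch.autograd.grad(f(x), x, create_graph=True)
        (Hv,) = torch.autograd.grad(g, x, grad_outputs=v)
        return Hv

    def hvp_fd_forward(x, v, h=1e-3):
        # one extra gradient evaluation: Hv ~ (g(x + h*u) - g(x)) * |v| / h
        vnorm = v.norm(); u = v / vnorm
        return (grad(x + h * u) - grad(x)) * vnorm / h

    def hvp_fd_central(x, v, h=1e-3):
        # two extra gradient evaluations, O(h^2) accurate
        vnorm = v.norm(); u = v / vnorm
        return (grad(x + h * u) - grad(x - h * u)) * vnorm / (2 * h)

    x, v = torch.randn(5), torch.randn(5)
    print(hvp_autograd(x, v), hvp_fd_forward(x, v), hvp_fd_central(x, v))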
@@ -82,100 +79,72 @@ class NewtonCG(Module):
         maxiter: int | None = None,
         tol: float = 1e-8,
         reg: float = 1e-8,
-        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
-        solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
+        hvp_method: HVPMethod = "autograd",
+        solver: Literal['cg', 'minres'] = 'cg',
+        npc_terminate: bool = False,
         h: float = 1e-3, # tuned 1e-4 or 1e-3
         miniter:int = 1,
         warm_start=False,
+        warm_beta:float=0,
         inner: Chainable | None = None,
     ):
         defaults = locals().copy()
         del defaults['self'], defaults['inner']
-        super().__init__(defaults,)
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        super().__init__(defaults, inner=inner)
 
         self._num_hvps = 0
         self._num_hvps_last_step = 0
 
     @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        tol = settings['tol']
-        reg = settings['reg']
-        maxiter = settings['maxiter']
-        hvp_method = settings['hvp_method']
-        solver = settings['solver'].lower().strip()
-        h = settings['h']
-        warm_start = settings['warm_start']
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+        hvp_method = fs['hvp_method']
+        h = fs['h']
 
-        self._num_hvps_last_step = 0
         # ---------------------- Hessian vector product function --------------------- #
-        if hvp_method == 'autograd':
-            grad = var.get_grad(create_graph=True)
-
-            def H_mm(x):
-                self._num_hvps_last_step += 1
-                with torch.enable_grad():
-                    return TensorList(hvp(params, grad, x, retain_graph=True))
-
-        else:
-
-            with torch.enable_grad():
-                grad = var.get_grad()
-
-            if hvp_method == 'forward':
-                def H_mm(x):
-                    self._num_hvps_last_step += 1
-                    return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
-
-            elif hvp_method == 'central':
-                def H_mm(x):
-                    self._num_hvps_last_step += 1
-                    return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
-
-            else:
-                raise ValueError(hvp_method)
+        _, H_mv = objective.list_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
+        objective.temp = H_mv
 
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        self._num_hvps_last_step = 0
+        H_mv = objective.poptemp()
 
-        # -------------------------------- inner step -------------------------------- #
-        b = var.get_update()
-        if 'inner' in self.children:
-            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
-        b = as_tensorlist(b)
+        fs = settings[0]
+        tol = fs['tol']
+        reg = fs['reg']
+        maxiter = fs['maxiter']
+        solver = fs['solver'].lower().strip()
+        warm_start = fs['warm_start']
+        npc_terminate = fs["npc_terminate"]
 
         # ---------------------------------- run cg ---------------------------------- #
         x0 = None
-        if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway
+        if warm_start:
+            x0 = unpack_states(states, objective.params, 'prev_x', cls=TensorList)
+
+        b = TensorList(objective.get_updates())
 
         if solver == 'cg':
-            d, _ = cg(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, miniter=self.defaults["miniter"],reg=reg)
+            d, _ = cg(A_mv=H_mv, b=b, x0=x0, tol=tol, maxiter=maxiter,
+                      miniter=fs["miniter"], reg=reg, npc_terminate=npc_terminate)
 
         elif solver == 'minres':
-            d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
-
-        elif solver == 'minres_npc':
-            d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
+            d = minres(A_mv=H_mv, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=npc_terminate)
 
         else:
             raise ValueError(f"Unknown solver {solver}")
 
         if warm_start:
             assert x0 is not None
-            x0.copy_(d)
-
-        var.update = d
+            x0.lerp_(d, weight = 1-fs["warm_beta"])
 
+        objective.updates = d
         self._num_hvps += self._num_hvps_last_step
-        return var
+        return objective
 
 
-class NewtonCGSteihaug(Module):
+class NewtonCGSteihaug(Transform):
     """Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.
 
     Notes:
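Both classes now hand the linear solve to `cg`/`minres` from `torchzero.linalg.solve`, with the Hessian passed as a product callback (the `A_mm` → `A_mv` rename). For orientation, a textbook matrix-free CG loop with the same knobs the call sites use (`x0` warm start, `reg` as Tikhonov damping, an `npc_terminate`-style negative-curvature exit); a hedged sketch for a single flat tensor, not the library's implementation:

    import torch

    def cg_sketch(A_mv, b, x0=None, tol=1e-8, maxiter=None, reg=0.0, npc_terminate=False):
        # solves (A + reg*I) x = b, where A is available only through A_mv(v) = A @ v
        x = torch.zeros_like(b) if x0 is None else x0.clone()
        r = b - (A_mv(x) + reg * x)   # residual
        p = r.clone()                 # search direction
        rs = r.dot(r)
        for _ in range(maxiter if maxiter is not None else b.numel()):
            if rs.sqrt() <= tol:
                break
            Ap = A_mv(p) + reg * p
            pAp = p.dot(Ap)
            if npc_terminate and pAp <= 0:
                break                 # negative curvature detected, stop early
            alpha = rs / pAp
            x = x + alpha * p
            r = r - alpha * Ap
            rs_new = r.dot(r)
            p = r + (rs_new / rs) * p
            rs = rs_new
        return x

    # example: a diagonal "Hessian" given only as a product
    H_diag = torch.linspace(1.0, 5.0, 8)
    b = torch.randn(8)
    x = cg_sketch(lambda v: H_diag * v, b)
    print(torch.allclose(H_diag * x, b, atol=1e-5))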
@@ -219,7 +188,7 @@ class NewtonCGSteihaug(Module):
             whether to terminate CG/MINRES whenever negative curvature is detected. Defaults to False.
 
         hvp_method (str, optional):
-            either "forward" to use forward formula which requires one backward pass per Hvp, or "central" to use a more accurate central formula which requires two backward passes. "forward" is usually accurate enough. Defaults to "forward".
+            either ``"fd_forward"`` to use forward formula which requires one backward pass per hessian-vector product, or ``"fd_central"`` to use a more accurate central formula which requires two backward passes. ``"fd_forward"`` is usually accurate enough. Defaults to ``"fd_forward"``.
         h (float, optional): finite difference step size. Defaults to 1e-3.
 
         inner (Chainable | None, optional):
@@ -261,7 +230,7 @@ class NewtonCGSteihaug(Module):
         npc_terminate: bool = False,
 
         # hvp settings
-        hvp_method: Literal["forward", "central"] = "central",
+        hvp_method: Literal["fd_forward", "fd_central"] = "fd_central",
         h: float = 1e-3, # tuned 1e-4 or 1e-3
 
         # inner
@@ -269,72 +238,51 @@ class NewtonCGSteihaug(Module):
     ):
         defaults = locals().copy()
         del defaults['self'], defaults['inner']
-        super().__init__(defaults,)
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        super().__init__(defaults, inner=inner)
 
         self._num_hvps = 0
         self._num_hvps_last_step = 0
 
-    @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        tol = self.defaults['tol'] * self.global_state.get('tol_mul', 1)
-        solver = self.defaults['solver'].lower().strip()
-
-        (reg, maxiter, hvp_method, h, max_attempts, boundary_tol,
-         eta, nplus, nminus, rho_good, rho_bad, init, npc_terminate,
-         miniter, max_history, adapt_tol) = itemgetter(
-            "reg", "maxiter", "hvp_method", "h", "max_attempts", "boundary_tol",
-            "eta", "nplus", "nminus", "rho_good", "rho_bad", "init", "npc_terminate",
-            "miniter", "max_history", "adapt_tol",
-        )(self.defaults)
 
-        self._num_hvps_last_step = 0
+    @torch.no_grad
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+        hvp_method = fs['hvp_method']
+        h = fs['h']
 
         # ---------------------- Hessian vector product function --------------------- #
-        if hvp_method == 'autograd':
-            grad = var.get_grad(create_graph=True)
-
-            def H_mm(x):
-                self._num_hvps_last_step += 1
-                with torch.enable_grad():
-                    return TensorList(hvp(params, grad, x, retain_graph=True))
-
-        else:
-
-            with torch.enable_grad():
-                grad = var.get_grad()
+        _, H_mv = objective.list_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
+        objective.temp = H_mv
 
-            if hvp_method == 'forward':
-                def H_mm(x):
-                    self._num_hvps_last_step += 1
-                    return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        self._num_hvps_last_step = 0
 
-            elif hvp_method == 'central':
-                def H_mm(x):
-                    self._num_hvps_last_step += 1
-                    return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+        H_mv = objective.poptemp()
+        params = TensorList(objective.params)
+        fs = settings[0]
 
-            else:
-                raise ValueError(hvp_method)
+        tol = fs['tol'] * self.global_state.get('tol_mul', 1)
+        solver = fs['solver'].lower().strip()
 
+        reg=fs["reg"]
+        maxiter=fs["maxiter"]
+        max_attempts=fs["max_attempts"]
+        init=fs["init"]
+        npc_terminate=fs["npc_terminate"]
+        miniter=fs["miniter"]
+        max_history=fs["max_history"]
+        adapt_tol=fs["adapt_tol"]
 
-        # -------------------------------- inner step -------------------------------- #
-        b = var.get_update()
-        if 'inner' in self.children:
-            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
-        b = as_tensorlist(b)
 
         # ------------------------------- trust region ------------------------------- #
         success = False
         d = None
-        x0 = [p.clone() for p in params]
+        orig_params = [p.clone() for p in params]
+        b = TensorList(objective.get_updates())
         solution = None
+        closure = objective.closure
+        assert closure is not None
 
         while not success:
             max_attempts -= 1
@@ -343,7 +291,7 @@ class NewtonCGSteihaug(Module):
             trust_radius = self.global_state.get('trust_radius', init)
 
             # -------------- make sure trust radius isn't too small or large ------------- #
-            finfo = torch.finfo(x0[0].dtype)
+            finfo = torch.finfo(orig_params[0].dtype)
             if trust_radius < finfo.tiny * 2:
                 trust_radius = self.global_state['trust_radius'] = init
                 if adapt_tol:
@@ -360,7 +308,7 @@ class NewtonCGSteihaug(Module):
             if d is None:
                 if solver == 'cg':
                     d, solution = cg(
-                        A_mm=H_mm,
+                        A_mv=H_mv,
                         b=b,
                         tol=tol,
                         maxiter=maxiter,
@@ -372,40 +320,40 @@ class NewtonCGSteihaug(Module):
                     )
 
                 elif solver == 'minres':
-                    d = minres(A_mm=H_mm, b=b, trust_radius=trust_radius, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=npc_terminate)
+                    d = minres(A_mv=H_mv, b=b, trust_radius=trust_radius, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=npc_terminate)
 
                 else:
                     raise ValueError(f"unknown solver {solver}")
 
             # ---------------------------- update trust radius --------------------------- #
             self.global_state["trust_radius"], success = default_radius(
-                params=params,
-                closure=closure,
-                f=tofloat(var.get_loss(False)),
-                g=b,
-                H=H_mm,
-                d=d,
-                trust_radius=trust_radius,
-                eta=eta,
-                nplus=nplus,
-                nminus=nminus,
-                rho_good=rho_good,
-                rho_bad=rho_bad,
-                boundary_tol=boundary_tol,
-
-                init=init, # init isn't used because check_overflow=False
-                state=self.global_state, # not used
-                settings=self.defaults, # not used
-                check_overflow=False, # this is checked manually to adapt tolerance
+                params = params,
+                closure = closure,
+                f = tofloat(objective.get_loss(False)),
+                g = b,
+                H = H_mv,
+                d = d,
+                trust_radius = trust_radius,
+                eta = fs["eta"],
+                nplus = fs["nplus"],
+                nminus = fs["nminus"],
+                rho_good = fs["rho_good"],
+                rho_bad = fs["rho_bad"],
+                boundary_tol = fs["boundary_tol"],
+
+                init = cast(int, None), # init isn't used because check_overflow=False
+                state = cast(dict, None), # not used
+                settings = cast(dict, None), # not used
+                check_overflow = False, # this is checked manually to adapt tolerance
             )
 
             # --------------------------- assign new direction --------------------------- #
             assert d is not None
             if success:
-                var.update = d
+                objective.updates = d
 
             else:
-                var.update = params.zeros_like()
+                objective.updates = params.zeros_like()
 
         self._num_hvps += self._num_hvps_last_step
-        return var
+        return objective
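The `default_radius` call above performs the trust-region acceptance test. The textbook rule its argument names (`eta`, `nplus`, `nminus`, `rho_good`, `rho_bad`) point at looks roughly like this; the values shown are standard trust-region defaults, a hedged sketch rather than torchzero's exact logic:

    def radius_update(f_old, f_new, pred_reduction, trust_radius,
                      eta=0.0, nplus=2.0, nminus=0.25, rho_good=0.75, rho_bad=0.25):
        # rho compares the achieved reduction with the decrease predicted by
        # the local quadratic model along the proposed step
        rho = (f_old - f_new) / max(pred_reduction, 1e-32)
        if rho > rho_good:
            trust_radius *= nplus    # model matched reality, expand the radius
        elif rho < rho_bad:
            trust_radius *= nminus   # model was poor, shrink the radius
        success = rho > eta          # accept the step only above the eta threshold
        return trust_radius, success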