torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,7 @@ from typing import Literal
 
 import torch
 
-from ...core import Chainable, Module, apply
+from ...core import Chainable, Module, apply_transform
 from ...utils import TensorList, vec_to_tensors
 from ...utils.derivatives import (
     hessian_list_to_mat,
@@ -18,9 +18,12 @@ from ...utils.derivatives import (
 
 
 def lu_solve(H: torch.Tensor, g: torch.Tensor):
-    x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
-    if info == 0: return x
-    return None
+    try:
+        x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
+        if info == 0: return x
+        return None
+    except RuntimeError:
+        return None
 
 def cholesky_solve(H: torch.Tensor, g: torch.Tensor):
     x, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
@@ -32,12 +35,17 @@ def cholesky_solve(H: torch.Tensor, g: torch.Tensor):
 def least_squares_solve(H: torch.Tensor, g: torch.Tensor):
     return torch.linalg.lstsq(H, g)[0] # pylint:disable=not-callable
 
-def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None):
+def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_negative: bool):
     try:
         L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
         if tfm is not None: L = tfm(L)
-        L.reciprocal_()
-        return torch.linalg.multi_dot([Q * L.unsqueeze(-2), Q.mH, g]) # pylint:disable=not-callable
+        if search_negative and L[0] < 0:
+            d = Q[0]
+            # use eigvec or -eigvec depending on if it points in same direction as gradient
+            return g.dot(d).sign() * d
+
+        return Q @ ((Q.mH @ g) / L)
+
     except torch.linalg.LinAlgError:
         return None
 
@@ -45,103 +53,286 @@ def tikhonov_(H: torch.Tensor, reg: float):
     if reg!=0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(reg))
     return H
 
-def eig_tikhonov_(H: torch.Tensor, reg: float):
-    v = torch.linalg.eigvalsh(H).min().clamp_(max=0).neg_() + reg # pylint:disable=not-callable
-    return tikhonov_(H, v)
-
 
 class Newton(Module):
-    """Exact newton via autograd.
+    """Exact newton's method via autograd.
+
+    .. note::
+        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating the hessian.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
 
     Args:
         reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
-        eig_reg (bool, optional): whether to use largest negative eigenvalue as regularizer. Defaults to False.
+        search_negative (bool, Optional):
+            if True, whenever a negative eigenvalue is detected,
+            search direction is proposed along an eigenvector corresponding to a negative eigenvalue.
         hessian_method (str):
             how to calculate hessian. Defaults to "autograd".
         vectorize (bool, optional):
             whether to enable vectorized hessian. Defaults to True.
-        inner (Chainable | None, optional): inner modules. Defaults to None.
+        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
         H_tfm (Callable | None, optional):
             optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
 
-            must return a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
-            which must be True if transform inverted the hessian and False otherwise. Defaults to None.
+            must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
+            which must be True if transform inverted the hessian and False otherwise.
+
+            Or it returns a single tensor which is used as the update.
+
+            Defaults to None.
         eigval_tfm (Callable | None, optional):
             optional eigenvalues transform, for example :code:`torch.abs` or :code:`lambda L: torch.clip(L, min=1e-8)`.
-            If this is specified, eigendecomposition will be used to solve Hx = g.
+            If this is specified, eigendecomposition will be used to invert the hessian.
+
+    Examples:
+        Newton's method with backtracking line search
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Newton(),
+                tz.m.Backtracking()
+            )
+
+        Newton's method modified for non-convex functions by taking matrix absolute value of the hessian
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Newton(eigval_tfm=lambda x: torch.abs(x).clip(min=0.1)),
+                tz.m.Backtracking()
+            )
+
+        Newton's method modified for non-convex functions by searching along negative curvature directions
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Newton(search_negative=True),
+                tz.m.Backtracking()
+            )
+
+        Newton preconditioning applied to momentum
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Newton(inner=tz.m.EMA(0.9)),
+                tz.m.LR(0.1)
+            )
+
+        Diagonal newton example. This will still evaluate the entire hessian so it isn't efficient, but if you wanted to see how diagonal newton behaves or compares to full newton, you can use this.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Newton(H_tfm = lambda H, g: g/H.diag()),
+                tz.m.Backtracking()
+            )
 
     """
     def __init__(
         self,
         reg: float = 1e-6,
-        eig_reg: bool = False,
+        search_negative: bool = False,
+        update_freq: int = 1,
         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
         vectorize: bool = True,
         inner: Chainable | None = None,
-        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | None = None,
+        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
         eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
     ):
-        defaults = dict(reg=reg, eig_reg=eig_reg, abs=abs,hessian_method=hessian_method, vectorize=vectorize, H_tfm=H_tfm, eigval_tfm=eigval_tfm)
+        defaults = dict(reg=reg, hessian_method=hessian_method, vectorize=vectorize, H_tfm=H_tfm, eigval_tfm=eigval_tfm, search_negative=search_negative, update_freq=update_freq)
         super().__init__(defaults)
 
         if inner is not None:
             self.set_child('inner', inner)
 
     @torch.no_grad
-    def step(self, vars):
-        params = TensorList(vars.params)
-        closure = vars.closure
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
 
         settings = self.settings[params[0]]
         reg = settings['reg']
-        eig_reg = settings['eig_reg']
+        search_negative = settings['search_negative']
         hessian_method = settings['hessian_method']
         vectorize = settings['vectorize']
         H_tfm = settings['H_tfm']
         eigval_tfm = settings['eigval_tfm']
+        update_freq = settings['update_freq']
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        g_list = var.grad
+        H = None
+        if step % update_freq == 0:
+            # ------------------------ calculate grad and hessian ------------------------ #
+            if hessian_method == 'autograd':
+                with torch.enable_grad():
+                    loss = var.loss = var.loss_approx = closure(False)
+                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+                    g_list = [t[0] for t in g_list] # remove leading dim from loss
+                    var.grad = g_list
+                    H = hessian_list_to_mat(H_list)
 
-        # ------------------------ calculate grad and hessian ------------------------ #
-        if hessian_method == 'autograd':
-            with torch.enable_grad():
-                loss = vars.loss = vars.loss_approx = closure(False)
-                g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                g_list = [t[0] for t in g_list] # remove leading dim from loss
-                vars.grad = g_list
-                H = hessian_list_to_mat(H_list)
-
-        elif hessian_method in ('func', 'autograd.functional'):
-            strat = 'forward-mode' if vectorize else 'reverse-mode'
-            with torch.enable_grad():
-                g_list = vars.get_grad(retain_graph=True)
-                H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
-                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-        else:
-            raise ValueError(hessian_method)
+            elif hessian_method in ('func', 'autograd.functional'):
+                strat = 'forward-mode' if vectorize else 'reverse-mode'
+                with torch.enable_grad():
+                    g_list = var.get_grad(retain_graph=True)
+                    H = hessian_mat(partial(closure, backward=False), params,
+                        method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+            else:
+                raise ValueError(hessian_method)
+
+            H = tikhonov_(H, reg)
+            if update_freq != 1:
+                self.global_state['H'] = H
+
+        if H is None:
+            H = self.global_state["H"]
+
+        # var.storage['hessian'] = H
 
         # -------------------------------- inner step -------------------------------- #
-        update = vars.get_update()
+        update = var.get_update()
         if 'inner' in self.children:
-            update = apply(self.children['inner'], update, params=params, grads=list(g_list), vars=vars)
-        g = torch.cat([t.view(-1) for t in update])
+            update = apply_transform(self.children['inner'], update, params=params, grads=g_list, var=var)
+
+        g = torch.cat([t.ravel() for t in update])
 
-        # ------------------------------- regulazition ------------------------------- #
-        if eig_reg: H = eig_tikhonov_(H, reg)
-        else: H = tikhonov_(H, reg)
 
         # ----------------------------------- solve ---------------------------------- #
         update = None
         if H_tfm is not None:
-            H, is_inv = H_tfm(H, g)
-            if is_inv: update = H
+            ret = H_tfm(H, g)
 
-        if eigval_tfm is not None:
-            update = eigh_solve(H, g, eigval_tfm)
+            if isinstance(ret, torch.Tensor):
+                update = ret
+
+            else: # returns (H, is_inv)
+                H, is_inv = ret
+                if is_inv: update = H @ g
+
+        if search_negative or (eigval_tfm is not None):
+            update = eigh_solve(H, g, eigval_tfm, search_negative=search_negative)
 
         if update is None: update = cholesky_solve(H, g)
         if update is None: update = lu_solve(H, g)
         if update is None: update = least_squares_solve(H, g)
 
-        vars.update = vec_to_tensors(update, params)
-        return vars
+        var.update = vec_to_tensors(update, params)
+
+        return var
+
+
+class InverseFreeNewton(Module):
+    """Inverse-free newton's method
+
+    .. note::
+        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating the hessian.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference
+        Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.
+    """
+    def __init__(
+        self,
+        update_freq: int = 1,
+        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+        vectorize: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(hessian_method=hessian_method, vectorize=vectorize, update_freq=update_freq)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        hessian_method = settings['hessian_method']
+        vectorize = settings['vectorize']
+        update_freq = settings['update_freq']
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        g_list = var.grad
+        Y = None
+        if step % update_freq == 0:
+            # ------------------------ calculate grad and hessian ------------------------ #
+            if hessian_method == 'autograd':
+                with torch.enable_grad():
+                    loss = var.loss = var.loss_approx = closure(False)
+                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+                    g_list = [t[0] for t in g_list] # remove leading dim from loss
+                    var.grad = g_list
+                    H = hessian_list_to_mat(H_list)
+
+            elif hessian_method in ('func', 'autograd.functional'):
+                strat = 'forward-mode' if vectorize else 'reverse-mode'
+                with torch.enable_grad():
+                    g_list = var.get_grad(retain_graph=True)
+                    H = hessian_mat(partial(closure, backward=False), params,
+                        method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+            else:
+                raise ValueError(hessian_method)
+
+            # inverse free part
+            if 'Y' not in self.global_state:
+                num = H.T
+                denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
+                eps = torch.finfo(H.dtype).eps
+                Y = self.global_state['Y'] = num.div_(denom.clip(min=eps, max=1/eps))
+
+            else:
+                Y = self.global_state['Y']
+                I = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
+                I -= H @ Y
+                Y = self.global_state['Y'] = Y @ I
+
+        if Y is None:
+            Y = self.global_state["Y"]
+
+
+        # -------------------------------- inner step -------------------------------- #
+        update = var.get_update()
+        if 'inner' in self.children:
+            update = apply_transform(self.children['inner'], update, params=params, grads=g_list, var=var)
+
+        g = torch.cat([t.ravel() for t in update])
+
+
+        # ----------------------------------- solve ---------------------------------- #
+        var.update = vec_to_tensors(Y@g, params)
+
+        return var
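The new InverseFreeNewton class added at the end of this hunk never solves a linear system: it keeps an approximate inverse Hessian Y, initialized as Y0 = H.T / (||H||_1 * ||H||_inf) and refined each refresh with the hyperpower (Newton-Schulz) step Y <- Y @ (2I - H @ Y), then applies Y @ g as the update. A minimal standalone sketch of that iteration (illustrative only, not code from the package; the test matrix and tolerances are arbitrary):

    # Hyperpower iteration used by InverseFreeNewton: Y converges to H^-1
    # as long as the initial residual ||I - H @ Y|| is below 1, which the
    # 1-norm/inf-norm initialization guarantees.
    import torch

    torch.manual_seed(0)
    n = 5
    A = torch.randn(n, n)
    H = A @ A.T + n * torch.eye(n)  # stand-in symmetric positive definite hessian

    Y = H.T / (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf')))  # same init as the diff
    I = torch.eye(n)
    for _ in range(30):             # error contracts quadratically once it is below 1
        Y = Y @ (2 * I - H @ Y)

    g = torch.randn(n)
    print(torch.allclose(Y @ g, torch.linalg.solve(H, g), atol=1e-4))  # expected: True

Assuming the class is exported like the other modules (for example as tz.m.InverseFreeNewton), it should drop into the same chains shown in the Newton docstring examples above, e.g. tz.Modular(model.parameters(), tz.m.InverseFreeNewton(), tz.m.Backtracking()).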