torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/sg2.py

@@ -1,6 +1,6 @@
 import torch

-from ...core import Module, Chainable, apply_transform
+from ...core import Module, Chainable, step
 from ...utils import TensorList, vec_to_tensors
 from ..second_order.newton import _newton_step, _get_H

@@ -58,12 +58,12 @@ class SG2(Module):
         if inner is not None: self.set_child('inner', inner)

     @torch.no_grad
-    def update(self, var):
+    def update(self, objective):
         k = self.global_state.get('step', 0) + 1
         self.global_state["step"] = k

-        params = TensorList(var.params)
-        closure = var.closure
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None:
             raise RuntimeError("closure is required for SG2")
         generator = self.get_generator(params[0].device, self.defaults["seed"])
@@ -79,7 +79,7 @@ class SG2(Module):

         # one sided
         if self.defaults["one_sided"]:
-            g_0 = TensorList(var.get_grad())
+            g_0 = TensorList(objective.get_grads())
             params.add_(cd)
             closure()

@@ -126,9 +126,9 @@ class SG2(Module):


     @torch.no_grad
-    def apply(self, var):
+    def apply(self, objective):
         dir = _newton_step(
-            var=var,
+            objective=objective,
             H = self.global_state["H"],
             damping = self.defaults["damping"],
             inner = self.children.get("inner", None),
@@ -138,10 +138,10 @@ class SG2(Module):
             g_proj=None,
         )

-        var.update = vec_to_tensors(dir, var.params)
-        return var
+        objective.updates = vec_to_tensors(dir, objective.params)
+        return objective

-    def get_H(self,var=...):
+    def get_H(self,objective=...):
         return _get_H(self.global_state["H"], self.defaults["eigval_fn"])


@@ -198,12 +198,12 @@ class SPSA2(Module):
         if inner is not None: self.set_child('inner', inner)

     @torch.no_grad
-    def update(self, var):
+    def update(self, objective):
         k = self.global_state.get('step', 0) + 1
         self.global_state["step"] = k

-        params = TensorList(var.params)
-        closure = var.closure
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None:
             raise RuntimeError("closure is required for SPSA2")

@@ -260,7 +260,7 @@ class SPSA2(Module):
         H_hat /= n_samples

         # set grad to approximated grad
-        var.grad = g_0
+        objective.grads = g_0

         # update H
         H = self.global_state.get("H", None)
@@ -273,9 +273,9 @@ class SPSA2(Module):
         self.global_state["H"] = H

     @torch.no_grad
-    def apply(self, var):
+    def apply(self, objective):
         dir = _newton_step(
-            var=var,
+            objective=objective,
             H = self.global_state["H"],
             damping = self.defaults["damping"],
             inner = self.children.get("inner", None),
@@ -285,8 +285,8 @@ class SPSA2(Module):
             g_proj=None,
         )

-        var.update = vec_to_tensors(dir, var.params)
-        return var
+        objective.updates = vec_to_tensors(dir, objective.params)
+        return objective

-    def get_H(self,var=...):
+    def get_H(self,objective=...):
         return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
torchzero/modules/restarts/restars.py

@@ -4,12 +4,14 @@ from typing import final, Literal, cast

 import torch

-from ...core import Chainable, Module, Var
+from ...core import Chainable, Module, Objective
 from ...utils import TensorList
 from ..termination import TerminationCriteriaBase

 def _reset_except_self(optimizer, var, self: Module):
-    for m in optimizer.unrolled_modules: m.reset()
+    for m in optimizer.unrolled_modules:
+        if m is not self:
+            m.reset()

 class RestartStrategyBase(Module, ABC):
     """Base class for restart strategies.
@@ -24,7 +26,7 @@ class RestartStrategyBase(Module, ABC):
         self.set_child('modules', modules)

     @abstractmethod
-    def should_reset(self, var: Var) -> bool:
+    def should_reset(self, var: Objective) -> bool:
         """returns whether reset should occur"""

     def _reset_on_condition(self, var):
@@ -39,23 +41,23 @@ class RestartStrategyBase(Module, ABC):
         return modules

     @final
-    def update(self, var):
-        modules = self._reset_on_condition(var)
+    def update(self, objective):
+        modules = self._reset_on_condition(objective)
         if modules is not None:
-            modules.update(var)
+            modules.update(objective)

     @final
-    def apply(self, var):
+    def apply(self, objective):
         # don't check here because it was check in `update`
         modules = self.children.get('modules', None)
-        if modules is None: return var
-        return modules.apply(var.clone(clone_update=False))
+        if modules is None: return objective
+        return modules.apply(objective.clone(clone_updates=False))

     @final
-    def step(self, var):
-        modules = self._reset_on_condition(var)
-        if modules is None: return var
-        return modules.step(var.clone(clone_update=False))
+    def step(self, objective):
+        modules = self._reset_on_condition(objective)
+        if modules is None: return objective
+        return modules.step(objective.clone(clone_updates=False))



@@ -170,7 +172,7 @@ class PowellRestart(RestartStrategyBase):
         super().__init__(defaults, modules)

     def should_reset(self, var):
-        g = TensorList(var.get_grad())
+        g = TensorList(var.get_grads())
         cond1 = self.defaults['cond1']; cond2 = self.defaults['cond2']

         # -------------------------------- initialize -------------------------------- #
@@ -192,7 +194,7 @@ class PowellRestart(RestartStrategyBase):

         # ------------------------------- 2nd condition ------------------------------ #
         if (cond2 is not None) and (not reset):
-            d_g = TensorList(var.get_update()).dot(g)
+            d_g = TensorList(var.get_updates()).dot(g)
             if (-1-cond2) * g_g < d_g < (-1 + cond2) * g_g:
                 reset = True

@@ -229,17 +231,17 @@ class BirginMartinezRestart(Module):

         self.set_child("module", module)

-    def update(self, var):
+    def update(self, objective):
         module = self.children['module']
-        module.update(var)
+        module.update(objective)

-    def apply(self, var):
+    def apply(self, objective):
         module = self.children['module']
-        var = module.apply(var.clone(clone_update=False))
+        objective = module.apply(objective.clone(clone_updates=False))

         cond = self.defaults['cond']
-        g = TensorList(var.get_grad())
-        d = TensorList(var.get_update())
+        g = TensorList(objective.get_grads())
+        d = TensorList(objective.get_updates())
         d_g = d.dot(g)
         d_norm = d.global_vector_norm()
         g_norm = g.global_vector_norm()
@@ -247,7 +249,7 @@ class BirginMartinezRestart(Module):
         # d in our case is same direction as g so it has a minus sign
         if -d_g > -cond * d_norm * g_norm:
             module.reset()
-            var.update = g.clone()
-            return var
+            objective.updates = g.clone()
+            return objective

-        return var
+        return objective
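Restart strategies keep the same shape after the rename: a subclass only implements `should_reset`, while `update`/`apply`/`step` remain final on the base class. A hypothetical strategy against the updated base class; the gradient-norm criterion and the constructor shape are illustrative, only `should_reset`, `get_grads`, `global_vector_norm` and the `(defaults, modules)` constructor come from the hunks above.

```python
# Hypothetical restart strategy; the norm-spike criterion is made up for illustration.
from torchzero.modules.restarts.restars import RestartStrategyBase
from torchzero.utils import TensorList


class GradNormSpikeRestart(RestartStrategyBase):
    """Resets the wrapped modules when the gradient norm grows by `factor`."""

    def __init__(self, modules, factor: float = 10.0):
        super().__init__(dict(factor=factor), modules)

    def should_reset(self, var):
        # norm of the current gradient across all parameters
        g_norm = TensorList(var.get_grads()).global_vector_norm()
        prev = self.global_state.get("prev_norm", None)
        self.global_state["prev_norm"] = g_norm
        return prev is not None and g_norm > self.defaults["factor"] * prev
```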
torchzero/modules/second_order/__init__.py

@@ -1,7 +1,7 @@
 from .ifn import InverseFreeNewton
-from .inm import INM
+from .inm import ImprovedNewton
 from .multipoint import SixthOrder3P, SixthOrder3PM2, SixthOrder5P, TwoPointNewton
 from .newton import Newton
 from .newton_cg import NewtonCG, NewtonCGSteihaug
 from .nystrom import NystromPCG, NystromSketchAndSolve
-from .rsn import RSN
+from .rsn import SubspaceNewton
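After this change the two classes are exported only under their new names, so imports have to be updated accordingly:

```python
# 0.3.x
# from torchzero.modules.second_order import INM, RSN

# 0.4.0
from torchzero.modules.second_order import ImprovedNewton, SubspaceNewton
```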
torchzero/modules/second_order/ifn.py

@@ -1,89 +1,58 @@
-import warnings
-from collections.abc import Callable
-from functools import partial
-from typing import Literal
-
 import torch

-from ...core import Chainable, Module, apply_transform, Var
+from ...core import Chainable, Transform, HessianMethod
 from ...utils import TensorList, vec_to_tensors
-from ...utils.linalg.linear_operator import DenseWithInverse, Dense
-from .newton import _get_H, _get_loss_grad_and_hessian, _newton_step
+from ...linalg.linear_operator import DenseWithInverse


-class InverseFreeNewton(Module):
+class InverseFreeNewton(Transform):
     """Inverse-free newton's method

-    .. note::
-        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
-
-    .. note::
-        This module requires the a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating the hessian.
-        The closure must accept a ``backward`` argument (refer to documentation).
-
-    .. warning::
-        this uses roughly O(N^2) memory.
-
     Reference
         [Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.](https://www.jaac-online.com/article/doi/10.11948/20240428)
     """
     def __init__(
         self,
         update_freq: int = 1,
-        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-        vectorize: bool = True,
+        hessian_method: HessianMethod = "batched_autograd",
+        h: float = 1e-3,
         inner: Chainable | None = None,
     ):
-        defaults = dict(hessian_method=hessian_method, vectorize=vectorize, update_freq=update_freq)
-        super().__init__(defaults)
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        defaults = dict(hessian_method=hessian_method, h=h)
+        super().__init__(defaults, update_freq=update_freq, inner=inner)

     @torch.no_grad
-    def update(self, var):
-        update_freq = self.defaults['update_freq']
+    def update_states(self, objective, states, settings):
+        fs = settings[0]

-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
+        _, _, H = objective.hessian(
+            hessian_method=fs['hessian_method'],
+            h=fs['h'],
+            at_x0=True
+        )

-        if step % update_freq == 0:
-            loss, g_list, H = _get_loss_grad_and_hessian(
-                var, self.defaults['hessian_method'], self.defaults['vectorize']
-            )
-            self.global_state["H"] = H
+        self.global_state["H"] = H

-            # inverse free part
-            if 'Y' not in self.global_state:
-                num = H.T
-                denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
+        # inverse free part
+        if 'Y' not in self.global_state:
+            num = H.T
+            denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable

-                finfo = torch.finfo(H.dtype)
-                self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))
+            finfo = torch.finfo(H.dtype)
+            self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))

-            else:
-                Y = self.global_state['Y']
-                I2 = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
-                I2 -= H @ Y
-                self.global_state['Y'] = Y @ I2
+        else:
+            Y = self.global_state['Y']
+            I2 = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
+            I2 -= H @ Y
+            self.global_state['Y'] = Y @ I2


-    def apply(self, var):
+    def apply_states(self, objective, states, settings):
         Y = self.global_state["Y"]
-        params = var.params
-
-        # -------------------------------- inner step -------------------------------- #
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
-
-        g = torch.cat([t.ravel() for t in update])
-
-        # ----------------------------------- solve ---------------------------------- #
-        var.update = vec_to_tensors(Y@g, params)
-
-        return var
+        g = torch.cat([t.ravel() for t in objective.get_updates()])
+        objective.updates = vec_to_tensors(Y@g, objective.params)
+        return objective

-    def get_H(self,var):
+    def get_H(self,objective=...):
         return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])
torchzero/modules/second_order/inm.py

@@ -1,12 +1,11 @@
 from collections.abc import Callable
-from typing import Literal

 import torch

-from ...core import Chainable, Module
-from ...utils import TensorList, vec_to_tensors
+from ...core import Chainable, Transform, HessianMethod
+from ...utils import TensorList, vec_to_tensors, unpack_states
 from ..functional import safe_clip
-from .newton import _get_H, _get_loss_grad_and_hessian, _newton_step
+from .newton import _get_H, _newton_step

 @torch.no_grad
 def inm(f:torch.Tensor, J:torch.Tensor, s:torch.Tensor, y:torch.Tensor):
@@ -25,7 +24,7 @@ def _eigval_fn(J: torch.Tensor, fn) -> torch.Tensor:
     L, Q = torch.linalg.eigh(J) # pylint:disable=not-callable
     return (Q * L.unsqueeze(-2)) @ Q.mH

-class INM(Module):
+class ImprovedNewton(Transform):
     """Improved Newton's Method (INM).

     Reference:
@@ -37,69 +36,66 @@ class INM(Module):
         damping: float = 0,
         use_lstsq: bool = False,
         update_freq: int = 1,
-        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-        vectorize: bool = True,
-        inner: Chainable | None = None,
         H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        hessian_method: HessianMethod = "batched_autograd",
+        h: float = 1e-3,
+        inner: Chainable | None = None,
     ):
-        defaults = dict(damping=damping, hessian_method=hessian_method, use_lstsq=use_lstsq, vectorize=vectorize, H_tfm=H_tfm, eigval_fn=eigval_fn, update_freq=update_freq)
-        super().__init__(defaults)
-
-        if inner is not None:
-            self.set_child("inner", inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["update_freq"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner, )

     @torch.no_grad
-    def update(self, var):
-        update_freq = self.defaults['update_freq']
-
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        if step % update_freq == 0:
-            _, f_list, J = _get_loss_grad_and_hessian(
-                var, self.defaults['hessian_method'], self.defaults['vectorize']
-            )
-
-            f = torch.cat([t.ravel() for t in f_list])
-            J = _eigval_fn(J, self.defaults["eigval_fn"])
-
-            x_list = TensorList(var.params)
-            f_list = TensorList(var.get_grad())
-            x_prev, f_prev = self.get_state(var.params, "x_prev", "f_prev", cls=TensorList)
-
-            # initialize on 1st step, do Newton step
-            if step == 0:
-                x_prev.copy_(x_list)
-                f_prev.copy_(f_list)
-                self.global_state["P"] = J
-                return
-
-            # INM update
-            s_list = x_list - x_prev
-            y_list = f_list - f_prev
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+
+        _, f_list, J = objective.hessian(
+            hessian_method=fs['hessian_method'],
+            h=fs['h'],
+            at_x0=True
+        )
+        if f_list is None: f_list = objective.get_grads()
+
+        f = torch.cat([t.ravel() for t in f_list])
+        J = _eigval_fn(J, fs["eigval_fn"])
+
+        x_list = TensorList(objective.params)
+        f_list = TensorList(objective.get_grads())
+        x_prev, f_prev = unpack_states(states, objective.params, "x_prev", "f_prev", cls=TensorList)
+
+        # initialize on 1st step, do Newton step
+        if "P" not in self.global_state:
             x_prev.copy_(x_list)
             f_prev.copy_(f_list)
+            self.global_state["P"] = J
+            return

-            self.global_state["P"] = inm(f, J, s=s_list.to_vec(), y=y_list.to_vec())
+        # INM update
+        s_list = x_list - x_prev
+        y_list = f_list - f_prev
+        x_prev.copy_(x_list)
+        f_prev.copy_(f_list)
+
+        self.global_state["P"] = inm(f, J, s=s_list.to_vec(), y=y_list.to_vec())


     @torch.no_grad
-    def apply(self, var):
-        params = var.params
+    def apply_states(self, objective, states, settings):
+        fs = settings[0]
+
         update = _newton_step(
-            var=var,
+            objective = objective,
             H = self.global_state["P"],
-            damping=self.defaults["damping"],
-            inner=self.children.get("inner", None),
-            H_tfm=self.defaults["H_tfm"],
-            eigval_fn=None, # it is applied in `update`
-            use_lstsq=self.defaults["use_lstsq"],
+            damping = fs["damping"],
+            H_tfm = fs["H_tfm"],
+            eigval_fn = None, # it is applied in `update`
+            use_lstsq = fs["use_lstsq"],
         )

-        var.update = vec_to_tensors(update, params)
+        objective.updates = vec_to_tensors(update, objective.params)

-        return var
+        return objective

-    def get_H(self,var=...):
+    def get_H(self,objective=...):
         return _get_H(self.global_state["P"], eigval_fn=None)