torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/restarts/restars.py

@@ -4,12 +4,14 @@ from typing import final, Literal, cast
 
 import torch
 
-from ...core import Chainable, Module, Var
+from ...core import Chainable, Module, Objective
 from ...utils import TensorList
 from ..termination import TerminationCriteriaBase
 
-def _reset_except_self(optimizer, var, self: Module):
-    for m in optimizer.unrolled_modules: m.reset()
+def _reset_except_self(objective, modules, self: Module):
+    for m in modules:
+        if m is not self:
+            m.reset()
 
 class RestartStrategyBase(Module, ABC):
     """Base class for restart strategies.
@@ -24,38 +26,38 @@ class RestartStrategyBase(Module, ABC):
         self.set_child('modules', modules)
 
     @abstractmethod
-    def should_reset(self, var: Var) -> bool:
+    def should_reset(self, objective: Objective) -> bool:
         """returns whether reset should occur"""
 
-    def _reset_on_condition(self, var):
+    def _reset_on_condition(self, objective: Objective):
         modules = self.children.get('modules', None)
 
-        if self.should_reset(var):
+        if self.should_reset(objective):
             if modules is None:
-                var.post_step_hooks.append(partial(_reset_except_self, self=self))
+                objective.post_step_hooks.append(partial(_reset_except_self, self=self))
             else:
                 modules.reset()
 
         return modules
 
     @final
-    def update(self, var):
-        modules = self._reset_on_condition(var)
+    def update(self, objective):
+        modules = self._reset_on_condition(objective)
         if modules is not None:
-            modules.update(var)
+            modules.update(objective)
 
     @final
-    def apply(self, var):
+    def apply(self, objective):
         # don't check here because it was check in `update`
         modules = self.children.get('modules', None)
-        if modules is None: return var
-        return modules.apply(var.clone(clone_update=False))
+        if modules is None: return objective
+        return modules.apply(objective.clone(clone_updates=False))
 
     @final
-    def step(self, var):
-        modules = self._reset_on_condition(var)
-        if modules is None: return var
-        return modules.step(var.clone(clone_update=False))
+    def step(self, objective):
+        modules = self._reset_on_condition(objective)
+        if modules is None: return objective
+        return modules.step(objective.clone(clone_updates=False))
 
 
 
@@ -76,11 +78,11 @@ class RestartOnStuck(RestartStrategyBase):
         super().__init__(defaults, modules)
 
     @torch.no_grad
-    def should_reset(self, var):
+    def should_reset(self, objective):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
-        params = TensorList(var.params)
+        params = TensorList(objective.params)
         tol = self.defaults['tol']
         if tol is None: tol = torch.finfo(params[0].dtype).tiny * 2
         n_tol = self.defaults['n_tol']
@@ -122,12 +124,12 @@ class RestartEvery(RestartStrategyBase):
         defaults = dict(steps=steps)
         super().__init__(defaults, modules)
 
-    def should_reset(self, var):
+    def should_reset(self, objective):
         step = self.global_state.get('step', 0) + 1
         self.global_state['step'] = step
 
         n = self.defaults['steps']
-        if isinstance(n, str): n = sum(p.numel() for p in var.params if p.requires_grad)
+        if isinstance(n, str): n = sum(p.numel() for p in objective.params if p.requires_grad)
 
         # reset every n steps
         if step % n == 0:
@@ -141,9 +143,9 @@ class RestartOnTerminationCriteria(RestartStrategyBase):
         super().__init__(None, modules)
         self.set_child('criteria', criteria)
 
-    def should_reset(self, var):
+    def should_reset(self, objective):
         criteria = cast(TerminationCriteriaBase, self.children['criteria'])
-        return criteria.should_terminate(var)
+        return criteria.should_terminate(objective)
 
 class PowellRestart(RestartStrategyBase):
     """Powell's two restarting criterions for conjugate gradient methods.
@@ -169,14 +171,14 @@ class PowellRestart(RestartStrategyBase):
         defaults=dict(cond1=cond1, cond2=cond2)
         super().__init__(defaults, modules)
 
-    def should_reset(self, var):
-        g = TensorList(var.get_grad())
+    def should_reset(self, objective):
+        g = TensorList(objective.get_grads())
        cond1 = self.defaults['cond1']; cond2 = self.defaults['cond2']
 
        # -------------------------------- initialize -------------------------------- #
        if 'initialized' not in self.global_state:
            self.global_state['initialized'] = 0
-            g_prev = self.get_state(var.params, 'g_prev', init=g)
+            g_prev = self.get_state(objective.params, 'g_prev', init=g)
            return False
 
        g_g = g.dot(g)
@@ -184,7 +186,7 @@
        reset = False
        # ------------------------------- 1st condition ------------------------------ #
        if cond1 is not None:
-            g_prev = self.get_state(var.params, 'g_prev', must_exist=True, cls=TensorList)
+            g_prev = self.get_state(objective.params, 'g_prev', must_exist=True, cls=TensorList)
            g_g_prev = g_prev.dot(g)
 
            if g_g_prev.abs() >= cond1 * g_g:
@@ -192,7 +194,7 @@
 
        # ------------------------------- 2nd condition ------------------------------ #
        if (cond2 is not None) and (not reset):
-            d_g = TensorList(var.get_update()).dot(g)
+            d_g = TensorList(objective.get_updates()).dot(g)
            if (-1-cond2) * g_g < d_g < (-1 + cond2) * g_g:
                reset = True
 
@@ -229,17 +231,17 @@ class BirginMartinezRestart(Module):
 
        self.set_child("module", module)
 
-    def update(self, var):
+    def update(self, objective):
        module = self.children['module']
-        module.update(var)
+        module.update(objective)
 
-    def apply(self, var):
+    def apply(self, objective):
        module = self.children['module']
-        var = module.apply(var.clone(clone_update=False))
+        objective = module.apply(objective.clone(clone_updates=False))
 
        cond = self.defaults['cond']
-        g = TensorList(var.get_grad())
-        d = TensorList(var.get_update())
+        g = TensorList(objective.get_grads())
+        d = TensorList(objective.get_updates())
        d_g = d.dot(g)
        d_norm = d.global_vector_norm()
        g_norm = g.global_vector_norm()
@@ -247,7 +249,7 @@
        # d in our case is same direction as g so it has a minus sign
        if -d_g > -cond * d_norm * g_norm:
            module.reset()
-            var.update = g.clone()
-            return var
+            objective.updates = g.clone()
+            return objective
 
-        return var
+        return objective
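For readers skimming the `PowellRestart` hunks above: `should_reset` triggers a restart of the wrapped modules when either of Powell's conditions holds. Restated from the code (with c1 = `cond1`, c2 = `cond2`, g_k the current gradient, g_{k-1} the stored previous gradient, and d_k the current update direction):

```latex
% Powell's restart conditions as implemented in PowellRestart.should_reset above
\lvert g_k^\top g_{k-1} \rvert \;\ge\; c_1 \,\lVert g_k \rVert^2
\qquad \text{or} \qquad
(-1 - c_2)\,\lVert g_k \rVert^2 \;<\; d_k^\top g_k \;<\; (-1 + c_2)\,\lVert g_k \rVert^2
```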
torchzero/modules/second_order/__init__.py

@@ -1,7 +1,7 @@
 from .ifn import InverseFreeNewton
-from .inm import INM
+from .inm import ImprovedNewton
 from .multipoint import SixthOrder3P, SixthOrder3PM2, SixthOrder5P, TwoPointNewton
 from .newton import Newton
 from .newton_cg import NewtonCG, NewtonCGSteihaug
 from .nystrom import NystromPCG, NystromSketchAndSolve
-from .rsn import RSN
+from .rsn import SubspaceNewton
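The `__init__.py` hunk above renames two exports, so 0.3.15 code importing the old names from this subpackage needs updating, e.g.:

```python
# before (torchzero 0.3.15)
from torchzero.modules.second_order import INM, RSN

# after (torchzero 0.4.1)
from torchzero.modules.second_order import ImprovedNewton, SubspaceNewton
```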
torchzero/modules/second_order/ifn.py

@@ -1,89 +1,58 @@
-import warnings
-from collections.abc import Callable
-from functools import partial
-from typing import Literal
-
 import torch
 
-from ...core import Chainable, Module, apply_transform, Var
+from ...core import Chainable, Transform, HessianMethod
 from ...utils import TensorList, vec_to_tensors
-from ...utils.linalg.linear_operator import DenseWithInverse, Dense
-from .newton import _get_H, _get_loss_grad_and_hessian, _newton_step
+from ...linalg.linear_operator import DenseWithInverse
 
 
-class InverseFreeNewton(Module):
+class InverseFreeNewton(Transform):
     """Inverse-free newton's method
 
-    .. note::
-        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
-
-    .. note::
-        This module requires the a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating the hessian.
-        The closure must accept a ``backward`` argument (refer to documentation).
-
-    .. warning::
-        this uses roughly O(N^2) memory.
-
     Reference
         [Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.](https://www.jaac-online.com/article/doi/10.11948/20240428)
     """
     def __init__(
         self,
         update_freq: int = 1,
-        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-        vectorize: bool = True,
+        hessian_method: HessianMethod = "batched_autograd",
+        h: float = 1e-3,
         inner: Chainable | None = None,
     ):
-        defaults = dict(hessian_method=hessian_method, vectorize=vectorize, update_freq=update_freq)
-        super().__init__(defaults)
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        defaults = dict(hessian_method=hessian_method, h=h)
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
 
     @torch.no_grad
-    def update(self, var):
-        update_freq = self.defaults['update_freq']
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
 
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
+        _, _, H = objective.hessian(
+            hessian_method=fs['hessian_method'],
+            h=fs['h'],
+            at_x0=True
+        )
 
-        if step % update_freq == 0:
-            loss, g_list, H = _get_loss_grad_and_hessian(
-                var, self.defaults['hessian_method'], self.defaults['vectorize']
-            )
-            self.global_state["H"] = H
+        self.global_state["H"] = H
 
-            # inverse free part
-            if 'Y' not in self.global_state:
-                num = H.T
-                denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
+        # inverse free part
+        if 'Y' not in self.global_state:
+            num = H.T
+            denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
 
-                finfo = torch.finfo(H.dtype)
-                self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))
+            finfo = torch.finfo(H.dtype)
+            self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))
 
-            else:
-                Y = self.global_state['Y']
-                I2 = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
-                I2 -= H @ Y
-                self.global_state['Y'] = Y @ I2
+        else:
+            Y = self.global_state['Y']
+            I2 = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
+            I2 -= H @ Y
+            self.global_state['Y'] = Y @ I2
 
 
-    def apply(self, var):
+    def apply_states(self, objective, states, settings):
         Y = self.global_state["Y"]
-        params = var.params
-
-        # -------------------------------- inner step -------------------------------- #
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
-
-        g = torch.cat([t.ravel() for t in update])
-
-        # ----------------------------------- solve ---------------------------------- #
-        var.update = vec_to_tensors(Y@g, params)
-
-        return var
+        g = torch.cat([t.ravel() for t in objective.get_updates()])
+        objective.updates = vec_to_tensors(Y@g, objective.params)
+        return objective
 
-    def get_H(self,var):
+    def get_H(self,objective=...):
         return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])
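The `update_states` body above never factorizes or inverts H: it seeds Y with H transposed divided by the product of the 1-norm and infinity-norm of H, then refines Y ← Y(2I − HY), a Newton–Schulz-style iteration whose fixed point is the inverse of H; `apply_states` then just multiplies the flattened update by Y. A minimal standalone sketch of that scheme (illustration only, not the module itself):

```python
import torch

def inverse_free_newton_precondition(H: torch.Tensor, g: torch.Tensor,
                                      Y: torch.Tensor | None = None):
    """Standalone sketch of the scheme in InverseFreeNewton above: maintain Y ~= inverse of H
    without solving any linear system, then precondition the flattened update g with it."""
    if Y is None:
        # seed used by the module: Y0 = H^T / (||H||_1 * ||H||_inf)
        Y = H.T / (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf')))
    else:
        # refinement: Y <- Y (2I - H Y), same as the I2 computation in the hunk
        I2 = 2 * torch.eye(H.shape[0], dtype=H.dtype, device=H.device)
        Y = Y @ (I2 - H @ Y)
    return Y @ g, Y  # preconditioned step and the refined inverse estimate
```

Each optimizer step therefore refines the inverse estimate by one such sweep instead of recomputing it from scratch.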
torchzero/modules/second_order/inm.py

@@ -1,12 +1,11 @@
 from collections.abc import Callable
-from typing import Literal
 
 import torch
 
-from ...core import Chainable, Module
-from ...utils import TensorList, vec_to_tensors
-from ..functional import safe_clip
-from .newton import _get_H, _get_loss_grad_and_hessian, _newton_step
+from ...core import Chainable, Transform, HessianMethod
+from ...utils import TensorList, vec_to_tensors_, unpack_states
+from ..opt_utils import safe_clip
+from .newton import _newton_update_state_, _newton_solve, _newton_get_H
 
 @torch.no_grad
 def inm(f:torch.Tensor, J:torch.Tensor, s:torch.Tensor, y:torch.Tensor):
@@ -25,7 +24,7 @@ def _eigval_fn(J: torch.Tensor, fn) -> torch.Tensor:
     L, Q = torch.linalg.eigh(J) # pylint:disable=not-callable
     return (Q * L.unsqueeze(-2)) @ Q.mH
 
-class INM(Module):
+class ImprovedNewton(Transform):
     """Improved Newton's Method (INM).
 
     Reference:
@@ -35,71 +34,76 @@ class INM(Module):
     def __init__(
         self,
         damping: float = 0,
-        use_lstsq: bool = False,
+        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
         update_freq: int = 1,
-        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-        vectorize: bool = True,
+        precompute_inverse: bool | None = None,
+        use_lstsq: bool = False,
+        hessian_method: HessianMethod = "batched_autograd",
+        h: float = 1e-3,
         inner: Chainable | None = None,
-        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
-        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
     ):
-        defaults = dict(damping=damping, hessian_method=hessian_method, use_lstsq=use_lstsq, vectorize=vectorize, H_tfm=H_tfm, eigval_fn=eigval_fn, update_freq=update_freq)
-        super().__init__(defaults)
-
-        if inner is not None:
-            self.set_child("inner", inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["update_freq"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner, )
 
     @torch.no_grad
-    def update(self, var):
-        update_freq = self.defaults['update_freq']
-
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
 
-        if step % update_freq == 0:
-            _, f_list, J = _get_loss_grad_and_hessian(
-                var, self.defaults['hessian_method'], self.defaults['vectorize']
-            )
+        _, f_list, J = objective.hessian(
+            hessian_method=fs['hessian_method'],
+            h=fs['h'],
+            at_x0=True
+        )
+        if f_list is None: f_list = objective.get_grads()
 
-            f = torch.cat([t.ravel() for t in f_list])
-            J = _eigval_fn(J, self.defaults["eigval_fn"])
+        f = torch.cat([t.ravel() for t in f_list])
+        J = _eigval_fn(J, fs["eigval_fn"])
 
-            x_list = TensorList(var.params)
-            f_list = TensorList(var.get_grad())
-            x_prev, f_prev = self.get_state(var.params, "x_prev", "f_prev", cls=TensorList)
+        x_list = TensorList(objective.params)
+        f_list = TensorList(objective.get_grads())
+        x_prev, f_prev = unpack_states(states, objective.params, "x_prev", "f_prev", cls=TensorList)
 
-            # initialize on 1st step, do Newton step
-            if step == 0:
-                x_prev.copy_(x_list)
-                f_prev.copy_(f_list)
-                self.global_state["P"] = J
-                return
+        # initialize on 1st step, do Newton step
+        if "H" not in self.global_state:
+            x_prev.copy_(x_list)
+            f_prev.copy_(f_list)
+            P = J
 
-            # INM update
+        # INM update
+        else:
             s_list = x_list - x_prev
             y_list = f_list - f_prev
             x_prev.copy_(x_list)
             f_prev.copy_(f_list)
 
-            self.global_state["P"] = inm(f, J, s=s_list.to_vec(), y=y_list.to_vec())
+            P = inm(f, J, s=s_list.to_vec(), y=y_list.to_vec())
 
+        # update state
+        precompute_inverse = fs["precompute_inverse"]
+        if precompute_inverse is None:
+            precompute_inverse = fs["__update_freq"] >= 10
 
-    @torch.no_grad
-    def apply(self, var):
-        params = var.params
-        update = _newton_step(
-            var=var,
-            H = self.global_state["P"],
-            damping=self.defaults["damping"],
-            inner=self.children.get("inner", None),
-            H_tfm=self.defaults["H_tfm"],
-            eigval_fn=None, # it is applied in `update`
-            use_lstsq=self.defaults["use_lstsq"],
+        _newton_update_state_(
+            H=P,
+            state = self.global_state,
+            damping = fs["damping"],
+            eigval_fn = fs["eigval_fn"],
+            precompute_inverse = precompute_inverse,
+            use_lstsq = fs["use_lstsq"]
         )
 
-        var.update = vec_to_tensors(update, params)
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        updates = objective.get_updates()
+        fs = settings[0]
+
+        b = torch.cat([t.ravel() for t in updates])
+        sol = _newton_solve(b=b, state=self.global_state, use_lstsq=fs["use_lstsq"])
+
+        vec_to_tensors_(sol, updates)
+        return objective
 
-        return var
 
-    def get_H(self,var=...):
-        return _get_H(self.global_state["P"], eigval_fn=None)
+    def get_H(self,objective=...):
+        return _newton_get_H(self.global_state)
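Taken together, the `ifn.py` and `inm.py` hunks show the 0.4.1 module API migration: `Module.update(var)` / `apply(var)` become `Transform.update_states(objective, states, settings)` / `apply_states(objective, states, settings)`, `Var` becomes `Objective`, `var.get_grad()` / `var.get_update()` become `objective.get_grads()` / `objective.get_updates()`, Hessians come from `objective.hessian(...)`, and `update_freq` / `inner` are handled by the `Transform` base class. A minimal hypothetical transform written against that pattern (the class name and its diagonal-scaling logic are illustrative, not part of torchzero; a sketch, not a tested implementation):

```python
import torch
from torchzero.core import Chainable, HessianMethod, Transform
from torchzero.utils import vec_to_tensors

class DiagonalNewton(Transform):
    """Hypothetical transform following the 0.4.1 pattern visible in the hunks above:
    scale the incoming update by the inverse absolute Hessian diagonal."""
    def __init__(self, update_freq: int = 1,
                 hessian_method: HessianMethod = "batched_autograd",
                 h: float = 1e-3, inner: Chainable | None = None):
        defaults = dict(hessian_method=hessian_method, h=h)
        super().__init__(defaults, update_freq=update_freq, inner=inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]
        # same Hessian call as InverseFreeNewton / ImprovedNewton above
        _, _, H = objective.hessian(hessian_method=fs["hessian_method"], h=fs["h"], at_x0=True)
        self.global_state["d"] = H.diagonal().abs().clamp_(min=torch.finfo(H.dtype).tiny * 2)

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        g = torch.cat([t.ravel() for t in objective.get_updates()])
        objective.updates = vec_to_tensors(g / self.global_state["d"], objective.params)
        return objective
```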
torchzero/modules/second_order/multipoint.py

@@ -1,19 +1,17 @@
-from collections.abc import Callable
-from contextlib import nullcontext
 from abc import ABC, abstractmethod
+from collections.abc import Callable, Mapping
+from typing import Any
+
 import numpy as np
 import torch
 
-from ...core import Chainable, Module, apply_transform, Var
-from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
-from ...utils.derivatives import (
-    flatten_jacobian,
-    jacobian_wrt,
-)
+from ...core import Chainable, DerivativesMethod, Objective, Transform
+from ...utils import TensorList, vec_to_tensors
+
 
-class HigherOrderMethodBase(Module, ABC):
-    def __init__(self, defaults: dict | None = None, vectorize: bool = True):
-        self._vectorize = vectorize
+class HigherOrderMethodBase(Transform, ABC):
+    def __init__(self, defaults: dict | None = None, derivatives_method: DerivativesMethod = 'batched_autograd'):
+        self._derivatives_method: DerivativesMethod = derivatives_method
         super().__init__(defaults)
 
     @abstractmethod
@@ -21,61 +19,27 @@ class HigherOrderMethodBase(Module, ABC):
         self,
         x: torch.Tensor,
         evaluate: Callable[[torch.Tensor, int], tuple[torch.Tensor, ...]],
-        var: Var,
+        objective: Objective,
+        setting: Mapping[str, Any],
     ) -> torch.Tensor:
         """"""
 
     @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        x0 = params.clone()
-        closure = var.closure
+    def apply_states(self, objective, states, settings):
+        params = TensorList(objective.params)
+
+        closure = objective.closure
         if closure is None: raise RuntimeError('MultipointNewton requires closure')
-        vectorize = self._vectorize
+        derivatives_method = self._derivatives_method
 
         def evaluate(x, order) -> tuple[torch.Tensor, ...]:
             """order=0 - returns (loss,), order=1 - returns (loss, grad), order=2 - returns (loss, grad, hessian), etc."""
-            params.from_vec_(x)
-
-            if order == 0:
-                loss = closure(False)
-                params.copy_(x0)
-                return (loss, )
-
-            if order == 1:
-                with torch.enable_grad():
-                    loss = closure()
-                grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-                params.copy_(x0)
-                return loss, torch.cat([g.ravel() for g in grad])
-
-            with torch.enable_grad():
-                loss = var.loss = var.loss_approx = closure(False)
-
-                g_list = torch.autograd.grad(loss, params, create_graph=True)
-                var.grad = list(g_list)
-
-                g = torch.cat([t.ravel() for t in g_list])
-                n = g.numel()
-                ret = [loss, g]
-                T = g # current derivatives tensor
-
-                # get all derivative up to order
-                for o in range(2, order + 1):
-                    is_last = o == order
-                    T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
-                    with torch.no_grad() if is_last else nullcontext():
-                        # the shape is (ndim, ) * order
-                        T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
-                    ret.append(T)
-
-            params.copy_(x0)
-            return tuple(ret)
+            return objective.derivatives_at(x, order, method=derivatives_method)
 
         x = torch.cat([p.ravel() for p in params])
-        dir = self.one_iteration(x, evaluate, var)
-        var.update = vec_to_tensors(dir, var.params)
-        return var
+        dir = self.one_iteration(x, evaluate, objective, settings[0])
+        objective.updates = vec_to_tensors(dir, objective.params)
+        return objective
 
 def _inv(A: torch.Tensor, lstsq:bool) -> torch.Tensor:
     if lstsq: return torch.linalg.pinv(A) # pylint:disable=not-callable
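In the hunk above, the `evaluate` callback keeps its contract (order 0 → `(loss,)`, order 1 → `(loss, grad)`, order 2 → `(loss, grad, hessian)`, and so on) but now just delegates to `objective.derivatives_at`, and `one_iteration` gains a per-parameter `setting` mapping. A hypothetical subclass written against the new signature (the class and its plain damped-Newton step are illustrative only):

```python
import torch
from torchzero.modules.second_order.multipoint import HigherOrderMethodBase

class DampedNewtonIteration(HigherOrderMethodBase):
    """Hypothetical HigherOrderMethodBase subclass using the 0.4.1 one_iteration signature."""
    def __init__(self, damping: float = 0.0, derivatives_method="batched_autograd"):
        super().__init__(defaults=dict(damping=damping), derivatives_method=derivatives_method)

    @torch.no_grad
    def one_iteration(self, x, evaluate, objective, setting):
        # evaluate(x, 2) -> (loss, gradient, hessian), per the evaluate docstring above
        _, g, H = evaluate(x, 2)
        H = H + setting["damping"] * torch.eye(H.shape[0], dtype=H.dtype, device=H.device)
        x_star = x - torch.linalg.solve(H, g)
        # like the SixthOrder* classes, return the update x - x_star (here: the Newton step)
        return x - x_star
```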
@@ -106,16 +70,15 @@ class SixthOrder3P(HigherOrderMethodBase):
 
     Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
     """
-    def __init__(self, lstsq: bool=False, vectorize: bool = True):
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults, vectorize=vectorize)
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-    def one_iteration(self, x, evaluate, var):
-        settings = self.defaults
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f(x): return evaluate(x, 1)[1]
         def f_j(x): return evaluate(x, 2)[1:]
-        x_star = sixth_order_3p(x, f, f_j, lstsq)
+        x_star = sixth_order_3p(x, f, f_j, setting['lstsq'])
         return x - x_star
 
 # I don't think it works (I tested root finding with this and it goes all over the place)
@@ -173,15 +136,14 @@ def sixth_order_5p(x:torch.Tensor, f_j, lstsq:bool=False):
 
 class SixthOrder5P(HigherOrderMethodBase):
     """Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139."""
-    def __init__(self, lstsq: bool=False, vectorize: bool = True):
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults, vectorize=vectorize)
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-    def one_iteration(self, x, evaluate, var):
-        settings = self.defaults
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f_j(x): return evaluate(x, 2)[1:]
-        x_star = sixth_order_5p(x, f_j, lstsq)
+        x_star = sixth_order_5p(x, f_j, setting['lstsq'])
         return x - x_star
 
 # 2f 1J 2 solves
@@ -196,16 +158,15 @@ class TwoPointNewton(HigherOrderMethodBase):
     """two-point Newton method with frozen derivative with third order convergence.
 
     Sharma, Janak Raj, and Deepak Kumar. "A fast and efficient composite Newton–Chebyshev method for systems of nonlinear equations." Journal of Complexity 49 (2018): 56-73."""
-    def __init__(self, lstsq: bool=False, vectorize: bool = True):
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults, vectorize=vectorize)
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-    def one_iteration(self, x, evaluate, var):
-        settings = self.defaults
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f(x): return evaluate(x, 1)[1]
         def f_j(x): return evaluate(x, 2)[1:]
-        x_star = two_point_newton(x, f, f_j, lstsq)
+        x_star = two_point_newton(x, f, f_j, setting['lstsq'])
         return x - x_star
 
 #3f 2J 1inv
@@ -224,15 +185,14 @@ def sixth_order_3pm2(x:torch.Tensor, f, f_j, lstsq:bool=False):
 
 class SixthOrder3PM2(HigherOrderMethodBase):
     """Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45."""
-    def __init__(self, lstsq: bool=False, vectorize: bool = True):
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults, vectorize=vectorize)
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-    def one_iteration(self, x, evaluate, var):
-        settings = self.defaults
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f_j(x): return evaluate(x, 2)[1:]
         def f(x): return evaluate(x, 1)[1]
-        x_star = sixth_order_3pm2(x, f, f_j, lstsq)
+        x_star = sixth_order_3pm2(x, f, f_j, setting['lstsq'])
         return x - x_star
 
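For reference, the frozen-derivative two-point Newton scheme described by the `TwoPointNewton` docstring (two residual evaluations and one Jacobian per iteration, matching the "2f 1J 2 solves" comment, with third-order convergence) is classically written as below; the exact variant inside `two_point_newton` is not shown in this diff.

```latex
% classical frozen-Jacobian two-point Newton iteration (third-order convergence);
% the precise variant implemented by two_point_newton is not visible in this diff
y_k     = x_k - F'(x_k)^{-1} F(x_k), \qquad
x_{k+1} = y_k - F'(x_k)^{-1} F(y_k)
```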