torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,292 @@
+ import torch
+
+ from ...core import Module, Chainable, step
+ from ...utils import TensorList, vec_to_tensors
+ from ..second_order.newton import _newton_step, _get_H
+
+ def sg2_(
+     delta_g: torch.Tensor,
+     cd: torch.Tensor,
+ ) -> torch.Tensor:
+     """cd is c * perturbation, and must be multiplied by two if hessian estimate is two-sided
+     (or divide delta_g by two)."""
+
+     M = torch.outer(1.0 / cd, delta_g)
+     H_hat = 0.5 * (M + M.T)
+
+     return H_hat
+
+
+
+ class SG2(Module):
+     """second-order stochastic gradient
+
+     SG2 with line search
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.SG2(),
+         tz.m.Backtracking()
+     )
+     ```
+
+     SG2 with trust region
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.LevenbergMarquardt(tz.m.SG2()),
+     )
+     ```
+
+     """
+
+     def __init__(
+         self,
+         n_samples: int = 1,
+         h: float = 1e-2,
+         beta: float | None = None,
+         damping: float = 0,
+         eigval_fn=None,
+         one_sided: bool = False, # one-sided hessian
+         use_lstsq: bool = True,
+         seed=None,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, one_sided=one_sided, seed=seed, use_lstsq=use_lstsq)
+         super().__init__(defaults)
+
+         if inner is not None: self.set_child('inner', inner)
+
+     @torch.no_grad
+     def update(self, objective):
+         k = self.global_state.get('step', 0) + 1
+         self.global_state["step"] = k
+
+         params = TensorList(objective.params)
+         closure = objective.closure
+         if closure is None:
+             raise RuntimeError("closure is required for SG2")
+         generator = self.get_generator(params[0].device, self.defaults["seed"])
+
+         h = self.get_settings(params, "h")
+         x_0 = params.clone()
+         n_samples = self.defaults["n_samples"]
+         H_hat = None
+
+         for i in range(n_samples):
+             # generate perturbation
+             cd = params.rademacher_like(generator=generator).mul_(h)
+
+             # one sided
+             if self.defaults["one_sided"]:
+                 g_0 = TensorList(objective.get_grads())
+                 params.add_(cd)
+                 closure()
+
+                 g_p = params.grad.fill_none_(params)
+                 delta_g = (g_p - g_0) * 2
+
+             # two sided
+             else:
+                 params.add_(cd)
+                 closure()
+                 g_p = params.grad.fill_none_(params)
+
+                 params.copy_(x_0)
+                 params.sub_(cd)
+                 closure()
+                 g_n = params.grad.fill_none_(params)
+
+                 delta_g = g_p - g_n
+
+             # restore params
+             params.set_(x_0)
+
+             # compute H hat
+             H_i = sg2_(
+                 delta_g = delta_g.to_vec(),
+                 cd = cd.to_vec(),
+             )
+
+             if H_hat is None: H_hat = H_i
+             else: H_hat += H_i
+
+         assert H_hat is not None
+         if n_samples > 1: H_hat /= n_samples
+
+         # update H
+         H = self.global_state.get("H", None)
+         if H is None: H = H_hat
+         else:
+             beta = self.defaults["beta"]
+             if beta is None: beta = k / (k+1)
+             H.lerp_(H_hat, 1-beta)
+
+         self.global_state["H"] = H
+
+
+     @torch.no_grad
+     def apply(self, objective):
+         dir = _newton_step(
+             objective=objective,
+             H = self.global_state["H"],
+             damping = self.defaults["damping"],
+             inner = self.children.get("inner", None),
+             H_tfm=None,
+             eigval_fn=self.defaults["eigval_fn"],
+             use_lstsq=self.defaults["use_lstsq"],
+             g_proj=None,
+         )
+
+         objective.updates = vec_to_tensors(dir, objective.params)
+         return objective
+
+     def get_H(self,objective=...):
+         return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
+
+
+
+
+ # two sided
+ # we have g via x + d, x - d
+ # H via g(x + d), g(x - d)
+ # 1 is x, x+2d
+ # 2 is x, x-2d
+ # 5 evals in total
+
+ # one sided
+ # g via x, x + d
+ # 1 is x, x + d
+ # 2 is x, x - d
+ # 3 evals and can use two sided for g_0
+
+ class SPSA2(Module):
+     """second-order SPSA
+
+     SPSA2 with line search
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.SPSA2(),
+         tz.m.Backtracking()
+     )
+     ```
+
+     SPSA2 with trust region
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.LevenbergMarquardt(tz.m.SPSA2()),
+     )
+     ```
+     """
+
+     def __init__(
+         self,
+         n_samples: int = 1,
+         h: float = 1e-2,
+         beta: float | None = None,
+         damping: float = 0,
+         eigval_fn=None,
+         use_lstsq: bool = True,
+         seed=None,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, seed=seed, use_lstsq=use_lstsq)
+         super().__init__(defaults)
+
+         if inner is not None: self.set_child('inner', inner)
+
+     @torch.no_grad
+     def update(self, objective):
+         k = self.global_state.get('step', 0) + 1
+         self.global_state["step"] = k
+
+         params = TensorList(objective.params)
+         closure = objective.closure
+         if closure is None:
+             raise RuntimeError("closure is required for SPSA2")
+
+         generator = self.get_generator(params[0].device, self.defaults["seed"])
+
+         h = self.get_settings(params, "h")
+         x_0 = params.clone()
+         n_samples = self.defaults["n_samples"]
+         H_hat = None
+         g_0 = None
+
+         for i in range(n_samples):
+             # perturbations for g and H
+             cd_g = params.rademacher_like(generator=generator).mul_(h)
+             cd_H = params.rademacher_like(generator=generator).mul_(h)
+
+             # evaluate 4 points
+             x_p = x_0 + cd_g
+             x_n = x_0 - cd_g
+
+             params.set_(x_p)
+             f_p = closure(False)
+             params.add_(cd_H)
+             f_pp = closure(False)
+
+             params.set_(x_n)
+             f_n = closure(False)
+             params.add_(cd_H)
+             f_np = closure(False)
+
+             g_p_vec = (f_pp - f_p) / cd_H
+             g_n_vec = (f_np - f_n) / cd_H
+             delta_g = g_p_vec - g_n_vec
+
+             # restore params
+             params.set_(x_0)
+
+             # compute grad
+             g_i = (f_p - f_n) / (2 * cd_g)
+             if g_0 is None: g_0 = g_i
+             else: g_0 += g_i
+
+             # compute H hat
+             H_i = sg2_(
+                 delta_g = delta_g.to_vec().div_(2.0),
+                 cd = cd_g.to_vec(), # The interval is measured by the original 'cd'
+             )
+             if H_hat is None: H_hat = H_i
+             else: H_hat += H_i
+
+         assert g_0 is not None and H_hat is not None
+         if n_samples > 1:
+             g_0 /= n_samples
+             H_hat /= n_samples
+
+         # set grad to approximated grad
+         objective.grads = g_0
+
+         # update H
+         H = self.global_state.get("H", None)
+         if H is None: H = H_hat
+         else:
+             beta = self.defaults["beta"]
+             if beta is None: beta = k / (k+1)
+             H.lerp_(H_hat, 1-beta)
+
+         self.global_state["H"] = H
+
+     @torch.no_grad
+     def apply(self, objective):
+         dir = _newton_step(
+             objective=objective,
+             H = self.global_state["H"],
+             damping = self.defaults["damping"],
+             inner = self.children.get("inner", None),
+             H_tfm=None,
+             eigval_fn=self.defaults["eigval_fn"],
+             use_lstsq=self.defaults["use_lstsq"],
+             g_proj=None,
+         )
+
+         objective.updates = vec_to_tensors(dir, objective.params)
+         return objective
+
+     def get_H(self,objective=...):
+         return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
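For readers unfamiliar with the method, the `sg2_` helper above appears to implement the simultaneous-perturbation Hessian estimate from Spall's adaptive (second-order) SPSA. The following is a sketch of my reading of the diff, with $c\Delta$ denoting the Rademacher perturbation stored in `cd` and $\delta g$ the gradient difference across it:

$$
\hat H = \tfrac{1}{2}\bigl(M + M^\top\bigr), \qquad M_{jm} = \frac{\delta g_m}{(c\Delta)_j},
$$

where, per the docstring, `cd` must be doubled (or $\delta g$ halved) whenever the two-sided difference $g(x + c\Delta) - g(x - c\Delta)$ is used. The per-sample estimates are averaged over `n_samples` and then smoothed across steps via `H.lerp_(H_hat, 1-beta)`, i.e.

$$
\bar H_k = \beta_k \bar H_{k-1} + (1-\beta_k)\,\hat H_k, \qquad \beta_k = \tfrac{k}{k+1}\ \text{when } \texttt{beta=None},
$$

which with the default $\beta_k$ reduces to a running average of all per-step estimates; the actual step is then the damped Newton solve performed by `_newton_step`.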
@@ -4,12 +4,14 @@ from typing import final, Literal, cast
 
  import torch
 
- from ...core import Chainable, Module, Var
+ from ...core import Chainable, Module, Objective
  from ...utils import TensorList
  from ..termination import TerminationCriteriaBase
 
  def _reset_except_self(optimizer, var, self: Module):
-     for m in optimizer.unrolled_modules: m.reset()
+     for m in optimizer.unrolled_modules:
+         if m is not self:
+             m.reset()
 
  class RestartStrategyBase(Module, ABC):
      """Base class for restart strategies.
@@ -24,7 +26,7 @@ class RestartStrategyBase(Module, ABC):
          self.set_child('modules', modules)
 
      @abstractmethod
-     def should_reset(self, var: Var) -> bool:
+     def should_reset(self, var: Objective) -> bool:
          """returns whether reset should occur"""
 
      def _reset_on_condition(self, var):
@@ -39,23 +41,23 @@ class RestartStrategyBase(Module, ABC):
          return modules
 
      @final
-     def update(self, var):
-         modules = self._reset_on_condition(var)
+     def update(self, objective):
+         modules = self._reset_on_condition(objective)
          if modules is not None:
-             modules.update(var)
+             modules.update(objective)
 
      @final
-     def apply(self, var):
+     def apply(self, objective):
          # don't check here because it was check in `update`
          modules = self.children.get('modules', None)
-         if modules is None: return var
-         return modules.apply(var.clone(clone_update=False))
+         if modules is None: return objective
+         return modules.apply(objective.clone(clone_updates=False))
 
      @final
-     def step(self, var):
-         modules = self._reset_on_condition(var)
-         if modules is None: return var
-         return modules.step(var.clone(clone_update=False))
+     def step(self, objective):
+         modules = self._reset_on_condition(objective)
+         if modules is None: return objective
+         return modules.step(objective.clone(clone_updates=False))
 
 
 
@@ -170,7 +172,7 @@ class PowellRestart(RestartStrategyBase):
          super().__init__(defaults, modules)
 
      def should_reset(self, var):
-         g = TensorList(var.get_grad())
+         g = TensorList(var.get_grads())
          cond1 = self.defaults['cond1']; cond2 = self.defaults['cond2']
 
          # -------------------------------- initialize -------------------------------- #
@@ -192,7 +194,7 @@ class PowellRestart(RestartStrategyBase):
 
          # ------------------------------- 2nd condition ------------------------------ #
          if (cond2 is not None) and (not reset):
-             d_g = TensorList(var.get_update()).dot(g)
+             d_g = TensorList(var.get_updates()).dot(g)
              if (-1-cond2) * g_g < d_g < (-1 + cond2) * g_g:
                  reset = True
 
@@ -229,17 +231,17 @@ class BirginMartinezRestart(Module):
 
          self.set_child("module", module)
 
-     def update(self, var):
+     def update(self, objective):
          module = self.children['module']
-         module.update(var)
+         module.update(objective)
 
-     def apply(self, var):
+     def apply(self, objective):
          module = self.children['module']
-         var = module.apply(var.clone(clone_update=False))
+         objective = module.apply(objective.clone(clone_updates=False))
 
          cond = self.defaults['cond']
-         g = TensorList(var.get_grad())
-         d = TensorList(var.get_update())
+         g = TensorList(objective.get_grads())
+         d = TensorList(objective.get_updates())
          d_g = d.dot(g)
          d_norm = d.global_vector_norm()
          g_norm = g.global_vector_norm()
@@ -247,7 +249,7 @@ class BirginMartinezRestart(Module):
          # d in our case is same direction as g so it has a minus sign
          if -d_g > -cond * d_norm * g_norm:
              module.reset()
-             var.update = g.clone()
-             return var
+             objective.updates = g.clone()
+             return objective
 
-         return var
+         return objective
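A worked form of the test in `BirginMartinezRestart.apply` may help here. Since in this codebase the produced update $d$ points in the same direction as the gradient $g$ (the actual step is $-d$), the module resets the wrapped module and falls back to a plain gradient update whenever

$$
-d^\top g \;>\; -\eta\,\lVert d\rVert\,\lVert g\rVert
\quad\Longleftrightarrow\quad
\frac{d^\top g}{\lVert d\rVert\,\lVert g\rVert} \;<\; \eta,
$$

i.e. whenever $-d$ fails the angle test certifying it as a sufficiently good descent direction, with $\eta$ the `cond` setting. This reading is an interpretation of the diff, not a statement from the library's documentation.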
@@ -1,4 +1,7 @@
- from .newton import Newton, InverseFreeNewton
+ from .ifn import InverseFreeNewton
+ from .inm import ImprovedNewton
+ from .multipoint import SixthOrder3P, SixthOrder3PM2, SixthOrder5P, TwoPointNewton
+ from .newton import Newton
  from .newton_cg import NewtonCG, NewtonCGSteihaug
- from .nystrom import NystromSketchAndSolve, NystromPCG
- from .multipoint import SixthOrder3P, SixthOrder5P, TwoPointNewton, SixthOrder3PM2
+ from .nystrom import NystromPCG, NystromSketchAndSolve
+ from .rsn import SubspaceNewton
@@ -0,0 +1,58 @@
+ import torch
+
+ from ...core import Chainable, Transform, HessianMethod
+ from ...utils import TensorList, vec_to_tensors
+ from ...linalg.linear_operator import DenseWithInverse
+
+
+ class InverseFreeNewton(Transform):
+     """Inverse-free newton's method
+
+     Reference
+         [Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.](https://www.jaac-online.com/article/doi/10.11948/20240428)
+     """
+     def __init__(
+         self,
+         update_freq: int = 1,
+         hessian_method: HessianMethod = "batched_autograd",
+         h: float = 1e-3,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(hessian_method=hessian_method, h=h)
+         super().__init__(defaults, update_freq=update_freq, inner=inner)
+
+     @torch.no_grad
+     def update_states(self, objective, states, settings):
+         fs = settings[0]
+
+         _, _, H = objective.hessian(
+             hessian_method=fs['hessian_method'],
+             h=fs['h'],
+             at_x0=True
+         )
+
+         self.global_state["H"] = H
+
+         # inverse free part
+         if 'Y' not in self.global_state:
+             num = H.T
+             denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
+
+             finfo = torch.finfo(H.dtype)
+             self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))
+
+         else:
+             Y = self.global_state['Y']
+             I2 = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
+             I2 -= H @ Y
+             self.global_state['Y'] = Y @ I2
+
+
+     def apply_states(self, objective, states, settings):
+         Y = self.global_state["Y"]
+         g = torch.cat([t.ravel() for t in objective.get_updates()])
+         objective.updates = vec_to_tensors(Y@g, objective.params)
+         return objective
+
+     def get_H(self,objective=...):
+         return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])
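The `update_states` logic above looks like the classic inverse-free (Newton–Schulz style) recurrence for maintaining an approximate inverse $Y \approx H^{-1}$ alongside the Hessian. As a sketch of my reading of the diff:

$$
Y_0 = \frac{H^\top}{\lVert H\rVert_1\,\lVert H\rVert_\infty}, \qquad
Y_{k+1} = Y_k\,(2I - H\,Y_k),
$$

with the update direction taken as $Y g$. The initial scaling is the standard choice that, for nonsingular $H$, keeps the spectral radius of $I - H Y_0$ below one (since $\lVert H\rVert_2^2 \le \lVert H\rVert_1 \lVert H\rVert_\infty$), so the iteration refines $Y$ toward $H^{-1}$ without ever solving a linear system; `get_H` then exposes the pair $(H, Y)$ as a `DenseWithInverse` operator.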
@@ -0,0 +1,101 @@
+ from collections.abc import Callable
+
+ import torch
+
+ from ...core import Chainable, Transform, HessianMethod
+ from ...utils import TensorList, vec_to_tensors, unpack_states
+ from ..functional import safe_clip
+ from .newton import _get_H, _newton_step
+
+ @torch.no_grad
+ def inm(f:torch.Tensor, J:torch.Tensor, s:torch.Tensor, y:torch.Tensor):
+
+     yy = safe_clip(y.dot(y))
+     ss = safe_clip(s.dot(s))
+
+     term1 = y.dot(y - J@s) / yy
+     FbT = f.outer(s).mul_(term1 / ss)
+
+     P = FbT.add_(J)
+     return P
+
+ def _eigval_fn(J: torch.Tensor, fn) -> torch.Tensor:
+     if fn is None: return J
+     L, Q = torch.linalg.eigh(J) # pylint:disable=not-callable
+     return (Q * L.unsqueeze(-2)) @ Q.mH
+
+ class ImprovedNewton(Transform):
+     """Improved Newton's Method (INM).
+
+     Reference:
+         [Saheya, B., et al. "A new Newton-like method for solving nonlinear equations." SpringerPlus 5.1 (2016): 1269.](https://d-nb.info/1112813721/34)
+     """
+
+     def __init__(
+         self,
+         damping: float = 0,
+         use_lstsq: bool = False,
+         update_freq: int = 1,
+         H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+         hessian_method: HessianMethod = "batched_autograd",
+         h: float = 1e-3,
+         inner: Chainable | None = None,
+     ):
+         defaults = locals().copy()
+         del defaults['self'], defaults['inner'], defaults["update_freq"]
+         super().__init__(defaults, update_freq=update_freq, inner=inner, )
+
+     @torch.no_grad
+     def update_states(self, objective, states, settings):
+         fs = settings[0]
+
+         _, f_list, J = objective.hessian(
+             hessian_method=fs['hessian_method'],
+             h=fs['h'],
+             at_x0=True
+         )
+         if f_list is None: f_list = objective.get_grads()
+
+         f = torch.cat([t.ravel() for t in f_list])
+         J = _eigval_fn(J, fs["eigval_fn"])
+
+         x_list = TensorList(objective.params)
+         f_list = TensorList(objective.get_grads())
+         x_prev, f_prev = unpack_states(states, objective.params, "x_prev", "f_prev", cls=TensorList)
+
+         # initialize on 1st step, do Newton step
+         if "P" not in self.global_state:
+             x_prev.copy_(x_list)
+             f_prev.copy_(f_list)
+             self.global_state["P"] = J
+             return
+
+         # INM update
+         s_list = x_list - x_prev
+         y_list = f_list - f_prev
+         x_prev.copy_(x_list)
+         f_prev.copy_(f_list)
+
+         self.global_state["P"] = inm(f, J, s=s_list.to_vec(), y=y_list.to_vec())
+
+
+     @torch.no_grad
+     def apply_states(self, objective, states, settings):
+         fs = settings[0]
+
+         update = _newton_step(
+             objective = objective,
+             H = self.global_state["P"],
+             damping = fs["damping"],
+             H_tfm = fs["H_tfm"],
+             eigval_fn = None, # it is applied in `update`
+             use_lstsq = fs["use_lstsq"],
+         )
+
+         objective.updates = vec_to_tensors(update, objective.params)
+
+         return objective
+
+     def get_H(self,objective=...):
+         return _get_H(self.global_state["P"], eigval_fn=None)