torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. tests/test_identical.py +22 -22
  2. tests/test_opts.py +199 -198
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +1 -1
  5. torchzero/core/functional.py +1 -1
  6. torchzero/core/modular.py +5 -5
  7. torchzero/core/module.py +2 -2
  8. torchzero/core/objective.py +10 -10
  9. torchzero/core/transform.py +1 -1
  10. torchzero/linalg/__init__.py +3 -2
  11. torchzero/linalg/eigh.py +223 -4
  12. torchzero/linalg/orthogonalize.py +2 -4
  13. torchzero/linalg/qr.py +12 -0
  14. torchzero/linalg/solve.py +1 -3
  15. torchzero/linalg/svd.py +47 -20
  16. torchzero/modules/__init__.py +4 -3
  17. torchzero/modules/adaptive/__init__.py +11 -3
  18. torchzero/modules/adaptive/adagrad.py +10 -10
  19. torchzero/modules/adaptive/adahessian.py +2 -2
  20. torchzero/modules/adaptive/adam.py +1 -1
  21. torchzero/modules/adaptive/adan.py +1 -1
  22. torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
  23. torchzero/modules/adaptive/esgd.py +2 -2
  24. torchzero/modules/adaptive/ggt.py +186 -0
  25. torchzero/modules/adaptive/lion.py +2 -1
  26. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  27. torchzero/modules/adaptive/mars.py +2 -2
  28. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  29. torchzero/modules/adaptive/msam.py +4 -4
  30. torchzero/modules/adaptive/muon.py +9 -6
  31. torchzero/modules/adaptive/natural_gradient.py +32 -15
  32. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  33. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  34. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  35. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  36. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  37. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  38. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  39. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  40. torchzero/modules/adaptive/rprop.py +2 -2
  41. torchzero/modules/adaptive/sam.py +4 -4
  42. torchzero/modules/adaptive/shampoo.py +28 -3
  43. torchzero/modules/adaptive/soap.py +3 -3
  44. torchzero/modules/adaptive/sophia_h.py +2 -2
  45. torchzero/modules/clipping/clipping.py +7 -7
  46. torchzero/modules/conjugate_gradient/cg.py +2 -2
  47. torchzero/modules/experimental/__init__.py +5 -0
  48. torchzero/modules/experimental/adanystrom.py +258 -0
  49. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  50. torchzero/modules/experimental/cubic_adam.py +160 -0
  51. torchzero/modules/experimental/eigen_sr1.py +182 -0
  52. torchzero/modules/experimental/eigengrad.py +207 -0
  53. torchzero/modules/experimental/l_infinity.py +1 -1
  54. torchzero/modules/experimental/matrix_nag.py +122 -0
  55. torchzero/modules/experimental/newton_solver.py +2 -2
  56. torchzero/modules/experimental/newtonnewton.py +34 -40
  57. torchzero/modules/grad_approximation/fdm.py +2 -2
  58. torchzero/modules/grad_approximation/rfdm.py +4 -4
  59. torchzero/modules/least_squares/gn.py +68 -45
  60. torchzero/modules/line_search/backtracking.py +2 -2
  61. torchzero/modules/line_search/line_search.py +1 -1
  62. torchzero/modules/line_search/strong_wolfe.py +2 -2
  63. torchzero/modules/misc/escape.py +1 -1
  64. torchzero/modules/misc/gradient_accumulation.py +1 -1
  65. torchzero/modules/misc/misc.py +1 -1
  66. torchzero/modules/misc/multistep.py +4 -7
  67. torchzero/modules/misc/regularization.py +2 -2
  68. torchzero/modules/misc/split.py +1 -1
  69. torchzero/modules/misc/switch.py +2 -2
  70. torchzero/modules/momentum/cautious.py +3 -3
  71. torchzero/modules/momentum/momentum.py +1 -1
  72. torchzero/modules/ops/higher_level.py +1 -1
  73. torchzero/modules/ops/multi.py +1 -1
  74. torchzero/modules/projections/projection.py +5 -2
  75. torchzero/modules/quasi_newton/__init__.py +1 -1
  76. torchzero/modules/quasi_newton/damping.py +1 -1
  77. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  78. torchzero/modules/quasi_newton/lbfgs.py +3 -3
  79. torchzero/modules/quasi_newton/lsr1.py +3 -3
  80. torchzero/modules/quasi_newton/quasi_newton.py +44 -29
  81. torchzero/modules/quasi_newton/sg2.py +69 -205
  82. torchzero/modules/restarts/restars.py +17 -17
  83. torchzero/modules/second_order/inm.py +33 -25
  84. torchzero/modules/second_order/newton.py +132 -130
  85. torchzero/modules/second_order/newton_cg.py +3 -3
  86. torchzero/modules/second_order/nystrom.py +83 -32
  87. torchzero/modules/second_order/rsn.py +41 -44
  88. torchzero/modules/smoothing/laplacian.py +1 -1
  89. torchzero/modules/smoothing/sampling.py +2 -3
  90. torchzero/modules/step_size/adaptive.py +6 -6
  91. torchzero/modules/step_size/lr.py +2 -2
  92. torchzero/modules/trust_region/cubic_regularization.py +1 -1
  93. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  94. torchzero/modules/trust_region/trust_cg.py +1 -1
  95. torchzero/modules/variance_reduction/svrg.py +4 -5
  96. torchzero/modules/weight_decay/reinit.py +2 -2
  97. torchzero/modules/weight_decay/weight_decay.py +5 -5
  98. torchzero/modules/wrappers/optim_wrapper.py +4 -4
  99. torchzero/modules/zeroth_order/cd.py +1 -1
  100. torchzero/optim/mbs.py +291 -0
  101. torchzero/optim/wrappers/nevergrad.py +0 -9
  102. torchzero/optim/wrappers/optuna.py +2 -0
  103. torchzero/utils/benchmarks/__init__.py +0 -0
  104. torchzero/utils/benchmarks/logistic.py +122 -0
  105. torchzero/utils/derivatives.py +4 -4
  106. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  107. torchzero-0.4.1.dist-info/RECORD +209 -0
  108. torchzero/modules/adaptive/lmadagrad.py +0 -241
  109. torchzero-0.4.0.dist-info/RECORD +0 -191
  110. /torchzero/modules/{functional.py → opt_utils.py} +0 -0
  111. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  112. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/sg2.py
@@ -1,29 +1,39 @@
 import torch

-from ...core import Module, Chainable, step
-from ...utils import TensorList, vec_to_tensors
-from ..second_order.newton import _newton_step, _get_H
+from ...core import Chainable, Transform
+from ...utils import TensorList, unpack_dicts, unpack_states, vec_to_tensors_
+from ...linalg.linear_operator import Dense
+

 def sg2_(
     delta_g: torch.Tensor,
     cd: torch.Tensor,
 ) -> torch.Tensor:
-    """cd is c * perturbation, and must be multiplied by two if hessian estimate is two-sided
-    (or divide delta_g by two)."""
+    """cd is c * perturbation."""

-    M = torch.outer(1.0 / cd, delta_g)
+    M = torch.outer(0.5 / cd, delta_g)
     H_hat = 0.5 * (M + M.T)

     return H_hat



-class SG2(Module):
+class SG2(Transform):
     """second-order stochastic gradient

+    2SPSA (second-order SPSA)
+    ```python
+    opt = tz.Optimizer(
+        model.parameters(),
+        tz.m.SPSA(),
+        tz.m.SG2(),
+        tz.m.LR(1e-2),
+    )
+    ```
+
     SG2 with line search
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SG2(),
         tz.m.Backtracking()
@@ -32,9 +42,9 @@ class SG2(Module):

     SG2 with trust region
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
-        tz.m.LevenbergMarquardt(tz.m.SG2()),
+        tz.m.LevenbergMarquardt(tz.m.SG2(beta=0.75, n_samples=4)),
     )
     ```

@@ -43,24 +53,22 @@ class SG2(Module):
     def __init__(
         self,
         n_samples: int = 1,
-        h: float = 1e-2,
+        n_first_step_samples: int = 10,
+        start_step: int = 10,
         beta: float | None = None,
-        damping: float = 0,
-        eigval_fn=None,
-        one_sided: bool = False, # one-sided hessian
-        use_lstsq: bool = True,
+        damping: float = 1e-4,
+        h: float = 1e-2,
         seed=None,
+        update_freq: int = 1,
         inner: Chainable | None = None,
     ):
-        defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, one_sided=one_sided, seed=seed, use_lstsq=use_lstsq)
-        super().__init__(defaults)
-
-        if inner is not None: self.set_child('inner', inner)
+        defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, seed=seed, start_step=start_step, n_first_step_samples=n_first_step_samples)
+        super().__init__(defaults, update_freq=update_freq, inner=inner)

     @torch.no_grad
-    def update(self, objective):
-        k = self.global_state.get('step', 0) + 1
-        self.global_state["step"] = k
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+        k = self.increment_counter("step", 0)

         params = TensorList(objective.params)
         closure = objective.closure
@@ -68,36 +76,28 @@ class SG2(Module):
             raise RuntimeError("closure is required for SG2")
         generator = self.get_generator(params[0].device, self.defaults["seed"])

-        h = self.get_settings(params, "h")
+        h = unpack_dicts(settings, "h")
         x_0 = params.clone()
-        n_samples = self.defaults["n_samples"]
+        n_samples = fs["n_samples"]
+        if k == 0: n_samples = fs["n_first_step_samples"]
         H_hat = None

+        # compute new approximation
         for i in range(n_samples):
             # generate perturbation
             cd = params.rademacher_like(generator=generator).mul_(h)

-            # one sided
-            if self.defaults["one_sided"]:
-                g_0 = TensorList(objective.get_grads())
-                params.add_(cd)
-                closure()
+            # two sided hessian approximation
+            params.add_(cd)
+            closure()
+            g_p = params.grad.fill_none_(params)

-                g_p = params.grad.fill_none_(params)
-                delta_g = (g_p - g_0) * 2
+            params.copy_(x_0)
+            params.sub_(cd)
+            closure()
+            g_n = params.grad.fill_none_(params)

-            # two sided
-            else:
-                params.add_(cd)
-                closure()
-                g_p = params.grad.fill_none_(params)
-
-                params.copy_(x_0)
-                params.sub_(cd)
-                closure()
-                g_n = params.grad.fill_none_(params)
-
-                delta_g = g_p - g_n
+            delta_g = g_p - g_n

             # restore params
             params.set_(x_0)
@@ -114,179 +114,43 @@ class SG2(Module):
         assert H_hat is not None
         if n_samples > 1: H_hat /= n_samples

+        # add damping
+        if fs["damping"] != 0:
+            reg = torch.eye(H_hat.size(0), device=H_hat.device, dtype=H_hat.dtype).mul_(fs["damping"])
+            H_hat += reg
+
         # update H
         H = self.global_state.get("H", None)
         if H is None: H = H_hat
         else:
-            beta = self.defaults["beta"]
-            if beta is None: beta = k / (k+1)
+            beta = fs["beta"]
+            if beta is None: beta = (k+1) / (k+2)
             H.lerp_(H_hat, 1-beta)

         self.global_state["H"] = H


     @torch.no_grad
-    def apply(self, objective):
-        dir = _newton_step(
-            objective=objective,
-            H = self.global_state["H"],
-            damping = self.defaults["damping"],
-            inner = self.children.get("inner", None),
-            H_tfm=None,
-            eigval_fn=self.defaults["eigval_fn"],
-            use_lstsq=self.defaults["use_lstsq"],
-            g_proj=None,
-        )
-
-        objective.updates = vec_to_tensors(dir, objective.params)
+    def apply_states(self, objective, states, settings):
+        fs = settings[0]
+        updates = objective.get_updates()
+
+        H: torch.Tensor = self.global_state["H"]
+        k = self.global_state["step"]
+        if k < fs["start_step"]:
+            # don't precondition yet
+            # I guess we can try using trace to scale the update
+            # because it will have horrible scaling otherwise
+            torch._foreach_div_(updates, H.trace())
+            return objective
+
+        b = torch.cat([t.ravel() for t in updates])
+        sol = torch.linalg.lstsq(H, b).solution # pylint:disable=not-callable
+
+        vec_to_tensors_(sol, updates)
         return objective

-    def get_H(self,objective=...):
-        return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
-
-
-
-
-# two sided
-# we have g via x + d, x - d
-# H via g(x + d), g(x - d)
-# 1 is x, x+2d
-# 2 is x, x-2d
-# 5 evals in total
-
-# one sided
-# g via x, x + d
-# 1 is x, x + d
-# 2 is x, x - d
-# 3 evals and can use two sided for g_0
-
-class SPSA2(Module):
-    """second-order SPSA
-
-    SPSA2 with line search
-    ```python
-    opt = tz.Modular(
-        model.parameters(),
-        tz.m.SPSA2(),
-        tz.m.Backtracking()
-    )
-    ```
-
-    SPSA2 with trust region
-    ```python
-    opt = tz.Modular(
-        model.parameters(),
-        tz.m.LevenbergMarquardt(tz.m.SPSA2()),
-    )
-    ```
-    """
-
-    def __init__(
-        self,
-        n_samples: int = 1,
-        h: float = 1e-2,
-        beta: float | None = None,
-        damping: float = 0,
-        eigval_fn=None,
-        use_lstsq: bool = True,
-        seed=None,
-        inner: Chainable | None = None,
-    ):
-        defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, seed=seed, use_lstsq=use_lstsq)
-        super().__init__(defaults)
-
-        if inner is not None: self.set_child('inner', inner)
-
-    @torch.no_grad
-    def update(self, objective):
-        k = self.global_state.get('step', 0) + 1
-        self.global_state["step"] = k
-
-        params = TensorList(objective.params)
-        closure = objective.closure
-        if closure is None:
-            raise RuntimeError("closure is required for SPSA2")
-
-        generator = self.get_generator(params[0].device, self.defaults["seed"])
-
-        h = self.get_settings(params, "h")
-        x_0 = params.clone()
-        n_samples = self.defaults["n_samples"]
-        H_hat = None
-        g_0 = None
-
-        for i in range(n_samples):
-            # perturbations for g and H
-            cd_g = params.rademacher_like(generator=generator).mul_(h)
-            cd_H = params.rademacher_like(generator=generator).mul_(h)
-
-            # evaluate 4 points
-            x_p = x_0 + cd_g
-            x_n = x_0 - cd_g
-
-            params.set_(x_p)
-            f_p = closure(False)
-            params.add_(cd_H)
-            f_pp = closure(False)
-
-            params.set_(x_n)
-            f_n = closure(False)
-            params.add_(cd_H)
-            f_np = closure(False)
-
-            g_p_vec = (f_pp - f_p) / cd_H
-            g_n_vec = (f_np - f_n) / cd_H
-            delta_g = g_p_vec - g_n_vec
-
-            # restore params
-            params.set_(x_0)
-
-            # compute grad
-            g_i = (f_p - f_n) / (2 * cd_g)
-            if g_0 is None: g_0 = g_i
-            else: g_0 += g_i
-
-            # compute H hat
-            H_i = sg2_(
-                delta_g = delta_g.to_vec().div_(2.0),
-                cd = cd_g.to_vec(), # The interval is measured by the original 'cd'
-            )
-            if H_hat is None: H_hat = H_i
-            else: H_hat += H_i
-
-        assert g_0 is not None and H_hat is not None
-        if n_samples > 1:
-            g_0 /= n_samples
-            H_hat /= n_samples
-
-        # set grad to approximated grad
-        objective.grads = g_0
+    def get_H(self, objective=...):
+        return Dense(self.global_state["H"])

-        # update H
-        H = self.global_state.get("H", None)
-        if H is None: H = H_hat
-        else:
-            beta = self.defaults["beta"]
-            if beta is None: beta = k / (k+1)
-            H.lerp_(H_hat, 1-beta)
-
-        self.global_state["H"] = H
-
-    @torch.no_grad
-    def apply(self, objective):
-        dir = _newton_step(
-            objective=objective,
-            H = self.global_state["H"],
-            damping = self.defaults["damping"],
-            inner = self.children.get("inner", None),
-            H_tfm=None,
-            eigval_fn=self.defaults["eigval_fn"],
-            use_lstsq=self.defaults["use_lstsq"],
-            g_proj=None,
-        )
-
-        objective.updates = vec_to_tensors(dir, objective.params)
-        return objective

-    def get_H(self,objective=...):
-        return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
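
Note: the one-sided branch is gone, so SG2 now always forms the two-sided (central) estimate, with the 2c denominator folded into sg2_ via the 0.5 / cd factor. A minimal self-contained sketch of one such Hessian sample, assuming a callable grad_fn that returns the gradient at a point (the names here are illustrative, not torchzero API):

```python
import torch

def sg2_hessian_sample(grad_fn, x: torch.Tensor, c: float = 1e-2) -> torch.Tensor:
    # Rademacher perturbation scaled by c; plays the role of `cd` above
    cd = torch.empty_like(x).bernoulli_(0.5).mul_(2).sub_(1).mul_(c)

    # central (two-sided) gradient difference g(x + cd) - g(x - cd)
    delta_g = grad_fn(x + cd) - grad_fn(x - cd)

    # H_ij ~= 0.5 * (delta_g_j / (2 cd_i) + delta_g_i / (2 cd_j)),
    # i.e. a symmetrized outer product with the 2c denominator baked in
    M = torch.outer(0.5 / cd, delta_g)
    return 0.5 * (M + M.T)
```

Averaging n_samples such samples and blending the result into the running H with H.lerp_(H_hat, 1 - beta) is what update_states above does each step.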
torchzero/modules/restarts/restars.py
@@ -8,8 +8,8 @@ from ...core import Chainable, Module, Objective
 from ...utils import TensorList
 from ..termination import TerminationCriteriaBase

-def _reset_except_self(optimizer, var, self: Module):
-    for m in optimizer.unrolled_modules:
+def _reset_except_self(objective, modules, self: Module):
+    for m in modules:
         if m is not self:
             m.reset()

@@ -26,15 +26,15 @@ class RestartStrategyBase(Module, ABC):
         self.set_child('modules', modules)

     @abstractmethod
-    def should_reset(self, var: Objective) -> bool:
+    def should_reset(self, objective: Objective) -> bool:
         """returns whether reset should occur"""

-    def _reset_on_condition(self, var):
+    def _reset_on_condition(self, objective: Objective):
         modules = self.children.get('modules', None)

-        if self.should_reset(var):
+        if self.should_reset(objective):
             if modules is None:
-                var.post_step_hooks.append(partial(_reset_except_self, self=self))
+                objective.post_step_hooks.append(partial(_reset_except_self, self=self))
             else:
                 modules.reset()

@@ -78,11 +78,11 @@ class RestartOnStuck(RestartStrategyBase):
         super().__init__(defaults, modules)

     @torch.no_grad
-    def should_reset(self, var):
+    def should_reset(self, objective):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

-        params = TensorList(var.params)
+        params = TensorList(objective.params)
         tol = self.defaults['tol']
         if tol is None: tol = torch.finfo(params[0].dtype).tiny * 2
         n_tol = self.defaults['n_tol']
@@ -124,12 +124,12 @@ class RestartEvery(RestartStrategyBase):
         defaults = dict(steps=steps)
         super().__init__(defaults, modules)

-    def should_reset(self, var):
+    def should_reset(self, objective):
         step = self.global_state.get('step', 0) + 1
         self.global_state['step'] = step

         n = self.defaults['steps']
-        if isinstance(n, str): n = sum(p.numel() for p in var.params if p.requires_grad)
+        if isinstance(n, str): n = sum(p.numel() for p in objective.params if p.requires_grad)

         # reset every n steps
         if step % n == 0:
@@ -143,9 +143,9 @@ class RestartOnTerminationCriteria(RestartStrategyBase):
         super().__init__(None, modules)
         self.set_child('criteria', criteria)

-    def should_reset(self, var):
+    def should_reset(self, objective):
         criteria = cast(TerminationCriteriaBase, self.children['criteria'])
-        return criteria.should_terminate(var)
+        return criteria.should_terminate(objective)

 class PowellRestart(RestartStrategyBase):
     """Powell's two restarting criterions for conjugate gradient methods.
@@ -171,14 +171,14 @@ class PowellRestart(RestartStrategyBase):
         defaults=dict(cond1=cond1, cond2=cond2)
         super().__init__(defaults, modules)

-    def should_reset(self, var):
-        g = TensorList(var.get_grads())
+    def should_reset(self, objective):
+        g = TensorList(objective.get_grads())
         cond1 = self.defaults['cond1']; cond2 = self.defaults['cond2']

         # -------------------------------- initialize -------------------------------- #
         if 'initialized' not in self.global_state:
             self.global_state['initialized'] = 0
-            g_prev = self.get_state(var.params, 'g_prev', init=g)
+            g_prev = self.get_state(objective.params, 'g_prev', init=g)
             return False

         g_g = g.dot(g)
@@ -186,7 +186,7 @@
         reset = False
         # ------------------------------- 1st condition ------------------------------ #
         if cond1 is not None:
-            g_prev = self.get_state(var.params, 'g_prev', must_exist=True, cls=TensorList)
+            g_prev = self.get_state(objective.params, 'g_prev', must_exist=True, cls=TensorList)
             g_g_prev = g_prev.dot(g)

             if g_g_prev.abs() >= cond1 * g_g:
@@ -194,7 +194,7 @@

         # ------------------------------- 2nd condition ------------------------------ #
         if (cond2 is not None) and (not reset):
-            d_g = TensorList(var.get_updates()).dot(g)
+            d_g = TensorList(objective.get_updates()).dot(g)
             if (-1-cond2) * g_g < d_g < (-1 + cond2) * g_g:
                 reset = True

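
For reference, the two Powell tests that should_reset implements above (with g_k the current gradient, g_{k-1} the stored one, d_k the current update direction, and c_1, c_2 standing for cond1 and cond2) can be written as:

```latex
% restart when conjugacy is lost:
\lvert g_k^{\top} g_{k-1} \rvert \ge c_1\, g_k^{\top} g_k
% or when the direction stops being a sufficient descent direction:
d_k^{\top} g_k \notin \bigl( (-1 - c_2)\, g_k^{\top} g_k,\; (-1 + c_2)\, g_k^{\top} g_k \bigr)
```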
torchzero/modules/second_order/inm.py
@@ -3,9 +3,9 @@ from collections.abc import Callable
 import torch

 from ...core import Chainable, Transform, HessianMethod
-from ...utils import TensorList, vec_to_tensors, unpack_states
-from ..functional import safe_clip
-from .newton import _get_H, _newton_step
+from ...utils import TensorList, vec_to_tensors_, unpack_states
+from ..opt_utils import safe_clip
+from .newton import _newton_update_state_, _newton_solve, _newton_get_H

 @torch.no_grad
 def inm(f:torch.Tensor, J:torch.Tensor, s:torch.Tensor, y:torch.Tensor):
@@ -34,10 +34,10 @@ class ImprovedNewton(Transform):
     def __init__(
         self,
         damping: float = 0,
-        use_lstsq: bool = False,
-        update_freq: int = 1,
-        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        update_freq: int = 1,
+        precompute_inverse: bool | None = None,
+        use_lstsq: bool = False,
         hessian_method: HessianMethod = "batched_autograd",
         h: float = 1e-3,
         inner: Chainable | None = None,
@@ -65,37 +65,45 @@
         x_prev, f_prev = unpack_states(states, objective.params, "x_prev", "f_prev", cls=TensorList)

         # initialize on 1st step, do Newton step
-        if "P" not in self.global_state:
+        if "H" not in self.global_state:
             x_prev.copy_(x_list)
             f_prev.copy_(f_list)
-            self.global_state["P"] = J
-            return
+            P = J

         # INM update
-        s_list = x_list - x_prev
-        y_list = f_list - f_prev
-        x_prev.copy_(x_list)
-        f_prev.copy_(f_list)
+        else:
+            s_list = x_list - x_prev
+            y_list = f_list - f_prev
+            x_prev.copy_(x_list)
+            f_prev.copy_(f_list)

-        self.global_state["P"] = inm(f, J, s=s_list.to_vec(), y=y_list.to_vec())
+            P = inm(f, J, s=s_list.to_vec(), y=y_list.to_vec())

+        # update state
+        precompute_inverse = fs["precompute_inverse"]
+        if precompute_inverse is None:
+            precompute_inverse = fs["__update_freq"] >= 10
+
+        _newton_update_state_(
+            H=P,
+            state = self.global_state,
+            damping = fs["damping"],
+            eigval_fn = fs["eigval_fn"],
+            precompute_inverse = precompute_inverse,
+            use_lstsq = fs["use_lstsq"]
+        )

     @torch.no_grad
     def apply_states(self, objective, states, settings):
+        updates = objective.get_updates()
         fs = settings[0]

-        update = _newton_step(
-            objective = objective,
-            H = self.global_state["P"],
-            damping = fs["damping"],
-            H_tfm = fs["H_tfm"],
-            eigval_fn = None, # it is applied in `update`
-            use_lstsq = fs["use_lstsq"],
-        )
-
-        objective.updates = vec_to_tensors(update, objective.params)
+        b = torch.cat([t.ravel() for t in updates])
+        sol = _newton_solve(b=b, state=self.global_state, use_lstsq=fs["use_lstsq"])

+        vec_to_tensors_(sol, updates)
         return objective

+
     def get_H(self,objective=...):
-        return _get_H(self.global_state["P"], eigval_fn=None)
+        return _newton_get_H(self.global_state)
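
Both SG2.apply_states and ImprovedNewton.apply_states now share the same flatten / solve / scatter-back pattern for turning the stored curvature matrix into a per-parameter update. A minimal sketch of that pattern, where copy_back_ is a stand-in for torchzero's vec_to_tensors_ (whose exact behavior is assumed here):

```python
import torch

def copy_back_(vec: torch.Tensor, tensors: list[torch.Tensor]) -> None:
    # scatter a flat solution vector back into the per-parameter tensors, in place
    offset = 0
    for t in tensors:
        n = t.numel()
        t.copy_(vec[offset:offset + n].view_as(t))
        offset += n

def apply_preconditioner_(H: torch.Tensor, updates: list[torch.Tensor]) -> None:
    # flatten per-parameter updates into one right-hand side vector
    b = torch.cat([t.ravel() for t in updates])
    # a least-squares solve tolerates a singular or indefinite H
    sol = torch.linalg.lstsq(H, b).solution
    copy_back_(sol, updates)
```

SG2 performs this solve directly with torch.linalg.lstsq, while ImprovedNewton routes it through _newton_solve, which (judging by the new precompute_inverse option) can reuse a factorization or inverse prepared in _newton_update_state_.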