torchzero 0.3.14__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. tests/test_opts.py +4 -3
  2. torchzero/core/__init__.py +4 -1
  3. torchzero/core/chain.py +50 -0
  4. torchzero/core/functional.py +37 -0
  5. torchzero/core/modular.py +237 -0
  6. torchzero/core/module.py +8 -599
  7. torchzero/core/reformulation.py +3 -1
  8. torchzero/core/transform.py +7 -5
  9. torchzero/core/var.py +376 -0
  10. torchzero/modules/__init__.py +0 -1
  11. torchzero/modules/adaptive/adahessian.py +2 -2
  12. torchzero/modules/adaptive/esgd.py +2 -2
  13. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  14. torchzero/modules/adaptive/sophia_h.py +2 -2
  15. torchzero/modules/experimental/__init__.py +1 -0
  16. torchzero/modules/experimental/newtonnewton.py +5 -5
  17. torchzero/modules/experimental/spsa1.py +2 -2
  18. torchzero/modules/functional.py +7 -0
  19. torchzero/modules/line_search/__init__.py +1 -1
  20. torchzero/modules/line_search/_polyinterp.py +3 -1
  21. torchzero/modules/line_search/adaptive.py +3 -3
  22. torchzero/modules/line_search/backtracking.py +1 -1
  23. torchzero/modules/line_search/interpolation.py +160 -0
  24. torchzero/modules/line_search/line_search.py +11 -20
  25. torchzero/modules/line_search/strong_wolfe.py +3 -3
  26. torchzero/modules/misc/misc.py +2 -2
  27. torchzero/modules/misc/multistep.py +13 -13
  28. torchzero/modules/quasi_newton/__init__.py +2 -0
  29. torchzero/modules/quasi_newton/quasi_newton.py +15 -6
  30. torchzero/modules/quasi_newton/sg2.py +292 -0
  31. torchzero/modules/second_order/__init__.py +6 -3
  32. torchzero/modules/second_order/ifn.py +89 -0
  33. torchzero/modules/second_order/inm.py +105 -0
  34. torchzero/modules/second_order/newton.py +103 -193
  35. torchzero/modules/second_order/nystrom.py +1 -1
  36. torchzero/modules/second_order/rsn.py +227 -0
  37. torchzero/modules/wrappers/optim_wrapper.py +49 -42
  38. torchzero/utils/derivatives.py +19 -19
  39. torchzero/utils/linalg/linear_operator.py +50 -2
  40. {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/METADATA +1 -1
  41. {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/RECORD +44 -36
  42. torchzero/modules/higher_order/__init__.py +0 -1
  43. /torchzero/modules/{higher_order → experimental}/higher_order_newton.py +0 -0
  44. {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/WHEEL +0 -0
  45. {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,9 @@ from typing import Any, Literal, final
  import torch

  from ..utils import TensorList, set_storage_, vec_to_tensors
- from .module import Chain, Chainable, Module, Var
+ from .chain import Chain
+ from .module import Chainable, Module
+ from .var import Var

  Target = Literal['grad', 'update', 'closure', 'params_direct', 'params_difference', 'update_difference']

@@ -86,7 +88,7 @@ class Transform(Module, ABC):

  @final
  @torch.no_grad
- def transform_update(
+ def update_transform(
  self,
  tensors: list[torch.Tensor],
  params: list[torch.Tensor],
@@ -123,7 +125,7 @@ class Transform(Module, ABC):

  @final
  @torch.no_grad
- def transform_apply(
+ def apply_transform(
  self,
  tensors: list[torch.Tensor],
  params: list[torch.Tensor],
@@ -190,7 +192,7 @@ class Transform(Module, ABC):
  ):
  """`params` will be used as keys and need to always point to same tensor objects.`"""
  states, settings = self._get_keyed_states_settings(params)
- self.transform_update(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+ self.update_transform(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)


  @final
@@ -204,7 +206,7 @@ class Transform(Module, ABC):
  ):
  """`params` will be used as keys and need to always point to same tensor objects.`"""
  states, settings = self._get_keyed_states_settings(params)
- return self.transform_apply(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+ return self.apply_transform(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)


  def pre_step(self, var: Var) -> None:
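These hunks rename ``Transform.transform_update`` to ``update_transform`` and ``Transform.transform_apply`` to ``apply_transform``; the keyed wrappers below them are updated accordingly. A minimal migration sketch for callers, assuming only the keyword arguments visible in the calls above (any further parameters of these methods are not shown in this diff):

```python
# Hedged sketch: keyword names are taken from the calls in the hunks above;
# `transform`, `tensors`, `params`, `grads` and `loss` are assumed to exist.
states, settings = transform._get_keyed_states_settings(params)
transform.update_transform(tensors=tensors, params=params, grads=grads,
                           loss=loss, states=states, settings=settings)
out = transform.apply_transform(tensors=tensors, params=params, grads=grads,
                                loss=loss, states=states, settings=settings)
```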
torchzero/core/var.py ADDED
@@ -0,0 +1,376 @@
+
+ import warnings
+ from abc import ABC, abstractmethod
+ from collections import ChainMap, defaultdict
+ from collections.abc import Callable, Iterable, MutableMapping, Sequence
+ from operator import itemgetter
+ from typing import Any, final, overload, Literal, cast, TYPE_CHECKING
+
+ import torch
+
+ from ..utils import (
+ Init,
+ ListLike,
+ Params,
+ _make_param_groups,
+ get_state_vals,
+ vec_to_tensors
+ )
+ from ..utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward, flatten_jacobian
+ from ..utils.python_tools import flatten
+ from ..utils.linalg.linear_operator import LinearOperator
+
+ if TYPE_CHECKING:
+ from .modular import Modular
+
+ def _closure_backward(closure, params, retain_graph, create_graph):
+ with torch.enable_grad():
+ if not (retain_graph or create_graph):
+ return closure()
+
+ for p in params: p.grad = None
+ loss = closure(False)
+ grad = torch.autograd.grad(loss, params, retain_graph=retain_graph, create_graph=create_graph)
+ for p,g in zip(params,grad): p.grad = g
+ return loss
+
+ # region Vars
+ # ----------------------------------- var ----------------------------------- #
+ class Var:
+ """
+ Holds parameters, gradient, update, objective function (closure) if supplied, loss, and some other info.
+ Modules take in a ``Var`` object, modify and it is passed to the next module.
+
+ """
+ def __init__(
+ self,
+ params: list[torch.Tensor],
+ closure: Callable | None,
+ model: torch.nn.Module | None,
+ current_step: int,
+ parent: "Var | None" = None,
+ modular: "Modular | None" = None,
+ loss: torch.Tensor | None = None,
+ storage: dict | None = None,
+ ):
+ self.params: list[torch.Tensor] = params
+ """List of all parameters with requires_grad = True."""
+
+ self.closure = closure
+ """A closure that reevaluates the model and returns the loss, None if it wasn't specified"""
+
+ self.model = model
+ """torch.nn.Module object of the model, None if it wasn't specified."""
+
+ self.current_step: int = current_step
+ """global current step, starts at 0. This may not correspond to module current step,
+ for example a module may step every 10 global steps."""
+
+ self.parent: "Var | None" = parent
+ """parent ``Var`` object. When ``self.get_grad()`` is called, it will also set ``parent.grad``.
+ Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
+ e.g. when projecting."""
+
+ self.modular: "Modular | None" = modular
+ """Modular optimizer object that created this ``Var``."""
+
+ self.update: list[torch.Tensor] | None = None
+ """
+ current update. Update is assumed to be a transformed gradient, therefore it is subtracted.
+
+ If closure is None, this is initially set to cloned gradient. Otherwise this is set to None.
+
+ At the end ``var.get_update()`` is subtracted from parameters. Therefore if ``var.update`` is ``None``,
+ gradient will be used and calculated if needed.
+ """
+
+ self.grad: list[torch.Tensor] | None = None
+ """gradient with current parameters. If closure is not ``None``, this is set to ``None`` and can be calculated if needed."""
+
+ self.loss: torch.Tensor | Any | None = loss
+ """loss with current parameters."""
+
+ self.loss_approx: torch.Tensor | Any | None = None
+ """loss at a point near current point. This can be useful as some modules only calculate loss at perturbed points,
+ whereas some other modules require loss strictly at current point."""
+
+ self.post_step_hooks: list[Callable[[Modular, Var]]] = []
+ """list of functions to be called after optimizer step.
+
+ This attribute should always be modified in-place (using ``append`` or ``extend``).
+
+ The signature is:
+
+ ```python
+ def hook(optimizer: Modular, var: Vars): ...
+ ```
+ """
+
+ self.stop: bool = False
+ """if True, all following modules will be skipped.
+ If this module is a child, it only affects modules at the same level (in the same Chain)."""
+
+ self.skip_update: bool = False
+ """if True, the parameters will not be updated."""
+
+ # self.storage: dict = {}
+ # """Storage for any other data, such as hessian estimates, etc."""
+
+ self.attrs: dict = {}
+ """attributes, Modular.attrs is updated with this after each step. This attribute should always be modified in-place"""
+
+ if storage is None: storage = {}
+ self.storage: dict = storage
+ """additional kwargs passed to closure will end up in this dict. This attribute should always be modified in-place"""
+
+ self.should_terminate: bool | None = None
+ """termination criteria, Modular.should_terminate is set to this after each step if not None"""
+
+ def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
+ """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning ``var.loss``.
+ Do not call this at perturbed parameters. Backward always sets grads to None before recomputing."""
+ if self.loss is None:
+
+ if self.closure is None: raise RuntimeError("closure is None")
+ if backward:
+ with torch.enable_grad():
+ self.loss = self.loss_approx = _closure_backward(
+ closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
+ )
+
+ # initializing to zeros_like is equivalent to using zero_grad with set_to_none = False.
+ # it is technically a more correct approach for when some parameters conditionally receive gradients
+ # and in this case it shouldn't be slower.
+
+ # next time closure() is called, it will set grad to None.
+ # zero_grad(set_to_none=False) shouldn't be used (I should add a warning)
+ self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
+ else:
+ self.loss = self.loss_approx = self.closure(False)
+
+ # if self.loss was not None, above branch wasn't executed because loss has already been evaluated, but without backward since self.grad is None.
+ # and now it is requested to be evaluated with backward.
+ if backward and self.grad is None:
+ warnings.warn('get_loss was called with backward=False, and then with backward=True so it had to be re-evaluated, so the closure was evaluated twice where it could have been evaluated once.')
+ if self.closure is None: raise RuntimeError("closure is None")
+
+ with torch.enable_grad():
+ self.loss = self.loss_approx = _closure_backward(
+ closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
+ )
+ self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
+
+ # set parent grad
+ if self.parent is not None:
+ # the way projections/split work, they make a new closure which evaluates original
+ # closure and projects the gradient, and set it as their var.closure.
+ # then on `get_loss(backward=True)` it is called, so it also sets original parameters gradient.
+ # and we set it to parent var here.
+ if self.parent.loss is None: self.parent.loss = self.loss
+ if self.parent.grad is None and backward:
+ if all(p.grad is None for p in self.parent.params):
+ warnings.warn("Parent grad is None after backward.")
+ self.parent.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.parent.params]
+
+ return self.loss # type:ignore
+
+ def get_grad(self, retain_graph: bool | None = None, create_graph: bool = False) -> list[torch.Tensor]:
+ """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning
+ ``var.grad`` and potentially ``var.loss``. Do not call this at perturbed parameters."""
+ if self.grad is None:
+ if self.closure is None: raise RuntimeError("closure is None")
+ self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grad
+
+ assert self.grad is not None
+ return self.grad
+
+ def get_update(self) -> list[torch.Tensor]:
+ """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to ``var.update``.
+ Computing the gradients may assign ``var.grad`` and ``var.loss`` if they haven't been computed.
+ Do not call this at perturbed parameters."""
+ if self.update is None: self.update = [g.clone() for g in self.get_grad()]
+ return self.update
+
+ def clone(self, clone_update: bool, parent: "Var | None" = None):
+ """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via ``torch.clone``).
+
+ Setting ``parent`` is only if clone's parameters are something different,
+ while clone's closure referes to the same objective but with a "view" on parameters.
+ """
+ copy = Var(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step, parent=parent)
+
+ if clone_update and self.update is not None:
+ copy.update = [u.clone() for u in self.update]
+ else:
+ copy.update = self.update
+
+ copy.grad = self.grad
+ copy.loss = self.loss
+ copy.loss_approx = self.loss_approx
+ copy.closure = self.closure
+ copy.post_step_hooks = self.post_step_hooks
+ copy.stop = self.stop
+ copy.skip_update = self.skip_update
+
+ copy.modular = self.modular
+ copy.attrs = self.attrs
+ copy.storage = self.storage
+ copy.should_terminate = self.should_terminate
+
+ return copy
+
+ def update_attrs_from_clone_(self, var: "Var"):
+ """Updates attributes of this `Vars` instance from a cloned instance.
+ Typically called after a child module has processed a cloned `Vars`
+ object. This propagates any newly computed loss or gradient values
+ from the child's context back to the parent `Vars` if the parent
+ didn't have them computed already.
+
+ Also, as long as ``post_step_hooks`` and ``attrs`` are modified in-place,
+ if the child updates them, the update will affect the parent too.
+ """
+ if self.loss is None: self.loss = var.loss
+ if self.loss_approx is None: self.loss_approx = var.loss_approx
+ if self.grad is None: self.grad = var.grad
+
+ if var.should_terminate is not None: self.should_terminate = var.should_terminate
+
+ def zero_grad(self, set_to_none=True):
+ if set_to_none:
+ for p in self.params: p.grad = None
+ else:
+ grads = [p.grad for p in self.params if p.grad is not None]
+ if len(grads) != 0: torch._foreach_zero_(grads)
+
+
+ # ------------------------------ HELPER METHODS ------------------------------ #
+ @torch.no_grad
+ def hessian_vector_product(
+ self,
+ v: Sequence[torch.Tensor],
+ at_x0: bool,
+ rgrad: Sequence[torch.Tensor] | None,
+ hvp_method: Literal['autograd', 'forward', 'central'],
+ h: float,
+ normalize: bool,
+ retain_graph: bool,
+ ) -> tuple[list[torch.Tensor], Sequence[torch.Tensor] | None]:
+ """
+ Returns ``(Hvp, rgrad)``, where ``rgrad`` is gradient at current parameters,
+ possibly with ``create_graph=True``, or it may be None with ``hvp_method="central"``.
+ Gradient is set to vars automatically if ``at_x0``, you can always access it with ``vars.get_grad()``
+
+ Single sample example:
+
+ ```python
+ Hvp, _ = self.hessian_vector_product(v, at_x0=True, rgrad=None, ..., retain_graph=False)
+ ```
+
+ Multiple samples example:
+
+ ```python
+ D = None
+ rgrad = None
+ for i in range(n_samples):
+ v = [torch.randn_like(p) for p in params]
+ Hvp, rgrad = self.hessian_vector_product(v, at_x0=True, rgrad=rgrad, ..., retain_graph=i < n_samples-1)
+
+ if D is None: D = Hvp
+ else: torch._foreach_add_(D, Hvp)
+
+ if n_samples > 1: torch._foreach_div_(D, n_samples)
+ ```
+
+ Args:
+ v (Sequence[torch.Tensor]): vector in hessian-vector product
+ at_x0 (bool): whether this is being called at original or perturbed parameters.
+ var (Var): Var
+ rgrad (Sequence[torch.Tensor] | None): pass None initially, then pass what this returns.
+ hvp_method (str): hvp method.
+ h (float): finite difference step size
+ normalize (bool): whether to normalize v for finite difference
+ retain_grad (bool): retain grad
+ """
+ # get grad
+ if rgrad is None and hvp_method in ('autograd', 'forward'):
+ if at_x0: rgrad = self.get_grad(create_graph = hvp_method=='autograd')
+ else:
+ if self.closure is None: raise RuntimeError("Closure is required to calculate HVp")
+ with torch.enable_grad():
+ loss = self.closure()
+ rgrad = torch.autograd.grad(loss, self.params, create_graph = hvp_method=='autograd')
+
+ if hvp_method == 'autograd':
+ assert rgrad is not None
+ Hvp = hvp(self.params, rgrad, v, retain_graph=retain_graph)
+
+ elif hvp_method == 'forward':
+ assert rgrad is not None
+ loss, Hvp = hvp_fd_forward(self.closure, self.params, v, h=h, g_0=rgrad, normalize=normalize)
+
+ elif hvp_method == 'central':
+ loss, Hvp = hvp_fd_central(self.closure, self.params, v, h=h, normalize=normalize)
+
+ else:
+ raise ValueError(hvp_method)
+
+ return list(Hvp), rgrad
+
+ @torch.no_grad
+ def hessian_matrix_product(
+ self,
+ M: torch.Tensor,
+ at_x0: bool,
+ rgrad: Sequence[torch.Tensor] | None,
+ hvp_method: Literal["batched", 'autograd', 'forward', 'central'],
+ h: float,
+ normalize: bool,
+ retain_graph: bool,
+ ) -> tuple[torch.Tensor, Sequence[torch.Tensor] | None]:
+ """M is (n_dim, n_hvps), computes H @ M - (n_dim, n_hvps)."""
+
+ # get grad
+ if rgrad is None and hvp_method in ('autograd', 'forward', "batched"):
+ if at_x0: rgrad = self.get_grad(create_graph = hvp_method in ('autograd', "batched"))
+ else:
+ if self.closure is None: raise RuntimeError("Closure is required to calculate HVp")
+ with torch.enable_grad():
+ loss = self.closure()
+ create_graph = hvp_method in ('autograd', "batched")
+ rgrad = torch.autograd.grad(loss, self.params, create_graph=create_graph)
+
+ if hvp_method == "batched":
+ assert rgrad is not None
+ with torch.enable_grad():
+ flat_inputs = torch.cat([g.ravel() for g in rgrad])
+ HM_list = torch.autograd.grad(flat_inputs, self.params, grad_outputs=M.T, is_grads_batched=True, retain_graph=retain_graph)
+ HM = flatten_jacobian(HM_list).T
+
+ elif hvp_method == 'autograd':
+ assert rgrad is not None
+ with torch.enable_grad():
+ flat_inputs = torch.cat([g.ravel() for g in rgrad])
+ HV_tensors = [torch.autograd.grad(
+ flat_inputs, self.params, grad_outputs=col,
+ retain_graph = retain_graph or (i < M.size(1) - 1)
+ ) for i,col in enumerate(M.unbind(1))]
+ HM_list = [torch.cat([t.ravel() for t in tensors]) for tensors in HV_tensors]
+ HM = torch.stack(HM_list, 1)
+
+ elif hvp_method == 'forward':
+ assert rgrad is not None
+ HV_tensors = [hvp_fd_forward(self.closure, self.params, vec_to_tensors(col, self.params), h=h, g_0=rgrad, normalize=normalize)[1] for col in M.unbind(1)]
+ HM_list = [torch.cat([t.ravel() for t in tensors]) for tensors in HV_tensors]
+ HM = flatten_jacobian(HM_list)
+
+ elif hvp_method == 'central':
+ HV_tensors = [hvp_fd_central(self.closure, self.params, vec_to_tensors(col, self.params), h=h, normalize=normalize)[1] for col in M.unbind(1)]
+ HM_list = [torch.cat([t.ravel() for t in tensors]) for tensors in HV_tensors]
+ HM = flatten_jacobian(HM_list)
+
+ else:
+ raise ValueError(hvp_method)
+
+ return HM, rgrad
+
+ # endregion
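For orientation, a minimal sketch of how the new ``Var`` helpers fit together. The closure convention (``closure(False)`` evaluates without backward) follows ``_closure_backward`` above; the toy parameter and objective are assumptions:

```python
import torch

p = torch.zeros(3, requires_grad=True)  # hypothetical single parameter

def closure(backward=True):
    loss = (p ** 2).sum()
    if backward:
        p.grad = None
        loss.backward()
    return loss

var = Var(params=[p], closure=closure, model=None, current_step=0)
loss = var.get_loss(backward=True)  # evaluates the closure, fills var.loss and var.grad
grad = var.get_grad()               # reuses the cached gradient
update = var.get_update()           # clones the gradient on first call
```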
@@ -2,7 +2,6 @@ from . import experimental
  from .clipping import *
  from .conjugate_gradient import *
  from .grad_approximation import *
- from .higher_order import *
  from .least_squares import *
  from .line_search import *
  from .misc import *
@@ -193,8 +193,8 @@ class AdaHessian(Module):
  for i in range(n_samples):
  u = [_rademacher_like(p, generator=generator) for p in params]

- Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
- h=fd_h, normalize=True, retain_grad=i < n_samples-1)
+ Hvp, rgrad = var.hessian_vector_product(u, at_x0=True, rgrad=rgrad, hvp_method=hvp_method,
+ h=fd_h, normalize=True, retain_graph=i < n_samples-1)
  Hvp = tuple(Hvp)

  if D is None: D = Hvp
@@ -144,8 +144,8 @@ class ESGD(Module):
  for j in range(n_samples):
  u = [torch.randn(p.size(), generator=generator, device=p.device, dtype=p.dtype) for p in params]

- Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
- h=fd_h, normalize=True, retain_grad=j < n_samples-1)
+ Hvp, rgrad = var.hessian_vector_product(u, at_x0=True, rgrad=rgrad, hvp_method=hvp_method,
+ h=fd_h, normalize=True, retain_graph=j < n_samples-1)

  if D is None: D = Hvp
  else: torch._foreach_add_(D, Hvp)
@@ -74,7 +74,7 @@ class MatrixMomentum(Module):
  if step > 0:
  s = p - p_prev

- Hs, _ = self.Hvp(s, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
+ Hs, _ = var.hessian_vector_product(s, at_x0=True, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_graph=False)
  Hs = [t.detach() for t in Hs]

  if 'hvp_tfm' in self.children:
@@ -155,8 +155,8 @@ class SophiaH(Module):
  for i in range(n_samples):
  u = [torch.randn(p.shape, device=p.device, dtype=p.dtype, generator=generator) for p in params]

- Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
- h=fd_h, normalize=True, retain_grad=i < n_samples-1)
+ Hvp, rgrad = var.hessian_vector_product(u, at_x0=True, rgrad=rgrad, hvp_method=hvp_method,
+ h=fd_h, normalize=True, retain_graph=i < n_samples-1)
  Hvp = tuple(Hvp)

  if h is None: h = Hvp
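The four hunks above replace the old ``self.Hvp(..., var=var, retain_grad=...)`` helper with ``var.hessian_vector_product(..., retain_graph=...)``. The shared multi-sample accumulation pattern, restated from the ``Var.hessian_vector_product`` docstring (``hvp_method``, ``fd_h`` and ``n_samples`` are assumed to be defined by the calling module):

```python
# Hedged sketch of the averaging loop used by AdaHessian, ESGD and SophiaH.
D = None
rgrad = None
for i in range(n_samples):
    u = [torch.randn_like(p) for p in params]
    Hvp, rgrad = var.hessian_vector_product(
        u, at_x0=True, rgrad=rgrad, hvp_method=hvp_method,
        h=fd_h, normalize=True, retain_graph=i < n_samples - 1)
    if D is None: D = Hvp
    else: torch._foreach_add_(D, Hvp)
if n_samples > 1: torch._foreach_div_(D, n_samples)
```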
@@ -4,6 +4,7 @@ from .curveball import CurveBall
  # from dct import DCTProjection
  from .fft import FFTProjection
  from .gradmin import GradMin
+ from .higher_order_newton import HigherOrderNewton
  from .l_infinity import InfinityNormTrustRegion
  from .momentum import (
  CoordinateMomentum,
@@ -45,9 +45,9 @@ class NewtonNewton(Module):
  order: int = 3,
  search_negative: bool = False,
  vectorize: bool = True,
- eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
+ eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
  ):
- defaults = dict(order=order, reg=reg, vectorize=vectorize, eigval_tfm=eigval_tfm, search_negative=search_negative)
+ defaults = dict(order=order, reg=reg, vectorize=vectorize, eigval_fn=eigval_fn, search_negative=search_negative)
  super().__init__(defaults)

  @torch.no_grad
@@ -61,7 +61,7 @@ class NewtonNewton(Module):
  vectorize = settings['vectorize']
  order = settings['order']
  search_negative = settings['search_negative']
- eigval_tfm = settings['eigval_tfm']
+ eigval_fn = settings['eigval_fn']

  # ------------------------ calculate grad and hessian ------------------------ #
  Hs = []
@@ -82,8 +82,8 @@ class NewtonNewton(Module):
  Hs.append(H)

  x = None
- if search_negative or (is_last and eigval_tfm is not None):
- x = _eigh_solve(H, xp, eigval_tfm, search_negative=search_negative)
+ if search_negative or (is_last and eigval_fn is not None):
+ x = _eigh_solve(H, xp, eigval_fn, search_negative=search_negative)
  if x is None: x = _cholesky_solve(H, xp)
  if x is None: x = _lu_solve(H, xp)
  if x is None: x = _least_squares_solve(H, xp)
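These hunks rename NewtonNewton's ``eigval_tfm`` argument to ``eigval_fn``. A hedged migration note; the example transform is hypothetical:

```python
# before (0.3.14):  NewtonNewton(..., eigval_tfm=my_fn)
# after  (0.3.15):  NewtonNewton(..., eigval_fn=my_fn)
my_fn = lambda eigvals: eigvals.abs().clamp(min=1e-8)  # hypothetical eigenvalue transform
```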
@@ -48,7 +48,7 @@ class SPSA1(GradApproximator):
  n_samples = self.defaults['n_samples']
  h = self.get_settings(var.params, 'h')

- perturbations = [params.sample_like(distribution='rademacher', generator=generator) for _ in range(n_samples)]
+ perturbations = [params.rademacher_like(generator=generator) for _ in range(n_samples)]
  torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])

  for param, prt in zip(params, zip(*perturbations)):
@@ -74,7 +74,7 @@ class SPSA1(GradApproximator):
  prt = perturbations[i]

  if prt[0] is None:
- prt = params.sample_like('rademacher', generator=generator).mul_(h)
+ prt = params.rademacher_like(generator=generator).mul_(h)

  else: prt = TensorList(prt)

@@ -253,3 +253,10 @@ def safe_clip(x: torch.Tensor, min=None):

  if x.abs() < min: return x.new_full(x.size(), min).copysign(x)
  return x
+
+
+ def clip_by_finfo(x, finfo: torch.finfo):
+ """clips by (dtype.max / 2, dtype.min / 2)"""
+ if x > finfo.max / 2: return finfo.max / 2
+ if x < finfo.min / 2: return finfo.min / 2
+ return x
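The new ``clip_by_finfo`` helper clamps a scalar to half of the dtype's representable range. A hedged usage sketch; the input values are arbitrary:

```python
import torch

finfo = torch.finfo(torch.float32)
clip_by_finfo(1e39, finfo)  # returns finfo.max / 2 (about 1.7e38)
clip_by_finfo(3.0, finfo)   # within range, returned unchanged
```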
@@ -1,4 +1,4 @@
- from .adaptive import AdaptiveTracking
+ from .adaptive import AdaptiveBisection
  from .backtracking import AdaptiveBacktracking, Backtracking
  from .line_search import LineSearchBase
  from .scipy import ScipyMinimizeScalar
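``AdaptiveTracking`` is renamed to ``AdaptiveBisection`` (and ``adaptive_tracking`` to ``adaptive_bisection`` in the hunks that follow). A hedged import migration, using the module path shown in this hunk:

```python
# before (0.3.14)
# from torchzero.modules.line_search import AdaptiveTracking
# after (0.3.15)
from torchzero.modules.line_search import AdaptiveBisection
```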
@@ -2,7 +2,7 @@ import numpy as np
  import torch

  from .line_search import LineSearchBase
-
+ from ...utils import tofloat

  # polynomial interpolation
  # this code is from https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
@@ -284,6 +284,8 @@ def polyinterp2(points, lb, ub, unbounded: bool = False):
  x_sol = _cubic_interp(p, lb, ub)
  if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol

+ if lb is not None: lb = tofloat(lb)
+ if ub is not None: ub = tofloat(ub)
  x_sol = _poly_interp(points, lb, ub)
  if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
  return polyinterp2(points[1:], lb, ub)
@@ -10,7 +10,7 @@ import torch
  from .line_search import LineSearchBase, TerminationCondition, termination_condition


- def adaptive_tracking(
+ def adaptive_bisection(
  f,
  a_init,
  maxiter: int,
@@ -56,7 +56,7 @@ def adaptive_tracking(
  return 0, f_0, niter


- class AdaptiveTracking(LineSearchBase):
+ class AdaptiveBisection(LineSearchBase):
  """A line search that evaluates previous step size, if value increased, backtracks until the value stops decreasing,
  otherwise forward-tracks until value stops decreasing.

@@ -98,7 +98,7 @@ class AdaptiveTracking(LineSearchBase):
  if a_init < torch.finfo(var.params[0].dtype).tiny * 2:
  a_init = torch.finfo(var.params[0].dtype).max / 2

- step_size, f, niter = adaptive_tracking(
+ step_size, f, niter = adaptive_bisection(
  objective,
  a_init=a_init,
  maxiter=maxiter,
@@ -136,7 +136,7 @@ class Backtracking(LineSearchBase):
  if adaptive:
  finfo = torch.finfo(var.params[0].dtype)
  if init_scale <= finfo.tiny * 2:
- self.global_state["init_scale"] = finfo.max / 2
+ self.global_state["init_scale"] = init * 2
  else:
  self.global_state['init_scale'] = init_scale * beta**maxiter
  return 0