torchzero 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (62)
  1. tests/test_identical.py +1 -1
  2. torchzero/__init__.py +3 -1
  3. torchzero/_minimize/__init__.py +0 -0
  4. torchzero/_minimize/methods.py +95 -0
  5. torchzero/_minimize/minimize.py +518 -0
  6. torchzero/core/__init__.py +5 -5
  7. torchzero/core/chain.py +2 -1
  8. torchzero/core/functional.py +2 -1
  9. torchzero/core/module.py +75 -4
  10. torchzero/core/transform.py +6 -5
  11. torchzero/linalg/eigh.py +116 -68
  12. torchzero/linalg/linear_operator.py +1 -0
  13. torchzero/linalg/orthogonalize.py +60 -5
  14. torchzero/linalg/sketch.py +39 -0
  15. torchzero/modules/__init__.py +1 -0
  16. torchzero/modules/adaptive/adagrad.py +2 -0
  17. torchzero/modules/adaptive/adam.py +5 -1
  18. torchzero/modules/adaptive/adan.py +3 -0
  19. torchzero/modules/adaptive/ggt.py +20 -18
  20. torchzero/modules/adaptive/lion.py +3 -1
  21. torchzero/modules/adaptive/mars.py +6 -5
  22. torchzero/modules/adaptive/msam.py +3 -0
  23. torchzero/modules/adaptive/rmsprop.py +2 -0
  24. torchzero/modules/adaptive/rprop.py +9 -7
  25. torchzero/modules/adaptive/shampoo.py +9 -1
  26. torchzero/modules/adaptive/soap.py +32 -29
  27. torchzero/modules/basis/__init__.py +2 -0
  28. torchzero/modules/basis/ggt_basis.py +199 -0
  29. torchzero/modules/basis/soap_basis.py +254 -0
  30. torchzero/modules/clipping/ema_clipping.py +32 -27
  31. torchzero/modules/clipping/growth_clipping.py +1 -0
  32. torchzero/modules/experimental/__init__.py +1 -6
  33. torchzero/modules/experimental/coordinate_momentum.py +2 -0
  34. torchzero/modules/experimental/cubic_adam.py +4 -0
  35. torchzero/modules/grad_approximation/__init__.py +3 -2
  36. torchzero/modules/least_squares/gn.py +6 -0
  37. torchzero/modules/misc/gradient_accumulation.py +1 -0
  38. torchzero/modules/misc/misc.py +6 -0
  39. torchzero/modules/momentum/averaging.py +6 -0
  40. torchzero/modules/momentum/momentum.py +13 -9
  41. torchzero/modules/ops/__init__.py +0 -1
  42. torchzero/modules/ops/accumulate.py +4 -0
  43. torchzero/modules/ops/higher_level.py +6 -1
  44. torchzero/modules/second_order/inm.py +4 -0
  45. torchzero/modules/second_order/newton.py +11 -3
  46. torchzero/modules/second_order/newton_cg.py +7 -3
  47. torchzero/modules/second_order/nystrom.py +14 -19
  48. torchzero/modules/second_order/rsn.py +37 -6
  49. torchzero/modules/trust_region/trust_region.py +2 -1
  50. torchzero/utils/benchmarks/logistic.py +33 -18
  51. torchzero/utils/optuna_tools.py +1 -1
  52. torchzero/utils/params.py +13 -1
  53. torchzero/utils/tensorlist.py +2 -2
  54. {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/METADATA +1 -1
  55. {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/RECORD +58 -55
  56. torchzero/modules/experimental/adanystrom.py +0 -258
  57. torchzero/modules/experimental/common_directions_whiten.py +0 -142
  58. torchzero/modules/experimental/eigen_sr1.py +0 -182
  59. torchzero/modules/experimental/eigengrad.py +0 -207
  60. /torchzero/modules/{experimental → grad_approximation}/spsa1.py +0 -0
  61. {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/WHEEL +0 -0
  62. {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/top_level.txt +0 -0
tests/test_identical.py CHANGED
@@ -105,7 +105,7 @@ def test_adam(amsgrad):
     tz_fn_ops = lambda p: tz.Optimizer(
         p,
         tz.m.DivModules(
-            tz.m.EMA(0.9, debiased=True),
+            tz.m.EMA(0.9, debias=True),
             [tz.m.SqrtEMASquared(0.999, debiased=True, amsgrad=amsgrad), tz.m.Add(1e-8)]
         ))
     tz_fn_ops2 = lambda p: tz.Optimizer(
torchzero/__init__.py CHANGED
@@ -1,4 +1,6 @@
 from . import core, optim, utils
 from .core import Optimizer
 from .utils.compile import enable_compilation
-from . import modules as m
+from . import modules as m
+
+from ._minimize.minimize import minimize
torchzero/_minimize/__init__.py ADDED
File without changes
torchzero/_minimize/methods.py ADDED
@@ -0,0 +1,95 @@
+"""WIP API"""
+import itertools
+import time
+from collections import deque
+from collections.abc import Callable, Sequence, Mapping, Iterable
+from typing import Any, NamedTuple, cast, overload
+
+import numpy as np
+import torch
+
+from .. import m
+from ..core import Module, Optimizer
+from ..utils import tofloat
+
+
+def _get_method_from_str(method: str) -> list[Module]:
+    method = ''.join(c for c in method.lower().strip() if c.isalnum())
+
+    if method == "bfgs":
+        return [m.RestartOnStuck(m.BFGS()), m.Backtracking()]
+
+    if method == "lbfgs":
+        return [m.LBFGS(100), m.Backtracking()]
+
+    if method == "newton":
+        return [m.Newton(), m.Backtracking()]
+
+    if method == "sfn":
+        return [m.Newton(eigval_fn=lambda x: x.abs().clip(min=1e-10)), m.Backtracking()]
+
+    if method == "inm":
+        return [m.ImprovedNewton(), m.Backtracking()]
+
+    if method == 'crn':
+        return [m.CubicRegularization(m.Newton())]
+
+    if method == "commondirections":
+        return [m.SubspaceNewton(sketch_type='common_directions'), m.Backtracking()]
+
+    if method == "trust":
+        return [m.LevenbergMarquardt(m.Newton())]
+
+    if method == "trustexact":
+        return [m.TrustCG(m.Newton())]
+
+    if method == "dogleg":
+        return [m.Dogleg(m.Newton())]
+
+    if method == "trustbfgs":
+        return [m.LevenbergMarquardt(m.BFGS())]
+
+    if method == "trustsr1":
+        return [m.LevenbergMarquardt(m.SR1())]
+
+    if method == "newtoncg":
+        return [m.NewtonCG(), m.Backtracking()]
+
+    if method == "tn":
+        return [m.NewtonCG(maxiter=10), m.Backtracking()]
+
+    if method == "trustncg":
+        return [m.NewtonCGSteihaug()]
+
+    if method == "gd":
+        return [m.Backtracking()]
+
+    if method == "cg":
+        return [m.FletcherReeves(), m.StrongWolfe(c2=0.1, fallback=True)]
+
+    if method == "bb":
+        return [m.RestartOnStuck(m.BarzilaiBorwein())]
+
+    if method == "bbstab":
+        return [m.BBStab()]
+
+    if method == "adgd":
+        return [m.AdGD()]
+
+    if method in ("gn", "gaussnewton"):
+        return [m.GaussNewton(), m.Backtracking()]
+
+    if method == "rprop":
+        return [m.Rprop(alpha=1e-3)]
+
+    if method == "lm":
+        return [m.LevenbergMarquardt(m.GaussNewton())]
+
+    if method == "mlm":
+        return [m.LevenbergMarquardt(m.GaussNewton(), y=1)]
+
+    if method == "cd":
+        return [m.CD(), m.ScipyMinimizeScalar(maxiter=8)]
+
+
+    raise NotImplementedError(method)
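Since `_get_method_from_str` simply expands a method name into a module chain, passing a string to the new `minimize` entry point is interchangeable with passing the modules explicitly. A minimal sketch of that equivalence, assuming a simple quadratic objective (`f` and `x0` below are illustrative and not part of the package):

```python
import torch
import torchzero as tz

# illustrative objective and starting point, not part of the package
f = lambda x: (x ** 2).sum()
x0 = torch.randn(10)

# "lbfgs" resolves (per _get_method_from_str above) to [m.LBFGS(100), m.Backtracking()],
# so these two calls should build the same optimizer
res_a = tz.minimize(f, x0.clone(), method="lbfgs", maxiter=50)
res_b = tz.minimize(f, x0.clone(), method=[tz.m.LBFGS(100), tz.m.Backtracking()], maxiter=50)
```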
torchzero/_minimize/minimize.py ADDED
@@ -0,0 +1,518 @@
+"""WIP API"""
+import itertools
+import time
+from collections import deque
+from collections.abc import Callable, Iterable, Mapping, Sequence
+from typing import Any, NamedTuple, cast, overload
+
+import numpy as np
+import torch
+
+from ..core import Module, Optimizer
+from ..utils import tofloat
+from .methods import _get_method_from_str
+
+_fn_autograd = Callable[[torch.Tensor], torch.Tensor | Any]
+_fn_custom_grad = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+_scalar = float | np.ndarray | torch.Tensor
+_method = str | Module | Sequence[Module] | Callable[..., torch.optim.Optimizer]
+
+def _tensorlist_norm(tensors: Iterable[torch.Tensor], ord) -> torch.Tensor:
+    """returns a scalar - global norm of tensors"""
+    if ord == torch.inf:
+        return max(torch._foreach_max(torch._foreach_abs(tuple(tensors))))
+
+    if ord == 1:
+        return cast(torch.Tensor, sum(t.abs().sum() for t in tensors))
+
+    if ord % 2 != 0:
+        tensors = torch._foreach_abs(tuple(tensors))
+
+    tensors = torch._foreach_pow(tuple(tensors), ord)
+    return sum(t.sum() for t in tensors) ** (1 / ord)
+
+
+
+class Params:
+    __slots__ = ("args", "kwargs")
+    def __init__(self, args: Sequence[torch.Tensor], kwargs: Mapping[str, torch.Tensor]):
+        self.args = tuple(args)
+        self.kwargs = dict(kwargs)
+
+    @property
+    def x(self):
+        assert len(self.args) == 1
+        assert len(self.kwargs) == 0
+        return self.args[0]
+
+    def parameters(self):
+        yield from self.args
+        yield from self.kwargs.values()
+
+    def clone(self):
+        return Params(
+            args = [a.clone() for a in self.args],
+            kwargs={k:v.clone() for k,v in self.kwargs.items()}
+        )
+
+    def __repr__(self):
+        if len(self.args) == 1 and len(self.kwargs) == 0:
+            return f"Params({repr(self.x)})"
+
+        s = "Params("
+        if len(self.args) > 0:
+            s = f"{s}\n\targs = (\n\t\t"
+            s += ",\n\t\t".join(str(a) for a in self.args)
+            s = s + "\n\t)"
+
+        if len(self.kwargs) > 0:
+            s = f'{s}\n\tkwargs = (\n\t\t'
+            for k,v in self.kwargs.items():
+                s = f"{s}{k}={v},\n\t\t"
+            s = s[:-2] + "\t)"
+
+        return f"{s}\n)"
+
+    def _call(self, f):
+        return f(*self.args, **self.kwargs)
+
+    def _detach_clone(self):
+        return Params(
+            args = [a.detach().clone() for a in self.args],
+            kwargs={k:v.detach().clone() for k,v in self.kwargs.items()}
+        )
+
+    def _detach_cpu_clone(self):
+        return Params(
+            args = [a.detach().cpu().clone() for a in self.args],
+            kwargs={k:v.detach().cpu().clone() for k,v in self.kwargs.items()}
+        )
+
+    def _requires_grad_(self, mode=True):
+        return Params(
+            args = [a.requires_grad_(mode) for a in self.args],
+            kwargs={k:v.requires_grad_(mode) for k,v in self.kwargs.items()}
+        )
+
+
+    def _grads(self):
+        params = tuple(self.parameters())
+        if all(p.grad is None for p in params): return None
+        return [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+
+_x0 = (
+    torch.Tensor |
+    Sequence[torch.Tensor] |
+    Mapping[str, torch.Tensor] |
+    Mapping[str, Sequence[torch.Tensor] | Mapping[str, torch.Tensor]] |
+    tuple[Sequence[torch.Tensor], Mapping[str, torch.Tensor]] |
+    Sequence[Sequence[torch.Tensor] | Mapping[str, torch.Tensor]] |
+    Params
+)
+
+
+
+def _get_opt_fn(method: _method):
+    if isinstance(method, str):
+        return lambda p: Optimizer(p, *_get_method_from_str(method))
+
+    if isinstance(method, Module):
+        return lambda p: Optimizer(p, method)
+
+    if isinstance(method, Sequence):
+        return lambda p: Optimizer(p, *method)
+
+    if callable(method):
+        return method
+
+    raise ValueError(method)
+
+def _is_scalar(x):
+    if isinstance(x, torch.Tensor): return x.numel() == 1
+    if isinstance(x, np.ndarray): return x.size == 1
+    return True
+
+def _maybe_detach_cpu(x):
+    if isinstance(x, torch.Tensor): return x.detach().cpu()
+    return x
+
+class _MaxEvaluationsReached(Exception): pass
+class _MaxSecondsReached(Exception): pass
+class Terminate(Exception): pass
+
+class _WrappedFunc:
+    def __init__(self, f: _fn_autograd | _fn_custom_grad, x0: Params, reduce_fn: Callable, max_history,
+                 maxeval:int | None, maxsec: float | None, custom_grad:bool):
+        self.f = f
+        self.maxeval = maxeval
+        self.reduce_fn = reduce_fn
+        self.custom_grad = custom_grad
+        self.maxsec = maxsec
+
+        self.x_best = x0.clone()
+        self.fmin = float("inf")
+        self.evals = 0
+        self.start = time.time()
+
+        if max_history == -1: max_history = None # unlimited history
+        if max_history == 0: self.history = None
+        else: self.history = deque(maxlen=max_history)
+
+    def __call__(self, x: Params, g: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor | None]:
+        if self.maxeval is not None and self.evals >= self.maxeval:
+            raise _MaxEvaluationsReached
+
+        if self.maxsec is not None and time.time() - self.start >= self.maxsec:
+            raise _MaxSecondsReached
+
+        self.evals += 1
+
+        if self.custom_grad:
+            assert g is not None
+            assert len(x.args) == 1 and len(x.kwargs) == 0
+            v = v_scalar = cast(_fn_custom_grad, self.f)(x.x, g)
+        else:
+            v = v_scalar = x._call(self.f)
+
+        with torch.no_grad():
+
+            # multi-value v, reduce using reduce func
+            if isinstance(v, torch.Tensor) and v.numel() > 1:
+                v_scalar = self.reduce_fn(v)
+
+            if v_scalar < self.fmin:
+                self.fmin = tofloat(v_scalar)
+                self.x_best = x._detach_clone()
+
+            if self.history is not None:
+                self.history.append((x._detach_cpu_clone(), _maybe_detach_cpu(v)))
+
+        return v, g
+
+
+
+class MinimizeResult(NamedTuple):
+    params: Params
+    x: torch.Tensor | None
+    success: bool
+    message: str
+    fun: float
+    n_iters: int
+    n_evals: int
+    g_norm: torch.Tensor | None
+    dir_norm: torch.Tensor | None
+    losses: list[float]
+    history: deque[tuple[torch.Tensor, torch.Tensor]]
+
+    def __repr__(self):
+        newline = "\n"
+        ident = " " * 10
+        return (
+            f"message: {self.message}\n"
+            f"success: {self.success}\n"
+            f"fun: {self.fun}\n"
+            f"params: {repr(self.params).replace(newline, newline+ident)}\n"
+            f"x: {self.x}\n"
+            f"n_iters: {self.n_iters}\n"
+            f"n_evals: {self.n_evals}\n"
+            f"g_norm: {self.g_norm}\n"
+            f"dir_norm: {self.dir_norm}\n"
+        )
+
+
+
+def _make_params(x0: _x0):
+    x = cast(Any, x0)
+
+    # kwargs
+    if isinstance(x, Params): return x
+
+    # single tensor
+    if isinstance(x, torch.Tensor): return Params(args = (x, ), kwargs = {})
+
+    if isinstance(x, Sequence):
+        # args
+        if isinstance(x[0], torch.Tensor): return Params(args=x, kwargs = {})
+
+        # tuple of (args, kwrgs)
+        assert len(x) == 2 and isinstance(x[0], Sequence) and isinstance(x[1], Mapping)
+        return Params(args=x[0], kwargs=x[1])
+
+    if isinstance(x, Mapping):
+        # dict with args and kwargs
+        if "args" in x or "kwargs" in x: return Params(args=x.get("args", ()), kwargs=x.get("kwargs", {}))
+
+        # kwargs
+        return Params(args=(), kwargs=x)
+
+    raise TypeError(type(x))
+
+
+def minimize(
+    f: _fn_autograd | _fn_custom_grad,
+    x0: _x0,
+
+    method: _method | None = None,
+
+    maxeval: int | None = None,
+    maxiter: int | None = None,
+    maxsec: float | None = None,
+    ftol: _scalar | None = None,
+    gtol: _scalar | None = 1e-5,
+    xtol: _scalar | None = None,
+    max_no_improvement_iters: int | None = 100,
+
+    reduce_fn: Callable[[torch.Tensor], torch.Tensor] = torch.sum,
+    max_history: int = 0,
+
+    custom_grad: bool = False,
+    use_termination_exceptions: bool = True,
+    norm = torch.inf,
+
+) -> MinimizeResult:
+    """Minimize a scalar or multiobjective function of one or more variables.
+
+    Args:
+        f (_fn_autograd | _fn_custom_grad):
+            The objective function to be minimized.
+        x0 (_x0):
+            Initial guess. Can be torch.Tensor, tuple of torch.Tensors to pass as args,
+            or dictionary of torch.Tensors to pass as kwargs.
+        method (_method | None, optional):
+            Type of solver. Can be a string, a ``Module`` (like ``tz.m.BFGS()``), or a list of ``Module``.
+            By default chooses BFGS or L-BFGS depending on number of variables. Defaults to None.
+        maxeval (int | None, optional):
+            terminate when exceeded this number of function evaluations. Defaults to None.
+        maxiter (int | None, optional):
+            terminate when exceeded this number of solver iterations,
+            each iteration may perform multiple function evaluations. Defaults to None.
+        maxsec (float | None, optional):
+            terminate after optimizing for this many seconds. Defaults to None.
+        ftol (_scalar | None, optional):
+            terminate when reached a solution with objective value less or equal to this value. Defaults to None.
+        gtol (_scalar | None, optional):
+            terminate when gradient norm is less or equal to this value.
+            The type of norm is controlled by ``norm`` argument and is infinity norm by default. Defaults to 1e-5.
+        xtol (_scalar | None, optional):
+            terminate when norm of difference between successive parameters is less or equal to this value. Defaults to None.
+        max_no_improvement_iters (int | None, optional):
+            terminate when objective value hasn't improved once for this many consecutive iterations. Defaults to 100.
+        reduce_fn (Callable[[torch.Tensor], torch.Tensor], optional):
+            only has effect when ``f`` is multi-objective / least-squares. Determines how to convert
+            vector returned by ``f`` to a single scalar value for ``ftol`` and ``max_no_improvement_iters``.
+            Defaults to torch.sum.
+        max_history (int, optional):
+            stores this many last evaluated parameters and their values.
+            Set to -1 to store all parameters. Set to 0 to store nothing (default).
+        custom_grad (bool, optional):
+            Allows specifying a custom gradient function instead of using autograd.
+            if True, objective function ``f`` must of the following form:
+            ```python
+            def f(x, grad):
+                value = objective(x)
+                if grad.numel() > 0:
+                    grad[:] = objective_gradient(x)
+                return value
+            ```
+
+            Defaults to False.
+        use_termination_exceptions (bool, optional):
+            if True, ``maxeval`` and ``maxsec`` use exceptions to terminate, therefore they are able to trigger
+            mid-iteration. If False, they can only trigger after iteration, so it might perform slightly more
+            evals and for slightly more seconds than requested. Defaults to True.
+        norm (float, optional):
+            type of norm to use for gradient and update tolerances. Defaults to torch.inf.
+
+    Raises:
+        RuntimeError: _description_
+
+    Returns:
+        MinimizeResult: _description_
+    """
+
+    x0 = _make_params(x0)
+    x = x0._requires_grad_(True)
+
+    # checks
+    if custom_grad:
+        if not (len(x.args) == 1 and len(x.kwargs) == 0):
+            raise RuntimeError("custom_grad only works when `x` is a single tensor.")
+
+    # determine method if None
+    if method is None:
+        max_dim = 5_000 if next(iter(x.parameters())).is_cuda else 1_000
+        if sum(p.numel() for p in x.parameters()) > max_dim: method = 'lbfgs'
+        else: method = 'bfgs'
+
+    opt_fn = _get_opt_fn(method)
+    optimizer = opt_fn(list(x.parameters()))
+
+    f_wrapped = _WrappedFunc(
+        f,
+        x0=x0,
+        reduce_fn=reduce_fn,
+        max_history=max_history,
+        maxeval=maxeval,
+        custom_grad=custom_grad,
+        maxsec=maxsec,
+    )
+
+    def closure(backward=True):
+
+        g = None
+        v = None
+        if custom_grad:
+            v = x.x
+            if backward: g = torch.empty_like(v)
+            else: g = torch.empty(0, device=v.device, dtype=v.dtype)
+
+        loss, g = f_wrapped(x, g=g)
+
+        if backward:
+
+            # custom gradients provided by user
+            if g is not None:
+                assert v is not None
+                v.grad = g
+
+            # autograd
+            else:
+                optimizer.zero_grad()
+                loss.backward()
+
+        return loss
+
+    losses = []
+
+    tiny = torch.finfo(list(x0.parameters())[0].dtype).tiny ** 2
+    if gtol == 0: gtol = tiny
+    if xtol == 0: xtol = tiny
+
+    p_prev = None if xtol is None else [p.detach().clone() for p in x.parameters()]
+    fmin = float("inf")
+    niter = 0
+    n_no_improvement = 0
+    g_norm = None
+    dir_norm = None
+
+    terminate_msg = "max iterations reached"
+    success = False
+
+    exceptions: list | tuple = [Terminate]
+    if use_termination_exceptions:
+        if maxeval is not None: exceptions.append(_MaxEvaluationsReached)
+        if maxsec is not None: exceptions.append(_MaxSecondsReached)
+    exceptions = tuple(exceptions)
+
+    for i in (range(maxiter) if maxiter is not None else itertools.count()):
+        niter += 1
+
+        # ----------------------------------- step ----------------------------------- #
+        try:
+            v = v_scalar = optimizer.step(closure) # pyright:ignore[reportCallIssue,reportArgumentType]
+        except exceptions:
+            break
+
+        with torch.no_grad():
+            assert v is not None and v_scalar is not None
+
+            if isinstance(v, torch.Tensor) and v.numel() > 1:
+                v_scalar = reduce_fn(v)
+
+            losses.append(tofloat(v_scalar))
+
+            # --------------------------- termination criteria --------------------------- #
+
+            # termination criteria on optimizer
+            if isinstance(optimizer, Optimizer) and optimizer.should_terminate:
+                terminate_msg = 'optimizer-specific termination criteria triggered'
+                success = True
+                break
+
+            # max seconds (when use_termination_exceptions=False)
+            if maxsec is not None and time.time() - f_wrapped.start >= maxsec:
+                terminate_msg = 'max seconds reached'
+                success = False
+                break
+
+            # max evals (when use_termination_exceptions=False)
+            if maxeval is not None and f_wrapped.evals >= maxeval:
+                terminate_msg = 'max evaluations reached'
+                success = False
+                break
+
+            # min function value
+            if ftol is not None and v_scalar <= ftol:
+                terminate_msg = 'target function value reached'
+                success = True
+                break
+
+            # gradient infinity norm
+            if gtol is not None:
+                grads = x._grads()
+                if grads is not None:
+                    g_norm = _tensorlist_norm(grads, norm)
+                    if g_norm <= gtol:
+                        terminate_msg = 'gradient norm is below tolerance'
+                        success = True
+                        break
+
+                # due to the way torchzero works we sometimes don't populate .grad,
+                # e.g. with Newton, therefore fallback on xtol
+                else:
+                    if xtol is None: xtol = tiny
+
+            # difference in parameters
+            if xtol is not None:
+                p_new = [p.detach().clone() for p in x.parameters()]
+
+                if p_prev is None: # happens when xtol is set in gtol logic
+                    p_prev = p_new
+
+                else:
+                    dir_norm = _tensorlist_norm(torch._foreach_sub(p_new, p_prev), norm)
+                    if dir_norm <= xtol:
+                        terminate_msg = 'update norm is below tolerance'
+                        success = True
+                        break
+
+                    p_prev = p_new
+
+            # no improvement steps
+            if max_no_improvement_iters is not None:
+                if f_wrapped.fmin >= fmin:
+                    n_no_improvement += 1
+                else:
+                    fmin = f_wrapped.fmin
+                    n_no_improvement = 0
+
+                if n_no_improvement >= max_no_improvement_iters:
+                    terminate_msg = 'reached maximum steps without improvement'
+                    success = False
+                    break
+
+    history=f_wrapped.history
+    if history is None: history = deque()
+
+    x_vec = None
+    if len(x0.args) == 1 and len(x0.kwargs) == 0:
+        x_vec = f_wrapped.x_best.x
+
+    result = MinimizeResult(
+        params = f_wrapped.x_best,
+        x = x_vec,
+        success = success,
+        message = terminate_msg,
+        fun = f_wrapped.fmin,
+        n_iters = niter,
+        n_evals = f_wrapped.evals,
+        g_norm = g_norm,
+        dir_norm = dir_norm,
+        losses = losses,
+        history = history,
+    )
+
+    return result
+
+
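With the top-level re-export in `torchzero/__init__.py` above, the new entry point is available as `tz.minimize`. A minimal usage sketch against the signature and `MinimizeResult` fields shown in this hunk, using the Rosenbrock function as an illustrative objective (the function itself is not shipped with the package):

```python
import torch
import torchzero as tz

def rosenbrock(x: torch.Tensor) -> torch.Tensor:
    # classic test problem with its minimum at (1, 1); illustrative only
    return (1 - x[0]) ** 2 + 100 * (x[1] - x[0] ** 2) ** 2

x0 = torch.tensor([-1.2, 1.0])

# method=None selects BFGS here, since the problem has only 2 variables;
# gradients are computed with autograd
result = tz.minimize(rosenbrock, x0, maxiter=200, gtol=1e-8)

print(result.message)  # e.g. "gradient norm is below tolerance"
print(result.fun)      # best objective value seen
print(result.x)        # best parameters; populated because x0 is a single tensor
print(result.n_evals)  # number of function evaluations
```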
torchzero/core/__init__.py CHANGED
@@ -1,8 +1,8 @@
-from .transform import TensorTransform, Transform
-from .module import Chainable, Module
-from .objective import DerivativesMethod, HessianMethod, HVPMethod, Objective
+from .chain import Chain, maybe_chain
+from .functional import apply, step, step_tensors, update
 
 # order is important to avoid circular imports
 from .modular import Optimizer
-from .functional import apply, step, step_tensors, update
-from .chain import Chain, maybe_chain
+from .module import Module, Chainable, ProjectedBuffer
+from .objective import Objective, DerivativesMethod, HessianMethod, HVPMethod
+from .transform import TensorTransform, Transform
torchzero/core/chain.py CHANGED
@@ -1,8 +1,9 @@
 from collections.abc import Iterable
 
 from ..utils.python_tools import flatten
-from .module import Module, Chainable
 from .functional import _chain_step
+from .module import Chainable, Module
+
 
 class Chain(Module):
     """Chain modules, mostly used internally"""
torchzero/core/functional.py CHANGED
@@ -83,6 +83,7 @@ def step_tensors(
         modules = (modules, )
 
     # make fake params if they are only used for shapes
+    # note that if modules use states, tensors must always be the same python object
    if params is None:
         params = [t.view_as(t).requires_grad_() for t in tensors]
 
@@ -96,7 +97,7 @@
    objective.updates = list(tensors)
 
    # step with modules
-    # this won't update parameters in-place because objective.Optimizer is None
+    # this won't update parameters in-place (on modules with fused update) because objective.Optimizer is None
    objective = _chain_step(objective, modules)
 
    # return updates
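The comment added in `step_tensors` encodes a usage constraint: when the chained modules are stateful, their per-tensor state is tied to the tensor objects they receive, so callers must pass the same tensor objects on every call rather than fresh copies. A rough sketch of what that means in practice, assuming `step_tensors(modules, tensors)` is called positionally and returns the transformed updates (the full signature is not visible in this diff):

```python
import torch
import torchzero as tz
from torchzero.core import step_tensors

grads = [torch.randn(3), torch.randn(4)]  # illustrative tensors standing in for gradients
ema = tz.m.EMA(0.9)                       # a stateful module

# fine: the same tensor objects are reused, so the module's per-tensor state persists
for _ in range(10):
    updates = step_tensors(ema, grads)

# problematic: cloning produces new python objects every call, so any state keyed
# on the tensors is effectively reset each iteration
for _ in range(10):
    updates = step_tensors(ema, [g.clone() for g in grads])
```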