torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. tests/test_identical.py +22 -22
  2. tests/test_opts.py +199 -198
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +1 -1
  5. torchzero/core/functional.py +1 -1
  6. torchzero/core/modular.py +5 -5
  7. torchzero/core/module.py +2 -2
  8. torchzero/core/objective.py +10 -10
  9. torchzero/core/transform.py +1 -1
  10. torchzero/linalg/__init__.py +3 -2
  11. torchzero/linalg/eigh.py +223 -4
  12. torchzero/linalg/orthogonalize.py +2 -4
  13. torchzero/linalg/qr.py +12 -0
  14. torchzero/linalg/solve.py +1 -3
  15. torchzero/linalg/svd.py +47 -20
  16. torchzero/modules/__init__.py +4 -3
  17. torchzero/modules/adaptive/__init__.py +11 -3
  18. torchzero/modules/adaptive/adagrad.py +10 -10
  19. torchzero/modules/adaptive/adahessian.py +2 -2
  20. torchzero/modules/adaptive/adam.py +1 -1
  21. torchzero/modules/adaptive/adan.py +1 -1
  22. torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
  23. torchzero/modules/adaptive/esgd.py +2 -2
  24. torchzero/modules/adaptive/ggt.py +186 -0
  25. torchzero/modules/adaptive/lion.py +2 -1
  26. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  27. torchzero/modules/adaptive/mars.py +2 -2
  28. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  29. torchzero/modules/adaptive/msam.py +4 -4
  30. torchzero/modules/adaptive/muon.py +9 -6
  31. torchzero/modules/adaptive/natural_gradient.py +32 -15
  32. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  33. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  34. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  35. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  36. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  37. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  38. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  39. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  40. torchzero/modules/adaptive/rprop.py +2 -2
  41. torchzero/modules/adaptive/sam.py +4 -4
  42. torchzero/modules/adaptive/shampoo.py +28 -3
  43. torchzero/modules/adaptive/soap.py +3 -3
  44. torchzero/modules/adaptive/sophia_h.py +2 -2
  45. torchzero/modules/clipping/clipping.py +7 -7
  46. torchzero/modules/conjugate_gradient/cg.py +2 -2
  47. torchzero/modules/experimental/__init__.py +5 -0
  48. torchzero/modules/experimental/adanystrom.py +258 -0
  49. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  50. torchzero/modules/experimental/cubic_adam.py +160 -0
  51. torchzero/modules/experimental/eigen_sr1.py +182 -0
  52. torchzero/modules/experimental/eigengrad.py +207 -0
  53. torchzero/modules/experimental/l_infinity.py +1 -1
  54. torchzero/modules/experimental/matrix_nag.py +122 -0
  55. torchzero/modules/experimental/newton_solver.py +2 -2
  56. torchzero/modules/experimental/newtonnewton.py +34 -40
  57. torchzero/modules/grad_approximation/fdm.py +2 -2
  58. torchzero/modules/grad_approximation/rfdm.py +4 -4
  59. torchzero/modules/least_squares/gn.py +68 -45
  60. torchzero/modules/line_search/backtracking.py +2 -2
  61. torchzero/modules/line_search/line_search.py +1 -1
  62. torchzero/modules/line_search/strong_wolfe.py +2 -2
  63. torchzero/modules/misc/escape.py +1 -1
  64. torchzero/modules/misc/gradient_accumulation.py +1 -1
  65. torchzero/modules/misc/misc.py +1 -1
  66. torchzero/modules/misc/multistep.py +4 -7
  67. torchzero/modules/misc/regularization.py +2 -2
  68. torchzero/modules/misc/split.py +1 -1
  69. torchzero/modules/misc/switch.py +2 -2
  70. torchzero/modules/momentum/cautious.py +3 -3
  71. torchzero/modules/momentum/momentum.py +1 -1
  72. torchzero/modules/ops/higher_level.py +1 -1
  73. torchzero/modules/ops/multi.py +1 -1
  74. torchzero/modules/projections/projection.py +5 -2
  75. torchzero/modules/quasi_newton/__init__.py +1 -1
  76. torchzero/modules/quasi_newton/damping.py +1 -1
  77. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  78. torchzero/modules/quasi_newton/lbfgs.py +3 -3
  79. torchzero/modules/quasi_newton/lsr1.py +3 -3
  80. torchzero/modules/quasi_newton/quasi_newton.py +44 -29
  81. torchzero/modules/quasi_newton/sg2.py +69 -205
  82. torchzero/modules/restarts/restars.py +17 -17
  83. torchzero/modules/second_order/inm.py +33 -25
  84. torchzero/modules/second_order/newton.py +132 -130
  85. torchzero/modules/second_order/newton_cg.py +3 -3
  86. torchzero/modules/second_order/nystrom.py +83 -32
  87. torchzero/modules/second_order/rsn.py +41 -44
  88. torchzero/modules/smoothing/laplacian.py +1 -1
  89. torchzero/modules/smoothing/sampling.py +2 -3
  90. torchzero/modules/step_size/adaptive.py +6 -6
  91. torchzero/modules/step_size/lr.py +2 -2
  92. torchzero/modules/trust_region/cubic_regularization.py +1 -1
  93. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  94. torchzero/modules/trust_region/trust_cg.py +1 -1
  95. torchzero/modules/variance_reduction/svrg.py +4 -5
  96. torchzero/modules/weight_decay/reinit.py +2 -2
  97. torchzero/modules/weight_decay/weight_decay.py +5 -5
  98. torchzero/modules/wrappers/optim_wrapper.py +4 -4
  99. torchzero/modules/zeroth_order/cd.py +1 -1
  100. torchzero/optim/mbs.py +291 -0
  101. torchzero/optim/wrappers/nevergrad.py +0 -9
  102. torchzero/optim/wrappers/optuna.py +2 -0
  103. torchzero/utils/benchmarks/__init__.py +0 -0
  104. torchzero/utils/benchmarks/logistic.py +122 -0
  105. torchzero/utils/derivatives.py +4 -4
  106. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  107. torchzero-0.4.1.dist-info/RECORD +209 -0
  108. torchzero/modules/adaptive/lmadagrad.py +0 -241
  109. torchzero-0.4.0.dist-info/RECORD +0 -191
  110. /torchzero/modules/{functional.py → opt_utils.py} +0 -0
  111. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  112. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
@@ -43,7 +43,7 @@ class InfinityNormTrustRegion(TrustRegionBase):

      .. code-block:: python

-         opt = tz.Modular(
+         opt = tz.Optimizer(
              model.parameters(),
              tz.m.InfinityNormTrustRegion(hess_module=tz.m.BFGS(inverse=False)),
          )
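The change that recurs through nearly every docstring hunk below is the rename of the modular entry point from tz.Modular to tz.Optimizer. A minimal end-to-end sketch of the new call pattern follows; the constructor signature and the closure(backward=True) convention are assumptions inferred from the docstring examples elsewhere in this diff, not verbatim library documentation.

```python
# Hedged usage sketch (not taken verbatim from the package): assumes tz.Optimizer
# keeps the old tz.Modular signature of parameters followed by a chain of modules,
# and the closure convention shown in the GaussNewton docstring further down.
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
X, y = torch.randn(16, 4), torch.randn(16, 1)

# 0.4.1 entry point: tz.Optimizer replaces tz.Modular
opt = tz.Optimizer(model.parameters(), tz.m.LBFGS(), tz.m.Backtracking())

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(20):
    opt.step(closure)
```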
@@ -0,0 +1,122 @@
+ from collections.abc import Callable
+ from typing import Literal
+
+ import torch
+ from torchzero.core import Chainable, Transform, HVPMethod
+ from torchzero.utils import NumberList, TensorList
+
+
+ def matrix_nag_(
+     tensors_: TensorList,
+     s: TensorList,
+     Hvp_fn: Callable,
+     mu: float | NumberList,
+ ):
+     s += tensors_
+     Hv = TensorList(Hvp_fn(s))
+     s -= Hv.mul_(mu)
+     return tensors_.add_(s)
+
+
+ class MatrixNAG(Transform):
+     """nesterov momentum version of matrix momentum. It seemed to work really well but adapting doesn't work,
+     I need to test more"""
+     def __init__(
+         self,
+         mu=0.1,
+         hvp_method: HVPMethod = "autograd",
+         h: float = 1e-3,
+         adaptive:bool = False,
+         adapt_freq: int | None = None,
+         hvp_tfm: Chainable | None = None,
+     ):
+         defaults = dict(mu=mu, hvp_method=hvp_method, h=h, adaptive=adaptive, adapt_freq=adapt_freq)
+         super().__init__(defaults)
+
+         if hvp_tfm is not None:
+             self.set_child('hvp_tfm', hvp_tfm)
+
+     def reset_for_online(self):
+         super().reset_for_online()
+         self.clear_state_keys('p_prev')
+
+     @torch.no_grad
+     def apply_states(self, objective, states, settings):
+         assert objective.closure is not None
+         step = self.global_state.get("step", 0)
+         self.global_state["step"] = step + 1
+
+         p = TensorList(objective.params)
+         g = TensorList(objective.get_grads(create_graph=self.defaults["hvp_method"] == "autograd"))
+         p_prev = self.get_state(p, "p_prev", init=p, cls=TensorList)
+         s = p - p_prev
+         p_prev.copy_(p)
+
+         # -------------------------------- adaptive mu ------------------------------- #
+         if self.defaults["adaptive"]:
+
+             if step == 1:
+                 self.global_state["mu_mul"] = 0
+
+             else:
+                 # ---------------------------- deterministic case ---------------------------- #
+                 if self.defaults["adapt_freq"] is None:
+                     g_prev = self.get_state(objective.params, "g_prev", cls=TensorList)
+                     y = g - g_prev
+                     g_prev.copy_(g)
+
+                     denom = y.global_vector_norm()
+                     denom = denom.clip(min = torch.finfo(denom.dtype).tiny * 2)
+                     self.global_state["mu_mul"] = s.global_vector_norm() / denom
+
+                 # -------------------------------- stochastic -------------------------------- #
+                 else:
+                     adapt_freq = self.defaults["adapt_freq"]
+
+                     # we start on 1nd step, and want to adapt when we start, so use (step - 1)
+                     if (step - 1) % adapt_freq == 0:
+                         assert objective.closure is not None
+                         p_cur = p.clone()
+
+                         # move to previous params and evaluate p_prev with current mini-batch
+                         p.copy_(self.get_state(objective.params, 'p_prev'))
+                         with torch.enable_grad():
+                             objective.closure()
+                         g_prev = [t.grad if t.grad is not None else torch.zeros_like(t) for t in p]
+                         y = g - g_prev
+
+                         # move back to current params
+                         p.copy_(p_cur)
+
+                         denom = y.global_vector_norm()
+                         denom = denom.clip(min = torch.finfo(denom.dtype).tiny * 2)
+                         self.global_state["mu_mul"] = s.global_vector_norm() / denom
+
+         # -------------------------- matrix momentum update -------------------------- #
+         mu = self.get_settings(p, "mu", cls=NumberList)
+         if "mu_mul" in self.global_state:
+             mu = mu * self.global_state["mu_mul"]
+
+         # def Hvp_fn(v):
+         #     Hv, _ = self.Hvp(
+         #         v=v,
+         #         at_x0=True,
+         #         var=objective,
+         #         rgrad=g,
+         #         hvp_method=self.defaults["hvp_method"],
+         #         h=self.defaults["h"],
+         #         normalize=True,
+         #         retain_grad=False,
+         #     )
+         #     return Hv
+
+         _, Hvp_fn = objective.list_Hvp_function(hvp_method=self.defaults["hvp_method"], h=self.defaults["h"], at_x0=True)
+
+         objective.updates = matrix_nag_(
+             tensors_=TensorList(objective.get_updates()),
+             s=s,
+             Hvp_fn=Hvp_fn,
+             mu=mu,
+         )
+
+         return objective
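To make the new matrix_nag_ update easier to follow: with s being the most recent parameter step p - p_prev, the transform replaces the incoming update with update + (I - mu*H)(s + update), the Nesterov-style variant of the matrix-momentum recursion named in the docstring. Below is an illustrative dense-Hessian rewrite with plain tensors; it is a sketch, not library code, and the full matrix H stands in for the Hessian-vector products the module actually uses.

```python
# Illustrative dense-Hessian sketch of matrix_nag_ above (assumption: H is the
# full Hessian as a matrix; the module itself only calls Hvp_fn for products).
import torch

def matrix_nag_dense(update: torch.Tensor, s: torch.Tensor, H: torch.Tensor, mu: float) -> torch.Tensor:
    s = s + update          # look ahead along the proposed update
    s = s - mu * (H @ s)    # damp the lookahead through (I - mu * H)
    return update + s       # plays the role of objective.updates
```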
@@ -3,7 +3,7 @@ from typing import Any

  import torch

- from ...core import Chainable, Modular, Module, step, HVPMethod
+ from ...core import Chainable, Optimizer, Module, step, HVPMethod
  from ...utils import TensorList
  from ..quasi_newton import LBFGS

@@ -12,7 +12,7 @@ class NewtonSolver(Module):
      """Matrix free newton via with any custom solver (this is for testing, use NewtonCG or NystromPCG)."""
      def __init__(
          self,
-         solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
+         solver: Callable[[list[torch.Tensor]], Any] = lambda p: Optimizer(p, LBFGS()),
          maxiter=None,
          maxiter1=None,
          tol:float | None=1e-3,
@@ -7,22 +7,21 @@ from typing import Literal

  import torch

- from ...core import Chainable, Module, step
+ from ...core import Chainable, Transform, step
  from ...linalg.linear_operator import Dense
- from ...utils import TensorList, vec_to_tensors
+ from ...utils import TensorList, vec_to_tensors_
  from ...utils.derivatives import (
      flatten_jacobian,
      jacobian_wrt,
  )
  from ..second_order.newton import (
-     _cholesky_solve,
-     _eigh_solve,
+     _try_cholesky_solve,
      _least_squares_solve,
-     _lu_solve,
+     _try_lu_solve,
  )


- class NewtonNewton(Module):
+ class NewtonNewton(Transform):
      """Applies Newton-like preconditioning to Newton step.

      This is a method that I thought of and then it worked. Here is how it works:
@@ -34,39 +33,32 @@ class NewtonNewton(Module):
      3. Solve H2 x2 = x for x2.

      4. Optionally, repeat (if order is higher than 3.)
-
-     Memory is n^order. It tends to converge faster on convex functions, but can be unstable on non-convex. Orders higher than 3 are usually too unsable and have little benefit.
-
-     3rd order variant can minimize some convex functions with up to 100 variables in less time than Newton's method,
-     this is if pytorch can vectorize hessian computation efficiently.
      """
      def __init__(
          self,
          reg: float = 1e-6,
          order: int = 3,
-         search_negative: bool = False,
          vectorize: bool = True,
-         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+         update_freq: int = 1,
+         inner: Chainable | None = None,
      ):
-         defaults = dict(order=order, reg=reg, vectorize=vectorize, eigval_fn=eigval_fn, search_negative=search_negative)
-         super().__init__(defaults)
+         defaults = dict(order=order, reg=reg, vectorize=vectorize)
+         super().__init__(defaults, update_freq=update_freq, inner=inner)

      @torch.no_grad
-     def update(self, objective):
+     def update_states(self, objective, states, settings):
+         fs = settings[0]

          params = TensorList(objective.params)
          closure = objective.closure
          if closure is None: raise RuntimeError('NewtonNewton requires closure')

-         settings = self.settings[params[0]]
-         reg = settings['reg']
-         vectorize = settings['vectorize']
-         order = settings['order']
-         search_negative = settings['search_negative']
-         eigval_fn = settings['eigval_fn']
+         reg = fs['reg']
+         vectorize = fs['vectorize']
+         order = fs['order']

          # ------------------------ calculate grad and hessian ------------------------ #
-         Hs = []
+         P = None
          with torch.enable_grad():
              loss = objective.loss = objective.loss_approx = closure(False)
              g_list = torch.autograd.grad(loss, params, create_graph=True)
@@ -81,28 +73,30 @@ class NewtonNewton(Module):
              with torch.no_grad() if is_last else nullcontext():
                  H = flatten_jacobian(H_list)
                  if reg != 0: H = H + I * reg
-                 Hs.append(H)
+                 if P is None: P = H
+                 else: P = P @ H

-                 x = None
-                 if search_negative or (is_last and eigval_fn is not None):
-                     x = _eigh_solve(H, xp, eigval_fn, search_negative=search_negative)
-                 if x is None: x = _cholesky_solve(H, xp)
-                 if x is None: x = _lu_solve(H, xp)
-                 if x is None: x = _least_squares_solve(H, xp)
-                 xp = x.squeeze()
+                 if not is_last:
+                     x = _try_cholesky_solve(H, xp)
+                     if x is None: x = _try_lu_solve(H, xp)
+                     if x is None: x = _least_squares_solve(H, xp)
+                     xp = x.squeeze()

-         self.global_state["Hs"] = Hs
-         self.global_state['xp'] = xp.nan_to_num_(0,0,0)
+         self.global_state["P"] = P

      @torch.no_grad
-     def apply(self, objective):
-         params = objective.params
-         xp = self.global_state['xp']
-         objective.updates = vec_to_tensors(xp, params)
+     def apply_states(self, objective, states, settings):
+         updates = objective.get_updates()
+         P = self.global_state['P']
+         b = torch.cat([t.ravel() for t in updates])
+
+         sol = _try_cholesky_solve(P, b)
+         if sol is None: sol = _try_lu_solve(P, b)
+         if sol is None: sol = _least_squares_solve(P, b)
+
+         vec_to_tensors_(sol, updates)
          return objective

      @torch.no_grad
      def get_H(self, objective=...):
-         Hs = self.global_state["Hs"]
-         if len(Hs) == 1: return Dense(Hs[0])
-         return Dense(torch.linalg.multi_dot(self.global_state["Hs"])) # pylint:disable=not-callable
+         return Dense(self.global_state["P"])
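In 0.4.1 NewtonNewton accumulates a single preconditioner P = H1 @ H2 @ ... during update_states and solves P x = update in apply_states, instead of chaining per-level solves into a stored direction. A standalone sketch of that solve chain with dense matrices follows; it is illustrative only, and the direct solve plus least-squares fallback stands in for the library's _try_* helper chain.

```python
# Sketch of the product-preconditioner solve used above (illustrative only).
import torch

def solve_with_product_preconditioner(Hs: list[torch.Tensor], b: torch.Tensor) -> torch.Tensor:
    P = Hs[0]
    for H in Hs[1:]:
        P = P @ H                              # P = H1 @ H2 @ ... as built in update_states
    x, info = torch.linalg.solve_ex(P, b)      # direct solve, analogous to the _try_* chain
    if info.item() != 0:                       # singular P: fall back to least squares
        x = torch.linalg.lstsq(P, b).solution
    return x
```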
@@ -106,12 +106,12 @@ class FDM(GradApproximator):
      plain FDM:

      ```python
-     fdm = tz.Modular(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
+     fdm = tz.Optimizer(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
      ```

      Any gradient-based method can use FDM-estimated gradients.
      ```python
-     fdm_ncg = tz.Modular(
+     fdm_ncg = tz.Optimizer(
          model.parameters(),
          tz.m.FDM(),
          # set hvp_method to "forward" so that it
@@ -174,7 +174,7 @@ class RandomizedFDM(GradApproximator):

      SPSA is randomized FDM with rademacher distribution and central formula.
      ```py
-     spsa = tz.Modular(
+     spsa = tz.Optimizer(
          model.parameters(),
          tz.m.RandomizedFDM(formula="fd_central", distribution="rademacher"),
          tz.m.LR(1e-2)
@@ -185,7 +185,7 @@ class RandomizedFDM(GradApproximator):

      RDSA is randomized FDM with usually gaussian distribution and central formula.
      ```
-     rdsa = tz.Modular(
+     rdsa = tz.Optimizer(
          model.parameters(),
          tz.m.RandomizedFDM(formula="fd_central", distribution="gaussian"),
          tz.m.LR(1e-2)
@@ -196,7 +196,7 @@ class RandomizedFDM(GradApproximator):

      GS uses many gaussian samples with possibly a larger finite difference step size.
      ```
-     gs = tz.Modular(
+     gs = tz.Optimizer(
          model.parameters(),
          tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
          tz.m.NewtonCG(hvp_method="forward"),
@@ -208,7 +208,7 @@ class RandomizedFDM(GradApproximator):

      Momentum might help by reducing the variance of the estimated gradients.
      ```
-     momentum_spsa = tz.Modular(
+     momentum_spsa = tz.Optimizer(
          model.parameters(),
          tz.m.RandomizedFDM(),
          tz.m.HeavyBall(0.9),
@@ -1,12 +1,12 @@
  import torch

- from ...core import Chainable, Module, step
+ from ...core import Chainable, Transform
  from ...linalg import linear_operator
  from ...utils import vec_to_tensors
  from ...utils.derivatives import flatten_jacobian, jacobian_wrt


- class SumOfSquares(Module):
+ class SumOfSquares(Transform):
      """Sets loss to be the sum of squares of values returned by the closure.

      This is meant to be used to test least squares methods against ordinary minimization methods.
@@ -18,7 +18,7 @@ class SumOfSquares(Module):
          super().__init__()

      @torch.no_grad
-     def update(self, objective):
+     def update_states(self, objective, states, settings):
          closure = objective.closure

          if closure is not None:
@@ -43,7 +43,11 @@ class SumOfSquares(Module):
          if objective.loss_approx is not None:
              objective.loss_approx = objective.loss_approx.pow(2).sum()

- class GaussNewton(Module):
+     @torch.no_grad
+     def apply_states(self, objective, states, settings):
+         return objective
+
+ class GaussNewton(Transform):
      """Gauss-newton method.

      To use this, the closure should return a vector of values to minimize sum of squares of.
@@ -57,6 +61,9 @@ class GaussNewton(Module):

      Args:
          reg (float, optional): regularization parameter. Defaults to 1e-8.
+         update_freq (int, optional):
+             frequency of computing the jacobian. When jacobian is not computed, only residuals are computed and updated.
+             Defaults to 1.
          batched (bool, optional): whether to use vmapping. Defaults to True.

      Examples:
@@ -68,7 +75,7 @@ class GaussNewton(Module):
              return torch.stack([(1 - x1), 100 * (x2 - x1**2)])

          X = torch.tensor([-1.1, 2.5], requires_grad=True)
-         opt = tz.Modular([X], tz.m.GaussNewton(), tz.m.Backtracking())
+         opt = tz.Optimizer([X], tz.m.GaussNewton(), tz.m.Backtracking())

          # define the closure for line search
          def closure(backward=True):
@@ -86,7 +93,7 @@ class GaussNewton(Module):
          y = torch.randn(64, 10)

          model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
-         opt = tz.Modular(
+         opt = tz.Optimizer(
              model.parameters(),
              tz.m.TrustCG(tz.m.GaussNewton()),
          )
@@ -101,33 +108,49 @@ class GaussNewton(Module):
          print(f'{losses.mean() = }')
          ```
      """
-     def __init__(self, reg:float = 1e-8, batched:bool=True, inner: Chainable | None = None):
-         super().__init__(defaults=dict(batched=batched, reg=reg))
+     def __init__(self, reg:float = 1e-8, update_freq: int= 1, batched:bool=True, inner: Chainable | None = None):
+         defaults=dict(update_freq=update_freq,batched=batched, reg=reg)
+         super().__init__(defaults=defaults)
          if inner is not None: self.set_child('inner', inner)

      @torch.no_grad
-     def update(self, objective):
+     def update_states(self, objective, states, settings):
+         fs = settings[0]
          params = objective.params
-         batched = self.defaults['batched']
-
          closure = objective.closure
-         assert closure is not None
+         batched = fs['batched']
+         update_freq = fs['update_freq']
+
+         # compute residuals
+         r = objective.loss
+         if r is None:
+             assert closure is not None
+             with torch.enable_grad():
+                 r = objective.get_loss(backward=False) # n_residuals
+             assert isinstance(r, torch.Tensor)
+
+         # set sum of squares scalar loss and it's gradient to objective
+         objective.loss = r.pow(2).sum()

-         # gauss newton direction
-         with torch.enable_grad():
-             r = objective.get_loss(backward=False) # nresiduals
-             assert isinstance(r, torch.Tensor)
-             J_list = jacobian_wrt([r.ravel()], params, batched=batched)
+         step = self.increment_counter("step", start=0)

-         objective.loss = r.pow(2).sum()
+         if step % update_freq == 0:
+
+             # compute jacobian
+             with torch.enable_grad():
+                 J_list = jacobian_wrt([r.ravel()], params, batched=batched)
+
+             J = self.global_state["J"] = flatten_jacobian(J_list) # (n_residuals, ndim)
+
+         else:
+             J = self.global_state["J"]

-         J = self.global_state["J"] = flatten_jacobian(J_list) # (nresiduals, ndim)
          Jr = J.T @ r.detach() # (ndim)

          # if there are more residuals, solve (J^T J)x = J^T r, so we need Jr
          # otherwise solve (J J^T)z = r and set x = J^T z, so we need r
-         nresiduals, ndim = J.shape
-         if nresiduals >= ndim or "inner" in self.children:
+         n_residuals, ndim = J.shape
+         if n_residuals >= ndim or "inner" in self.children:
              self.global_state["Jr"] = Jr

          else:
@@ -136,8 +159,9 @@ class GaussNewton(Module):
              objective.grads = vec_to_tensors(Jr, objective.params)

          # set closure to calculate sum of squares for line searches etc
-         if objective.closure is not None:
+         if closure is not None:
              def sos_closure(backward=True):
+
                  if backward:
                      objective.zero_grad()
                      with torch.enable_grad():
@@ -151,8 +175,9 @@ class GaussNewton(Module):
              objective.closure = sos_closure

      @torch.no_grad
-     def apply(self, objective):
-         reg = self.defaults['reg']
+     def apply_states(self, objective, states, settings):
+         fs = settings[0]
+         reg = fs['reg']

          J: torch.Tensor = self.global_state['J']
          nresiduals, ndim = J.shape
@@ -170,39 +195,37 @@ class GaussNewton(Module):
              Jr_list = objective.get_updates()
              Jr = torch.cat([t.ravel() for t in Jr_list])

-             JJ = J.T @ J # (ndim, ndim)
+             JtJ = J.T @ J # (ndim, ndim)
              if reg != 0:
-                 JJ.add_(torch.eye(JJ.size(0), device=JJ.device, dtype=JJ.dtype).mul_(reg))
+                 JtJ.add_(torch.eye(JtJ.size(0), device=JtJ.device, dtype=JtJ.dtype).mul_(reg))

              if nresiduals >= ndim:
-                 v, info = torch.linalg.solve_ex(JJ, Jr) # pylint:disable=not-callable
+                 v, info = torch.linalg.solve_ex(JtJ, Jr) # pylint:disable=not-callable
              else:
-                 v = torch.linalg.lstsq(JJ, Jr).solution # pylint:disable=not-callable
+                 v = torch.linalg.lstsq(JtJ, Jr).solution # pylint:disable=not-callable

              objective.updates = vec_to_tensors(v, objective.params)
              return objective

-         else:
-             # solve (J J^T)z = r and set v = J^T z
-             # derivation
-             # we need (J^T J)v = J^T r
-             # suppose z is solution to (G G^T)z = r, and v = J^T z
-             # if v = J^T z, then (J^T J)v = (J^T J) (J^T z) = J^T (J J^T) z = J^T r
-             # therefore with our presuppositions (J^T J)v = J^T r
+         # else:
+         # solve (J J^T)z = r and set v = J^T z
+         # we need (J^T J)v = J^T r
+         # if z is solution to (G G^T)z = r, and v = J^T z
+         # then (J^T J)v = (J^T J) (J^T z) = J^T (J J^T) z = J^T r
+         # therefore (J^T J)v = J^T r
+         # also this gives a minimum norm solution

-             # also this gives a minimum norm solution
+         r = self.global_state['r']

-             r = self.global_state['r']
+         JJT = J @ J.T # (nresiduals, nresiduals)
+         if reg != 0:
+             JJT.add_(torch.eye(JJT.size(0), device=JJT.device, dtype=JJT.dtype).mul_(reg))

-             JJT = J @ J.T # (nresiduals, nresiduals)
-             if reg != 0:
-                 JJT.add_(torch.eye(JJT.size(0), device=JJT.device, dtype=JJT.dtype).mul_(reg))
-
-             z, info = torch.linalg.solve_ex(JJT, r) # pylint:disable=not-callable
-             v = J.T @ z
+         z, info = torch.linalg.solve_ex(JJT, r) # pylint:disable=not-callable
+         v = J.T @ z

-             objective.updates = vec_to_tensors(v, objective.params)
-             return objective
+         objective.updates = vec_to_tensors(v, objective.params)
+         return objective

      def get_H(self, objective=...):
          J = self.global_state['J']
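The two solve branches shown in this hunk (normal equations when there are at least as many residuals as parameters, and the minimum-norm J J^T route otherwise) can be summarized in a standalone dense-tensor sketch; this is illustrative and not the module's actual code path, which works through objective updates and stored global state.

```python
# Illustrative Gauss-Newton step from a Jacobian J (n_residuals x ndim) and a
# residual vector r, mirroring the branching in apply_states above.
import torch

def gauss_newton_step(J: torch.Tensor, r: torch.Tensor, reg: float = 1e-8) -> torch.Tensor:
    n_residuals, ndim = J.shape
    if n_residuals >= ndim:
        # solve (J^T J + reg*I) v = J^T r
        A = J.T @ J + reg * torch.eye(ndim, device=J.device, dtype=J.dtype)
        v, _ = torch.linalg.solve_ex(A, J.T @ r)
    else:
        # solve (J J^T + reg*I) z = r, then v = J^T z (minimum-norm solution)
        A = J @ J.T + reg * torch.eye(n_residuals, device=J.device, dtype=J.dtype)
        z, _ = torch.linalg.solve_ex(A, r)
        v = J.T @ z
    return v
```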
@@ -77,7 +77,7 @@ class Backtracking(LineSearchBase):
      Gradient descent with backtracking line search:

      ```python
-     opt = tz.Modular(
+     opt = tz.Optimizer(
          model.parameters(),
          tz.m.Backtracking()
      )
@@ -85,7 +85,7 @@ class Backtracking(LineSearchBase):

      L-BFGS with backtracking line search:
      ```python
-     opt = tz.Modular(
+     opt = tz.Optimizer(
          model.parameters(),
          tz.m.LBFGS(),
          tz.m.Backtracking()
@@ -10,7 +10,7 @@ import torch

  from ...core import Module, Objective
  from ...utils import tofloat, set_storage_
- from ..functional import clip_by_finfo
+ from ..opt_utils import clip_by_finfo


  class MaxLineSearchItersReached(Exception): pass
@@ -236,7 +236,7 @@ class StrongWolfe(LineSearchBase):
      Conjugate gradient method with strong wolfe line search. Nocedal, Wright recommend setting c2 to 0.1 for CG. Since CG doesn't produce well scaled directions, initial alpha can be determined from function values by ``a_init="first-order"``.

      ```python
-     opt = tz.Modular(
+     opt = tz.Optimizer(
          model.parameters(),
          tz.m.PolakRibiere(),
          tz.m.StrongWolfe(c2=0.1, a_init="first-order")
@@ -245,7 +245,7 @@ class StrongWolfe(LineSearchBase):

      LBFGS strong wolfe line search:
      ```python
-     opt = tz.Modular(
+     opt = tz.Optimizer(
          model.parameters(),
          tz.m.LBFGS(),
          tz.m.StrongWolfe()
3
3
  from typing import Literal
4
4
  import torch
5
5
 
6
- from ...core import Modular, Module, Objective, Chainable
6
+ from ...core import Optimizer, Module, Objective, Chainable
7
7
  from ...utils import NumberList, TensorList
8
8
 
9
9
 
@@ -24,7 +24,7 @@ class GradientAccumulation(Module):
24
24
  Adam with gradients accumulated for 16 batches.
25
25
 
26
26
  ```python
27
- opt = tz.Modular(
27
+ opt = tz.Optimizer(
28
28
  model.parameters(),
29
29
  tz.m.GradientAccumulation(),
30
30
  tz.m.Adam(),
@@ -342,7 +342,7 @@ class SaveBest(Module):
342
342
  return (1 - x)**2 + (100 * (y - x**2))**2
343
343
 
344
344
  xy = torch.tensor((-1.1, 2.5), requires_grad=True)
345
- opt = tz.Modular(
345
+ opt = tz.Optimizer(
346
346
  [xy],
347
347
  tz.m.NAG(0.999),
348
348
  tz.m.LR(1e-6),
@@ -129,7 +129,7 @@ class Online(Module):

      Online L-BFGS with Backtracking line search
      ```python
-     opt = tz.Modular(
+     opt = tz.Optimizer(
          model.parameters(),
          tz.m.Online(tz.m.LBFGS()),
          tz.m.Backtracking()
@@ -138,19 +138,16 @@ class Online(Module):

      Online L-BFGS trust region
      ```python
-     opt = tz.Modular(
+     opt = tz.Optimizer(
          model.parameters(),
          tz.m.TrustCG(tz.m.Online(tz.m.LBFGS()))
      )
      ```

      """
-     def __init__(self, *modules: Module,):
+     def __init__(self, module: Module,):
          super().__init__()
-         if len(modules) == 0:
-             raise RuntimeError("Online got empty list of modules. To make a module online, wrap it in tz.m.Online, e.g. `tz.m.Online(tz.m.LBFGS())`")
-
-         self.set_child('module', modules)
+         self.set_child('module', module)

      @torch.no_grad
      def update(self, objective):
@@ -23,7 +23,7 @@ class Dropout(Transform):
      Gradient dropout.

      ```python
-     opt = tz.Modular(
+     opt = tz.Optimizer(
          model.parameters(),
          tz.m.Dropout(0.5),
          tz.m.Adam(),
@@ -34,7 +34,7 @@ class Dropout(Transform):
      Update dropout.

      ``python
-     opt = tz.Modular(
+     opt = tz.Optimizer(
          model.parameters(),
          tz.m.Adam(),
          tz.m.Dropout(0.5),