torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
@@ -38,15 +38,15 @@ class SPSA1(GradApproximator):
         super().__init__(defaults, target=target)
 
 
-    def pre_step(self, var):
+    def pre_step(self, objective):
 
         if self.defaults['pre_generate']:
 
-            params = TensorList(var.params)
+            params = TensorList(objective.params)
             generator = self.get_generator(params[0].device, self.defaults['seed'])
 
             n_samples = self.defaults['n_samples']
-            h = self.get_settings(var.params, 'h')
+            h = self.get_settings(objective.params, 'h')
 
             perturbations = [params.rademacher_like(generator=generator) for _ in range(n_samples)]
             torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])
@@ -1,11 +1,8 @@
 import math
-
-import numpy as np
 import torch
 
 from ...core import Chainable
-from ...utils import vec_to_tensors, TensorList
-from ..adaptive.shampoo import _merge_small_dims
+from ...utils import vec_to_tensors
 from ..projections import ProjectionBase
 
 
@@ -106,12 +106,12 @@ class FDM(GradApproximator):
     plain FDM:
 
     ```python
-    fdm = tz.Modular(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
+    fdm = tz.Optimizer(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
     ```
 
     Any gradient-based method can use FDM-estimated gradients.
     ```python
-    fdm_ncg = tz.Modular(
+    fdm_ncg = tz.Optimizer(
         model.parameters(),
         tz.m.FDM(),
         # set hvp_method to "forward" so that it
@@ -52,11 +52,11 @@ class ForwardGradient(RandomizedFDM):
         params = TensorList(params)
         loss_approx = None
 
-        settings = self.settings[params[0]]
-        n_samples = settings['n_samples']
-        jvp_method = settings['jvp_method']
-        h = settings['h']
-        distribution = settings['distribution']
+        fs = self.settings[params[0]]
+        n_samples = fs['n_samples']
+        jvp_method = fs['jvp_method']
+        h = fs['h']
+        distribution = fs['distribution']
         default = [None]*n_samples
         perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
         generator = self.get_generator(params[0].device, self.defaults['seed'])
@@ -74,10 +74,10 @@ class ForwardGradient(RandomizedFDM):
                 loss, d = jvp(partial(closure, False), params=params, tangent=prt)
 
             elif jvp_method == 'forward':
-                loss, d = jvp_fd_forward(partial(closure, False), params=params, tangent=prt, v_0=loss, normalize=True, h=h)
+                loss, d = jvp_fd_forward(partial(closure, False), params=params, tangent=prt, v_0=loss, h=h)
 
             elif jvp_method == 'central':
-                loss_approx, d = jvp_fd_central(partial(closure, False), params=params, tangent=prt, normalize=True, h=h)
+                loss_approx, d = jvp_fd_central(partial(closure, False), params=params, tangent=prt, h=h)
 
             else: raise ValueError(jvp_method)
 
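
The ForwardGradient changes above keep the estimator itself unchanged: a directional derivative (a JVP, exact or finite-difference) along a random tangent, scaled back onto that tangent. A minimal standalone sketch of that idea in plain torch (not package code; `forward_gradient_estimate` is a hypothetical helper):

```python
import torch

def forward_gradient_estimate(f, x, h=1e-4):
    # forward-gradient idea: estimate grad f(x) as (directional derivative along v) * v
    # for a random gaussian tangent v; E[v v^T] = I makes the estimator unbiased
    v = torch.randn_like(x)
    d = (f(x + h * v) - f(x)) / h   # finite-difference JVP, i.e. directional derivative
    return d * v

x = torch.tensor([1.0, -2.0, 0.5])
f = lambda t: (t ** 2).sum()
print(forward_gradient_estimate(f, x))  # noisy estimate of the true gradient 2*x
```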
@@ -5,7 +5,7 @@ from typing import Any, Literal
 
 import torch
 
-from ...core import Module, Var
+from ...core import Module, Objective
 
 GradTarget = Literal['update', 'grad', 'closure']
 _Scalar = torch.Tensor | float
@@ -62,24 +62,25 @@ class GradApproximator(Module, ABC):
            return spsa_grads, None, loss_plus
    ```
    """
-    def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
+    def __init__(self, defaults: dict[str, Any] | None = None, return_approx_loss:bool=False, target: GradTarget = 'closure'):
        super().__init__(defaults)
        self._target: GradTarget = target
+        self._return_approx_loss = return_approx_loss
 
 
    @abstractmethod
    def approximate(self, closure: Callable, params: list[torch.Tensor], loss: torch.Tensor | None) -> tuple[Iterable[torch.Tensor], torch.Tensor | None, torch.Tensor | None]:
        """Returns a tuple: ``(grad, loss, loss_approx)``, make sure this resets parameters to their original values!"""
-    def pre_step(self, var: Var) -> None:
+    def pre_step(self, objective: Objective) -> None:
        """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
        evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
 
    @torch.no_grad
-    def step(self, var):
-        self.pre_step(var)
+    def update(self, objective):
+        self.pre_step(objective)
 
-        if var.closure is None: raise RuntimeError("Gradient approximation requires closure")
-        params, closure, loss = var.params, var.closure, var.loss
+        if objective.closure is None: raise RuntimeError("Gradient approximation requires closure")
+        params, closure, loss = objective.params, objective.closure, objective.loss
 
        if self._target == 'closure':
 
@@ -88,20 +89,26 @@
                    # set loss to None because closure might be evaluated at different points
                    grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None)
                    for p, g in zip(params, grad): p.grad = g
-                    return l if l is not None else closure(False)
+                    if l is not None: return l
+                    if self._return_approx_loss and l_approx is not None: return l_approx
+                    return closure(False)
+
                return closure(False)
 
-            var.closure = approx_closure
-            return var
+            objective.closure = approx_closure
+            return
 
        # if var.grad is not None:
        #     warnings.warn('Using grad approximator when `var.grad` is already set.')
-        grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss)
-        if loss_approx is not None: var.loss_approx = loss_approx
-        if loss is not None: var.loss = var.loss_approx = loss
-        if self._target == 'grad': var.grad = list(grad)
-        elif self._target == 'update': var.update = list(grad)
+        grad, loss, loss_approx = self.approximate(closure=closure, params=params, loss=loss)
+        if loss_approx is not None: objective.loss_approx = loss_approx
+        if loss is not None: objective.loss = objective.loss_approx = loss
+        if self._target == 'grad': objective.grads = list(grad)
+        elif self._target == 'update': objective.updates = list(grad)
        else: raise ValueError(self._target)
-        return var
+        return
+
+    def apply(self, objective):
+        return objective
 
 _FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', 'bspsa4']
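
For orientation, here is a rough sketch of what a subclass of the reworked GradApproximator interface might look like, based only on the signatures visible in this hunk (`approximate` returning ``(grads, loss, loss_approx)`` and restoring parameters before returning). The import path and the `TwoPointFD` class are illustrative assumptions, not part of the package:

```python
import torch
# assumed import path, inferred from the file layout shown in this diff
from torchzero.modules.grad_approximation.grad_approximator import GradApproximator

class TwoPointFD(GradApproximator):
    """Illustrative coordinate-wise forward-difference approximator (hypothetical)."""
    def __init__(self, h: float = 1e-3, target='closure'):
        super().__init__(dict(h=h), target=target)

    @torch.no_grad
    def approximate(self, closure, params, loss):
        h = self.defaults['h']
        if loss is None:
            loss = closure(False)
        grads = []
        for p in params:
            g = torch.empty_like(p)
            flat = p.view(-1)
            for i in range(flat.numel()):
                orig = flat[i].item()
                flat[i] = orig + h                       # perturb one coordinate
                g.view(-1)[i] = (closure(False) - loss) / h
                flat[i] = orig                           # restore the parameter, as the contract requires
            grads.append(g)
        return grads, loss, loss
```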
@@ -174,9 +174,9 @@ class RandomizedFDM(GradApproximator):
 
     SPSA is randomized FDM with rademacher distribution and central formula.
     ```py
-    spsa = tz.Modular(
+    spsa = tz.Optimizer(
         model.parameters(),
-        tz.m.RandomizedFDM(formula="central", distribution="rademacher"),
+        tz.m.RandomizedFDM(formula="fd_central", distribution="rademacher"),
         tz.m.LR(1e-2)
     )
     ```
@@ -185,9 +185,9 @@ class RandomizedFDM(GradApproximator):
 
     RDSA is randomized FDM with usually gaussian distribution and central formula.
     ```
-    rdsa = tz.Modular(
+    rdsa = tz.Optimizer(
         model.parameters(),
-        tz.m.RandomizedFDM(formula="central", distribution="gaussian"),
+        tz.m.RandomizedFDM(formula="fd_central", distribution="gaussian"),
         tz.m.LR(1e-2)
     )
     ```
@@ -196,7 +196,7 @@ class RandomizedFDM(GradApproximator):
 
     GS uses many gaussian samples with possibly a larger finite difference step size.
     ```
-    gs = tz.Modular(
+    gs = tz.Optimizer(
         model.parameters(),
         tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
         tz.m.NewtonCG(hvp_method="forward"),
@@ -208,7 +208,7 @@ class RandomizedFDM(GradApproximator):
 
     Momentum might help by reducing the variance of the estimated gradients.
     ```
-    momentum_spsa = tz.Modular(
+    momentum_spsa = tz.Optimizer(
         model.parameters(),
         tz.m.RandomizedFDM(),
         tz.m.HeavyBall(0.9),
223
223
  n_samples: int = 1,
224
224
  formula: _FD_Formula = "central",
225
225
  distribution: Distributions = "rademacher",
226
- pre_generate = True,
226
+ pre_generate: bool = True,
227
+ return_approx_loss: bool = False,
227
228
  seed: int | None | torch.Generator = None,
228
229
  target: GradTarget = "closure",
229
230
  ):
230
231
  defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, pre_generate=pre_generate, seed=seed)
231
- super().__init__(defaults, target=target)
232
+ super().__init__(defaults, return_approx_loss=return_approx_loss, target=target)
232
233
 
233
234
 
234
- def pre_step(self, var):
235
- h = self.get_settings(var.params, 'h')
235
+ def pre_step(self, objective):
236
+ h = self.get_settings(objective.params, 'h')
236
237
  pre_generate = self.defaults['pre_generate']
237
238
 
238
239
  if pre_generate:
239
240
  n_samples = self.defaults['n_samples']
240
241
  distribution = self.defaults['distribution']
241
242
 
242
- params = TensorList(var.params)
243
+ params = TensorList(objective.params)
243
244
  generator = self.get_generator(params[0].device, self.defaults['seed'])
244
245
  perturbations = [params.sample_like(distribution=distribution, variance=1, generator=generator) for _ in range(n_samples)]
245
246
 
@@ -346,11 +347,12 @@ class RDSA(RandomizedFDM):
346
347
  n_samples: int = 1,
347
348
  formula: _FD_Formula = "central2",
348
349
  distribution: Distributions = "gaussian",
349
- pre_generate = True,
350
+ pre_generate: bool = True,
351
+ return_approx_loss: bool = False,
350
352
  target: GradTarget = "closure",
351
353
  seed: int | None | torch.Generator = None,
352
354
  ):
353
- super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed)
355
+ super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed, return_approx_loss=return_approx_loss)
354
356
 
355
357
  class GaussianSmoothing(RandomizedFDM):
356
358
  """
@@ -380,11 +382,12 @@ class GaussianSmoothing(RandomizedFDM):
380
382
  n_samples: int = 100,
381
383
  formula: _FD_Formula = "forward2",
382
384
  distribution: Distributions = "gaussian",
383
- pre_generate = True,
385
+ pre_generate: bool = True,
386
+ return_approx_loss: bool = False,
384
387
  target: GradTarget = "closure",
385
388
  seed: int | None | torch.Generator = None,
386
389
  ):
387
- super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed)
390
+ super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed, return_approx_loss=return_approx_loss)
388
391
 
389
392
  class MeZO(GradApproximator):
390
393
  """Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.
@@ -406,10 +409,10 @@ class MeZO(GradApproximator):
406
409
  """
407
410
 
408
411
  def __init__(self, h: float=1e-3, n_samples: int = 1, formula: _FD_Formula = 'central2',
409
- distribution: Distributions = 'rademacher', target: GradTarget = 'closure'):
412
+ distribution: Distributions = 'rademacher', return_approx_loss: bool = False, target: GradTarget = 'closure'):
410
413
 
411
414
  defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution)
412
- super().__init__(defaults, target=target)
415
+ super().__init__(defaults, return_approx_loss=return_approx_loss, target=target)
413
416
 
414
417
  def _seeded_perturbation(self, params: list[torch.Tensor], distribution, seed, h):
415
418
  prt = TensorList(params).sample_like(
@@ -419,19 +422,19 @@ class MeZO(GradApproximator):
419
422
  )
420
423
  return prt
421
424
 
422
- def pre_step(self, var):
423
- h = NumberList(self.settings[p]['h'] for p in var.params)
425
+ def pre_step(self, objective):
426
+ h = NumberList(self.settings[p]['h'] for p in objective.params)
424
427
 
425
428
  n_samples = self.defaults['n_samples']
426
429
  distribution = self.defaults['distribution']
427
430
 
428
- step = var.current_step
431
+ step = objective.current_step
429
432
 
430
433
  # create functions that generate a deterministic perturbation from seed based on current step
431
434
  prt_fns = []
432
435
  for i in range(n_samples):
433
436
 
434
- prt_fn = partial(self._seeded_perturbation, params=var.params, distribution=distribution, seed=1_000_000*step + i, h=h)
437
+ prt_fn = partial(self._seeded_perturbation, params=objective.params, distribution=distribution, seed=1_000_000*step + i, h=h)
435
438
  prt_fns.append(prt_fn)
436
439
 
437
440
  self.global_state['prt_fns'] = prt_fns
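
The MeZO hunk above rebuilds each perturbation from a seed derived from the step counter (`1_000_000*step + i`) instead of storing it between closure evaluations. A standalone sketch of why that works (plain torch with a hypothetical helper, not package code):

```python
import torch

def seeded_rademacher(shape, seed):
    # rebuilding a perturbation from its seed reproduces it exactly,
    # so it never has to be kept in memory between the two closure evaluations
    g = torch.Generator().manual_seed(seed)
    return (torch.randint(0, 2, shape, generator=g) * 2 - 1).float()

step, i = 7, 0
seed = 1_000_000 * step + i
p1 = seeded_rademacher((3, 3), seed)
p2 = seeded_rademacher((3, 3), seed)
assert torch.equal(p1, p2)  # identical perturbation, regenerated on demand
```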
@@ -1,28 +1,31 @@
 import torch
-from ...core import Module
 
-from ...utils.derivatives import jacobian_wrt, flatten_jacobian
+from ...core import Chainable, Transform
+from ...linalg import linear_operator
 from ...utils import vec_to_tensors
-from ...utils.linalg import linear_operator
-class SumOfSquares(Module):
+from ...utils.derivatives import flatten_jacobian, jacobian_wrt
+
+
+class SumOfSquares(Transform):
     """Sets loss to be the sum of squares of values returned by the closure.
 
     This is meant to be used to test least squares methods against ordinary minimization methods.
 
     To use this, the closure should return a vector of values to minimize sum of squares of.
-    Please add the `backward` argument, it will always be False but it is required.
+    Please add the ``backward`` argument, it will always be False but it is required.
     """
     def __init__(self):
         super().__init__()
 
     @torch.no_grad
-    def step(self, var):
-        closure = var.closure
+    def update_states(self, objective, states, settings):
+        closure = objective.closure
 
         if closure is not None:
+
             def sos_closure(backward=True):
                 if backward:
-                    var.zero_grad()
+                    objective.zero_grad()
                     with torch.enable_grad():
                         loss = closure(False)
                         loss = loss.pow(2).sum()
@@ -32,18 +35,19 @@ class SumOfSquares(Module):
                     loss = closure(False)
                     return loss.pow(2).sum()
 
-            var.closure = sos_closure
-
-            if var.loss is not None:
-                var.loss = var.loss.pow(2).sum()
+            objective.closure = sos_closure
 
-            if var.loss_approx is not None:
-                var.loss_approx = var.loss_approx.pow(2).sum()
+            if objective.loss is not None:
+                objective.loss = objective.loss.pow(2).sum()
 
-        return var
+            if objective.loss_approx is not None:
+                objective.loss_approx = objective.loss_approx.pow(2).sum()
 
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        return objective
 
-class GaussNewton(Module):
+class GaussNewton(Transform):
     """Gauss-newton method.
 
     To use this, the closure should return a vector of values to minimize sum of squares of.
@@ -57,6 +61,9 @@ class GaussNewton(Module):
 
     Args:
         reg (float, optional): regularization parameter. Defaults to 1e-8.
+        update_freq (int, optional):
+            frequency of computing the jacobian. When jacobian is not computed, only residuals are computed and updated.
+            Defaults to 1.
         batched (bool, optional): whether to use vmapping. Defaults to True.
 
     Examples:
@@ -68,7 +75,7 @@
            return torch.stack([(1 - x1), 100 * (x2 - x1**2)])
 
        X = torch.tensor([-1.1, 2.5], requires_grad=True)
-        opt = tz.Modular([X], tz.m.GaussNewton(), tz.m.Backtracking())
+        opt = tz.Optimizer([X], tz.m.GaussNewton(), tz.m.Backtracking())
 
        # define the closure for line search
        def closure(backward=True):
@@ -86,7 +93,7 @@
        y = torch.randn(64, 10)
 
        model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
-        opt = tz.Modular(
+        opt = tz.Optimizer(
            model.parameters(),
            tz.m.TrustCG(tz.m.GaussNewton()),
        )
@@ -101,35 +108,62 @@
        print(f'{losses.mean() = }')
        ```
    """
-    def __init__(self, reg:float = 1e-8, batched:bool=True, ):
-        super().__init__(defaults=dict(batched=batched, reg=reg))
+    def __init__(self, reg:float = 1e-8, update_freq: int= 1, batched:bool=True, inner: Chainable | None = None):
+        defaults=dict(update_freq=update_freq,batched=batched, reg=reg)
+        super().__init__(defaults=defaults)
+        if inner is not None: self.set_child('inner', inner)
 
    @torch.no_grad
-    def update(self, var):
-        params = var.params
-        batched = self.defaults['batched']
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+        params = objective.params
+        closure = objective.closure
+        batched = fs['batched']
+        update_freq = fs['update_freq']
+
+        # compute residuals
+        r = objective.loss
+        if r is None:
+            assert closure is not None
+            with torch.enable_grad():
+                r = objective.get_loss(backward=False) # n_residuals
+        assert isinstance(r, torch.Tensor)
+
+        # set sum of squares scalar loss and it's gradient to objective
+        objective.loss = r.pow(2).sum()
+
+        step = self.increment_counter("step", start=0)
+
+        if step % update_freq == 0:
+
+            # compute jacobian
+            with torch.enable_grad():
+                J_list = jacobian_wrt([r.ravel()], params, batched=batched)
+
+            J = self.global_state["J"] = flatten_jacobian(J_list) # (n_residuals, ndim)
 
-        closure = var.closure
-        assert closure is not None
+        else:
+            J = self.global_state["J"]
 
-        # gauss newton direction
-        with torch.enable_grad():
-            f = var.get_loss(backward=False) # n_out
-            assert isinstance(f, torch.Tensor)
-            G_list = jacobian_wrt([f.ravel()], params, batched=batched)
+        Jr = J.T @ r.detach() # (ndim)
 
-        var.loss = f.pow(2).sum()
+        # if there are more residuals, solve (J^T J)x = J^T r, so we need Jr
+        # otherwise solve (J J^T)z = r and set x = J^T z, so we need r
+        n_residuals, ndim = J.shape
+        if n_residuals >= ndim or "inner" in self.children:
+            self.global_state["Jr"] = Jr
 
-        G = self.global_state["G"] = flatten_jacobian(G_list) # (n_out, ndim)
-        Gtf = G.T @ f.detach() # (ndim)
-        self.global_state["Gtf"] = Gtf
-        var.grad = vec_to_tensors(Gtf, var.params)
+        else:
+            self.global_state["r"] = r
+
+        objective.grads = vec_to_tensors(Jr, objective.params)
 
        # set closure to calculate sum of squares for line searches etc
-        if var.closure is not None:
+        if closure is not None:
            def sos_closure(backward=True):
+
                if backward:
-                    var.zero_grad()
+                    objective.zero_grad()
                    with torch.enable_grad():
                        loss = closure(False).pow(2).sum()
                        loss.backward()
@@ -138,24 +172,61 @@
                    loss = closure(False).pow(2).sum()
                    return loss
 
-            var.closure = sos_closure
+            objective.closure = sos_closure
 
    @torch.no_grad
-    def apply(self, var):
-        reg = self.defaults['reg']
+    def apply_states(self, objective, states, settings):
+        fs = settings[0]
+        reg = fs['reg']
+
+        J: torch.Tensor = self.global_state['J']
+        nresiduals, ndim = J.shape
+        if nresiduals >= ndim or "inner" in self.children:
+
+            # (J^T J)v = J^T r
+            Jr: torch.Tensor = self.global_state['Jr']
+
+            # inner step
+            if "inner" in self.children:
+
+                # var.grad is set to unflattened Jr
+                assert objective.grads is not None
+                objective = self.inner_step("inner", objective, must_exist=True)
+                Jr_list = objective.get_updates()
+                Jr = torch.cat([t.ravel() for t in Jr_list])
+
+            JtJ = J.T @ J # (ndim, ndim)
+            if reg != 0:
+                JtJ.add_(torch.eye(JtJ.size(0), device=JtJ.device, dtype=JtJ.dtype).mul_(reg))
+
+            if nresiduals >= ndim:
+                v, info = torch.linalg.solve_ex(JtJ, Jr) # pylint:disable=not-callable
+            else:
+                v = torch.linalg.lstsq(JtJ, Jr).solution # pylint:disable=not-callable
+
+            objective.updates = vec_to_tensors(v, objective.params)
+            return objective
+
+        # else:
+        # solve (J J^T)z = r and set v = J^T z
+        # we need (J^T J)v = J^T r
+        # if z is solution to (G G^T)z = r, and v = J^T z
+        # then (J^T J)v = (J^T J) (J^T z) = J^T (J J^T) z = J^T r
+        # therefore (J^T J)v = J^T r
+        # also this gives a minimum norm solution
 
-        G = self.global_state['G']
-        Gtf = self.global_state['Gtf']
+        r = self.global_state['r']
 
-        GtG = G.T @ G # (ndim, ndim)
+        JJT = J @ J.T # (nresiduals, nresiduals)
        if reg != 0:
-            GtG.add_(torch.eye(GtG.size(0), device=GtG.device, dtype=GtG.dtype).mul_(reg))
+            JJT.add_(torch.eye(JJT.size(0), device=JJT.device, dtype=JJT.dtype).mul_(reg))
 
-        v = torch.linalg.lstsq(GtG, Gtf).solution # pylint:disable=not-callable
+        z, info = torch.linalg.solve_ex(JJT, r) # pylint:disable=not-callable
+        v = J.T @ z
 
-        var.update = vec_to_tensors(v, var.params)
-        return var
+        objective.updates = vec_to_tensors(v, objective.params)
+        return objective
 
-    def get_H(self, var):
-        G = self.global_state['G']
-        return linear_operator.AtA(G)
+    def get_H(self, objective=...):
+        J = self.global_state['J']
+        return linear_operator.AtA(J)
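
The comment block in this hunk justifies the underdetermined branch: when there are fewer residuals than parameters, the new GaussNewton solves the small system (J Jᵀ)z = r and lifts the result with v = Jᵀz. A quick standalone check of that identity (plain torch, not package code):

```python
import torch

torch.manual_seed(0)
n_residuals, ndim = 3, 5              # fewer residuals than parameters
J = torch.randn(n_residuals, ndim)
r = torch.randn(n_residuals)

z = torch.linalg.solve(J @ J.T, r)    # solve the small (n_residuals x n_residuals) system
v = J.T @ z                           # lift back to parameter space

# v satisfies the Gauss-Newton normal equations (J^T J) v = J^T r
print(torch.allclose(J.T @ (J @ v), J.T @ r, atol=1e-5))  # True
```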
@@ -77,7 +77,7 @@ class Backtracking(LineSearchBase):
    Gradient descent with backtracking line search:
 
    ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Backtracking()
    )
@@ -85,7 +85,7 @@ class Backtracking(LineSearchBase):
 
    L-BFGS with backtracking line search:
    ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LBFGS(),
        tz.m.Backtracking()
@@ -117,7 +117,7 @@ class Backtracking(LineSearchBase):
 
        # # directional derivative
        if c == 0: d = 0
-        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grads(), var.get_updates()))
 
        # scale init
        init_scale = self.global_state.get('init_scale', 1)
@@ -199,7 +199,7 @@ class AdaptiveBacktracking(LineSearchBase):
 
        # directional derivative (0 if c = 0 because it is not needed)
        if c == 0: d = 0
-        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), update))
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grads(), update))
 
        # scale beta
        beta = beta * self.global_state['beta_scale']
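
The directional derivative `d` computed above feeds the Armijo sufficient-decrease test that backtracking line searches are built on. A generic sketch of that rule (not the package's implementation; `backtrack` is a hypothetical helper):

```python
import torch

def backtrack(f, x, grad, u, c=1e-4, beta=0.5, t=1.0, max_iter=50):
    # Armijo rule: accept step size t once f(x - t*u) <= f(x) + c*t*d,
    # where d = -<grad, u> is the directional derivative along the step -u
    fx = f(x)
    d = -(grad * u).sum()
    for _ in range(max_iter):
        if f(x - t * u) <= fx + c * t * d:
            return t
        t *= beta                     # shrink the step and try again
    return t

x = torch.tensor([3.0, -2.0])
f = lambda v: (v ** 2).sum()
g = 2 * x                             # gradient of f at x
print(backtrack(f, x, g, u=g))        # accepted step size for a gradient step
```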